11#!/usr/bin/env python3
22
3- import random
3+ """
4+ Demonstrates how to:
5+ * use the MAML wrapper for fast-adaptation,
6+ * use the benchmark interface to load Omniglot, and
7+ * sample tasks and split them in adaptation and evaluation sets.
8+ """
49
10+ import random
511import numpy as np
612import torch
7- from PIL . Image import LANCZOS
13+ import learn2learn as l2l
814
915from torch import nn , optim
10- from torchvision import transforms
1116
12- import learn2learn as l2l
1317
1418
1519def accuracy (predictions , targets ):
@@ -62,58 +66,15 @@ def main(
6266 torch .cuda .manual_seed (seed )
6367 device = torch .device ('cuda' )
6468
65- omniglot = l2l .vision .datasets .FullOmniglot (root = '~/data' ,
66- transform = transforms .Compose ([
67- transforms .Resize (28 , interpolation = LANCZOS ),
68- transforms .ToTensor (),
69- lambda x : 1.0 - x ,
70- ]),
71- download = True )
72- dataset = l2l .data .MetaDataset (omniglot )
73- classes = list (range (1623 ))
74- random .shuffle (classes )
75-
76- train_transforms = [
77- l2l .data .transforms .FusedNWaysKShots (dataset ,
78- n = ways ,
79- k = 2 * shots ,
80- filter_labels = classes [:1100 ]),
81- l2l .data .transforms .LoadData (dataset ),
82- l2l .data .transforms .RemapLabels (dataset ),
83- l2l .data .transforms .ConsecutiveLabels (dataset ),
84- l2l .vision .transforms .RandomClassRotation (dataset , [0.0 , 90.0 , 180.0 , 270.0 ])
85- ]
86- train_tasks = l2l .data .TaskDataset (dataset ,
87- task_transforms = train_transforms ,
88- num_tasks = 20000 )
89-
90- valid_transforms = [
91- l2l .data .transforms .FusedNWaysKShots (dataset ,
92- n = ways ,
93- k = 2 * shots ,
94- filter_labels = classes [1100 :1200 ]),
95- l2l .data .transforms .LoadData (dataset ),
96- l2l .data .transforms .RemapLabels (dataset ),
97- l2l .data .transforms .ConsecutiveLabels (dataset ),
98- l2l .vision .transforms .RandomClassRotation (dataset , [0.0 , 90.0 , 180.0 , 270.0 ])
99- ]
100- valid_tasks = l2l .data .TaskDataset (dataset ,
101- task_transforms = valid_transforms ,
102- num_tasks = 1024 )
103-
104- test_transforms = [
105- l2l .data .transforms .FusedNWaysKShots (dataset ,
106- n = ways ,
107- k = 2 * shots ,
108- filter_labels = classes [1200 :]),
109- l2l .data .transforms .LoadData (dataset ),
110- l2l .data .transforms .RemapLabels (dataset ),
111- l2l .data .transforms .ConsecutiveLabels (dataset ),
112- l2l .vision .transforms .RandomClassRotation (dataset , [0.0 , 90.0 , 180.0 , 270.0 ])
113- ]
114- test_tasks = l2l .data .TaskDataset (dataset ,
115- task_transforms = test_transforms ,
116- num_tasks = 1024 )
69+ # Load train/validation/test tasksets using the benchmark interface
70+ tasksets = l2l .vision .benchmarks .get_tasksets ('omniglot' ,
71+ train_ways = ways ,
72+ train_samples = 2 * shots ,
73+ test_ways = ways ,
74+ test_samples = 2 * shots ,
75+ num_tasks = 20000 ,
76+ root = '~/data' ,
77+ )
11778
11879 # Create model
11980 model = l2l .vision .models .OmniglotFC (28 ** 2 , ways )
@@ -131,7 +92,7 @@ def main(
13192 for task in range (meta_batch_size ):
13293 # Compute meta-training loss
13394 learner = maml .clone ()
134- batch = train_tasks .sample ()
95+ batch = tasksets . train .sample ()
13596 evaluation_error , evaluation_accuracy = fast_adapt (batch ,
13697 learner ,
13798 loss ,
@@ -145,7 +106,7 @@ def main(
145106
146107 # Compute meta-validation loss
147108 learner = maml .clone ()
148- batch = valid_tasks .sample ()
109+ batch = tasksets . validation .sample ()
149110 evaluation_error , evaluation_accuracy = fast_adapt (batch ,
150111 learner ,
151112 loss ,
@@ -174,7 +135,7 @@ def main(
174135 for task in range (meta_batch_size ):
175136 # Compute meta-testing loss
176137 learner = maml .clone ()
177- batch = test_tasks .sample ()
138+ batch = tasksets . test .sample ()
178139 evaluation_error , evaluation_accuracy = fast_adapt (batch ,
179140 learner ,
180141 loss ,