Skip to content

Commit a8c8985

Browse files
authored
Add l2l.vision.benchmarks interface. (#146)
* Add Omniglot benchmarks and tests. * Update Omniglot example to use the MAML interface. * Add mini-ImageNet benchmark. * Add tiered-ImageNet, CIFAR-FS, and FC100 benchmarks. * Fix linting. * Add docs. * Update docs. * Fix tiered-ImageNet and docs. * Automatically download datasets in benchmarks. * Fix docs fonts. * Add download option to mini-imagenet. * Add download options to FC100 and CIFAR-FS. * Add dropbox links for mini-imagenet * Change test directory download folder for benchmarks. * Prefer dropbox over gdrive for downloading mini-ImageNet. * Omit tiered-imagenet from benchmark tests. * Fix lint.
1 parent 7545baa commit a8c8985

19 files changed

Lines changed: 595 additions & 143 deletions

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1010

1111
### Added
1212

13+
* Add l2l.vision.benchmarks interface.
14+
1315
### Changed
1416

1517
### Fixed

Makefile

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,11 @@ alltests:
4444
make notravis-tests >>alltests.txt 2>&1
4545

4646
docs:
47+
rm -f docs/mkdocs.yml
4748
cd docs && pydocmd build && pydocmd serve
4849

4950
docs-deploy:
51+
rm -f docs/mkdocs.yml
5052
cd docs && pydocmd gh-deploy
5153

5254
# https://dev.to/neshaz/a-tutorial-for-tagging-releases-in-git-147e

docs/pydocmd.yml

Lines changed: 24 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,9 @@ generate:
5656
- learn2learn.vision.datasets.FGVCAircraft
5757
- learn2learn.vision.transforms:
5858
- learn2learn.vision.transforms.RandomClassRotation
59+
- learn2learn.vision.benchmarks:
60+
- learn2learn.vision.benchmarks.list_tasksets
61+
- learn2learn.vision.benchmarks.get_tasksets
5962
- docs/learn2learn.text.md:
6063
- learn2learn.text.datasets.NewsClassification
6164

@@ -93,15 +96,15 @@ repo_name: 'learnables/learn2learn'
9396
repo_url: 'https://github.com/learnables/learn2learn'
9497

9598
theme:
96-
name: 'material'
97-
logo: 'assets/img/learn2learn_white.png'
98-
favicon: 'assets/img/favicons/favicon.ico'
99-
palette:
100-
primary: 'blue'
101-
accent: 'orange'
102-
font:
103-
text: 'Source Sans Pro'
104-
code: 'Ubuntu Mono'
99+
name: 'material'
100+
logo: 'assets/img/learn2learn_white.png'
101+
favicon: 'assets/img/favicons/favicon.ico'
102+
palette:
103+
primary: 'blue'
104+
accent: 'orange'
105+
font:
106+
text: 'Source Sans Pro'
107+
code: 'Ubuntu Mono'
105108

106109
extra:
107110
social:
@@ -127,22 +130,22 @@ headers: markdown
127130
# subdirectory of your project (eg. docs/), you may want to add the parent
128131
# directory here.
129132
additional_search_paths:
130-
- ..
133+
- ..
131134

132135
extra_javascript:
133-
- https://cdn.jsdelivr.net/npm/katex/dist/katex.min.js
134-
- https://cdn.jsdelivr.net/npm/katex/dist/contrib/mathtex-script-type.min.js
136+
- https://cdn.jsdelivr.net/npm/katex/dist/katex.min.js
137+
- https://cdn.jsdelivr.net/npm/katex/dist/contrib/mathtex-script-type.min.js
135138

136139
extra_css:
137-
- https://cdn.jsdelivr.net/npm/katex/dist/katex.min.css
138-
- 'assets/css/l2l_material.css'
140+
- https://cdn.jsdelivr.net/npm/katex/dist/katex.min.css
141+
- 'assets/css/l2l_material.css'
139142

140143
# Extensions
141144
markdown_extensions:
142-
- mdx_math
143-
- admonition
144-
- codehilite:
145-
guess_lang: true
146-
linenums: true
147-
- toc:
148-
permalink: true
145+
- mdx_math
146+
- admonition
147+
- codehilite:
148+
guess_lang: true
149+
linenums: true
150+
- toc:
151+
permalink: true

examples/vision/maml_miniimagenet.py

Lines changed: 27 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,26 @@
11
#!/usr/bin/env python3
22

3-
import random
3+
"""
4+
Demonstrates how to:
5+
* use the MAML wrapper for fast-adaptation,
6+
* use the benchmark interface to load mini-ImageNet, and
7+
* sample tasks and split them in adaptation and evaluation sets.
8+
9+
To contrast the use of the benchmark interface with directly instantiating mini-ImageNet datasets and tasks, compare with `protonet_miniimagenet.py`.
10+
"""
411

12+
import random
513
import numpy as np
14+
615
import torch
7-
from torch import nn
8-
from torch import optim
16+
from torch import nn, optim
917

1018
import learn2learn as l2l
11-
from learn2learn.data.transforms import NWays, KShots, LoadData, RemapLabels, ConsecutiveLabels
19+
from learn2learn.data.transforms import (NWays,
20+
KShots,
21+
LoadData,
22+
RemapLabels,
23+
ConsecutiveLabels)
1224

1325

1426
def accuracy(predictions, targets):
@@ -61,46 +73,14 @@ def main(
6173
torch.cuda.manual_seed(seed)
6274
device = torch.device('cuda')
6375

64-
# Create Datasets
65-
train_dataset = l2l.vision.datasets.MiniImagenet(root='~/data', mode='train')
66-
valid_dataset = l2l.vision.datasets.MiniImagenet(root='~/data', mode='validation')
67-
test_dataset = l2l.vision.datasets.MiniImagenet(root='~/data', mode='test')
68-
train_dataset = l2l.data.MetaDataset(train_dataset)
69-
valid_dataset = l2l.data.MetaDataset(valid_dataset)
70-
test_dataset = l2l.data.MetaDataset(test_dataset)
71-
72-
train_transforms = [
73-
NWays(train_dataset, ways),
74-
KShots(train_dataset, 2*shots),
75-
LoadData(train_dataset),
76-
RemapLabels(train_dataset),
77-
ConsecutiveLabels(train_dataset),
78-
]
79-
train_tasks = l2l.data.TaskDataset(train_dataset,
80-
task_transforms=train_transforms,
81-
num_tasks=20000)
82-
83-
valid_transforms = [
84-
NWays(valid_dataset, ways),
85-
KShots(valid_dataset, 2*shots),
86-
LoadData(valid_dataset),
87-
ConsecutiveLabels(valid_dataset),
88-
RemapLabels(valid_dataset),
89-
]
90-
valid_tasks = l2l.data.TaskDataset(valid_dataset,
91-
task_transforms=valid_transforms,
92-
num_tasks=600)
93-
94-
test_transforms = [
95-
NWays(test_dataset, ways),
96-
KShots(test_dataset, 2*shots),
97-
LoadData(test_dataset),
98-
RemapLabels(test_dataset),
99-
ConsecutiveLabels(test_dataset),
100-
]
101-
test_tasks = l2l.data.TaskDataset(test_dataset,
102-
task_transforms=test_transforms,
103-
num_tasks=600)
76+
# Create Tasksets using the benchmark interface
77+
tasksets = l2l.vision.benchmarks.get_tasksets('mini-imagenet',
78+
train_samples=2*shots,
79+
train_ways=ways,
80+
test_samples=2*shots,
81+
test_ways=ways,
82+
root='~/data',
83+
)
10484

10585
# Create model
10686
model = l2l.vision.models.MiniImagenetCNN(ways)
@@ -118,7 +98,7 @@ def main(
11898
for task in range(meta_batch_size):
11999
# Compute meta-training loss
120100
learner = maml.clone()
121-
batch = train_tasks.sample()
101+
batch = tasksets.train.sample()
122102
evaluation_error, evaluation_accuracy = fast_adapt(batch,
123103
learner,
124104
loss,
@@ -132,7 +112,7 @@ def main(
132112

133113
# Compute meta-validation loss
134114
learner = maml.clone()
135-
batch = valid_tasks.sample()
115+
batch = tasksets.validation.sample()
136116
evaluation_error, evaluation_accuracy = fast_adapt(batch,
137117
learner,
138118
loss,
@@ -161,7 +141,7 @@ def main(
161141
for task in range(meta_batch_size):
162142
# Compute meta-testing loss
163143
learner = maml.clone()
164-
batch = test_tasks.sample()
144+
batch = tasksets.test.sample()
165145
evaluation_error, evaluation_accuracy = fast_adapt(batch,
166146
learner,
167147
loss,

examples/vision/maml_omniglot.py

Lines changed: 20 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,19 @@
11
#!/usr/bin/env python3
22

3-
import random
3+
"""
4+
Demonstrates how to:
5+
* use the MAML wrapper for fast-adaptation,
6+
* use the benchmark interface to load Omniglot, and
7+
* sample tasks and split them in adaptation and evaluation sets.
8+
"""
49

10+
import random
511
import numpy as np
612
import torch
7-
from PIL.Image import LANCZOS
13+
import learn2learn as l2l
814

915
from torch import nn, optim
10-
from torchvision import transforms
1116

12-
import learn2learn as l2l
1317

1418

1519
def accuracy(predictions, targets):
@@ -62,58 +66,15 @@ def main(
6266
torch.cuda.manual_seed(seed)
6367
device = torch.device('cuda')
6468

65-
omniglot = l2l.vision.datasets.FullOmniglot(root='~/data',
66-
transform=transforms.Compose([
67-
transforms.Resize(28, interpolation=LANCZOS),
68-
transforms.ToTensor(),
69-
lambda x: 1.0 - x,
70-
]),
71-
download=True)
72-
dataset = l2l.data.MetaDataset(omniglot)
73-
classes = list(range(1623))
74-
random.shuffle(classes)
75-
76-
train_transforms = [
77-
l2l.data.transforms.FusedNWaysKShots(dataset,
78-
n=ways,
79-
k=2*shots,
80-
filter_labels=classes[:1100]),
81-
l2l.data.transforms.LoadData(dataset),
82-
l2l.data.transforms.RemapLabels(dataset),
83-
l2l.data.transforms.ConsecutiveLabels(dataset),
84-
l2l.vision.transforms.RandomClassRotation(dataset, [0.0, 90.0, 180.0, 270.0])
85-
]
86-
train_tasks = l2l.data.TaskDataset(dataset,
87-
task_transforms=train_transforms,
88-
num_tasks=20000)
89-
90-
valid_transforms = [
91-
l2l.data.transforms.FusedNWaysKShots(dataset,
92-
n=ways,
93-
k=2*shots,
94-
filter_labels=classes[1100:1200]),
95-
l2l.data.transforms.LoadData(dataset),
96-
l2l.data.transforms.RemapLabels(dataset),
97-
l2l.data.transforms.ConsecutiveLabels(dataset),
98-
l2l.vision.transforms.RandomClassRotation(dataset, [0.0, 90.0, 180.0, 270.0])
99-
]
100-
valid_tasks = l2l.data.TaskDataset(dataset,
101-
task_transforms=valid_transforms,
102-
num_tasks=1024)
103-
104-
test_transforms = [
105-
l2l.data.transforms.FusedNWaysKShots(dataset,
106-
n=ways,
107-
k=2*shots,
108-
filter_labels=classes[1200:]),
109-
l2l.data.transforms.LoadData(dataset),
110-
l2l.data.transforms.RemapLabels(dataset),
111-
l2l.data.transforms.ConsecutiveLabels(dataset),
112-
l2l.vision.transforms.RandomClassRotation(dataset, [0.0, 90.0, 180.0, 270.0])
113-
]
114-
test_tasks = l2l.data.TaskDataset(dataset,
115-
task_transforms=test_transforms,
116-
num_tasks=1024)
69+
# Load train/validation/test tasksets using the benchmark interface
70+
tasksets = l2l.vision.benchmarks.get_tasksets('omniglot',
71+
train_ways=ways,
72+
train_samples=2*shots,
73+
test_ways=ways,
74+
test_samples=2*shots,
75+
num_tasks=20000,
76+
root='~/data',
77+
)
11778

11879
# Create model
11980
model = l2l.vision.models.OmniglotFC(28 ** 2, ways)
@@ -131,7 +92,7 @@ def main(
13192
for task in range(meta_batch_size):
13293
# Compute meta-training loss
13394
learner = maml.clone()
134-
batch = train_tasks.sample()
95+
batch = tasksets.train.sample()
13596
evaluation_error, evaluation_accuracy = fast_adapt(batch,
13697
learner,
13798
loss,
@@ -145,7 +106,7 @@ def main(
145106

146107
# Compute meta-validation loss
147108
learner = maml.clone()
148-
batch = valid_tasks.sample()
109+
batch = tasksets.validation.sample()
149110
evaluation_error, evaluation_accuracy = fast_adapt(batch,
150111
learner,
151112
loss,
@@ -174,7 +135,7 @@ def main(
174135
for task in range(meta_batch_size):
175136
# Compute meta-testing loss
176137
learner = maml.clone()
177-
batch = test_tasks.sample()
138+
batch = tasksets.test.sample()
178139
evaluation_error, evaluation_accuracy = fast_adapt(batch,
179140
learner,
180141
loss,

learn2learn/vision/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@
33
from . import datasets
44
from . import models
55
from . import transforms
6+
from . import benchmarks

0 commit comments

Comments
 (0)