Skip to content

Commit 8c75077

Browse files
authored
Adds OnDeviceDataset, including device option for benchmarks. (#253)
* Version bump.
* Fix merge conflict in `_version.py`.
* Add scripts for supervised pretraining.
* Remove TODO.md.
* Fix linting.
* Add spec to pretrained backbones.
* Add support for OnDevice in benchmarks.
* Update supervised pretraining example.
* Fix import in `l2l.data`.
* Fix `lightning_anil_no_travis` test.
* Fix scipy import.
* Remove supervised pretraining for OnDevice merge.
* Update changelog.
1 parent 7522003 commit 8c75077

13 files changed

Lines changed: 253 additions & 34 deletions

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1515
* `l2l.nn.PrototypicalClassifier` and `l2l.nn.SVMClassifier`.
1616
* Add `l2l.vision.models.WRN28`.
1717
* Separate modules for `CNN4Backbone`, `ResNet12Backbone`, `WRN28Backbones` w/ pretrained weights.
18+
* Add `l2l.data.OnDeviceDataset` and implement `device` parameter for benchmarks.
1819

1920
### Changed
2021

learn2learn/data/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,4 @@
77
from . import transforms
88
from .meta_dataset import MetaDataset, UnionMetaDataset, FilteredMetaDataset
99
from .task_dataset import TaskDataset, DataDescription
10+
from .utils import OnDeviceDataset, partition_task, InfiniteIterator

learn2learn/data/utils.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#!/usr/bin/env python3
22

3+
import torch
34
import requests
45
import tqdm
56

@@ -44,3 +45,98 @@ def save_response_content(response, destination):
4445
for chunk in response.iter_content(CHUNK_SIZE):
4546
if chunk: # filter out keep-alive new chunks
4647
f.write(chunk)
48+
49+
50+
class InfiniteIterator(object):

    """Wraps a (possibly finite) dataloader so that iteration never stops:
    whenever the underlying iterator is exhausted, a fresh iterator is
    created and iteration restarts from the beginning.

    **Arguments**

    * **dataloader** - Any re-iterable object (e.g. a torch DataLoader).
    """

    def __init__(self, dataloader):
        self.dataloader = dataloader
        self.iterator = iter(self.dataloader)

    def __iter__(self):
        return self

    def __next__(self):
        try:
            return next(self.iterator)
        except StopIteration:
            # Exhausted: restart from a fresh iterator.
            self.iterator = iter(self.dataloader)
            try:
                return next(self.iterator)
            except StopIteration:
                # The dataloader yields nothing at all. The original
                # `while True` retry loop would spin forever in this case;
                # propagating StopIteration ends iteration cleanly instead.
                raise
66+
67+
def partition_task(data, labels, shots=1, ways=None):
    """Splits a classification task into adaptation (support) and
    evaluation (query) sets.

    Assumes every class contributes the same number of samples, so that
    after taking `shots` support samples per class, each class is left
    with an identical number of query samples.

    **Arguments**

    * **data** (Tensor) - Samples, first dimension indexes the task.
    * **labels** (Tensor) - Class label for each sample in `data`.
    * **shots** (int, *optional*, default=1) - Support samples per class.
    * **ways** (int, *optional*, default=None) - Number of classes;
      defaults to the number of unique labels.

    **Returns**

    `(support_data, support_labels), (query_data, query_labels)`
    """
    assert data.size(0) == labels.size(0)
    classes = labels.unique()
    if ways is None:
        ways = classes.numel()
    sample_shape = data.shape[1:]
    n_support = ways * shots
    n_query = data.size(0) - n_support
    assert n_query % ways == 0, 'Only query_shot == support_shot supported.'
    query_shots = n_query // ways

    def _blank(length, shape, like):
        # Pre-allocate an output buffer on the same device/dtype as `like`.
        return torch.empty((length,) + shape, device=like.device, dtype=like.dtype)

    support_x = _blank(n_support, sample_shape, data)
    support_y = _blank(n_support, (), labels)
    query_x = _blank(n_query, sample_shape, data)
    query_y = _blank(n_query, (), labels)

    for index, cls in enumerate(classes):
        sup = slice(index * shots, (index + 1) * shots)
        qry = slice(index * query_shots, (index + 1) * query_shots)

        # filter data
        cls_data = data[labels == cls]  # TODO: fancy index makes a copy.
        assert cls_data.size(0) == shots + query_shots, \
            'Only same number of query per label supported.'

        # set value of labels
        support_y[sup].fill_(cls)
        query_y[qry].fill_(cls)

        # set value of data
        support_x[sup].copy_(cls_data[:shots])
        query_x[qry].copy_(cls_data[shots:])

    return (support_x, support_y), (query_x, query_y)
118+
119+
120+
class OnDeviceDataset(torch.utils.data.TensorDataset):

    """Materializes an iterable of `(sample, label)` pairs as two stacked
    tensors, optionally moving them to `device` once up-front, and applies
    an optional `transform` to samples at access time.

    **Arguments**

    * **dataset** - Iterable of `(sample, label)` pairs; samples are tensors.
    * **device** (*optional*, default=None) - Target device for the data.
    * **transform** (*optional*, default=None) - Callable applied to each
      sample on `__getitem__`.
    """

    def __init__(self, dataset, device=None, transform=None):
        samples = []
        targets = []
        # Single pass over `dataset`, in case it is a one-shot iterable.
        for sample, target in dataset:
            samples.append(sample)
            targets.append(target)
        stacked = torch.stack(samples, dim=0)
        target_tensor = torch.tensor(targets)
        if device is not None:
            stacked = stacked.to(device)
            target_tensor = target_tensor.to(device)
        super(OnDeviceDataset, self).__init__(stacked, target_tensor)
        self.transform = transform
        # Preserve cached bookkeeping (consumed by MetaDataset) when present.
        if hasattr(dataset, '_bookkeeping_path'):
            self._bookkeeping_path = dataset._bookkeeping_path

    def __getitem__(self, index):
        sample, target = super(OnDeviceDataset, self).__getitem__(index)
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, target

learn2learn/utils/__init__.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
import copy
44
import torch
5+
import argparse
6+
import dataclasses
57

68

79
def magic_box(x):
@@ -313,6 +315,38 @@ def accuracy(preds, targets):
313315
return acc / preds.size(0)
314316

315317

318+
def flatten_config(args, prefix=None):
    """Recursively flattens a configuration object (argparse.Namespace,
    dataclass instance, or dict — possibly nested) into a flat dict whose
    keys are dot-separated paths, e.g. `{'optim.lr': 0.1}`.

    **Arguments**

    * **args** - The configuration to flatten.
    * **prefix** (str, *optional*, default=None) - Key prefix accumulated
      during recursion; leave as None at the top level.

    **Returns**

    * (dict) - Mapping from dotted key paths to leaf values.
    """
    flat_args = dict()
    if isinstance(args, argparse.Namespace):
        # Recurse on the Namespace's underlying dict. The original dropped
        # `prefix` here, which broke Namespaces nested inside other configs.
        return flatten_config(vars(args), prefix=prefix)
    elif not dataclasses.is_dataclass(args) and not isinstance(args, dict):
        # Leaf value: record it under the accumulated prefix.
        flat_args[prefix] = args
        return flat_args
    elif dataclasses.is_dataclass(args):
        keys = dataclasses.fields(args)

        def getvalue(x):
            # The original one-line `def` had no `return`, so every
            # dataclass value flattened to None.
            return getattr(args, x.name)
    elif isinstance(args, dict):
        keys = args.keys()

        def getvalue(x):
            # Same missing-`return` bug as the dataclass branch.
            return args[x]
    else:
        # Unreachable given the branches above; kept as a safety net.
        # (The original raised a plain string, a TypeError in Python 3.)
        raise TypeError('Unknown args: ' + repr(type(args)))
    for key in keys:
        value = getvalue(key)
        # Dict keys are strings; dataclass keys are dataclasses.Field.
        if isinstance(key, str):
            child_name = key
        elif isinstance(key, dataclasses.Field):
            child_name = key.name
        else:
            raise TypeError('Unknown key: ' + repr(key))
        if prefix is None:
            prefix_child = child_name
        else:
            # The original always used `key.name` here, which crashed with
            # AttributeError on nested dicts (string keys).
            prefix_child = prefix + '.' + child_name
        flat_child = flatten_config(value, prefix=prefix_child)
        flat_args.update(flat_child)
    return flat_args
348+
349+
316350
class _ImportRaiser(object):
317351

318352
def __init__(self, name, command):

learn2learn/vision/benchmarks/__init__.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ def get_tasksets(
8181
* **test_ways** (int, *optional*, default=5) - The number of classes per test tasks. Also used for validation tasks.
8282
* **test_samples** (int, *optional*, default=10) - The number of samples per test tasks. Also used for validation tasks.
8383
* **num_tasks** (int, *optional*, default=-1) - The number of tasks in each TaskDataset.
84+
* **device** (torch.Device, *optional*, default=None) - If not None, tasksets are loaded as Tensors on `device`.
8485
* **root** (str, *optional*, default='~/data') - Where the data is stored.
8586
8687
**Example**
@@ -96,15 +97,13 @@ def get_tasksets(
9697
"""
9798
root = os.path.expanduser(root)
9899

99-
if device is not None:
100-
raise NotImplementedError('Device other than None not implemented. (yet)')
101-
102100
# Load task-specific data and transforms
103101
datasets, transforms = _TASKSETS[name](train_ways=train_ways,
104102
train_samples=train_samples,
105103
test_ways=test_ways,
106104
test_samples=test_samples,
107105
root=root,
106+
device=device,
108107
**kwargs)
109108
train_dataset, validation_dataset, test_dataset = datasets
110109
train_transforms, validation_transforms, test_transforms = transforms

learn2learn/vision/benchmarks/cifarfs_benchmark.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ def cifarfs_tasksets(
1212
test_ways=5,
1313
test_samples=10,
1414
root='~/data',
15+
device=None,
1516
**kwargs,
1617
):
1718
"""Tasksets for CIFAR-FS benchmarks."""
@@ -28,6 +29,19 @@ def cifarfs_tasksets(
2829
transform=data_transform,
2930
mode='test',
3031
download=True)
32+
if device is not None:
33+
train_dataset = l2l.data.OnDeviceDataset(
34+
dataset=train_dataset,
35+
device=device,
36+
)
37+
valid_dataset = l2l.data.OnDeviceDataset(
38+
dataset=valid_dataset,
39+
device=device,
40+
)
41+
test_dataset = l2l.data.OnDeviceDataset(
42+
dataset=test_dataset,
43+
device=device,
44+
)
3145
train_dataset = l2l.data.MetaDataset(train_dataset)
3246
valid_dataset = l2l.data.MetaDataset(valid_dataset)
3347
test_dataset = l2l.data.MetaDataset(test_dataset)

learn2learn/vision/benchmarks/fc100_benchmark.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ def fc100_tasksets(
1212
test_ways=5,
1313
test_samples=10,
1414
root='~/data',
15+
device=None,
1516
**kwargs,
1617
):
1718
"""Tasksets for FC100 benchmarks."""
@@ -28,6 +29,19 @@ def fc100_tasksets(
2829
transform=data_transform,
2930
mode='test',
3031
download=True)
32+
if device is not None:
33+
train_dataset = l2l.data.OnDeviceDataset(
34+
dataset=train_dataset,
35+
device=device,
36+
)
37+
valid_dataset = l2l.data.OnDeviceDataset(
38+
dataset=valid_dataset,
39+
device=device,
40+
)
41+
test_dataset = l2l.data.OnDeviceDataset(
42+
dataset=test_dataset,
43+
device=device,
44+
)
3145
train_dataset = l2l.data.MetaDataset(train_dataset)
3246
valid_dataset = l2l.data.MetaDataset(valid_dataset)
3347
test_dataset = l2l.data.MetaDataset(test_dataset)

learn2learn/vision/benchmarks/mini_imagenet_benchmark.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ def mini_imagenet_tasksets(
1414
test_samples=10,
1515
root='~/data',
1616
data_augmentation=None,
17+
device=None,
1718
**kwargs,
1819
):
1920
"""Tasksets for mini-ImageNet benchmarks."""
@@ -47,21 +48,38 @@ def mini_imagenet_tasksets(
4748
train_dataset = l2l.vision.datasets.MiniImagenet(
4849
root=root,
4950
mode='train',
50-
transform=train_data_transforms,
5151
download=True,
5252
)
5353
valid_dataset = l2l.vision.datasets.MiniImagenet(
5454
root=root,
5555
mode='validation',
56-
transform=test_data_transforms,
5756
download=True,
5857
)
5958
test_dataset = l2l.vision.datasets.MiniImagenet(
6059
root=root,
6160
mode='test',
62-
transform=test_data_transforms,
6361
download=True,
6462
)
63+
if device is None:
64+
train_dataset.transform = train_data_transforms
65+
valid_dataset.transform = test_data_transforms
66+
test_dataset.transform = test_data_transforms
67+
else:
68+
train_dataset = l2l.data.OnDeviceDataset(
69+
dataset=train_dataset,
70+
transform=train_data_transforms,
71+
device=device,
72+
)
73+
valid_dataset = l2l.data.OnDeviceDataset(
74+
dataset=valid_dataset,
75+
transform=test_data_transforms,
76+
device=device,
77+
)
78+
test_dataset = l2l.data.OnDeviceDataset(
79+
dataset=test_dataset,
80+
transform=test_data_transforms,
81+
device=device,
82+
)
6583
train_dataset = l2l.data.MetaDataset(train_dataset)
6684
valid_dataset = l2l.data.MetaDataset(valid_dataset)
6785
test_dataset = l2l.data.MetaDataset(test_dataset)

learn2learn/vision/benchmarks/omniglot_benchmark.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@ def omniglot_tasksets(
1313
test_ways,
1414
test_samples,
1515
root,
16-
**kwargs
16+
device=None,
17+
**kwargs,
1718
):
1819
"""
1920
Benchmark definition for Omniglot.
@@ -28,18 +29,20 @@ def omniglot_tasksets(
2829
transform=data_transforms,
2930
download=True,
3031
)
32+
if device is not None:
33+
dataset = l2l.data.OnDeviceDataset(omniglot, device=device)
3134
dataset = l2l.data.MetaDataset(omniglot)
32-
train_dataset = dataset
33-
validation_datatset = dataset
34-
test_dataset = dataset
3535

3636
classes = list(range(1623))
3737
random.shuffle(classes)
38+
train_dataset = l2l.data.FilteredMetaDataset(dataset, labels=classes[:1100])
39+
validation_datatset = l2l.data.FilteredMetaDataset(dataset, labels=classes[1100:1200])
40+
test_dataset = l2l.data.FilteredMetaDataset(dataset, labels=classes[1200:])
41+
3842
train_transforms = [
3943
l2l.data.transforms.FusedNWaysKShots(dataset,
4044
n=train_ways,
41-
k=train_samples,
42-
filter_labels=classes[:1100]),
45+
k=train_samples),
4346
l2l.data.transforms.LoadData(dataset),
4447
l2l.data.transforms.RemapLabels(dataset),
4548
l2l.data.transforms.ConsecutiveLabels(dataset),
@@ -48,8 +51,7 @@ def omniglot_tasksets(
4851
validation_transforms = [
4952
l2l.data.transforms.FusedNWaysKShots(dataset,
5053
n=test_ways,
51-
k=test_samples,
52-
filter_labels=classes[1100:1200]),
54+
k=test_samples),
5355
l2l.data.transforms.LoadData(dataset),
5456
l2l.data.transforms.RemapLabels(dataset),
5557
l2l.data.transforms.ConsecutiveLabels(dataset),
@@ -58,8 +60,7 @@ def omniglot_tasksets(
5860
test_transforms = [
5961
l2l.data.transforms.FusedNWaysKShots(dataset,
6062
n=test_ways,
61-
k=test_samples,
62-
filter_labels=classes[1200:]),
63+
k=test_samples),
6364
l2l.data.transforms.LoadData(dataset),
6465
l2l.data.transforms.RemapLabels(dataset),
6566
l2l.data.transforms.ConsecutiveLabels(dataset),

0 commit comments

Comments
 (0)