Skip to content

Commit 7545baa

Browse files
authored
Improve clone_module support for non Module objects and update travis tests. (#142)
* Fix clone of None. * Update changelog. * Fix VGG flower download and untaring. * Update changelog. * Add torch 1.5 and torchvision 0.6 to travis.
1 parent 6ff649d commit 7545baa

5 files changed

Lines changed: 31 additions & 2 deletions

File tree

.travis.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ env:
1616
- 'TV=0.4.0 TORCH=1.2.0'
1717
- 'TV=0.4.1 TORCH=1.3.0'
1818
- 'TV=0.5.0 TORCH=1.4.0'
19+
- 'TV=0.6.0 TORCH=1.5.0'
1920

2021
before_install: |
2122
if [ "$TRAVIS_OS_NAME" == "osx" ]; then

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1414

1515
### Fixed
1616

17+
* clone_module supports non-Module objects.
18+
* VGG flowers now relies on tarfile.open() instead of tarfile.TarFile().
19+
1720

1821
## v0.1.1
1922

learn2learn/utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,8 @@ def clone_module(module):
9595
# First, create a copy of the module.
9696
# Adapted from:
9797
# https://github.com/pytorch/pytorch/blob/65bad41cbec096aa767b3752843eddebf845726f/torch/nn/modules/module.py#L1171
98+
if not isinstance(module, torch.nn.Module):
99+
return module
98100
clone = module.__new__(type(module))
99101
clone.__dict__ = module.__dict__.copy()
100102
clone._parameters = clone._parameters.copy()
@@ -152,6 +154,8 @@ def detach_module(module):
152154
error.backward() # Gradients are back-propagate on clone, not net.
153155
~~~
154156
"""
157+
if not isinstance(module, torch.nn.Module):
158+
return
155159
# First, re-write all parameters
156160
for param_key in module._parameters:
157161
if module._parameters[param_key] is not None:

learn2learn/vision/datasets/vgg_flowers.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,9 @@ def download(self):
9191
req = requests.get(IMAGES_URL)
9292
with open(tar_path, 'wb') as archive:
9393
archive.write(req.content)
94-
tar_file = tarfile.TarFile(tar_path)
94+
tar_file = tarfile.open(tar_path)
9595
tar_file.extractall(data_path)
96+
tar_file.close()
9697
os.remove(tar_path)
9798

9899
label_path = os.path.join(data_path, os.path.basename(LABELS_URL))
@@ -134,5 +135,5 @@ def __len__(self):
134135
assert len(SPLITS['validation']) == 15
135136
assert len(SPLITS['test']) == 16
136137
assert len(SPLITS['all']) == 102
137-
flowers = VGGFlower102('~/data', download=True)
138+
flowers = VGGFlower102('~/vgg_data', download=True)
138139
print(len(flowers))

tests/unit/utils_test.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,26 @@ def test_clone_module_basics(self):
9090
for a, b in zip(self.model.parameters(), cloned_model.parameters()):
9191
self.assertTrue(torch.equal(a, b))
9292

93+
def test_clone_module_nomodule(self):
94+
# Tests that we can clone non-module objects
95+
class TrickyModule(torch.nn.Module):
96+
97+
def __init__(self):
98+
super(TrickyModule, self).__init__()
99+
self.tricky_modules = torch.nn.ModuleList([
100+
torch.nn.Linear(2, 1),
101+
None,
102+
torch.nn.Linear(1, 1),
103+
])
104+
105+
model = TrickyModule()
106+
clone = l2l.clone_module(model)
107+
for i, submodule in enumerate(clone.tricky_modules):
108+
if i % 2 == 0:
109+
self.assertTrue(submodule is not None)
110+
else:
111+
self.assertTrue(submodule is None)
112+
93113
def test_clone_module_models(self):
94114
ref_models = [l2l.vision.models.OmniglotCNN(10),
95115
l2l.vision.models.MiniImagenetCNN(10)]

0 commit comments

Comments
 (0)