File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change 1616 - ' TV=0.4.0 TORCH=1.2.0'
1717 - ' TV=0.4.1 TORCH=1.3.0'
1818 - ' TV=0.5.0 TORCH=1.4.0'
19+ - ' TV=0.6.0 TORCH=1.5.0'
1920
2021before_install : |
2122 if [ "$TRAVIS_OS_NAME" == "osx" ]; then
Original file line number Diff line number Diff line change @@ -14,6 +14,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1414
1515### Fixed
1616
17+ * clone_module supports non-Module objects.
18+ * VGG flowers now relies on tarfile.open() instead of tarfile.TarFile().
19+
1720
1821## v0.1.1
1922
Original file line number Diff line number Diff line change @@ -95,6 +95,8 @@ def clone_module(module):
9595 # First, create a copy of the module.
9696 # Adapted from:
9797 # https://github.com/pytorch/pytorch/blob/65bad41cbec096aa767b3752843eddebf845726f/torch/nn/modules/module.py#L1171
98+ if not isinstance (module , torch .nn .Module ):
99+ return module
98100 clone = module .__new__ (type (module ))
99101 clone .__dict__ = module .__dict__ .copy ()
100102 clone ._parameters = clone ._parameters .copy ()
@@ -152,6 +154,8 @@ def detach_module(module):
152154 error.backward() # Gradients are back-propagate on clone, not net.
153155 ~~~
154156 """
157+ if not isinstance (module , torch .nn .Module ):
158+ return
155159 # First, re-write all parameters
156160 for param_key in module ._parameters :
157161 if module ._parameters [param_key ] is not None :
Original file line number Diff line number Diff line change @@ -91,8 +91,9 @@ def download(self):
9191 req = requests .get (IMAGES_URL )
9292 with open (tar_path , 'wb' ) as archive :
9393 archive .write (req .content )
94- tar_file = tarfile .TarFile (tar_path )
94+ tar_file = tarfile .open (tar_path )
9595 tar_file .extractall (data_path )
96+ tar_file .close ()
9697 os .remove (tar_path )
9798
9899 label_path = os .path .join (data_path , os .path .basename (LABELS_URL ))
@@ -134,5 +135,5 @@ def __len__(self):
134135 assert len (SPLITS ['validation' ]) == 15
135136 assert len (SPLITS ['test' ]) == 16
136137 assert len (SPLITS ['all' ]) == 102
137- flowers = VGGFlower102 ('~/data ' , download = True )
138+ flowers = VGGFlower102 ('~/vgg_data ' , download = True )
138139 print (len (flowers ))
Original file line number Diff line number Diff line change @@ -90,6 +90,26 @@ def test_clone_module_basics(self):
9090 for a , b in zip (self .model .parameters (), cloned_model .parameters ()):
9191 self .assertTrue (torch .equal (a , b ))
9292
93+ def test_clone_module_nomodule (self ):
94+ # Tests that we can clone non-module objects
95+ class TrickyModule (torch .nn .Module ):
96+
97+ def __init__ (self ):
98+ super (TrickyModule , self ).__init__ ()
99+ self .tricky_modules = torch .nn .ModuleList ([
100+ torch .nn .Linear (2 , 1 ),
101+ None ,
102+ torch .nn .Linear (1 , 1 ),
103+ ])
104+
105+ model = TrickyModule ()
106+ clone = l2l .clone_module (model )
107+ for i , submodule in enumerate (clone .tricky_modules ):
108+ if i % 2 == 0 :
109+ self .assertTrue (submodule is not None )
110+ else :
111+ self .assertTrue (submodule is None )
112+
93113 def test_clone_module_models (self ):
94114 ref_models = [l2l .vision .models .OmniglotCNN (10 ),
95115 l2l .vision .models .MiniImagenetCNN (10 )]
You can’t perform that action at this time.
0 commit comments