11#!/usr/bin/env python3
22
3- import os
43import torch
54import torch .nn as nn
65import torch .nn .functional as F
@@ -167,10 +166,10 @@ class ResNet12Backbone(nn.Module):
167166
168167 def __init__ (
169168 self ,
170- keep_prob = 1.0 , # dropout for embedding
171169 avg_pool = True , # Set to False for 16000-dim embeddings
172170 wider = True , # True mimics MetaOptNet, False mimics TADAM
173- drop_rate = 0.1 , # dropout for residual layers
171+ embedding_dropout = 0.0 , # dropout for embedding
172+ dropblock_dropout = 0.1 , # dropout for residual layers
174173 dropblock_size = 5 ,
175174 channels = 3 ,
176175 ):
@@ -186,27 +185,27 @@ def __init__(
186185 block ,
187186 num_filters [0 ],
188187 stride = 2 ,
189- drop_rate = drop_rate ,
188+ dropblock_dropout = dropblock_dropout ,
190189 )
191190 self .layer2 = self ._make_layer (
192191 block ,
193192 num_filters [1 ],
194193 stride = 2 ,
195- drop_rate = drop_rate ,
194+ dropblock_dropout = dropblock_dropout ,
196195 )
197196 self .layer3 = self ._make_layer (
198197 block ,
199198 num_filters [2 ],
200199 stride = 2 ,
201- drop_rate = drop_rate ,
200+ dropblock_dropout = dropblock_dropout ,
202201 drop_block = True ,
203202 block_size = dropblock_size ,
204203 )
205204 self .layer4 = self ._make_layer (
206205 block ,
207206 num_filters [3 ],
208207 stride = 2 ,
209- drop_rate = drop_rate ,
208+ dropblock_dropout = dropblock_dropout ,
210209 drop_block = True ,
211210 block_size = dropblock_size ,
212211 )
@@ -215,10 +214,10 @@ def __init__(
215214 else :
216215 self .avgpool = l2l .nn .Lambda (lambda x : x )
217216 self .flatten = l2l .nn .Flatten ()
218- self .keep_prob = keep_prob
217+ self .embedding_dropout = embedding_dropout
219218 self .keep_avg_pool = avg_pool
220- self .dropout = nn .Dropout (p = 1.0 - self .keep_prob , inplace = False )
221- self .drop_rate = drop_rate
219+ self .dropout = nn .Dropout (p = self .embedding_dropout , inplace = False )
220+ self .dropblock_dropout = dropblock_dropout
222221
223222 for m in self .modules ():
224223 if isinstance (m , nn .Conv2d ):
@@ -236,7 +235,7 @@ def _make_layer(
236235 block ,
237236 planes ,
238237 stride = 1 ,
239- drop_rate = 0.0 ,
238+ dropblock_dropout = 0.0 ,
240239 drop_block = False ,
241240 block_size = 1 ,
242241 ):
@@ -253,7 +252,7 @@ def _make_layer(
253252 planes ,
254253 stride ,
255254 downsample ,
256- drop_rate ,
255+ dropblock_dropout ,
257256 drop_block ,
258257 block_size )
259258 )
@@ -301,14 +300,14 @@ class ResNet12(nn.Module):
301300
302301 **Arguments**
303302
304- * **output_size** (int) - The dimensionality of the output.
303+ * **output_size** (int) - The dimensionality of the output (e.g., number of classes).
305304 * **hidden_size** (int, *optional*, default=640) - Size of the embedding once features are extracted.
306305 (640 is for mini-ImageNet; used for the classifier layer)
307- * **keep_prob** (float, *optional*, default=1.0) - Dropout rate on the embedding layer.
308306 * **avg_pool** (bool, *optional*, default=True) - Set to False for the 16k-dim embeddings of Lee et al, 2019.
309307 * **wider** (bool, *optional*, default=True) - True uses (64, 160, 320, 640) filters akin to Lee et al, 2019.
310308 False uses (64, 128, 256, 512) filters, akin to Oreshkin et al, 2018.
311- * **drop_rate** (float, *optional*, default=0.1) - Dropout rate for the residual layers.
309+ * **embedding_dropout** (float, *optional*, default=0.0) - Dropout rate on the flattened embedding layer.
310+ * **dropblock_dropout** (float, *optional*, default=0.1) - Dropout rate for the residual layers.
312311 * **dropblock_size** (int, *optional*, default=5) - Size of drop blocks.
313312
314313 **Example**
@@ -321,19 +320,19 @@ def __init__(
321320 self ,
322321 output_size ,
323322 hidden_size = 640 , # mini-ImageNet images, used for the classifier
324- keep_prob = 1.0 , # dropout for embedding
325323 avg_pool = True , # Set to False for 16000-dim embeddings
326324 wider = True , # True mimics MetaOptNet, False mimics TADAM
327- drop_rate = 0.1 , # dropout for residual layers
325+ embedding_dropout = 0.0 , # dropout for embedding
326+ dropblock_dropout = 0.1 , # dropout for residual layers
328327 dropblock_size = 5 ,
329328 channels = 3 ,
330329 ):
331330 super (ResNet12 , self ).__init__ ()
332331 self .features = ResNet12Backbone (
333- keep_prob = keep_prob ,
334332 avg_pool = avg_pool ,
335333 wider = wider ,
336- drop_rate = drop_rate ,
334+ embedding_dropout = embedding_dropout ,
335+ dropblock_dropout = dropblock_dropout ,
337336 dropblock_size = dropblock_size ,
338337 channels = channels ,
339338 )
@@ -346,10 +345,9 @@ def forward(self, x):
346345
347346
348347if __name__ == '__main__' :
349- model = ResNet12 (output_size = 5 , avg_pool = False , drop_rate = 0.0 )
348+ model = ResNet12 (output_size = 5 , avg_pool = False , dropblock_dropout = 0.0 )
350349 img = torch .randn (5 , 3 , 84 , 84 )
351350 model = model .to ('cuda' )
352351 img = img .to ('cuda' )
353352 out = model .features (img )
354353 print (out .shape )
355- __import__ ('pdb' ).set_trace ()
0 commit comments