@@ -66,7 +66,8 @@ def _make_stem(self):
6666 bn_layer = (not self .stem_bn_end ) if i == (len (self .stem_sizes ) - 2 ) else True ,
6767 act_fn = self .act_fn , bn_1st = self .bn_1st ))
6868 for i in range (len (self .stem_sizes ) - 1 )]
69- stem .append (('stem_pool' , self .stem_pool ))
69+ if self .stem_pool is not None :
70+ stem .append (('stem_pool' , self .stem_pool ))
7071 if self .stem_bn_end :
7172 stem .append (('norm' , self .norm (self .stem_sizes [- 1 ])))
7273 return nn .Sequential (OrderedDict (stem ))
@@ -84,9 +85,10 @@ def _make_layer(self, expansion, ni, nf, blocks, stride, sa):
8485
8586
8687def _make_body (self ):
88+ stride = 2 if self .stem_pool is None else 1 # if no pool on stem - stride = 2 for first block in body
8789 blocks = [(f"l_{ i } " , self ._make_layer (self , self .expansion ,
8890 ni = self .block_sizes [i ], nf = self .block_sizes [i + 1 ],
89- blocks = l , stride = 1 if i == 0 else 2 ,
91+ blocks = l , stride = stride if i == 0 else 2 ,
9092 sa = self .sa if i == 0 else False ))
9193 for i , l in enumerate (self .layers )]
9294 return nn .Sequential (OrderedDict (blocks ))
@@ -114,7 +116,7 @@ def __init__(self, name='MC', c_in=3, c_out=1000,
114116 zero_bn = True ,
115117 stem_stride_on = 0 ,
116118 stem_sizes = [32 , 32 , 64 ],
117- stem_pool = nn .MaxPool2d (kernel_size = 3 , stride = 2 , padding = 1 ),
119+ stem_pool = nn .MaxPool2d (kernel_size = 3 , stride = 2 , padding = 1 ), # if stem_pool is None - no pool at stem
118120 stem_bn_end = False ,
119121 _init_cnn = init_cnn ,
120122 _make_stem = _make_stem ,
@@ -128,6 +130,8 @@ def __init__(self, name='MC', c_in=3, c_out=1000,
128130 del params ['self' ]
129131 self .__dict__ = params
130132 self ._block_sizes = params ['block_sizes' ]
133+ if type (self .stem_pool ) is str : # Hydra pass string value
134+ self .stem_pool = None
131135 if self .stem_sizes [0 ] != self .c_in :
132136 self .stem_sizes = [self .c_in ] + self .stem_sizes
133137
0 commit comments