@@ -36,9 +36,10 @@ def __init__(self, expansion, ni, nh, stride=1,
3636 if div_groups is not None : # check if grops != 1 and div_groups
3737 groups = int (nh / div_groups )
3838 if expansion == 1 :
39- layers = [("conv_0" , conv_layer (ni , nh , 3 , stride = stride , act_fn = act_fn , bn_1st = bn_1st ,
40- groups = nh if dw else groups )),
41- ("conv_1" , conv_layer (nh , nf , 3 , zero_bn = zero_bn , act = False , bn_1st = bn_1st ))
39+ layers = [("conv_0" , conv_layer (ni , nh , 3 , stride = stride ,
40+ act_fn = act_fn , bn_1st = bn_1st , groups = ni if dw else groups )),
41+ ("conv_1" , conv_layer (nh , nf , 3 , zero_bn = zero_bn ,
42+ act = False , bn_1st = bn_1st , groups = nh if dw else groups ))
4243 ]
4344 else :
4445 layers = [("conv_0" , conv_layer (ni , nh , 1 , act_fn = act_fn , bn_1st = bn_1st )),
@@ -65,7 +66,8 @@ def _make_stem(self):
6566 bn_layer = (not self .stem_bn_end ) if i == (len (self .stem_sizes ) - 2 ) else True ,
6667 act_fn = self .act_fn , bn_1st = self .bn_1st ))
6768 for i in range (len (self .stem_sizes ) - 1 )]
68- stem .append (('stem_pool' , self .stem_pool ))
69+ if self .stem_pool is not None :
70+ stem .append (('stem_pool' , self .stem_pool ))
6971 if self .stem_bn_end :
7072 stem .append (('norm' , self .norm (self .stem_sizes [- 1 ])))
7173 return nn .Sequential (OrderedDict (stem ))
@@ -83,9 +85,10 @@ def _make_layer(self, expansion, ni, nf, blocks, stride, sa):
8385
8486
8587def _make_body (self ):
88+ stride = 2 if self .stem_pool is None else 1 # if no pool on stem - stride = 2 for first block in body
8689 blocks = [(f"l_{ i } " , self ._make_layer (self , self .expansion ,
8790 ni = self .block_sizes [i ], nf = self .block_sizes [i + 1 ],
88- blocks = l , stride = 1 if i == 0 else 2 ,
91+ blocks = l , stride = stride if i == 0 else 2 ,
8992 sa = self .sa if i == 0 else False ))
9093 for i , l in enumerate (self .layers )]
9194 return nn .Sequential (OrderedDict (blocks ))
@@ -113,7 +116,7 @@ def __init__(self, name='MC', c_in=3, c_out=1000,
113116 zero_bn = True ,
114117 stem_stride_on = 0 ,
115118 stem_sizes = [32 , 32 , 64 ],
116- stem_pool = nn .MaxPool2d (kernel_size = 3 , stride = 2 , padding = 1 ),
119+ stem_pool = nn .MaxPool2d (kernel_size = 3 , stride = 2 , padding = 1 ), # if stem_pool is None - no pool at stem
117120 stem_bn_end = False ,
118121 _init_cnn = init_cnn ,
119122 _make_stem = _make_stem ,
@@ -127,12 +130,14 @@ def __init__(self, name='MC', c_in=3, c_out=1000,
127130 del params ['self' ]
128131 self .__dict__ = params
129132 self ._block_sizes = params ['block_sizes' ]
133+ if type (self .stem_pool ) is str : # Hydra pass string value
134+ self .stem_pool = None
130135 if self .stem_sizes [0 ] != self .c_in :
131136 self .stem_sizes = [self .c_in ] + self .stem_sizes
132137
133138 @property
134139 def block_sizes (self ):
135- return [self .stem_sizes [- 1 ] // self .expansion ] + self ._block_sizes + [ 256 ] * ( len ( self . layers ) - 4 )
140+ return [self .stem_sizes [- 1 ] // self .expansion ] + self ._block_sizes
136141
137142 @property
138143 def stem (self ):
0 commit comments