@@ -167,7 +167,7 @@ def init_cnn(module: nn.Module):
167167 init_cnn (layer )
168168
169169
170- def _make_stem (self : CfgMC ) -> nn .Sequential :
170+ def make_stem (self : CfgMC ) -> nn .Sequential :
171171 stem : List [tuple [str , nn .Module ]] = [
172172 (f"conv_{ i } " , self .conv_layer (
173173 self .stem_sizes [i ], # type: ignore
@@ -188,34 +188,34 @@ def _make_stem(self: CfgMC) -> nn.Sequential:
188188 return nn .Sequential (OrderedDict (stem ))
189189
190190
191- def _make_layer ( self : CfgMC , layer_num : int ) -> nn .Sequential :
191+ def make_layer ( cfg : CfgMC , layer_num : int ) -> nn .Sequential :
192192 # expansion, in_channels, out_channels, blocks, stride, sa):
193193 # if no pool on stem - stride = 2 for first layer block in body
194- stride = 1 if self .stem_pool and layer_num == 0 else 2
195- num_blocks = self .layers [layer_num ]
196- block_chs = [self .stem_sizes [- 1 ] // self .expansion ] + self .block_sizes
194+ stride = 1 if cfg .stem_pool and layer_num == 0 else 2
195+ num_blocks = cfg .layers [layer_num ]
196+ block_chs = [cfg .stem_sizes [- 1 ] // cfg .expansion ] + cfg .block_sizes
197197 return nn .Sequential (
198198 OrderedDict (
199199 [
200200 (
201201 f"bl_{ block_num } " ,
202- self .block (
203- self .expansion , # type: ignore
202+ cfg .block (
203+ cfg .expansion , # type: ignore
204204 block_chs [layer_num ] if block_num == 0 else block_chs [layer_num + 1 ],
205205 block_chs [layer_num + 1 ],
206206 stride if block_num == 0 else 1 ,
207- sa = self .sa
207+ sa = cfg .sa
208208 if (block_num == num_blocks - 1 ) and layer_num == 0
209209 else None ,
210- conv_layer = self .conv_layer ,
211- act_fn = self .act_fn ,
212- pool = self .pool ,
213- zero_bn = self .zero_bn ,
214- bn_1st = self .bn_1st ,
215- groups = self .groups ,
216- div_groups = self .div_groups ,
217- dw = self .dw ,
218- se = self .se ,
210+ conv_layer = cfg .conv_layer ,
211+ act_fn = cfg .act_fn ,
212+ pool = cfg .pool ,
213+ zero_bn = cfg .zero_bn ,
214+ bn_1st = cfg .bn_1st ,
215+ groups = cfg .groups ,
216+ div_groups = cfg .div_groups ,
217+ dw = cfg .dw ,
218+ se = cfg .se ,
219219 ),
220220 )
221221 for block_num in range (num_blocks )
@@ -224,25 +224,25 @@ def _make_layer(self: CfgMC, layer_num: int) -> nn.Sequential:
224224 )
225225
226226
227- def _make_body ( self : CfgMC ) -> nn .Sequential :
227+ def make_body ( cfg : CfgMC ) -> nn .Sequential :
228228 return nn .Sequential (
229229 OrderedDict (
230230 [
231231 (
232232 f"l_{ layer_num } " ,
233- self ._make_layer (self , layer_num ) # type: ignore
233+ cfg ._make_layer (cfg , layer_num ) # type: ignore
234234 )
235- for layer_num in range (len (self .layers ))
235+ for layer_num in range (len (cfg .layers ))
236236 ]
237237 )
238238 )
239239
240240
241- def _make_head ( self : CfgMC ) -> nn .Sequential :
241+ def make_head ( cfg : CfgMC ) -> nn .Sequential :
242242 head = [
243243 ("pool" , nn .AdaptiveAvgPool2d (1 )),
244244 ("flat" , nn .Flatten ()),
245- ("fc" , nn .Linear (self .block_sizes [- 1 ] * self .expansion , self .num_classes )),
245+ ("fc" , nn .Linear (cfg .block_sizes [- 1 ] * cfg .expansion , cfg .num_classes )),
246246 ]
247247 return nn .Sequential (OrderedDict (head ))
248248
@@ -255,13 +255,13 @@ def __post_init__(self):
255255 if self ._init_cnn is None :
256256 self ._init_cnn = init_cnn
257257 if self ._make_stem is None :
258- self ._make_stem = _make_stem
258+ self ._make_stem = make_stem
259259 if self ._make_layer is None :
260- self ._make_layer = _make_layer
260+ self ._make_layer = make_layer
261261 if self ._make_body is None :
262- self ._make_body = _make_body
262+ self ._make_body = make_body
263263 if self ._make_head is None :
264- self ._make_head = _make_head
264+ self ._make_head = make_head
265265
266266 if self .stem_sizes [0 ] != self .in_chans :
267267 self .stem_sizes = [self .in_chans ] + self .stem_sizes
@@ -274,10 +274,6 @@ def __post_init__(self):
274274 "Deprecated. Pass se_module as se argument, se_reduction as arg to se."
275275 ) # add deprecation warning.
276276
277- # @property
278- # def block_sizes(self):
279- # return [self.stem_sizes[-1] // self.expansion] + self._block_sizes
280-
281277 @property
282278 def stem (self ):
283279 return self ._make_stem (self ) # type: ignore
0 commit comments