@@ -131,7 +131,7 @@ class ModelConstructorCfg:
131131 num_classes : int = 1000
132132 block : Type [nn .Module ] = ResBlock
133133 conv_layer : Type [nn .Module ] = ConvBnAct
134- _block_sizes : List [int ] = field (default_factory = lambda : [64 , 128 , 256 , 512 ])
134+ block_sizes : List [int ] = field (default_factory = lambda : [64 , 128 , 256 , 512 ])
135135 layers : List [int ] = field (default_factory = lambda : [2 , 2 , 2 , 2 ])
136136 norm : Type [nn .Module ] = nn .BatchNorm2d
137137 act_fn : nn .Module = nn .ReLU (inplace = True )
@@ -150,11 +150,11 @@ class ModelConstructorCfg:
150150 stem_sizes : List [int ] = field (default_factory = lambda : [32 , 32 , 64 ])
151151 stem_pool : Union [nn .Module , None ] = nn .MaxPool2d (kernel_size = 3 , stride = 2 , padding = 1 ) # type: ignore
152152 stem_bn_end : bool = False
153- _init_cnn : Optional [Callable [[nn .Module ], None ]] = None
154- _make_stem : Optional [Callable ] = None
155- _make_layer : Optional [Callable ] = None
156- _make_body : Optional [Callable ] = None
157- _make_head : Optional [Callable ] = None
153+ _init_cnn : Optional [Callable [[nn .Module ], None ]] = field ( repr = False , default = None )
154+ _make_stem : Optional [Callable ] = field ( repr = False , default = None )
155+ _make_layer : Optional [Callable ] = field ( repr = False , default = None )
156+ _make_body : Optional [Callable ] = field ( repr = False , default = None )
157+ _make_head : Optional [Callable ] = field ( repr = False , default = None )
158158
159159
160160def init_cnn (module : nn .Module ):
@@ -188,22 +188,21 @@ def _make_stem(self: ModelConstructorCfg) -> nn.Sequential:
188188 return nn .Sequential (OrderedDict (stem ))
189189
190190
191- def _make_layer (self , layer_num : int ) -> nn .Module :
191+ def _make_layer (self : ModelConstructorCfg , layer_num : int ) -> nn .Sequential :
192192 # expansion, in_channels, out_channels, blocks, stride, sa):
193193 # if no pool on stem - stride = 2 for first layer block in body
194194 stride = 1 if self .stem_pool and layer_num == 0 else 2
195195 num_blocks = self .layers [layer_num ]
196+ block_chs = [self .stem_sizes [- 1 ] // self .expansion ] + self .block_sizes
196197 return nn .Sequential (
197198 OrderedDict (
198199 [
199200 (
200201 f"bl_{ block_num } " ,
201202 self .block (
202- self .expansion ,
203- self .block_sizes [layer_num ]
204- if block_num == 0
205- else self .block_sizes [layer_num + 1 ],
206- self .block_sizes [layer_num + 1 ],
203+ self .expansion , # type: ignore
204+ block_chs [layer_num ] if block_num == 0 else block_chs [layer_num + 1 ],
205+ block_chs [layer_num + 1 ],
207206 stride if block_num == 0 else 1 ,
208207 sa = self .sa
209208 if (block_num == num_blocks - 1 ) and layer_num == 0
@@ -225,7 +224,7 @@ def _make_layer(self, layer_num: int) -> nn.Module:
225224 )
226225
227226
228- def _make_body (self ) :
227+ def _make_body (self : ModelConstructorCfg ) -> nn . Sequential :
229228 return nn .Sequential (
230229 OrderedDict (
231230 [
@@ -275,9 +274,9 @@ def __post_init__(self):
275274 "Deprecated. Pass se_module as se argument, se_reduction as arg to se."
276275 ) # add deprecation warning.
277276
278- @property
279- def block_sizes (self ):
280- return [self .stem_sizes [- 1 ] // self .expansion ] + self ._block_sizes
277+ # @property
278+ # def block_sizes(self):
279+ # return [self.stem_sizes[-1] // self.expansion] + self._block_sizes
281280
282281 @property
283282 def stem (self ):
@@ -303,18 +302,17 @@ def __call__(self):
303302 model .extra_repr = lambda : f"{ self .name } "
304303 return model
305304
306- def __repr__ (self ):
307- return (
305+ def simple_cfg (self ):
306+ print (
308307 f"{ self .name } constructor\n "
309308 f" in_chans: { self .in_chans } , num_classes: { self .num_classes } \n "
310309 f" expansion: { self .expansion } , groups: { self .groups } , dw: { self .dw } , div_groups: { self .div_groups } \n "
311310 f" sa: { self .sa } , se: { self .se } \n "
312311 f" stem sizes: { self .stem_sizes } , stride on { self .stem_stride_on } \n "
313- f" body sizes { self ._block_sizes } \n "
312+ f" body sizes { self .block_sizes } \n "
314313 f" layers: { self .layers } "
315314 )
316315
317-
318316# xresnet34 = partial(
319317# ModelConstructor, name="xresnet34", expansion=1, layers=[3, 4, 6, 3]
320318# )
0 commit comments