@@ -123,8 +123,8 @@ def forward(self, x):
123123
124124
125125@dataclass
126- class ModelConstructorCfg :
127- """Model constructor. As default - xresnet18"""
126+ class CfgMC :
127+ """Model constructor Config . As default - xresnet18"""
128128
129129 name : str = "MC"
130130 in_chans : int = 3
@@ -167,7 +167,7 @@ def init_cnn(module: nn.Module):
167167 init_cnn (layer )
168168
169169
170- def _make_stem (self : ModelConstructorCfg ) -> nn .Sequential :
170+ def _make_stem (self : CfgMC ) -> nn .Sequential :
171171 stem : List [tuple [str , nn .Module ]] = [
172172 (f"conv_{ i } " , self .conv_layer (
173173 self .stem_sizes [i ], # type: ignore
@@ -188,7 +188,7 @@ def _make_stem(self: ModelConstructorCfg) -> nn.Sequential:
188188 return nn .Sequential (OrderedDict (stem ))
189189
190190
191- def _make_layer (self : ModelConstructorCfg , layer_num : int ) -> nn .Sequential :
191+ def _make_layer (self : CfgMC , layer_num : int ) -> nn .Sequential :
192192 # expansion, in_channels, out_channels, blocks, stride, sa):
193193 # if no pool on stem - stride = 2 for first layer block in body
194194 stride = 1 if self .stem_pool and layer_num == 0 else 2
@@ -224,7 +224,7 @@ def _make_layer(self: ModelConstructorCfg, layer_num: int) -> nn.Sequential:
224224 )
225225
226226
227- def _make_body (self : ModelConstructorCfg ) -> nn .Sequential :
227+ def _make_body (self : CfgMC ) -> nn .Sequential :
228228 return nn .Sequential (
229229 OrderedDict (
230230 [
@@ -238,7 +238,7 @@ def _make_body(self: ModelConstructorCfg) -> nn.Sequential:
238238 )
239239
240240
241- def _make_head (self : ModelConstructorCfg ) -> nn .Sequential :
241+ def _make_head (self : CfgMC ) -> nn .Sequential :
242242 head = [
243243 ("pool" , nn .AdaptiveAvgPool2d (1 )),
244244 ("flat" , nn .Flatten ()),
@@ -248,7 +248,7 @@ def _make_head(self: ModelConstructorCfg) -> nn.Sequential:
248248
249249
250250@dataclass
251- class ModelConstructor (ModelConstructorCfg ):
251+ class ModelConstructor (CfgMC ):
252252 """Model constructor. As default - xresnet18"""
253253
254254 def __post_init__ (self ):
@@ -291,7 +291,7 @@ def body(self):
291291 return self ._make_body (self ) # type: ignore
292292
293293 @classmethod
294- def from_cfg (cls , cfg : ModelConstructorCfg ):
294+ def from_cfg (cls , cfg : CfgMC ):
295295 return cls (** asdict (cfg ))
296296
297297 def __call__ (self ):
@@ -313,9 +313,11 @@ def print_cfg(self):
313313 f" layers: { self .layers } "
314314 )
315315
316- # xresnet34 = partial(
317- # ModelConstructor, name="xresnet34", expansion=1, layers=[3, 4, 6, 3]
318- # )
319- # xresnet50 = partial(
320- # ModelConstructor, name="xresnet34", expansion=4, layers=[3, 4, 6, 3]
321- # )
316+
317+ xresnet34 = ModelConstructor .from_cfg (
318+ CfgMC (name = "xresnet34" , expansion = 1 , layers = [3 , 4 , 6 , 3 ])
319+ )
320+
321+ xresnet50 = ModelConstructor .from_cfg (
322+ CfgMC (name = "xresnet34" , expansion = 4 , layers = [3 , 4 , 6 , 3 ])
323+ )
0 commit comments