1111 "act_fn" ,
1212 "ResBlock" ,
1313 "ModelConstructor" ,
14- # "xresnet34 ",
15- # "xresnet50 ",
14+ "XResNet34 " ,
15+ "XResNet50 " ,
1616]
1717
1818
@@ -119,7 +119,7 @@ def forward(self, x):
119119 return self .act_fn (self .convs (x ) + identity )
120120
121121
122- class CfgMC (BaseModel ):
122+ class ModelCfg (BaseModel ):
123123 """Model constructor Config. As default - xresnet18"""
124124
125125 name : str = "MC"
@@ -176,7 +176,7 @@ def init_cnn(module: nn.Module):
176176 init_cnn (layer )
177177
178178
179- def make_stem (self : CfgMC ) -> nn .Sequential :
179+ def make_stem (self : ModelCfg ) -> nn .Sequential :
180180 stem : List [tuple [str , nn .Module ]] = [
181181 (f"conv_{ i } " , self .conv_layer (
182182 self .stem_sizes [i ], # type: ignore
@@ -197,7 +197,7 @@ def make_stem(self: CfgMC) -> nn.Sequential:
197197 return nn .Sequential (OrderedDict (stem ))
198198
199199
200- def make_layer (cfg : CfgMC , layer_num : int ) -> nn .Sequential :
200+ def make_layer (cfg : ModelCfg , layer_num : int ) -> nn .Sequential :
201201 # expansion, in_channels, out_channels, blocks, stride, sa):
202202 # if no pool on stem - stride = 2 for first layer block in body
203203 stride = 1 if cfg .stem_pool and layer_num == 0 else 2
@@ -233,7 +233,7 @@ def make_layer(cfg: CfgMC, layer_num: int) -> nn.Sequential:
233233 )
234234
235235
236- def make_body (cfg : CfgMC ) -> nn .Sequential :
236+ def make_body (cfg : ModelCfg ) -> nn .Sequential :
237237 return nn .Sequential (
238238 OrderedDict (
239239 [
@@ -247,7 +247,7 @@ def make_body(cfg: CfgMC) -> nn.Sequential:
247247 )
248248
249249
250- def make_head (cfg : CfgMC ) -> nn .Sequential :
250+ def make_head (cfg : ModelCfg ) -> nn .Sequential :
251251 head = [
252252 ("pool" , nn .AdaptiveAvgPool2d (1 )),
253253 ("flat" , nn .Flatten ()),
@@ -256,7 +256,7 @@ def make_head(cfg: CfgMC) -> nn.Sequential:
256256 return nn .Sequential (OrderedDict (head ))
257257
258258
259- class ModelConstructor (CfgMC ):
259+ class ModelConstructor (ModelCfg ):
260260 """Model constructor. As default - xresnet18"""
261261
262262 def __init__ (self , ** data ):
@@ -296,7 +296,7 @@ def body(self):
296296 return self .make_body (self ) # type: ignore
297297
298298 @classmethod
299- def from_cfg (cls , cfg : CfgMC ):
299+ def from_cfg (cls , cfg : ModelCfg ):
300300 return cls (** cfg .dict ())
301301
302302 def __call__ (self ):
0 commit comments