@@ -34,7 +34,8 @@ def init_cnn(module: nn.Module) -> None:
3434
3535
3636class BasicBlock (nn .Module ):
37- """Basic Resnet block."""
37+ """Basic Resnet block.
38+ Configurable - can use pool to reduce at identity path, change act etc. """
3839
3940 def __init__ (
4041 self ,
@@ -117,7 +118,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # type: ignore
117118
118119
119120class BottleneckBlock (nn .Module ):
120- """Bottleneck Resnet block."""
121+ """Bottleneck Resnet block.
122+ Configurable - can use pool to reduce at identity path, change act etc. """
121123
122124 def __init__ (
123125 self ,
@@ -211,21 +213,21 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # type: ignore
211213
212214
213215def make_stem (cfg : TModelCfg ) -> nn .Sequential : # type: ignore
214- """Create xResnet stem -> 3 conv 3*3 instead 1 conv 7*7"""
215- len_stem = len (cfg .stem_sizes )
216+ """Create Resnet stem."""
216217 stem : ListStrMod = [
217218 (
218- f"conv_ { i } " ,
219+ "conv_1 " ,
219220 cfg .conv_layer (
220- cfg .stem_sizes [i - 1 ] if i else cfg .in_chans , # type: ignore
221- cfg .stem_sizes [i ],
222- stride = 2 if i == cfg .stem_stride_on else 1 ,
223- bn_layer = (not cfg .stem_bn_end ) if i == (len_stem - 1 ) else True ,
221+ cfg .in_chans , # type: ignore
222+ cfg .stem_sizes [- 1 ],
223+ kernel_size = 7 ,
224+ stride = 2 ,
225+ padding = 3 ,
226+ bn_layer = not cfg .stem_bn_end ,
224227 act_fn = cfg .act_fn ,
225228 bn_1st = cfg .bn_1st ,
226229 ),
227230 )
228- for i in range (len_stem )
229231 ]
230232 if cfg .stem_pool :
231233 stem .append (("stem_pool" , cfg .stem_pool ()))
@@ -295,9 +297,7 @@ class ModelCfg(BaseModel):
295297 layers : list [int ] = [2 , 2 , 2 , 2 ]
296298 norm : type [nn .Module ] = nn .BatchNorm2d
297299 act_fn : type [nn .Module ] = nn .ReLU
298- pool : Callable [[Any ], nn .Module ] = partial (
299- nn .AvgPool2d , kernel_size = 2 , ceil_mode = True
300- )
300+ pool : Optional [Callable [[Any ], nn .Module ]] = None
301301 expansion : int = 1
302302 groups : int = 1
303303 dw : bool = False
@@ -309,7 +309,7 @@ class ModelCfg(BaseModel):
309309 bn_1st : bool = True
310310 zero_bn : bool = True
311311 stem_stride_on : int = 0
312- stem_sizes : list [int ] = [32 , 32 , 64 ]
312+ stem_sizes : list [int ] = [64 ]
313313 stem_pool : Union [Callable [[], nn .Module ], None ] = partial (
314314 nn .MaxPool2d , kernel_size = 3 , stride = 2 , padding = 1
315315 )
@@ -368,7 +368,7 @@ def print_changed(self) -> None:
368368
369369
370370class ModelConstructor (ModelCfg ):
371- """Model constructor. As default - xresnet18 """
371+ """Model constructor. As default - resnet18 """
372372
373373 @validator ("se" )
374374 def set_se ( # pylint: disable=no-self-argument
@@ -446,10 +446,10 @@ def __repr__(self) -> str:
446446 )
447447
448448
449- class XResNet34 (ModelConstructor ):
449+ class ResNet34 (ModelConstructor ):
450450 layers : list [int ] = [3 , 4 , 6 , 3 ]
451451
452452
453- class XResNet50 ( XResNet34 ):
453+ class ResNet50 ( ResNet34 ):
454454 block : type [nn .Module ] = BottleneckBlock
455455 block_sizes : list [int ] = [256 , 512 , 1024 , 2048 ]
0 commit comments