1- # pylance: disable=overridden method
21from collections import OrderedDict
32from functools import partial
43from typing import Any , Callable , Optional , TypeVar , Union
@@ -39,7 +38,6 @@ class BasicBlock(nn.Module):
3938
4039 def __init__ (
4140 self ,
42- # expansion: int,
4341 in_channels : int ,
4442 out_channels : int ,
4543 stride : int = 1 ,
@@ -56,7 +54,6 @@ def __init__(
5654 ):
5755 super ().__init__ ()
5856 # pool defined at ModelConstructor.
59- # out_channels, in_channels = mid_channels * expansion, in_channels * expansion
6057 if div_groups is not None : # check if groups != 1 and div_groups
6158 groups = int (out_channels / div_groups )
6259 layers : ListStrMod = [
@@ -66,7 +63,7 @@ def __init__(
6663 in_channels ,
6764 out_channels ,
6865 3 ,
69- stride = stride , # type: ignore
66+ stride = stride ,
7067 act_fn = act_fn ,
7168 bn_1st = bn_1st ,
7269 groups = in_channels if dw else groups ,
@@ -114,7 +111,7 @@ def __init__(
114111 self .id_conv = None
115112 self .act_fn = get_act (act_fn )
116113
117- def forward (self , x : torch .Tensor ) -> torch .Tensor :
114+ def forward (self , x : torch .Tensor ) -> torch .Tensor : # type: ignore
118115 identity = self .id_conv (x ) if self .id_conv is not None else x
119116 return self .act_fn (self .convs (x ) + identity )
120117
@@ -177,7 +174,7 @@ def __init__(
177174 act_fn = False ,
178175 bn_1st = bn_1st ,
179176 ),
180- ), # noqa E501
177+ ),
181178 ]
182179 if se :
183180 layers .append (("se" , se (out_channels )))
@@ -208,7 +205,7 @@ def __init__(
208205 self .id_conv = None
209206 self .act_fn = get_act (act_fn )
210207
211- def forward (self , x : torch .Tensor ) -> torch .Tensor :
208+ def forward (self , x : torch .Tensor ) -> torch .Tensor : # type: ignore
212209 identity = self .id_conv (x ) if self .id_conv is not None else x
213210 return self .act_fn (self .convs (x ) + identity )
214211
@@ -234,7 +231,7 @@ def make_stem(cfg: TModelCfg) -> nn.Sequential: # type: ignore
234231 stem .append (("stem_pool" , cfg .stem_pool ()))
235232 if cfg .stem_bn_end :
236233 stem .append (("norm" , cfg .norm (cfg .stem_sizes [- 1 ]))) # type: ignore
237- return nn . Sequential ( OrderedDict ( stem ) )
234+ return nn_seq ( stem )
238235
239236
240237def make_layer (cfg : TModelCfg , layer_num : int ) -> nn .Sequential : # type: ignore
@@ -247,15 +244,12 @@ def make_layer(cfg: TModelCfg, layer_num: int) -> nn.Sequential: # type: ignore
247244 (
248245 f"bl_{ block_num } " ,
249246 cfg .block (
250- # cfg.expansion, # type: ignore
251- block_chs [layer_num ]
247+ block_chs [layer_num ] # type: ignore
252248 if block_num == 0
253249 else block_chs [layer_num + 1 ],
254250 block_chs [layer_num + 1 ],
255251 stride if block_num == 0 else 1 ,
256- sa = cfg .sa
257- if (block_num == num_blocks - 1 ) and layer_num == 0
258- else None ,
252+ sa = cfg .sa if (block_num == num_blocks - 1 ) and layer_num == 0 else None ,
259253 conv_layer = cfg .conv_layer ,
260254 act_fn = cfg .act_fn ,
261255 pool = cfg .pool ,
@@ -265,21 +259,17 @@ def make_layer(cfg: TModelCfg, layer_num: int) -> nn.Sequential: # type: ignore
265259 div_groups = cfg .div_groups ,
266260 dw = cfg .dw ,
267261 se = cfg .se ,
268- )
262+ ),
269263 )
270264 for block_num in range (num_blocks )
271265 )
272266
273267
274268def make_body (cfg : TModelCfg ) -> nn .Sequential : # type: ignore
275269 """Create model body."""
276- return nn .Sequential (
277- OrderedDict (
278- [
279- (f"l_{ layer_num } " , cfg .make_layer (cfg , layer_num )) # type: ignore
280- for layer_num in range (len (cfg .layers ))
281- ]
282- )
270+ return nn_seq (
271+ (f"l_{ layer_num } " , cfg .make_layer (cfg , layer_num )) # type: ignore
272+ for layer_num in range (len (cfg .layers ))
283273 )
284274
285275
@@ -290,7 +280,7 @@ def make_head(cfg: TModelCfg) -> nn.Sequential: # type: ignore
290280 ("flat" , nn .Flatten ()),
291281 ("fc" , nn .Linear (cfg .block_sizes [- 1 ], cfg .num_classes )),
292282 ]
293- return nn . Sequential ( OrderedDict ( head ) )
283+ return nn_seq ( head )
294284
295285
296286class ModelCfg (BaseModel ):
@@ -381,25 +371,29 @@ class ModelConstructor(ModelCfg):
381371 """Model constructor. As default - xresnet18"""
382372
383373 @validator ("se" )
384- def set_se (cls , value : Union [bool , type [nn .Module ]]) -> Union [bool , type [nn .Module ]]:
374+ def set_se ( # pylint: disable=no-self-argument
375+ cls , value : Union [bool , type [nn .Module ]]
376+ ) -> Union [bool , type [nn .Module ]]:
385377 if value :
386378 if isinstance (value , (int , bool )):
387379 return SEModule
388380 return value
389381
390382 @validator ("sa" )
391- def set_sa (cls , value : Union [bool , type [nn .Module ]]) -> Union [bool , type [nn .Module ]]:
383+ def set_sa ( # pylint: disable=no-self-argument
384+ cls , value : Union [bool , type [nn .Module ]]
385+ ) -> Union [bool , type [nn .Module ]]:
392386 if value :
393387 if isinstance (value , (int , bool )):
394388 return SimpleSelfAttention # default: ks=1, sym=sym
395389 return value
396390
397- @validator ("se_module" , "se_reduction" )
398- def deprecation_warning (cls , value ): # pragma: no cover
399- print (
400- "Deprecated. Pass se_module as se argument, se_reduction as arg to se."
401- )
402- return value
391+ @validator ("se_module" , "se_reduction" ) # pragma: no cover
392+ def deprecation_warning ( # pylint: disable=no-self-argument
393+ cls , value : Union [ bool , int , None ]
394+ ) -> Union [ bool , int , None ]:
395+ print ( "Deprecated. Pass se_module as se argument, se_reduction as arg to se." )
396+ return value
403397
404398 @property
405399 def stem (self ):
@@ -420,9 +414,11 @@ def from_cfg(cls, cfg: ModelCfg):
420414 def __call__ (self ) -> nn .Sequential :
421415 """Create model."""
422416 model_name = self .name or self .__class__ .__name__
423- named_sequential = type (model_name , (nn .Sequential ,), {}) # create type named as model
417+ named_sequential = type (
418+ model_name , (nn .Sequential ,), {}
419+ ) # create type named as model
424420 model = named_sequential (
425- OrderedDict ([("stem" , self .stem ), ("body" , self .body ), ("head" , self .head )])
421+ OrderedDict ([("stem" , self .stem ), ("body" , self .body ), ("head" , self .head )]) # type: ignore
426422 )
427423 self .init_cnn (model ) # pylint: disable=too-many-function-args
428424 extra_repr = self .__repr_changed_args__ ()
@@ -449,4 +445,5 @@ class XResNet34(ModelConstructor):
449445
450446
451447class XResNet50 (XResNet34 ):
452- expansion : int = 4
448+ block : type [nn .Module ] = BottleneckBlock
449+ block_sizes : list [int ] = [256 , 512 , 1024 , 2048 ]
0 commit comments