11from collections import OrderedDict
22from functools import partial
3- from typing import Any , Callable , List , Optional , Type , TypeVar , Union
3+ from typing import Any , Callable , Optional , TypeVar , Union
44
55import torch .nn as nn
66from pydantic import BaseModel , root_validator
1919TModelCfg = TypeVar ("TModelCfg" , bound = "ModelCfg" )
2020
2121
22- def init_cnn (module : nn .Module ):
22+ def init_cnn (module : nn .Module ) -> None :
2323 "Init module - kaiming_normal for Conv2d and 0 for biases."
2424 if getattr (module , "bias" , None ) is not None :
2525 nn .init .constant_ (module .bias , 0 ) # type: ignore
@@ -39,7 +39,7 @@ def __init__(
3939 mid_channels : int ,
4040 stride : int = 1 ,
4141 conv_layer = ConvBnAct ,
42- act_fn : Type [nn .Module ] = nn .ReLU ,
42+ act_fn : type [nn .Module ] = nn .ReLU ,
4343 zero_bn : bool = True ,
4444 bn_1st : bool = True ,
4545 groups : int = 1 ,
@@ -144,16 +144,17 @@ def __init__(
144144 self .id_conv = nn .Sequential (OrderedDict (id_layers ))
145145 else :
146146 self .id_conv = None
147- self .act_fn = get_act (act_fn ) # type: ignore
147+ self .act_fn = get_act (act_fn )
148148
149149 def forward (self , x ):
150150 identity = self .id_conv (x ) if self .id_conv is not None else x
151151 return self .act_fn (self .convs (x ) + identity )
152152
153153
154154def make_stem (cfg : TModelCfg ) -> nn .Sequential : # type: ignore
155+ """Create xResnet stem -> 3 conv 3*3 instead 1 conv 7*7"""
155156 len_stem = len (cfg .stem_sizes )
156- stem : List [tuple [str , nn .Module ]] = [
157+ stem : list [tuple [str , nn .Module ]] = [
157158 (
158159 f"conv_{ i } " ,
159160 cfg .conv_layer (
@@ -175,7 +176,7 @@ def make_stem(cfg: TModelCfg) -> nn.Sequential: # type: ignore
175176
176177
177178def make_layer (cfg : TModelCfg , layer_num : int ) -> nn .Sequential : # type: ignore
178- # expansion, in_channels, out_channels, blocks, stride, sa):
179+ """Create layer (stage)"""
179180 # if no pool on stem - stride = 2 for first layer block in body
180181 stride = 1 if cfg .stem_pool and layer_num == 0 else 2
181182 num_blocks = cfg .layers [layer_num ]
@@ -213,6 +214,7 @@ def make_layer(cfg: TModelCfg, layer_num: int) -> nn.Sequential: # type: ignore
213214
214215
215216def make_body (cfg : TModelCfg ) -> nn .Sequential : # type: ignore
217+ """Create model body."""
216218 return nn .Sequential (
217219 OrderedDict (
218220 [
@@ -224,6 +226,7 @@ def make_body(cfg: TModelCfg) -> nn.Sequential: # type: ignore
224226
225227
226228def make_head (cfg : TModelCfg ) -> nn .Sequential : # type: ignore
229+ """Create head."""
227230 head = [
228231 ("pool" , nn .AdaptiveAvgPool2d (1 )),
229232 ("flat" , nn .Flatten ()),
@@ -238,27 +241,27 @@ class ModelCfg(BaseModel):
238241 name : Optional [str ] = None
239242 in_chans : int = 3
240243 num_classes : int = 1000
241- block : Type [nn .Module ] = ResBlock
242- conv_layer : Type [nn .Module ] = ConvBnAct
243- block_sizes : List [int ] = [64 , 128 , 256 , 512 ]
244- layers : List [int ] = [2 , 2 , 2 , 2 ]
245- norm : Type [nn .Module ] = nn .BatchNorm2d
246- act_fn : Type [nn .Module ] = nn .ReLU
244+ block : type [nn .Module ] = ResBlock
245+ conv_layer : type [nn .Module ] = ConvBnAct
246+ block_sizes : list [int ] = [64 , 128 , 256 , 512 ]
247+ layers : list [int ] = [2 , 2 , 2 , 2 ]
248+ norm : type [nn .Module ] = nn .BatchNorm2d
249+ act_fn : type [nn .Module ] = nn .ReLU
247250 pool : Callable [[Any ], nn .Module ] = partial (
248251 nn .AvgPool2d , kernel_size = 2 , ceil_mode = True
249252 )
250253 expansion : int = 1
251254 groups : int = 1
252255 dw : bool = False
253256 div_groups : Union [int , None ] = None
254- sa : Union [bool , int , Type [nn .Module ]] = False
255- se : Union [bool , int , Type [nn .Module ]] = False
257+ sa : Union [bool , int , type [nn .Module ]] = False
258+ se : Union [bool , int , type [nn .Module ]] = False
256259 se_module : Union [bool , None ] = None
257260 se_reduction : Union [int , None ] = None
258261 bn_1st : bool = True
259262 zero_bn : bool = True
260263 stem_stride_on : int = 0
261- stem_sizes : List [int ] = [32 , 32 , 64 ]
264+ stem_sizes : list [int ] = [32 , 32 , 64 ]
262265 stem_pool : Union [Callable [[], nn .Module ], None ] = partial (
263266 nn .MaxPool2d , kernel_size = 3 , stride = 2 , padding = 1
264267 )
@@ -286,7 +289,7 @@ def _get_str_value(self, field: str) -> str:
286289 def __repr__ (self ) -> str :
287290 return f"{ self .__repr_name__ ()} (\n { self .__repr_str__ (chr (10 ) + ' ' )} )"
288291
289- def __repr_args__ (self ):
292+ def __repr_args__ (self ) -> list [ tuple [ str , str ]] :
290293 return [
291294 (field , str_value )
292295 for field in self .__fields__
@@ -325,7 +328,8 @@ def body(self):
325328 def from_cfg (cls , cfg : ModelCfg ):
326329 return cls (** cfg .dict ())
327330
328- def __call__ (self ):
331+ def __call__ (self ) -> nn .Sequential :
332+ """Create model."""
329333 model_name = self .name or self .__class__ .__name__
330334 named_sequential = type (model_name , (nn .Sequential ,), {})
331335 model = named_sequential (
@@ -338,13 +342,14 @@ def __call__(self):
338342 return model
339343
340344 def _get_extra_repr (self ) -> str :
345+ """Repr for changed fields"""
341346 return " " .join (
342347 f"{ field } : { self ._get_str_value (field )} ,"
343348 for field in self .__fields_set__
344349 if field != "name"
345- )[:- 1 ]
350+ )[:- 1 ] # strip last comma.
346351
347- def __repr__ (self ):
352+ def __repr__ (self ) -> str :
348353 se_repr = self .se .__name__ if self .se else "False" # type: ignore
349354 model_name = self .name or self .__class__ .__name__
350355 return (
0 commit comments