1919TModelCfg = TypeVar ("TModelCfg" , bound = "ModelCfg" )
2020
2121
22- def init_cnn (module : nn .Module ):
22+ def init_cnn (module : nn .Module ) -> None :
2323 "Init module - kaiming_normal for Conv2d and 0 for biases."
2424 if getattr (module , "bias" , None ) is not None :
2525 nn .init .constant_ (module .bias , 0 ) # type: ignore
@@ -144,14 +144,15 @@ def __init__(
144144 self .id_conv = nn .Sequential (OrderedDict (id_layers ))
145145 else :
146146 self .id_conv = None
147- self .act_fn = get_act (act_fn ) # type: ignore
147+ self .act_fn = get_act (act_fn )
148148
149149 def forward (self , x ):
150150 identity = self .id_conv (x ) if self .id_conv is not None else x
151151 return self .act_fn (self .convs (x ) + identity )
152152
153153
154154def make_stem (cfg : TModelCfg ) -> nn .Sequential : # type: ignore
155+ """Create xResnet stem -> 3 conv 3*3 instead 1 conv 7*7"""
155156 len_stem = len (cfg .stem_sizes )
156157 stem : list [tuple [str , nn .Module ]] = [
157158 (
@@ -175,7 +176,7 @@ def make_stem(cfg: TModelCfg) -> nn.Sequential: # type: ignore
175176
176177
177178def make_layer (cfg : TModelCfg , layer_num : int ) -> nn .Sequential : # type: ignore
178- # expansion, in_channels, out_channels, blocks, stride, sa):
179+ """Create layer (stage)"""
179180 # if no pool on stem - stride = 2 for first layer block in body
180181 stride = 1 if cfg .stem_pool and layer_num == 0 else 2
181182 num_blocks = cfg .layers [layer_num ]
@@ -213,6 +214,7 @@ def make_layer(cfg: TModelCfg, layer_num: int) -> nn.Sequential: # type: ignore
213214
214215
215216def make_body (cfg : TModelCfg ) -> nn .Sequential : # type: ignore
217+ """Create model body."""
216218 return nn .Sequential (
217219 OrderedDict (
218220 [
@@ -224,6 +226,7 @@ def make_body(cfg: TModelCfg) -> nn.Sequential: # type: ignore
224226
225227
226228def make_head (cfg : TModelCfg ) -> nn .Sequential : # type: ignore
229+ """Create head."""
227230 head = [
228231 ("pool" , nn .AdaptiveAvgPool2d (1 )),
229232 ("flat" , nn .Flatten ()),
@@ -326,6 +329,7 @@ def from_cfg(cls, cfg: ModelCfg):
326329 return cls (** cfg .dict ())
327330
328331 def __call__ (self ) -> nn .Sequential :
332+ """Create model."""
329333 model_name = self .name or self .__class__ .__name__
330334 named_sequential = type (model_name , (nn .Sequential ,), {})
331335 model = named_sequential (
0 commit comments