@@ -42,7 +42,7 @@ class ModelCfg(Cfg, arbitrary_types_allowed=True, extra="forbid"):
4242 zero_bn : bool = True
4343 stem_stride_on : int = 0
4444 stem_sizes : list [int ] = [64 ]
45- stem_pool : Union [Callable [[], nn .Module ], None ] = partial (
45+ stem_pool : Optional [Callable [[], nn .Module ]] = partial (
4646 nn .MaxPool2d , kernel_size = 3 , stride = 2 , padding = 1
4747 )
4848 stem_bn_end : bool = False
@@ -61,7 +61,7 @@ def __repr__(self) -> str:
6161 )
6262
6363
64- def make_stem (cfg : ModelCfg ) -> nn .Sequential : # type: ignore
64+ def make_stem (cfg : ModelCfg ) -> nn .Sequential :
6565 """Create Resnet stem."""
6666 stem : ListStrMod = [
6767 (
@@ -116,15 +116,15 @@ def make_layer(cfg: ModelCfg, layer_num: int) -> nn.Sequential: # type: ignore
116116 )
117117
118118
119- def make_body (cfg : ModelCfg ) -> nn .Sequential : # type: ignore
119+ def make_body (cfg : ModelCfg ) -> nn .Sequential :
120120 """Create model body."""
121121 return nn_seq (
122122 (f"l_{ layer_num } " , cfg .make_layer (cfg , layer_num )) # type: ignore
123123 for layer_num in range (len (cfg .layers ))
124124 )
125125
126126
127- def make_head (cfg : ModelCfg ) -> nn .Sequential : # type: ignore
127+ def make_head (cfg : ModelCfg ) -> nn .Sequential :
128128 """Create head."""
129129 head = [
130130 ("pool" , nn .AdaptiveAvgPool2d (1 )),
@@ -138,10 +138,10 @@ class ModelConstructor(ModelCfg):
138138 """Model constructor. As default - resnet18"""
139139
140140 init_cnn : Callable [[nn .Module ], None ] = init_cnn
141- make_stem : Callable [[ModelCfg ], ModSeq ] = make_stem # type: ignore
142- make_layer : Callable [[ModelCfg , int ], ModSeq ] = make_layer # type: ignore
143- make_body : Callable [[ModelCfg ], ModSeq ] = make_body # type: ignore
144- make_head : Callable [[ModelCfg ], ModSeq ] = make_head # type: ignore
141+ make_stem : Callable [[ModelCfg ], ModSeq ] = make_stem
142+ make_layer : Callable [[ModelCfg , int ], ModSeq ] = make_layer
143+ make_body : Callable [[ModelCfg ], ModSeq ] = make_body
144+ make_head : Callable [[ModelCfg ], ModSeq ] = make_head
145145
146146 @field_validator ("se" )
147147 def set_se ( # pylint: disable=no-self-argument
0 commit comments