@@ -129,7 +129,7 @@ def forward(self, x):
129129 return self .act_fn (self .convs (x ) + identity )
130130
131131
132- def make_stem (cfg : TModelCfg ) -> nn .Sequential :
132+ def make_stem (cfg : TModelCfg ) -> nn .Sequential : # type: ignore
133133 stem : List [tuple [str , nn .Module ]] = [
134134 (f"conv_{ i } " , cfg .conv_layer (
135135 cfg .stem_sizes [i ], # type: ignore
@@ -150,7 +150,7 @@ def make_stem(cfg: TModelCfg) -> nn.Sequential:
150150 return nn .Sequential (OrderedDict (stem ))
151151
152152
153- def make_layer (cfg : TModelCfg , layer_num : int ) -> nn .Sequential :
153+ def make_layer (cfg : TModelCfg , layer_num : int ) -> nn .Sequential : # type: ignore
154154 # expansion, in_channels, out_channels, blocks, stride, sa):
155155 # if no pool on stem - stride = 2 for first layer block in body
156156 stride = 1 if cfg .stem_pool and layer_num == 0 else 2
@@ -186,7 +186,7 @@ def make_layer(cfg: TModelCfg, layer_num: int) -> nn.Sequential:
186186 )
187187
188188
189- def make_body (cfg : TModelCfg ) -> nn .Sequential :
189+ def make_body (cfg : TModelCfg ) -> nn .Sequential : # type: ignore
190190 return nn .Sequential (
191191 OrderedDict (
192192 [
@@ -200,7 +200,7 @@ def make_body(cfg: TModelCfg) -> nn.Sequential:
200200 )
201201
202202
203- def make_head (cfg : TModelCfg ) -> nn .Sequential :
203+ def make_head (cfg : TModelCfg ) -> nn .Sequential : # type: ignore
204204 head = [
205205 ("pool" , nn .AdaptiveAvgPool2d (1 )),
206206 ("flat" , nn .Flatten ()),
@@ -237,10 +237,10 @@ class ModelCfg(BaseModel):
237237 stem_pool : Union [Callable [[], nn .Module ], None ] = partial (nn .MaxPool2d , kernel_size = 3 , stride = 2 , padding = 1 )
238238 stem_bn_end : bool = False
239239 init_cnn : Callable [[nn .Module ], None ] = init_cnn
240- make_stem : Callable [[TModelCfg ], Union [nn .Module , nn .Sequential ]] = make_stem
241- make_layer : Callable [[TModelCfg , int ], Union [nn .Module , nn .Sequential ]] = make_layer
242- make_body : Callable [[TModelCfg ], Union [nn .Module , nn .Sequential ]] = make_body
243- make_head : Callable [[TModelCfg ], Union [nn .Module , nn .Sequential ]] = make_head
240+ make_stem : Callable [[TModelCfg ], Union [nn .Module , nn .Sequential ]] = make_stem # type: ignore
241+ make_layer : Callable [[TModelCfg , int ], Union [nn .Module , nn .Sequential ]] = make_layer # type: ignore
242+ make_body : Callable [[TModelCfg ], Union [nn .Module , nn .Sequential ]] = make_body # type: ignore
243+ make_head : Callable [[TModelCfg ], Union [nn .Module , nn .Sequential ]] = make_head # type: ignore
244244
245245 class Config :
246246 arbitrary_types_allowed = True
0 commit comments