11from collections import OrderedDict
22from functools import partial
3- from typing import Any , Callable , List , Optional , Type , TypeVar , Union
3+ from typing import Any , Callable , List , Type , TypeVar , Union
44
55import torch .nn as nn
66from pydantic import BaseModel , root_validator
@@ -236,11 +236,11 @@ class ModelCfg(BaseModel):
236236 stem_sizes : List [int ] = [32 , 32 , 64 ]
237237 stem_pool : Union [Callable [[], nn .Module ], None ] = partial (nn .MaxPool2d , kernel_size = 3 , stride = 2 , padding = 1 )
238238 stem_bn_end : bool = False
239- init_cnn : Optional [ Callable [[nn .Module ], None ] ] = init_cnn
240- make_stem : Optional [ Callable [[TModelCfg ], Union [nn .Module , nn .Sequential ] ]] = make_stem
241- make_layer : Optional [ Callable [[TModelCfg , int ], Union [nn .Module , nn .Sequential ] ]] = make_layer
242- make_body : Optional [ Callable [[TModelCfg ], Union [nn .Module , nn .Sequential ] ]] = make_body
243- make_head : Optional [ Callable [[TModelCfg ], Union [nn .Module , nn .Sequential ] ]] = make_head
239+ init_cnn : Callable [[nn .Module ], None ] = init_cnn
240+ make_stem : Callable [[TModelCfg ], Union [nn .Module , nn .Sequential ]] = make_stem
241+ make_layer : Callable [[TModelCfg , int ], Union [nn .Module , nn .Sequential ]] = make_layer
242+ make_body : Callable [[TModelCfg ], Union [nn .Module , nn .Sequential ]] = make_body
243+ make_head : Callable [[TModelCfg ], Union [nn .Module , nn .Sequential ]] = make_head
244244
245245 class Config :
246246 arbitrary_types_allowed = True
@@ -262,17 +262,6 @@ class ModelConstructor(ModelCfg):
262262
263263 @root_validator
264264 def post_init (cls , values ): # pylint: disable=E0213
265- # if values["init_cnn"] is None:
266- # values["init_cnn"] = init_cnn
267- # if values["make_stem"] is None:
268- # values["make_stem"] = make_stem
269- # if values["make_layer"] is None:
270- # values["make_layer"] = make_layer
271- # if values["make_body"] is None:
272- # values["make_body"] = make_body
273- # if values["make_head"] is None:
274- # values["make_head"] = make_head
275-
276265 if values ["stem_sizes" ][0 ] != values ["in_chans" ]:
277266 values ["stem_sizes" ] = [values ["in_chans" ]] + values ["stem_sizes" ]
278267 if values ["se" ] and isinstance (values ["se" ], (bool , int )): # if se=1 or se=True
@@ -287,15 +276,15 @@ def post_init(cls, values): # pylint: disable=E0213
287276
288277 @property
289278 def stem (self ):
290- return self .make_stem (self ) # type: ignore
279+ return self .make_stem (self ) # pylint: disable=too-many-function-args
291280
292281 @property
293282 def head (self ):
294- return self .make_head (self ) # type: ignore
283+ return self .make_head (self ) # pylint: disable=too-many-function-args
295284
296285 @property
297286 def body (self ):
298- return self .make_body (self ) # type: ignore
287+ return self .make_body (self ) # pylint: disable=too-many-function-args
299288
300289 @classmethod
301290 def from_cfg (cls , cfg : ModelCfg ):
@@ -305,12 +294,12 @@ def __call__(self):
305294 model = nn .Sequential (
306295 OrderedDict ([("stem" , self .stem ), ("body" , self .body ), ("head" , self .head )])
307296 )
308- self .init_cnn (model ) # type: ignore
297+ self .init_cnn (model ) # pylint: disable=too-many-function-args
309298 model .extra_repr = lambda : f"{ self .name } "
310299 return model
311300
312301 def __repr__ (self ):
313- se_repr = self .se .__name__ if self .se else "False"
302+ se_repr = self .se .__name__ if self .se else "False" # type: ignore
314303 return (
315304 f"{ self .name } constructor\n "
316305 f" in_chans: { self .in_chans } , num_classes: { self .num_classes } \n "
0 commit comments