2525}
2626
2727
28- nnModule = Union [type [nn .Module ], Callable [[], nn .Module ], str ]
28+ nnModule = Union [type [nn .Module ], Callable [[], nn .Module ]]
2929
3030
3131class ModelCfg (Cfg , arbitrary_types_allowed = True , extra = "forbid" ):
@@ -34,50 +34,48 @@ class ModelCfg(Cfg, arbitrary_types_allowed=True, extra="forbid"):
3434 name : Optional [str ] = None
3535 in_chans : int = 3
3636 num_classes : int = 1000
37- block : nnModule = BasicBlock
38- conv_layer : nnModule = ConvBnAct
37+ block : Union [ nnModule , str ] = BasicBlock
38+ conv_layer : Union [ nnModule , str ] = ConvBnAct
3939 block_sizes : list [int ] = [64 , 128 , 256 , 512 ]
4040 layers : list [int ] = [2 , 2 , 2 , 2 ]
41- norm : nnModule = nn .BatchNorm2d
42- act_fn : nnModule = nn .ReLU
43- pool : Optional [nnModule ] = None
41+ norm : Union [ nnModule , str ] = nn .BatchNorm2d
42+ act_fn : Union [ nnModule , str ] = nn .ReLU
43+ pool : Union [nnModule , str , None ] = None
4444 expansion : int = 1
4545 groups : int = 1
4646 dw : bool = False
4747 div_groups : Optional [int ] = None
48- sa : Union [bool , type [ nn . Module ], Callable [[], nn . Module ] ] = False
49- se : Union [bool , type [ nn . Module ], Callable [[], nn . Module ] ] = False
48+ sa : Union [bool , nnModule , str ] = False
49+ se : Union [bool , nnModule , str ] = False
5050 se_module : Optional [bool ] = None
5151 se_reduction : Optional [int ] = None
5252 bn_1st : bool = True
5353 zero_bn : bool = True
5454 stem_stride_on : int = 0
5555 stem_sizes : list [int ] = [64 ]
56- stem_pool : Optional [nnModule ] = partial (
56+ stem_pool : Union [nnModule , str , None ] = partial (
5757 nn .MaxPool2d , kernel_size = 3 , stride = 2 , padding = 1
5858 )
5959 stem_bn_end : bool = False
6060
6161 @field_validator ("act_fn" , "block" , "conv_layer" , "norm" , "pool" , "stem_pool" )
6262 def set_modules ( # pylint: disable=no-self-argument
63- cls , value : Union [type [ nn . Module ] , str ], info : FieldValidationInfo ,
63+ cls , value : Union [nnModule , str ],
6464 ) -> nnModule :
6565 """Check values, if string, convert to nn.Module."""
6666 if is_module (value ):
6767 return value
68- if isinstance (value , str ):
69- return instantiate_module (value )
70- # raise ValueError(f"{info.field_name} must be str or nn.Module")
68+ return instantiate_module (value )
7169
7270 @field_validator ("se" , "sa" )
7371 def set_se ( # pylint: disable=no-self-argument
74- cls , value : Union [bool , type [ nn . Module ] ], info : FieldValidationInfo ,
72+ cls , value : Union [bool , nnModule , str ], info : FieldValidationInfo ,
7573 ) -> nnModule :
7674 if isinstance (value , (int , bool )):
7775 return DEFAULT_SE_SA [info .field_name ]
7876 if is_module (value ):
7977 return value
80- # raise ValueError(f"{info.field_name} must be bool or nn.Module") # no need - check at init
78+ return instantiate_module ( value )
8179
8280 @field_validator ("se_module" , "se_reduction" ) # pragma: no cover
8381 def deprecation_warning ( # pylint: disable=no-self-argument
0 commit comments