77from torch import nn
88
99from .blocks import BasicBlock , BottleneckBlock
10- from .helpers import Cfg , ListStrMod , ModSeq , init_cnn , nn_seq
10+ from .helpers import (Cfg , ListStrMod , ModSeq , init_cnn , instantiate_module ,
11+ is_module , nn_seq )
1112from .layers import ConvBnAct , SEModule , SimpleSelfAttention
1213
1314__all__ = [
2425}
2526
2627
27- def is_module (val : Any ) -> bool :
28- """Check if val is a nn.Module or partial of nn.Module."""
29-
30- to_check = val
31- if isinstance (val , partial ):
32- to_check = val .func
33- try :
34- return issubclass (to_check , nn .Module )
35- except TypeError :
36- return False
28+ nnModule = Union [type [nn .Module ], Callable [[], nn .Module ], str ]
3729
3830
3931class ModelCfg (Cfg , arbitrary_types_allowed = True , extra = "forbid" ):
@@ -42,13 +34,13 @@ class ModelCfg(Cfg, arbitrary_types_allowed=True, extra="forbid"):
4234 name : Optional [str ] = None
4335 in_chans : int = 3
4436 num_classes : int = 1000
45- block : type [ nn . Module ] = BasicBlock
46- conv_layer : type [ nn . Module ] = ConvBnAct
37+ block : nnModule = BasicBlock
38+ conv_layer : nnModule = ConvBnAct
4739 block_sizes : list [int ] = [64 , 128 , 256 , 512 ]
4840 layers : list [int ] = [2 , 2 , 2 , 2 ]
49- norm : type [ nn . Module ] = nn .BatchNorm2d
50- act_fn : type [ nn . Module ] = nn .ReLU
51- pool : Optional [Callable [[ Any ], nn . Module ] ] = None
41+ norm : nnModule = nn .BatchNorm2d
42+ act_fn : nnModule = nn .ReLU
43+ pool : Optional [nnModule ] = None
5244 expansion : int = 1
5345 groups : int = 1
5446 dw : bool = False
@@ -61,11 +53,22 @@ class ModelCfg(Cfg, arbitrary_types_allowed=True, extra="forbid"):
6153 zero_bn : bool = True
6254 stem_stride_on : int = 0
6355 stem_sizes : list [int ] = [64 ]
64- stem_pool : Optional [Callable [[], nn . Module ] ] = partial (
56+ stem_pool : Optional [nnModule ] = partial (
6557 nn .MaxPool2d , kernel_size = 3 , stride = 2 , padding = 1
6658 )
6759 stem_bn_end : bool = False
6860
61+ @field_validator ("act_fn" , "block" , "conv_layer" , "norm" , "pool" , "stem_pool" )
62+ def set_modules ( # pylint: disable=no-self-argument
63+ cls , value : Union [type [nn .Module ], str ], info : FieldValidationInfo ,
64+ ) -> Union [type [nn .Module ], Callable [[], nn .Module ]]:
65+ """Check values, if string, convert to nn.Module."""
66+ if is_module (value ):
67+ return value
68+ if isinstance (value , str ):
69+ return instantiate_module (value )
70+ raise ValueError (f"{ info .field_name } must be str or nn.Module" )
71+
6972 @field_validator ("se" , "sa" )
7073 def set_se ( # pylint: disable=no-self-argument
7174 cls , value : Union [bool , type [nn .Module ]], info : FieldValidationInfo ,
0 commit comments