66from torch import nn
77
88from .blocks import BasicBlock , BottleneckBlock
9- from .helpers import Cfg , ListStrMod , init_cnn , nn_seq
9+ from .helpers import Cfg , ListStrMod , ModSeq , init_cnn , nn_seq
1010from .layers import ConvBnAct , SEModule , SimpleSelfAttention
1111
1212__all__ = [
@@ -33,20 +33,45 @@ class ModelCfg(Cfg, arbitrary_types_allowed=True, extra="forbid"):
3333 expansion : int = 1
3434 groups : int = 1
3535 dw : bool = False
36- div_groups : Union [int , None ] = None
36+ div_groups : Optional [int ] = None
3737 sa : Union [bool , type [nn .Module ]] = False
3838 se : Union [bool , type [nn .Module ]] = False
39- se_module : Union [bool , None ] = None
40- se_reduction : Union [int , None ] = None
39+ se_module : Optional [bool ] = None
40+ se_reduction : Optional [int ] = None
4141 bn_1st : bool = True
4242 zero_bn : bool = True
4343 stem_stride_on : int = 0
4444 stem_sizes : list [int ] = [64 ]
45- stem_pool : Union [Callable [[], nn .Module ], None ] = partial (
45+ stem_pool : Optional [Callable [[], nn .Module ]] = partial (
4646 nn .MaxPool2d , kernel_size = 3 , stride = 2 , padding = 1
4747 )
4848 stem_bn_end : bool = False
4949
50+ @field_validator ("se" )
51+ def set_se ( # pylint: disable=no-self-argument
52+ cls , value : Union [bool , type [nn .Module ]]
53+ ) -> Union [bool , type [nn .Module ]]:
54+ if value :
55+ if isinstance (value , (int , bool )):
56+ return SEModule
57+ return value
58+
59+ @field_validator ("sa" )
60+ def set_sa ( # pylint: disable=no-self-argument
61+ cls , value : Union [bool , type [nn .Module ]]
62+ ) -> Union [bool , type [nn .Module ]]:
63+ if value :
64+ if isinstance (value , (int , bool )):
65+ return SimpleSelfAttention # default: ks=1, sym=sym
66+ return value
67+
68+ @field_validator ("se_module" , "se_reduction" ) # pragma: no cover
69+ def deprecation_warning ( # pylint: disable=no-self-argument
70+ cls , value : Union [bool , int , None ]
71+ ) -> Union [bool , int , None ]:
72+ print ("Deprecated. Pass se_module as se argument, se_reduction as arg to se." )
73+ return value
74+
5075 def __repr__ (self ) -> str :
5176 se_repr = self .se .__name__ if self .se else "False" # type: ignore
5277 model_name = self .name or self .__class__ .__name__
@@ -61,7 +86,7 @@ def __repr__(self) -> str:
6186 )
6287
6388
64- def make_stem (cfg : ModelCfg ) -> nn .Sequential : # type: ignore
89+ def make_stem (cfg : ModelCfg ) -> nn .Sequential :
6590 """Create Resnet stem."""
6691 stem : ListStrMod = [
6792 (
@@ -116,15 +141,15 @@ def make_layer(cfg: ModelCfg, layer_num: int) -> nn.Sequential: # type: ignore
116141 )
117142
118143
119- def make_body (cfg : ModelCfg ) -> nn .Sequential : # type: ignore
144+ def make_body (cfg : ModelCfg ) -> nn .Sequential :
120145 """Create model body."""
121146 return nn_seq (
122147 (f"l_{ layer_num } " , cfg .make_layer (cfg , layer_num )) # type: ignore
123148 for layer_num in range (len (cfg .layers ))
124149 )
125150
126151
127- def make_head (cfg : ModelCfg ) -> nn .Sequential : # type: ignore
152+ def make_head (cfg : ModelCfg ) -> nn .Sequential :
128153 """Create head."""
129154 head = [
130155 ("pool" , nn .AdaptiveAvgPool2d (1 )),
@@ -138,35 +163,10 @@ class ModelConstructor(ModelCfg):
138163 """Model constructor. As default - resnet18"""
139164
140165 init_cnn : Callable [[nn .Module ], None ] = init_cnn
141- make_stem : Callable [[ModelCfg ], Union [nn .Module , nn .Sequential ]] = make_stem # type: ignore
142- make_layer : Callable [[ModelCfg , int ], Union [nn .Module , nn .Sequential ]] = make_layer # type: ignore
143- make_body : Callable [[ModelCfg ], Union [nn .Module , nn .Sequential ]] = make_body # type: ignore
144- make_head : Callable [[ModelCfg ], Union [nn .Module , nn .Sequential ]] = make_head # type: ignore
145-
146- @field_validator ("se" )
147- def set_se ( # pylint: disable=no-self-argument
148- cls , value : Union [bool , type [nn .Module ]]
149- ) -> Union [bool , type [nn .Module ]]:
150- if value :
151- if isinstance (value , (int , bool )):
152- return SEModule
153- return value
154-
155- @field_validator ("sa" )
156- def set_sa ( # pylint: disable=no-self-argument
157- cls , value : Union [bool , type [nn .Module ]]
158- ) -> Union [bool , type [nn .Module ]]:
159- if value :
160- if isinstance (value , (int , bool )):
161- return SimpleSelfAttention # default: ks=1, sym=sym
162- return value
163-
164- @field_validator ("se_module" , "se_reduction" ) # pragma: no cover
165- def deprecation_warning ( # pylint: disable=no-self-argument
166- cls , value : Union [bool , int , None ]
167- ) -> Union [bool , int , None ]:
168- print ("Deprecated. Pass se_module as se argument, se_reduction as arg to se." )
169- return value
166+ make_stem : Callable [[ModelCfg ], ModSeq ] = make_stem
167+ make_layer : Callable [[ModelCfg , int ], ModSeq ] = make_layer
168+ make_body : Callable [[ModelCfg ], ModSeq ] = make_body
169+ make_head : Callable [[ModelCfg ], ModSeq ] = make_head
170170
171171 @property
172172 def stem (self ):
@@ -186,7 +186,7 @@ def from_cfg(cls, cfg: ModelCfg):
186186
187187 @classmethod
188188 def create_model (
189- cls , cfg : Union [ModelCfg , None ] = None , ** kwargs : dict [str , Any ]
189+ cls , cfg : Optional [ModelCfg ] = None , ** kwargs : dict [str , Any ]
190190 ) -> nn .Sequential :
191191 if cfg :
192192 return cls (** cfg .model_dump ())()
0 commit comments