@@ -47,6 +47,31 @@ class ModelCfg(Cfg, arbitrary_types_allowed=True, extra="forbid"):
4747 )
4848 stem_bn_end : bool = False
4949
50+ @field_validator ("se" )
51+ def set_se ( # pylint: disable=no-self-argument
52+ cls , value : Union [bool , type [nn .Module ]]
53+ ) -> Union [bool , type [nn .Module ]]:
54+ if value :
55+ if isinstance (value , (int , bool )):
56+ return SEModule
57+ return value
58+
59+ @field_validator ("sa" )
60+ def set_sa ( # pylint: disable=no-self-argument
61+ cls , value : Union [bool , type [nn .Module ]]
62+ ) -> Union [bool , type [nn .Module ]]:
63+ if value :
64+ if isinstance (value , (int , bool )):
65+ return SimpleSelfAttention # default: ks=1, sym=sym
66+ return value
67+
68+ @field_validator ("se_module" , "se_reduction" ) # pragma: no cover
69+ def deprecation_warning ( # pylint: disable=no-self-argument
70+ cls , value : Union [bool , int , None ]
71+ ) -> Union [bool , int , None ]:
72+ print ("Deprecated. Pass se_module as se argument, se_reduction as arg to se." )
73+ return value
74+
5075 def __repr__ (self ) -> str :
5176 se_repr = self .se .__name__ if self .se else "False" # type: ignore
5277 model_name = self .name or self .__class__ .__name__
@@ -143,31 +168,6 @@ class ModelConstructor(ModelCfg):
143168 make_body : Callable [[ModelCfg ], ModSeq ] = make_body
144169 make_head : Callable [[ModelCfg ], ModSeq ] = make_head
145170
146- @field_validator ("se" )
147- def set_se ( # pylint: disable=no-self-argument
148- cls , value : Union [bool , type [nn .Module ]]
149- ) -> Union [bool , type [nn .Module ]]:
150- if value :
151- if isinstance (value , (int , bool )):
152- return SEModule
153- return value
154-
155- @field_validator ("sa" )
156- def set_sa ( # pylint: disable=no-self-argument
157- cls , value : Union [bool , type [nn .Module ]]
158- ) -> Union [bool , type [nn .Module ]]:
159- if value :
160- if isinstance (value , (int , bool )):
161- return SimpleSelfAttention # default: ks=1, sym=sym
162- return value
163-
164- @field_validator ("se_module" , "se_reduction" ) # pragma: no cover
165- def deprecation_warning ( # pylint: disable=no-self-argument
166- cls , value : Union [bool , int , None ]
167- ) -> Union [bool , int , None ]:
168- print ("Deprecated. Pass se_module as se argument, se_reduction as arg to se." )
169- return value
170-
171171 @property
172172 def stem (self ):
173173 return self .make_stem (self ) # pylint: disable=too-many-function-args
0 commit comments