@@ -53,7 +53,7 @@ class ModelCfg(Cfg, arbitrary_types_allowed=True, extra="forbid"):
5353 groups : int = 1
5454 dw : bool = False
5555 div_groups : Optional [int ] = None
56- sa : Union [bool , type [nn .Module ]] = False
56+ sa : Union [bool , type [nn .Module ], Callable [[], nn . Module ] ] = False
5757 se : Union [bool , type [nn .Module ], Callable [[], nn .Module ]] = False
5858 se_module : Optional [bool ] = None
5959 se_reduction : Optional [int ] = None
@@ -76,15 +76,6 @@ def set_se( # pylint: disable=no-self-argument
7676 return value
7777 raise ValueError (f"{ info .field_name } must be bool or nn.Module" )
7878
79- # @field_validator("sa")
80- # def set_sa( # pylint: disable=no-self-argument
81- # cls, value: Union[bool, type[nn.Module]]
82- # ) -> Union[bool, type[nn.Module]]:
83- # if value:
84- # if isinstance(value, (int, bool)):
85- # return SimpleSelfAttention # default: ks=1, sym=sym
86- # return value
87-
8879 @field_validator ("se_module" , "se_reduction" ) # pragma: no cover
8980 def deprecation_warning ( # pylint: disable=no-self-argument
9081 cls , value : Union [bool , int , None ]
0 commit comments