33from typing import Any , Callable , Optional , Union
44
55from pydantic import field_validator
6+ from pydantic_core .core_schema import FieldValidationInfo
67from torch import nn
78
89from .blocks import BasicBlock , BottleneckBlock
1718]
1819
1920
21+ DEFAULT_SE_SA = {
22+ "se" : SEModule ,
23+ "sa" : SimpleSelfAttention ,
24+ }
25+
26+
27+ def is_module (val : Any ) -> bool :
28+ """Check if val is a nn.Module or partial of nn.Module."""
29+
30+ to_check = val
31+ if isinstance (val , partial ):
32+ to_check = val .func
33+ try :
34+ return issubclass (to_check , nn .Module )
35+ except TypeError :
36+ return False
37+
38+
2039class ModelCfg (Cfg , arbitrary_types_allowed = True , extra = "forbid" ):
2140 """Model constructor Config. As default - xresnet18"""
2241
@@ -35,7 +54,7 @@ class ModelCfg(Cfg, arbitrary_types_allowed=True, extra="forbid"):
3554 dw : bool = False
3655 div_groups : Optional [int ] = None
3756 sa : Union [bool , type [nn .Module ]] = False
38- se : Union [bool , type [nn .Module ]] = False
57+ se : Union [bool , type [nn .Module ], Callable [[], nn . Module ] ] = False
3958 se_module : Optional [bool ] = None
4059 se_reduction : Optional [int ] = None
4160 bn_1st : bool = True
@@ -47,23 +66,24 @@ class ModelCfg(Cfg, arbitrary_types_allowed=True, extra="forbid"):
4766 )
4867 stem_bn_end : bool = False
4968
50- @field_validator ("se" )
69+ @field_validator ("se" , "sa" )
5170 def set_se ( # pylint: disable=no-self-argument
52- cls , value : Union [bool , type [nn .Module ]]
53- ) -> Union [bool , type [nn .Module ]]:
54- if value :
55- if isinstance (value , (int , bool )):
56- return SEModule
57- return value
58-
59- @field_validator ("sa" )
60- def set_sa ( # pylint: disable=no-self-argument
61- cls , value : Union [bool , type [nn .Module ]]
62- ) -> Union [bool , type [nn .Module ]]:
63- if value :
64- if isinstance (value , (int , bool )):
65- return SimpleSelfAttention # default: ks=1, sym=sym
66- return value
71+ cls , value : Union [bool , type [nn .Module ]], info : FieldValidationInfo ,
72+ ) -> Union [type [nn .Module ], Callable [[], nn .Module ]]:
73+ if isinstance (value , (int , bool )):
74+ return DEFAULT_SE_SA [info .field_name ]
75+ if is_module (value ):
76+ return value
77+ raise ValueError (f"{ info .field_name } must be bool or nn.Module" )
78+
79+ # @field_validator("sa")
80+ # def set_sa( # pylint: disable=no-self-argument
81+ # cls, value: Union[bool, type[nn.Module]]
82+ # ) -> Union[bool, type[nn.Module]]:
83+ # if value:
84+ # if isinstance(value, (int, bool)):
85+ # return SimpleSelfAttention # default: ks=1, sym=sym
86+ # return value
6787
6888 @field_validator ("se_module" , "se_reduction" ) # pragma: no cover
6989 def deprecation_warning ( # pylint: disable=no-self-argument
0 commit comments