Skip to content

Commit a540481

Browse files
committed
refactor se sa validator
1 parent 6dbecde commit a540481

File tree

1 file changed

+37
-17
lines changed

1 file changed

+37
-17
lines changed

src/model_constructor/model_constructor.py

Lines changed: 37 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from typing import Any, Callable, Optional, Union
44

55
from pydantic import field_validator
6+
from pydantic_core.core_schema import FieldValidationInfo
67
from torch import nn
78

89
from .blocks import BasicBlock, BottleneckBlock
@@ -17,6 +18,24 @@
1718
]
1819

1920

21+
DEFAULT_SE_SA = {
22+
"se": SEModule,
23+
"sa": SimpleSelfAttention,
24+
}
25+
26+
27+
def is_module(val: Any) -> bool:
28+
"""Check if val is a nn.Module or partial of nn.Module."""
29+
30+
to_check = val
31+
if isinstance(val, partial):
32+
to_check = val.func
33+
try:
34+
return issubclass(to_check, nn.Module)
35+
except TypeError:
36+
return False
37+
38+
2039
class ModelCfg(Cfg, arbitrary_types_allowed=True, extra="forbid"):
2140
"""Model constructor Config. As default - xresnet18"""
2241

@@ -35,7 +54,7 @@ class ModelCfg(Cfg, arbitrary_types_allowed=True, extra="forbid"):
3554
dw: bool = False
3655
div_groups: Optional[int] = None
3756
sa: Union[bool, type[nn.Module]] = False
38-
se: Union[bool, type[nn.Module]] = False
57+
se: Union[bool, type[nn.Module], Callable[[], nn.Module]] = False
3958
se_module: Optional[bool] = None
4059
se_reduction: Optional[int] = None
4160
bn_1st: bool = True
@@ -47,23 +66,24 @@ class ModelCfg(Cfg, arbitrary_types_allowed=True, extra="forbid"):
4766
)
4867
stem_bn_end: bool = False
4968

50-
@field_validator("se")
69+
@field_validator("se", "sa")
5170
def set_se( # pylint: disable=no-self-argument
52-
cls, value: Union[bool, type[nn.Module]]
53-
) -> Union[bool, type[nn.Module]]:
54-
if value:
55-
if isinstance(value, (int, bool)):
56-
return SEModule
57-
return value
58-
59-
@field_validator("sa")
60-
def set_sa( # pylint: disable=no-self-argument
61-
cls, value: Union[bool, type[nn.Module]]
62-
) -> Union[bool, type[nn.Module]]:
63-
if value:
64-
if isinstance(value, (int, bool)):
65-
return SimpleSelfAttention # default: ks=1, sym=sym
66-
return value
71+
cls, value: Union[bool, type[nn.Module]], info: FieldValidationInfo,
72+
) -> Union[type[nn.Module], Callable[[], nn.Module]]:
73+
if isinstance(value, (int, bool)):
74+
return DEFAULT_SE_SA[info.field_name]
75+
if is_module(value):
76+
return value
77+
raise ValueError(f"{info.field_name} must be bool or nn.Module")
78+
79+
# @field_validator("sa")
80+
# def set_sa( # pylint: disable=no-self-argument
81+
# cls, value: Union[bool, type[nn.Module]]
82+
# ) -> Union[bool, type[nn.Module]]:
83+
# if value:
84+
# if isinstance(value, (int, bool)):
85+
# return SimpleSelfAttention # default: ks=1, sym=sym
86+
# return value
6787

6888
@field_validator("se_module", "se_reduction") # pragma: no cover
6989
def deprecation_warning( # pylint: disable=no-self-argument

0 commit comments

Comments
 (0)