Skip to content

Commit de5ec7d

Browse files
committed
sa se
1 parent a540481 commit de5ec7d

File tree

1 file changed

+1
-10
lines changed

1 file changed

+1
-10
lines changed

src/model_constructor/model_constructor.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ class ModelCfg(Cfg, arbitrary_types_allowed=True, extra="forbid"):
5353
groups: int = 1
5454
dw: bool = False
5555
div_groups: Optional[int] = None
56-
sa: Union[bool, type[nn.Module]] = False
56+
sa: Union[bool, type[nn.Module], Callable[[], nn.Module]] = False
5757
se: Union[bool, type[nn.Module], Callable[[], nn.Module]] = False
5858
se_module: Optional[bool] = None
5959
se_reduction: Optional[int] = None
@@ -76,15 +76,6 @@ def set_se( # pylint: disable=no-self-argument
7676
return value
7777
raise ValueError(f"{info.field_name} must be bool or nn.Module")
7878

79-
# @field_validator("sa")
80-
# def set_sa( # pylint: disable=no-self-argument
81-
# cls, value: Union[bool, type[nn.Module]]
82-
# ) -> Union[bool, type[nn.Module]]:
83-
# if value:
84-
# if isinstance(value, (int, bool)):
85-
# return SimpleSelfAttention # default: ks=1, sym=sym
86-
# return value
87-
8879
@field_validator("se_module", "se_reduction") # pragma: no cover
8980
def deprecation_warning( # pylint: disable=no-self-argument
9081
cls, value: Union[bool, int, None]

0 commit comments

Comments
 (0)