Skip to content

Commit ae180ba

Browse files
committed
se, sa setter, remove root validator
1 parent 371cac0 commit ae180ba

File tree

2 files changed

+25
-16
lines changed

2 files changed

+25
-16
lines changed

src/model_constructor/model_constructor.py

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from typing import Any, Callable, Optional, TypeVar, Union
44

55
import torch.nn as nn
6-
from pydantic import BaseModel, root_validator
6+
from pydantic import BaseModel, validator
77

88
from .layers import ConvBnAct, SEModule, SimpleSelfAttention, get_act
99

@@ -38,7 +38,7 @@ def __init__(
3838
in_channels: int,
3939
mid_channels: int,
4040
stride: int = 1,
41-
conv_layer=ConvBnAct,
41+
conv_layer: type[nn.Module] = ConvBnAct,
4242
act_fn: type[nn.Module] = nn.ReLU,
4343
zero_bn: bool = True,
4444
bn_1st: bool = True,
@@ -254,8 +254,8 @@ class ModelCfg(BaseModel):
254254
groups: int = 1
255255
dw: bool = False
256256
div_groups: Union[int, None] = None
257-
sa: Union[bool, int, type[nn.Module]] = False
258-
se: Union[bool, int, type[nn.Module]] = False
257+
sa: Union[bool, type[nn.Module]] = False
258+
se: Union[bool, type[nn.Module]] = False
259259
se_module: Union[bool, None] = None
260260
se_reduction: Union[int, None] = None
261261
bn_1st: bool = True
@@ -322,17 +322,26 @@ def print_changed(self) -> None:
322322
class ModelConstructor(ModelCfg):
323323
"""Model constructor. As default - xresnet18"""
324324

325-
@root_validator
326-
def post_init(cls, values): # pylint: disable=E0213
327-
if values["se"] and isinstance(values["se"], (bool, int)): # if se=1 or se=True
328-
values["se"] = SEModule
329-
if values["sa"] and isinstance(values["sa"], (bool, int)): # if sa=1 or sa=True
330-
values["sa"] = SimpleSelfAttention # default: ks=1, sym=sym
331-
if values["se_module"] or values["se_reduction"]: # pragma: no cover
332-
print(
333-
"Deprecated. Pass se_module as se argument, se_reduction as arg to se."
334-
) # add deprecation warning.
335-
return values
325+
@validator("se")
326+
def set_se(cls, value: Union[bool, type[nn.Module]]) -> Union[bool, type[nn.Module]]:
327+
if value:
328+
if isinstance(value, (int, bool)):
329+
return SEModule
330+
return value
331+
332+
@validator("sa")
333+
def set_sa(cls, value: Union[bool, type[nn.Module]]) -> Union[bool, type[nn.Module]]:
334+
if value:
335+
if isinstance(value, (int, bool)):
336+
return SimpleSelfAttention # default: ks=1, sym=sym
337+
return value
338+
339+
@validator("se_module", "se_reduction")
340+
def deprecation_warning(cls, value): # pragma: no cover
341+
print(
342+
"Deprecated. Pass se_module as se argument, se_reduction as arg to se."
343+
)
344+
return value
336345

337346
@property
338347
def stem(self):

src/model_constructor/universal_blocks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def __init__(
2525
in_channels: int,
2626
mid_channels: int,
2727
stride: int = 1,
28-
conv_layer=ConvBnAct,
28+
conv_layer: type[nn.Module] = ConvBnAct,
2929
act_fn: type[nn.Module] = nn.ReLU,
3030
zero_bn: bool = True,
3131
bn_1st: bool = True,

0 commit comments

Comments
 (0)