|
3 | 3 | from typing import Any, Callable, Optional, TypeVar, Union |
4 | 4 |
|
5 | 5 | import torch.nn as nn |
6 | | -from pydantic import BaseModel, root_validator |
| 6 | +from pydantic import BaseModel, validator |
7 | 7 |
|
8 | 8 | from .layers import ConvBnAct, SEModule, SimpleSelfAttention, get_act |
9 | 9 |
|
@@ -38,7 +38,7 @@ def __init__( |
38 | 38 | in_channels: int, |
39 | 39 | mid_channels: int, |
40 | 40 | stride: int = 1, |
41 | | - conv_layer=ConvBnAct, |
| 41 | + conv_layer: type[nn.Module] = ConvBnAct, |
42 | 42 | act_fn: type[nn.Module] = nn.ReLU, |
43 | 43 | zero_bn: bool = True, |
44 | 44 | bn_1st: bool = True, |
@@ -254,8 +254,8 @@ class ModelCfg(BaseModel): |
254 | 254 | groups: int = 1 |
255 | 255 | dw: bool = False |
256 | 256 | div_groups: Union[int, None] = None |
257 | | - sa: Union[bool, int, type[nn.Module]] = False |
258 | | - se: Union[bool, int, type[nn.Module]] = False |
| 257 | + sa: Union[bool, type[nn.Module]] = False |
| 258 | + se: Union[bool, type[nn.Module]] = False |
259 | 259 | se_module: Union[bool, None] = None |
260 | 260 | se_reduction: Union[int, None] = None |
261 | 261 | bn_1st: bool = True |
@@ -322,17 +322,26 @@ def print_changed(self) -> None: |
322 | 322 | class ModelConstructor(ModelCfg): |
323 | 323 | """Model constructor. As default - xresnet18""" |
324 | 324 |
|
325 | | - @root_validator |
326 | | - def post_init(cls, values): # pylint: disable=E0213 |
327 | | - if values["se"] and isinstance(values["se"], (bool, int)): # if se=1 or se=True |
328 | | - values["se"] = SEModule |
329 | | - if values["sa"] and isinstance(values["sa"], (bool, int)): # if sa=1 or sa=True |
330 | | - values["sa"] = SimpleSelfAttention # default: ks=1, sym=sym |
331 | | - if values["se_module"] or values["se_reduction"]: # pragma: no cover |
332 | | - print( |
333 | | - "Deprecated. Pass se_module as se argument, se_reduction as arg to se." |
334 | | - ) # add deprecation warning. |
335 | | - return values |
| 325 | + @validator("se") |
| 326 | + def set_se(cls, value: Union[bool, type[nn.Module]]) -> Union[bool, type[nn.Module]]: |
| 327 | + if value: |
| 328 | + if isinstance(value, (int, bool)): |
| 329 | + return SEModule |
| 330 | + return value |
| 331 | + |
| 332 | + @validator("sa") |
| 333 | + def set_sa(cls, value: Union[bool, type[nn.Module]]) -> Union[bool, type[nn.Module]]: |
| 334 | + if value: |
| 335 | + if isinstance(value, (int, bool)): |
| 336 | + return SimpleSelfAttention # default: ks=1, sym=sym |
| 337 | + return value |
| 338 | + |
| 339 | + @validator("se_module", "se_reduction") |
| 340 | + def deprecation_warning(cls, value): # pragma: no cover |
| 341 | + print( |
| 342 | + "Deprecated. Pass se_module as se argument, se_reduction as arg to se." |
| 343 | + ) |
| 344 | + return value |
336 | 345 |
|
337 | 346 | @property |
338 | 347 | def stem(self): |
|
0 commit comments