Skip to content

Commit aa4f6a1

Browse files
committed
refactor to pydantic v2
1 parent 19da428 commit aa4f6a1

File tree

2 files changed

+10
-12
lines changed

2 files changed

+10
-12
lines changed

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
1+
pydantic
12
# pytorch
23
# numpy

src/model_constructor/model_constructor.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from typing import Any, Callable, Optional, TypeVar, Union
44

55
import torch
6-
from pydantic import BaseModel, validator
6+
from pydantic import BaseModel, field_validator
77
from torch import nn
88

99
from .helpers import nn_seq
@@ -285,7 +285,7 @@ def make_head(cfg: TModelCfg) -> nn.Sequential: # type: ignore
285285
return nn_seq(head)
286286

287287

288-
class ModelCfg(BaseModel):
288+
class ModelCfg(BaseModel, arbitrary_types_allowed=True, extra="forbid"):
289289
"""Model constructor Config. As default - xresnet18"""
290290

291291
name: Optional[str] = None
@@ -320,9 +320,6 @@ class ModelCfg(BaseModel):
320320
make_body: Callable[[TModelCfg], Union[nn.Module, nn.Sequential]] = make_body # type: ignore
321321
make_head: Callable[[TModelCfg], Union[nn.Module, nn.Sequential]] = make_head # type: ignore
322322

323-
class Config: # pylint: disable=too-few-public-methods
324-
arbitrary_types_allowed = True
325-
extra = "forbid"
326323

327324
def _get_str_value(self, field: str) -> str:
328325
value = getattr(self, field)
@@ -340,15 +337,15 @@ def __repr__(self) -> str:
340337
def __repr_args__(self) -> list[tuple[str, str]]:
341338
return [
342339
(field, str_value)
343-
for field in self.__fields__
340+
for field in self.model_fields
344341
if (str_value := self._get_str_value(field))
345342
]
346343

347344
def __repr_changed_args__(self) -> list[str]:
348345
"""Return list repr for changed fields"""
349346
return [
350347
f"{field}: {self._get_str_value(field)}"
351-
for field in self.__fields_set__
348+
for field in self.model_fields
352349
if field != "name"
353350
]
354351

@@ -370,7 +367,7 @@ def print_changed(self) -> None:
370367
class ModelConstructor(ModelCfg):
371368
"""Model constructor. As default - resnet18"""
372369

373-
@validator("se")
370+
@field_validator("se")
374371
def set_se( # pylint: disable=no-self-argument
375372
cls, value: Union[bool, type[nn.Module]]
376373
) -> Union[bool, type[nn.Module]]:
@@ -379,7 +376,7 @@ def set_se( # pylint: disable=no-self-argument
379376
return SEModule
380377
return value
381378

382-
@validator("sa")
379+
@field_validator("sa")
383380
def set_sa( # pylint: disable=no-self-argument
384381
cls, value: Union[bool, type[nn.Module]]
385382
) -> Union[bool, type[nn.Module]]:
@@ -388,7 +385,7 @@ def set_sa( # pylint: disable=no-self-argument
388385
return SimpleSelfAttention # default: ks=1, sym=sym
389386
return value
390387

391-
@validator("se_module", "se_reduction") # pragma: no cover
388+
@field_validator("se_module", "se_reduction") # pragma: no cover
392389
def deprecation_warning( # pylint: disable=no-self-argument
393390
cls, value: Union[bool, int, None]
394391
) -> Union[bool, int, None]:
@@ -409,12 +406,12 @@ def body(self):
409406

410407
@classmethod
411408
def from_cfg(cls, cfg: ModelCfg):
412-
return cls(**cfg.dict())
409+
return cls(**cfg.model_dump())
413410

414411
@classmethod
415412
def create_model(cls, cfg: Union[ModelCfg, None] = None, **kwargs: dict[str, Any]) -> nn.Sequential:
416413
if cfg:
417-
return cls(**cfg.dict())()
414+
return cls(**cfg.model_dump())()
418415
return cls(**kwargs)()
419416

420417
def __call__(self) -> nn.Sequential:

0 commit comments

Comments
 (0)