Skip to content

Commit e3d713d

Browse files
authored
Merge pull request #95 from ayasyrev/pydantic_v2
Pydantic v2
2 parents 19da428 + 5e4d304 commit e3d713d

File tree

2 files changed

+15
-16
lines changed

2 files changed

+15
-16
lines changed

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
1+
pydantic
12
# pytorch
23
# numpy

src/model_constructor/model_constructor.py

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from typing import Any, Callable, Optional, TypeVar, Union
44

55
import torch
6-
from pydantic import BaseModel, validator
6+
from pydantic import BaseModel, field_validator
77
from torch import nn
88

99
from .helpers import nn_seq
@@ -35,7 +35,7 @@ def init_cnn(module: nn.Module) -> None:
3535

3636
class BasicBlock(nn.Module):
3737
"""Basic Resnet block.
38-
Configurable - can use pool to reduce at identity path, change act etc. """
38+
Configurable - can use pool to reduce at identity path, change act etc."""
3939

4040
def __init__(
4141
self,
@@ -119,7 +119,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # type: ignore
119119

120120
class BottleneckBlock(nn.Module):
121121
"""Bottleneck Resnet block.
122-
Configurable - can use pool to reduce at identity path, change act etc. """
122+
Configurable - can use pool to reduce at identity path, change act etc."""
123123

124124
def __init__(
125125
self,
@@ -285,7 +285,7 @@ def make_head(cfg: TModelCfg) -> nn.Sequential: # type: ignore
285285
return nn_seq(head)
286286

287287

288-
class ModelCfg(BaseModel):
288+
class ModelCfg(BaseModel, arbitrary_types_allowed=True, extra="forbid"):
289289
"""Model constructor Config. As default - xresnet18"""
290290

291291
name: Optional[str] = None
@@ -320,10 +320,6 @@ class ModelCfg(BaseModel):
320320
make_body: Callable[[TModelCfg], Union[nn.Module, nn.Sequential]] = make_body # type: ignore
321321
make_head: Callable[[TModelCfg], Union[nn.Module, nn.Sequential]] = make_head # type: ignore
322322

323-
class Config: # pylint: disable=too-few-public-methods
324-
arbitrary_types_allowed = True
325-
extra = "forbid"
326-
327323
def _get_str_value(self, field: str) -> str:
328324
value = getattr(self, field)
329325
if isinstance(value, type):
@@ -340,15 +336,15 @@ def __repr__(self) -> str:
340336
def __repr_args__(self) -> list[tuple[str, str]]:
341337
return [
342338
(field, str_value)
343-
for field in self.__fields__
339+
for field in self.model_fields
344340
if (str_value := self._get_str_value(field))
345341
]
346342

347343
def __repr_changed_args__(self) -> list[str]:
348344
"""Return list repr for changed fields"""
349345
return [
350346
f"{field}: {self._get_str_value(field)}"
351-
for field in self.__fields_set__
347+
for field in self.model_fields_set
352348
if field != "name"
353349
]
354350

@@ -370,7 +366,7 @@ def print_changed(self) -> None:
370366
class ModelConstructor(ModelCfg):
371367
"""Model constructor. As default - resnet18"""
372368

373-
@validator("se")
369+
@field_validator("se")
374370
def set_se( # pylint: disable=no-self-argument
375371
cls, value: Union[bool, type[nn.Module]]
376372
) -> Union[bool, type[nn.Module]]:
@@ -379,7 +375,7 @@ def set_se( # pylint: disable=no-self-argument
379375
return SEModule
380376
return value
381377

382-
@validator("sa")
378+
@field_validator("sa")
383379
def set_sa( # pylint: disable=no-self-argument
384380
cls, value: Union[bool, type[nn.Module]]
385381
) -> Union[bool, type[nn.Module]]:
@@ -388,7 +384,7 @@ def set_sa( # pylint: disable=no-self-argument
388384
return SimpleSelfAttention # default: ks=1, sym=sym
389385
return value
390386

391-
@validator("se_module", "se_reduction") # pragma: no cover
387+
@field_validator("se_module", "se_reduction") # pragma: no cover
392388
def deprecation_warning( # pylint: disable=no-self-argument
393389
cls, value: Union[bool, int, None]
394390
) -> Union[bool, int, None]:
@@ -409,12 +405,14 @@ def body(self):
409405

410406
@classmethod
411407
def from_cfg(cls, cfg: ModelCfg):
412-
return cls(**cfg.dict())
408+
return cls(**cfg.model_dump())
413409

414410
@classmethod
415-
def create_model(cls, cfg: Union[ModelCfg, None] = None, **kwargs: dict[str, Any]) -> nn.Sequential:
411+
def create_model(
412+
cls, cfg: Union[ModelCfg, None] = None, **kwargs: dict[str, Any]
413+
) -> nn.Sequential:
416414
if cfg:
417-
return cls(**cfg.dict())()
415+
return cls(**cfg.model_dump())()
418416
return cls(**kwargs)()
419417

420418
def __call__(self) -> nn.Sequential:

0 commit comments

Comments
 (0)