33from typing import Any , Callable , Optional , TypeVar , Union
44
55import torch
6- from pydantic import BaseModel , validator
6+ from pydantic import BaseModel , field_validator
77from torch import nn
88
99from .helpers import nn_seq
@@ -35,7 +35,7 @@ def init_cnn(module: nn.Module) -> None:
3535
3636class BasicBlock (nn .Module ):
3737 """Basic Resnet block.
38- Configurable - can use pool to reduce at identity path, change act etc. """
38+ Configurable - can use pool to reduce at identity path, change act etc."""
3939
4040 def __init__ (
4141 self ,
@@ -119,7 +119,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # type: ignore
119119
120120class BottleneckBlock (nn .Module ):
121121 """Bottleneck Resnet block.
122- Configurable - can use pool to reduce at identity path, change act etc. """
122+ Configurable - can use pool to reduce at identity path, change act etc."""
123123
124124 def __init__ (
125125 self ,
@@ -285,7 +285,7 @@ def make_head(cfg: TModelCfg) -> nn.Sequential: # type: ignore
285285 return nn_seq (head )
286286
287287
288- class ModelCfg (BaseModel ):
288+ class ModelCfg (BaseModel , arbitrary_types_allowed = True , extra = "forbid" ):
289289 """Model constructor Config. As default - xresnet18"""
290290
291291 name : Optional [str ] = None
@@ -320,10 +320,6 @@ class ModelCfg(BaseModel):
320320 make_body : Callable [[TModelCfg ], Union [nn .Module , nn .Sequential ]] = make_body # type: ignore
321321 make_head : Callable [[TModelCfg ], Union [nn .Module , nn .Sequential ]] = make_head # type: ignore
322322
323- class Config : # pylint: disable=too-few-public-methods
324- arbitrary_types_allowed = True
325- extra = "forbid"
326-
327323 def _get_str_value (self , field : str ) -> str :
328324 value = getattr (self , field )
329325 if isinstance (value , type ):
@@ -340,15 +336,15 @@ def __repr__(self) -> str:
340336 def __repr_args__ (self ) -> list [tuple [str , str ]]:
341337 return [
342338 (field , str_value )
343- for field in self .__fields__
339+ for field in self .model_fields
344340 if (str_value := self ._get_str_value (field ))
345341 ]
346342
347343 def __repr_changed_args__ (self ) -> list [str ]:
348344 """Return list repr for changed fields"""
349345 return [
350346 f"{ field } : { self ._get_str_value (field )} "
351- for field in self .__fields_set__
347+ for field in self .model_fields_set
352348 if field != "name"
353349 ]
354350
@@ -370,7 +366,7 @@ def print_changed(self) -> None:
370366class ModelConstructor (ModelCfg ):
371367 """Model constructor. As default - resnet18"""
372368
373- @validator ("se" )
369+ @field_validator ("se" )
374370 def set_se ( # pylint: disable=no-self-argument
375371 cls , value : Union [bool , type [nn .Module ]]
376372 ) -> Union [bool , type [nn .Module ]]:
@@ -379,7 +375,7 @@ def set_se( # pylint: disable=no-self-argument
379375 return SEModule
380376 return value
381377
382- @validator ("sa" )
378+ @field_validator ("sa" )
383379 def set_sa ( # pylint: disable=no-self-argument
384380 cls , value : Union [bool , type [nn .Module ]]
385381 ) -> Union [bool , type [nn .Module ]]:
@@ -388,7 +384,7 @@ def set_sa( # pylint: disable=no-self-argument
388384 return SimpleSelfAttention # default: ks=1, sym=sym
389385 return value
390386
391- @validator ("se_module" , "se_reduction" ) # pragma: no cover
387+ @field_validator ("se_module" , "se_reduction" ) # pragma: no cover
392388 def deprecation_warning ( # pylint: disable=no-self-argument
393389 cls , value : Union [bool , int , None ]
394390 ) -> Union [bool , int , None ]:
@@ -409,12 +405,14 @@ def body(self):
409405
410406 @classmethod
411407 def from_cfg (cls , cfg : ModelCfg ):
412- return cls (** cfg .dict ())
408+ return cls (** cfg .model_dump ())
413409
414410 @classmethod
415- def create_model (cls , cfg : Union [ModelCfg , None ] = None , ** kwargs : dict [str , Any ]) -> nn .Sequential :
411+ def create_model (
412+ cls , cfg : Union [ModelCfg , None ] = None , ** kwargs : dict [str , Any ]
413+ ) -> nn .Sequential :
416414 if cfg :
417- return cls (** cfg .dict ())()
415+ return cls (** cfg .model_dump ())()
418416 return cls (** kwargs )()
419417
420418 def __call__ (self ) -> nn .Sequential :
0 commit comments