33from typing import Any , Callable , Optional , TypeVar , Union
44
55import torch
6- from pydantic import BaseModel , validator
6+ from pydantic import BaseModel , field_validator
77from torch import nn
88
99from .helpers import nn_seq
@@ -285,7 +285,7 @@ def make_head(cfg: TModelCfg) -> nn.Sequential: # type: ignore
285285 return nn_seq (head )
286286
287287
288- class ModelCfg (BaseModel ):
288+ class ModelCfg (BaseModel , arbitrary_types_allowed = True , extra = "forbid" ):
289289 """Model constructor Config. As default - xresnet18"""
290290
291291 name : Optional [str ] = None
@@ -320,9 +320,6 @@ class ModelCfg(BaseModel):
320320 make_body : Callable [[TModelCfg ], Union [nn .Module , nn .Sequential ]] = make_body # type: ignore
321321 make_head : Callable [[TModelCfg ], Union [nn .Module , nn .Sequential ]] = make_head # type: ignore
322322
323- class Config : # pylint: disable=too-few-public-methods
324- arbitrary_types_allowed = True
325- extra = "forbid"
326323
327324 def _get_str_value (self , field : str ) -> str :
328325 value = getattr (self , field )
@@ -340,15 +337,15 @@ def __repr__(self) -> str:
340337 def __repr_args__ (self ) -> list [tuple [str , str ]]:
341338 return [
342339 (field , str_value )
343- for field in self .__fields__
340+ for field in self .model_fields
344341 if (str_value := self ._get_str_value (field ))
345342 ]
346343
347344 def __repr_changed_args__ (self ) -> list [str ]:
348345 """Return list repr for changed fields"""
349346 return [
350347 f"{ field } : { self ._get_str_value (field )} "
351- for field in self .__fields_set__
348+ for field in self .model_fields
352349 if field != "name"
353350 ]
354351
@@ -370,7 +367,7 @@ def print_changed(self) -> None:
370367class ModelConstructor (ModelCfg ):
371368 """Model constructor. As default - resnet18"""
372369
373- @validator ("se" )
370+ @field_validator ("se" )
374371 def set_se ( # pylint: disable=no-self-argument
375372 cls , value : Union [bool , type [nn .Module ]]
376373 ) -> Union [bool , type [nn .Module ]]:
@@ -379,7 +376,7 @@ def set_se( # pylint: disable=no-self-argument
379376 return SEModule
380377 return value
381378
382- @validator ("sa" )
379+ @field_validator ("sa" )
383380 def set_sa ( # pylint: disable=no-self-argument
384381 cls , value : Union [bool , type [nn .Module ]]
385382 ) -> Union [bool , type [nn .Module ]]:
@@ -388,7 +385,7 @@ def set_sa( # pylint: disable=no-self-argument
388385 return SimpleSelfAttention # default: ks=1, sym=sym
389386 return value
390387
391- @validator ("se_module" , "se_reduction" ) # pragma: no cover
388+ @field_validator ("se_module" , "se_reduction" ) # pragma: no cover
392389 def deprecation_warning ( # pylint: disable=no-self-argument
393390 cls , value : Union [bool , int , None ]
394391 ) -> Union [bool , int , None ]:
@@ -409,12 +406,12 @@ def body(self):
409406
410407 @classmethod
411408 def from_cfg (cls , cfg : ModelCfg ):
412- return cls (** cfg .dict ())
409+ return cls (** cfg .model_dump ())
413410
414411 @classmethod
415412 def create_model (cls , cfg : Union [ModelCfg , None ] = None , ** kwargs : dict [str , Any ]) -> nn .Sequential :
416413 if cfg :
417- return cls (** cfg .dict ())()
414+ return cls (** cfg .model_dump ())()
418415 return cls (** kwargs )()
419416
420417 def __call__ (self ) -> nn .Sequential :
0 commit comments