11from collections import OrderedDict
22from functools import partial
3- from typing import Any , Callable , Optional , Union
3+ from typing import Any , Callable , Dict , List , Optional , Union , Type
44
55from pydantic import field_validator
66from pydantic_core .core_schema import FieldValidationInfo
77from torch import nn
88
99from .blocks import BasicBlock , BottleneckBlock
10- from .helpers import (Cfg , ListStrMod , ModSeq , init_cnn , instantiate_module ,
11- is_module , nn_seq )
10+ from .helpers import (
11+ Cfg ,
12+ ListStrMod ,
13+ ModSeq ,
14+ init_cnn ,
15+ instantiate_module ,
16+ is_module ,
17+ nn_seq ,
18+ )
1219from .layers import ConvBnAct , SEModule , SimpleSelfAttention
1320
1421__all__ = [
2532}
2633
2734
28- nnModule = Union [type [nn .Module ], Callable [[], nn .Module ]]
35+ nnModule = Union [Type [nn .Module ], Callable [[], nn .Module ]]
2936
3037
3138class ModelCfg (Cfg , arbitrary_types_allowed = True , extra = "forbid" ):
@@ -36,8 +43,8 @@ class ModelCfg(Cfg, arbitrary_types_allowed=True, extra="forbid"):
3643 num_classes : int = 1000
3744 block : Union [nnModule , str ] = BasicBlock
3845 conv_layer : Union [nnModule , str ] = ConvBnAct
39- block_sizes : list [int ] = [64 , 128 , 256 , 512 ]
40- layers : list [int ] = [2 , 2 , 2 , 2 ]
46+ block_sizes : List [int ] = [64 , 128 , 256 , 512 ]
47+ layers : List [int ] = [2 , 2 , 2 , 2 ]
4148 norm : Union [nnModule , str ] = nn .BatchNorm2d
4249 act_fn : Union [nnModule , str ] = nn .ReLU
4350 pool : Union [nnModule , str , None ] = None
@@ -52,15 +59,16 @@ class ModelCfg(Cfg, arbitrary_types_allowed=True, extra="forbid"):
5259 bn_1st : bool = True
5360 zero_bn : bool = True
5461 stem_stride_on : int = 0
55- stem_sizes : list [int ] = [64 ]
62+ stem_sizes : List [int ] = [64 ]
5663 stem_pool : Union [nnModule , str , None ] = partial (
5764 nn .MaxPool2d , kernel_size = 3 , stride = 2 , padding = 1
5865 )
5966 stem_bn_end : bool = False
6067
6168 @field_validator ("act_fn" , "block" , "conv_layer" , "norm" , "pool" , "stem_pool" )
6269 def set_modules ( # pylint: disable=no-self-argument
63- cls , value : Union [nnModule , str ],
70+ cls ,
71+ value : Union [nnModule , str ],
6472 ) -> nnModule :
6573 """Check values, if string, convert to nn.Module."""
6674 if is_module (value ):
@@ -69,7 +77,9 @@ def set_modules( # pylint: disable=no-self-argument
6977
7078 @field_validator ("se" , "sa" )
7179 def set_se ( # pylint: disable=no-self-argument
72- cls , value : Union [bool , nnModule , str ], info : FieldValidationInfo ,
80+ cls ,
81+ value : Union [bool , nnModule , str ],
82+ info : FieldValidationInfo ,
7383 ) -> nnModule :
7484 if isinstance (value , (int , bool )):
7585 return DEFAULT_SE_SA [info .field_name ]
@@ -154,8 +164,8 @@ def make_layer(cfg: ModelCfg, layer_num: int) -> nn.Sequential: # type: ignore
154164
155165
156166def make_body (
157- cfg : ModelCfg ,
158- layer_constructor : Callable [[ModelCfg , int ], nn .Sequential ] = make_layer ,
167+ cfg : ModelCfg ,
168+ layer_constructor : Callable [[ModelCfg , int ], nn .Sequential ] = make_layer ,
159169) -> nn .Sequential :
160170 """Create model body."""
161171 if hasattr (cfg , "make_layer" ):
@@ -203,7 +213,7 @@ def from_cfg(cls, cfg: ModelCfg):
203213
204214 @classmethod
205215 def create_model (
206- cls , cfg : Optional [ModelCfg ] = None , ** kwargs : dict [str , Any ]
216+ cls , cfg : Optional [ModelCfg ] = None , ** kwargs : Dict [str , Any ]
207217 ) -> nn .Sequential :
208218 if cfg :
209219 return cls (** cfg .model_dump (exclude_none = True ))()
@@ -226,9 +236,9 @@ def __call__(self) -> nn.Sequential:
226236
227237
228238class ResNet34 (ModelConstructor ):
229- layers : list [int ] = [3 , 4 , 6 , 3 ]
239+ layers : List [int ] = [3 , 4 , 6 , 3 ]
230240
231241
232242class ResNet50 (ResNet34 ):
233- block : type [nn .Module ] = BottleneckBlock
234- block_sizes : list [int ] = [256 , 512 , 1024 , 2048 ]
243+ block : Type [nn .Module ] = BottleneckBlock
244+ block_sizes : List [int ] = [256 , 512 , 1024 , 2048 ]
0 commit comments