11from collections import OrderedDict
22from functools import partial
3- from typing import Any , Callable , Optional , TypeVar , Union
3+ from typing import Any , Callable , Optional , Union
44
55from pydantic import field_validator
66from torch import nn
1717]
1818
1919
20- TModelCfg = TypeVar ("TModelCfg" , bound = "ModelCfg" )
21-
22-
2320class ModelCfg (Cfg , arbitrary_types_allowed = True , extra = "forbid" ):
2421 """Model constructor Config. As default - xresnet18"""
2522
@@ -64,7 +61,7 @@ def __repr__(self) -> str:
6461 )
6562
6663
67- def make_stem (cfg : TModelCfg ) -> nn .Sequential : # type: ignore
64+ def make_stem (cfg : ModelCfg ) -> nn .Sequential : # type: ignore
6865 """Create Resnet stem."""
6966 stem : ListStrMod = [
7067 (
@@ -88,7 +85,7 @@ def make_stem(cfg: TModelCfg) -> nn.Sequential: # type: ignore
8885 return nn_seq (stem )
8986
9087
91- def make_layer (cfg : TModelCfg , layer_num : int ) -> nn .Sequential : # type: ignore
88+ def make_layer (cfg : ModelCfg , layer_num : int ) -> nn .Sequential : # type: ignore
9289 """Create layer (stage)"""
9390 # if no pool on stem - stride = 2 for first layer block in body
9491 stride = 1 if cfg .stem_pool and layer_num == 0 else 2
@@ -119,15 +116,15 @@ def make_layer(cfg: TModelCfg, layer_num: int) -> nn.Sequential: # type: ignore
119116 )
120117
121118
122- def make_body (cfg : TModelCfg ) -> nn .Sequential : # type: ignore
119+ def make_body (cfg : ModelCfg ) -> nn .Sequential : # type: ignore
123120 """Create model body."""
124121 return nn_seq (
125122 (f"l_{ layer_num } " , cfg .make_layer (cfg , layer_num )) # type: ignore
126123 for layer_num in range (len (cfg .layers ))
127124 )
128125
129126
130- def make_head (cfg : TModelCfg ) -> nn .Sequential : # type: ignore
127+ def make_head (cfg : ModelCfg ) -> nn .Sequential : # type: ignore
131128 """Create head."""
132129 head = [
133130 ("pool" , nn .AdaptiveAvgPool2d (1 )),
@@ -141,10 +138,10 @@ class ModelConstructor(ModelCfg):
141138 """Model constructor. As default - resnet18"""
142139
143140 init_cnn : Callable [[nn .Module ], None ] = init_cnn
144- make_stem : Callable [[TModelCfg ], Union [nn .Module , nn .Sequential ]] = make_stem # type: ignore
145- make_layer : Callable [[TModelCfg , int ], Union [nn .Module , nn .Sequential ]] = make_layer # type: ignore
146- make_body : Callable [[TModelCfg ], Union [nn .Module , nn .Sequential ]] = make_body # type: ignore
147- make_head : Callable [[TModelCfg ], Union [nn .Module , nn .Sequential ]] = make_head # type: ignore
141+ make_stem : Callable [[ModelCfg ], Union [nn .Module , nn .Sequential ]] = make_stem # type: ignore
142+ make_layer : Callable [[ModelCfg , int ], Union [nn .Module , nn .Sequential ]] = make_layer # type: ignore
143+ make_body : Callable [[ModelCfg ], Union [nn .Module , nn .Sequential ]] = make_body # type: ignore
144+ make_head : Callable [[ModelCfg ], Union [nn .Module , nn .Sequential ]] = make_head # type: ignore
148145
149146 @field_validator ("se" )
150147 def set_se ( # pylint: disable=no-self-argument
0 commit comments