11from collections import OrderedDict
22from functools import partial
3- from typing import Any , Callable , List , Optional , Type , Union
3+ from typing import Any , Callable , List , Type , TypeVar , Union
44
55import torch .nn as nn
66from pydantic import BaseModel , root_validator
1616]
1717
1818
# Type variable for "ModelCfg or any subclass of it".  The bound is given as a
# string (forward reference) because ModelCfg is defined further down the file,
# after the builder functions that are annotated with this TypeVar.
TModelCfg = TypeVar("TModelCfg", bound="ModelCfg")
20+
21+
def init_cnn(module: nn.Module):
    """Recursively initialise `module` in place.

    For `module` and every descendant reached through `children()`:

    * if the module has a `bias` attribute that is not None, fill it with 0;
    * if the module is `nn.Conv2d` or `nn.Linear`, re-draw its weight with
      `nn.init.kaiming_normal_`.

    Returns None; the module is modified in place.
    """
    if getattr(module, "bias", None) is not None:
        nn.init.constant_(module.bias, 0)  # type: ignore
    if isinstance(module, (nn.Conv2d, nn.Linear)):
        nn.init.kaiming_normal_(module.weight)
    # Depth-first walk: each child repeats the same procedure on its own subtree.
    for layer in module.children():
        init_cnn(layer)
30+
31+
1932class ResBlock (nn .Module ):
2033 """Universal Resnet block. Basic block if expansion is 1, otherwise is Bottleneck."""
2134
@@ -116,86 +129,28 @@ def forward(self, x):
116129 return self .act_fn (self .convs (x ) + identity )
117130
118131
119- class ModelCfg (BaseModel ):
120- """Model constructor Config. As default - xresnet18"""
121-
122- name : str = "MC"
123- in_chans : int = 3
124- num_classes : int = 1000
125- block : Type [nn .Module ] = ResBlock
126- conv_layer : Type [nn .Module ] = ConvBnAct
127- block_sizes : List [int ] = [64 , 128 , 256 , 512 ]
128- layers : List [int ] = [2 , 2 , 2 , 2 ]
129- norm : Type [nn .Module ] = nn .BatchNorm2d
130- act_fn : Type [nn .Module ] = nn .ReLU
131- pool : Callable [[Any ], nn .Module ] = partial (nn .AvgPool2d , kernel_size = 2 , ceil_mode = True )
132- expansion : int = 1
133- groups : int = 1
134- dw : bool = False
135- div_groups : Union [int , None ] = None
136- sa : Union [bool , int , Type [nn .Module ]] = False
137- se : Union [bool , int , Type [nn .Module ]] = False
138- se_module : Union [bool , None ] = None
139- se_reduction : Union [int , None ] = None
140- bn_1st : bool = True
141- zero_bn : bool = True
142- stem_stride_on : int = 0
143- stem_sizes : List [int ] = [32 , 32 , 64 ]
144- stem_pool : Union [Callable [[], nn .Module ], None ] = partial (nn .MaxPool2d , kernel_size = 3 , stride = 2 , padding = 1 )
145- stem_bn_end : bool = False
146- init_cnn : Optional [Callable [[nn .Module ], None ]] = None
147- make_stem : Optional [Callable [["ModelCfg" ], nn .Module ]] = None
148- make_layer : Optional [Callable [["ModelCfg" ], nn .Module ]] = None
149- make_body : Optional [Callable [["ModelCfg" ], nn .Module ]] = None
150- make_head : Optional [Callable [["ModelCfg" ], nn .Module ]] = None
151-
152- class Config :
153- arbitrary_types_allowed = True
154- extra = "forbid"
155-
156- def extra_repr (self ) -> str :
157- res = ""
158- for k , v in self .dict ().items ():
159- if v is not None :
160- res += f"{ k } : { v } \n "
161- return res
162-
163- def pprint (self ) -> None :
164- print (self .extra_repr ())
165-
166-
167- def init_cnn (module : nn .Module ):
168- "Init module - kaiming_normal for Conv2d and 0 for biases."
169- if getattr (module , "bias" , None ) is not None :
170- nn .init .constant_ (module .bias , 0 ) # type: ignore
171- if isinstance (module , (nn .Conv2d , nn .Linear )):
172- nn .init .kaiming_normal_ (module .weight )
173- for layer in module .children ():
174- init_cnn (layer )
175-
176-
def make_stem(cfg: TModelCfg) -> nn.Sequential:  # type: ignore
    """Build the stem as an `nn.Sequential` of named layers.

    One `cfg.conv_layer` is created for every consecutive pair in
    `cfg.stem_sizes` and named `conv_{i}`.  The layer whose index equals
    `cfg.stem_stride_on` gets stride 2, all others stride 1.

    If `cfg.stem_pool` is truthy it is called with no arguments and the result
    is appended as `stem_pool`.  If `cfg.stem_bn_end` is truthy, the last conv
    layer is created with `bn_layer=False` and `cfg.norm(cfg.stem_sizes[-1])`
    is appended as `norm` instead.

    Args:
        cfg: config object providing `conv_layer`, `stem_sizes`,
            `stem_stride_on`, `stem_bn_end`, `act_fn`, `bn_1st`, `stem_pool`
            and `norm`.

    Returns:
        `nn.Sequential` built from an `OrderedDict` of the layers above.
    """
    sizes = cfg.stem_sizes
    last_conv = len(sizes) - 2  # index of the final conv layer in the stem
    stem: List[tuple[str, nn.Module]] = [
        (
            f"conv_{i}",
            cfg.conv_layer(
                ch_in,  # type: ignore
                ch_out,
                stride=2 if i == cfg.stem_stride_on else 1,
                # Norm is dropped only on the last conv, and only when a
                # separate norm layer is appended at the end of the stem.
                bn_layer=(not cfg.stem_bn_end) if i == last_conv else True,
                act_fn=cfg.act_fn,
                bn_1st=cfg.bn_1st,
            ),
        )
        for i, (ch_in, ch_out) in enumerate(zip(sizes, sizes[1:]))
    ]
    if cfg.stem_pool:
        stem.append(("stem_pool", cfg.stem_pool()))
    if cfg.stem_bn_end:
        stem.append(("norm", cfg.norm(sizes[-1])))  # type: ignore
    return nn.Sequential(OrderedDict(stem))
196151
197152
198- def make_layer (cfg : ModelCfg , layer_num : int ) -> nn .Sequential :
153+ def make_layer (cfg : TModelCfg , layer_num : int ) -> nn .Sequential : # type: ignore
199154 # expansion, in_channels, out_channels, blocks, stride, sa):
200155 # if no pool on stem - stride = 2 for first layer block in body
201156 stride = 1 if cfg .stem_pool and layer_num == 0 else 2
@@ -231,7 +186,7 @@ def make_layer(cfg: ModelCfg, layer_num: int) -> nn.Sequential:
231186 )
232187
233188
234- def make_body (cfg : ModelCfg ) -> nn .Sequential :
189+ def make_body (cfg : TModelCfg ) -> nn .Sequential : # type: ignore
235190 return nn .Sequential (
236191 OrderedDict (
237192 [
@@ -245,7 +200,7 @@ def make_body(cfg: ModelCfg) -> nn.Sequential:
245200 )
246201
247202
248- def make_head (cfg : ModelCfg ) -> nn .Sequential :
203+ def make_head (cfg : TModelCfg ) -> nn .Sequential : # type: ignore
249204 head = [
250205 ("pool" , nn .AdaptiveAvgPool2d (1 )),
251206 ("flat" , nn .Flatten ()),
@@ -254,22 +209,59 @@ def make_head(cfg: ModelCfg) -> nn.Sequential:
254209 return nn .Sequential (OrderedDict (head ))
255210
256211
class ModelCfg(BaseModel):
    """Model constructor Config. As default - xresnet18"""

    name: str = "MC"
    in_chans: int = 3
    num_classes: int = 1000
    # Building-block classes.
    block: Type[nn.Module] = ResBlock
    conv_layer: Type[nn.Module] = ConvBnAct
    block_sizes: List[int] = [64, 128, 256, 512]
    layers: List[int] = [2, 2, 2, 2]
    norm: Type[nn.Module] = nn.BatchNorm2d
    act_fn: Type[nn.Module] = nn.ReLU
    pool: Callable[[Any], nn.Module] = partial(nn.AvgPool2d, kernel_size=2, ceil_mode=True)
    expansion: int = 1
    groups: int = 1
    dw: bool = False
    div_groups: Union[int, None] = None
    # `sa` / `se` accept a flag (bool or int) or a module class.
    sa: Union[bool, int, Type[nn.Module]] = False
    se: Union[bool, int, Type[nn.Module]] = False
    se_module: Union[bool, None] = None
    se_reduction: Union[int, None] = None
    bn_1st: bool = True
    zero_bn: bool = True
    # Stem options, read by `make_stem`.
    stem_stride_on: int = 0
    stem_sizes: List[int] = [32, 32, 64]
    stem_pool: Union[Callable[[], nn.Module], None] = partial(nn.MaxPool2d, kernel_size=3, stride=2, padding=1)
    stem_bn_end: bool = False
    # Overridable hooks; the defaults are the module-level functions of the same name.
    init_cnn: Callable[[nn.Module], None] = init_cnn
    make_stem: Callable[[TModelCfg], Union[nn.Module, nn.Sequential]] = make_stem  # type: ignore
    make_layer: Callable[[TModelCfg, int], Union[nn.Module, nn.Sequential]] = make_layer  # type: ignore
    make_body: Callable[[TModelCfg], Union[nn.Module, nn.Sequential]] = make_body  # type: ignore
    make_head: Callable[[TModelCfg], Union[nn.Module, nn.Sequential]] = make_head  # type: ignore

    class Config:
        # Fields hold torch classes and callables, which pydantic cannot validate natively.
        arbitrary_types_allowed = True
        # Unknown keyword arguments are rejected.
        extra = "forbid"

    def extra_repr(self) -> str:
        """Return one `name: value` line per field whose value is not None."""
        return "".join(
            f"{k}: {v}\n" for k, v in self.dict().items() if v is not None
        )

    def pprint(self) -> None:
        """Print the text produced by `extra_repr`."""
        print(self.extra_repr())
258+
259+
257260class ModelConstructor (ModelCfg ):
258261 """Model constructor. As default - xresnet18"""
259262
260263 @root_validator
261- def post_init (cls , values ):
262- if values ["init_cnn" ] is None :
263- values ["init_cnn" ] = init_cnn
264- if values ["make_stem" ] is None :
265- values ["make_stem" ] = make_stem
266- if values ["make_layer" ] is None :
267- values ["make_layer" ] = make_layer
268- if values ["make_body" ] is None :
269- values ["make_body" ] = make_body
270- if values ["make_head" ] is None :
271- values ["make_head" ] = make_head
272-
264+ def post_init (cls , values ): # pylint: disable=E0213
273265 if values ["stem_sizes" ][0 ] != values ["in_chans" ]:
274266 values ["stem_sizes" ] = [values ["in_chans" ]] + values ["stem_sizes" ]
275267 if values ["se" ] and isinstance (values ["se" ], (bool , int )): # if se=1 or se=True
@@ -284,15 +276,15 @@ def post_init(cls, values):
284276
285277 @property
286278 def stem (self ):
287- return self .make_stem (self ) # type: ignore
279+ return self .make_stem (self ) # pylint: disable=too-many-function-args
288280
289281 @property
290282 def head (self ):
291- return self .make_head (self ) # type: ignore
283+ return self .make_head (self ) # pylint: disable=too-many-function-args
292284
293285 @property
294286 def body (self ):
295- return self .make_body (self ) # type: ignore
287+ return self .make_body (self ) # pylint: disable=too-many-function-args
296288
297289 @classmethod
298290 def from_cfg (cls , cfg : ModelCfg ):
@@ -302,12 +294,12 @@ def __call__(self):
302294 model = nn .Sequential (
303295 OrderedDict ([("stem" , self .stem ), ("body" , self .body ), ("head" , self .head )])
304296 )
305- self .init_cnn (model ) # type: ignore
297+ self .init_cnn (model ) # pylint: disable=too-many-function-args
306298 model .extra_repr = lambda : f"{ self .name } "
307299 return model
308300
309301 def __repr__ (self ):
310- se_repr = self .se .__name__ if self .se else "False"
302+ se_repr = self .se .__name__ if self .se else "False" # type: ignore
311303 return (
312304 f"{ self .name } constructor\n "
313305 f" in_chans: { self .in_chans } , num_classes: { self .num_classes } \n "
0 commit comments