1+ from dataclasses import dataclass , field , asdict
2+
13from collections import OrderedDict
2- from functools import partial
3- from typing import Callable , List , Type , Union
4+ # from functools import partial
5+ from typing import Callable , List , Optional , Type , Union
46
57import torch .nn as nn
68
1214 "act_fn" ,
1315 "ResBlock" ,
1416 "ModelConstructor" ,
15- "xresnet34" ,
16- "xresnet50" ,
17+ # "xresnet34",
18+ # "xresnet50",
1719]
1820
1921
# Default activation used by the constructor helpers below (in-place to save memory).
act_fn = nn.ReLU(inplace=True)
2123
2224
23- def init_cnn (module : nn .Module ):
24- "Init module - kaiming_normal for Conv2d and 0 for biases."
25- if getattr (module , "bias" , None ) is not None :
26- nn .init .constant_ (module .bias , 0 ) # type: ignore
27- if isinstance (module , (nn .Conv2d , nn .Linear )):
28- nn .init .kaiming_normal_ (module .weight )
29- for layer in module .children ():
30- init_cnn (layer )
31-
32-
3325class ResBlock (nn .Module ):
3426 """Universal Resnet block. Basic block if expansion is 1, otherwise is Bottleneck."""
3527
@@ -130,10 +122,55 @@ def forward(self, x):
130122 return self .act_fn (self .convs (x ) + identity )
131123
132124
133- def _make_stem (self ):
134- stem = [
@dataclass
class CfgMC:
    """Model constructor Config. As default - xresnet18."""

    name: str = "MC"
    in_chans: int = 3  # input channels (stem is widened to match if needed)
    num_classes: int = 1000
    block: Type[nn.Module] = ResBlock
    conv_layer: Type[nn.Module] = ConvBnAct
    block_sizes: List[int] = field(default_factory=lambda: [64, 128, 256, 512])
    layers: List[int] = field(default_factory=lambda: [2, 2, 2, 2])  # blocks per stage
    norm: Type[nn.Module] = nn.BatchNorm2d
    # FIX: module-instance defaults were single shared objects — every config
    # (and thus every constructed model) referenced the SAME ReLU/AvgPool/MaxPool
    # instance. default_factory gives each config its own module.
    act_fn: nn.Module = field(default_factory=lambda: nn.ReLU(inplace=True))
    pool: nn.Module = field(default_factory=lambda: nn.AvgPool2d(2, ceil_mode=True))
    expansion: int = 1
    groups: int = 1
    dw: bool = False  # depthwise convs
    div_groups: Union[int, None] = None
    # se/sa can be bool, int (0/1) or an nn.Module subclass; bool/int truthy values
    # are resolved to the default modules in ModelConstructor.__post_init__.
    sa: Union[bool, int, Type[nn.Module]] = False
    se: Union[bool, int, Type[nn.Module]] = False
    # Deprecated knobs, deliberately left un-annotated: they are class attributes,
    # not dataclass fields, so they stay out of __init__/asdict but can still be
    # checked to emit a deprecation warning.
    se_module = None
    se_reduction = None
    bn_1st: bool = True
    zero_bn: bool = True
    stem_stride_on: int = 0  # index of the stem conv that gets stride 2
    stem_sizes: List[int] = field(default_factory=lambda: [32, 32, 64])
    # stem_pool falsy -> no pooling layer at the end of the stem.
    stem_pool: Union[nn.Module, None] = field(
        default_factory=lambda: nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
    )
    stem_bn_end: bool = False
    # Build hooks; None means "use the module-level default" (resolved later).
    _init_cnn: Optional[Callable[[nn.Module], None]] = field(repr=False, default=None)
    _make_stem: Optional[Callable] = field(repr=False, default=None)
    _make_layer: Optional[Callable] = field(repr=False, default=None)
    _make_body: Optional[Callable] = field(repr=False, default=None)
    _make_head: Optional[Callable] = field(repr=False, default=None)
159+
def init_cnn(module: nn.Module):
    """Init module - kaiming_normal for Conv2d/Linear weights and 0 for biases.

    Walks the whole submodule tree (iteratively, instead of recursing),
    zeroing any `bias` attribute it finds on the way.
    """
    pending = [module]
    while pending:
        mod = pending.pop()
        bias = getattr(mod, "bias", None)
        if bias is not None:
            nn.init.constant_(bias, 0)  # type: ignore
        if isinstance(mod, (nn.Conv2d, nn.Linear)):
            nn.init.kaiming_normal_(mod.weight)
        pending.extend(mod.children())
169+
170+ def make_stem (self : CfgMC ) -> nn .Sequential :
171+ stem : List [tuple [str , nn .Module ]] = [
135172 (f"conv_{ i } " , self .conv_layer (
136- self .stem_sizes [i ],
173+ self .stem_sizes [i ], # type: ignore
137174 self .stem_sizes [i + 1 ],
138175 stride = 2 if i == self .stem_stride_on else 1 ,
139176 bn_layer = (not self .stem_bn_end )
@@ -147,39 +184,38 @@ def _make_stem(self):
147184 if self .stem_pool :
148185 stem .append (("stem_pool" , self .stem_pool ))
149186 if self .stem_bn_end :
150- stem .append (("norm" , self .norm (self .stem_sizes [- 1 ])))
187+ stem .append (("norm" , self .norm (self .stem_sizes [- 1 ]))) # type: ignore
151188 return nn .Sequential (OrderedDict (stem ))
152189
153190
def make_layer(cfg: CfgMC, layer_num: int) -> nn.Sequential:
    """Build one body stage: cfg.layers[layer_num] blocks as an nn.Sequential."""
    # If the stem has no pool, the very first body block downsamples instead.
    first_stride = 1 if cfg.stem_pool and layer_num == 0 else 2
    num_blocks = cfg.layers[layer_num]
    # Channel plan: stem output (scaled down by expansion) feeds the first stage.
    channels = [cfg.stem_sizes[-1] // cfg.expansion] + cfg.block_sizes
    blocks = []
    for block_num in range(num_blocks):
        is_first = block_num == 0
        in_channels = channels[layer_num] if is_first else channels[layer_num + 1]
        # Self-attention only on the last block of the first stage.
        use_sa = cfg.sa if (block_num == num_blocks - 1) and layer_num == 0 else None
        block_module = cfg.block(
            cfg.expansion,  # type: ignore
            in_channels,
            channels[layer_num + 1],
            first_stride if is_first else 1,
            sa=use_sa,
            conv_layer=cfg.conv_layer,
            act_fn=cfg.act_fn,
            pool=cfg.pool,
            zero_bn=cfg.zero_bn,
            bn_1st=cfg.bn_1st,
            groups=cfg.groups,
            div_groups=cfg.div_groups,
            dw=cfg.dw,
            se=cfg.se,
        )
        blocks.append((f"bl_{block_num}", block_module))
    return nn.Sequential(OrderedDict(blocks))
190226
def make_body(cfg: CfgMC) -> nn.Sequential:
    """Stack every body stage ("l_0" ... "l_{n-1}") into a single nn.Sequential."""
    stages: "OrderedDict[str, nn.Module]" = OrderedDict()
    for layer_num in range(len(cfg.layers)):
        # The layer hook is stored unbound on the cfg, so pass cfg explicitly.
        stages[f"l_{layer_num}"] = cfg._make_layer(cfg, layer_num)  # type: ignore
    return nn.Sequential(stages)
203239
204240
def make_head(cfg: CfgMC) -> nn.Sequential:
    """Classification head: global average pool -> flatten -> linear classifier."""
    modules = OrderedDict(
        pool=nn.AdaptiveAvgPool2d(1),
        flat=nn.Flatten(),
        fc=nn.Linear(cfg.block_sizes[-1] * cfg.expansion, cfg.num_classes),
    )
    return nn.Sequential(modules)
212248
213249
@dataclass
class ModelConstructor(CfgMC):
    """Model constructor. As default - xresnet18.

    Inherits all configuration fields from CfgMC; calling an instance builds
    the stem/body/head nn.Sequential and initializes it.
    """

    def __post_init__(self) -> None:
        # Resolve build hooks: None means "use the module-level defaults".
        if self._init_cnn is None:
            self._init_cnn = init_cnn
        if self._make_stem is None:
            self._make_stem = make_stem
        if self._make_layer is None:
            self._make_layer = make_layer
        if self._make_body is None:
            self._make_body = make_body
        if self._make_head is None:
            self._make_head = make_head

        # Widen the stem so its first conv accepts the requested input channels.
        if self.stem_sizes[0] != self.in_chans:
            self.stem_sizes = [self.in_chans] + self.stem_sizes
        # se/sa given as truthy bool/int select the default modules; an
        # nn.Module subclass passed directly is kept as-is.
        if self.se and isinstance(self.se, (bool, int)):  # if se=1 or se=True
            self.se = SEModule
        if self.sa and isinstance(self.sa, (bool, int)):  # if sa=1 or sa=True
            self.sa = SimpleSelfAttention  # default: ks=1, sym=sym
        if self.se_module or self.se_reduction:  # pragma: no cover
            print(
                "Deprecated. Pass se_module as se argument, se_reduction as arg to se."
            )  # add deprecation warning.

    @property
    def stem(self):
        # Hooks are stored unbound, hence the explicit self argument.
        return self._make_stem(self)  # type: ignore

    @property
    def head(self):
        return self._make_head(self)  # type: ignore

    @property
    def body(self):
        return self._make_body(self)  # type: ignore

    @classmethod
    def from_cfg(cls, cfg: CfgMC):
        """Alternate constructor: build from a plain CfgMC config.

        NOTE(review): asdict() deep-copies field values, so module instances in
        the cfg (act_fn, pool, ...) are copied rather than shared — confirm
        that is the intended behavior.
        """
        return cls(**asdict(cfg))

    def __call__(self):
        # Build a fresh model each call: stem -> body -> head, then init weights.
        model = nn.Sequential(
            OrderedDict([("stem", self.stem), ("body", self.body), ("head", self.head)])
        )
        self._init_cnn(model)  # type: ignore
        model.extra_repr = lambda: f"{self.name}"
        return model

    def print_cfg(self) -> None:
        """Print a short human-readable summary of this constructor's config."""
        print(
            f"{self.name} constructor\n"
            f"  in_chans: {self.in_chans}, num_classes: {self.num_classes}\n"
            f"  expansion: {self.expansion}, groups: {self.groups}, dw: {self.dw}, div_groups: {self.div_groups}\n"
            f"  sa: {self.sa}, se: {self.se}\n"
            f"  stem sizes: {self.stem_sizes}, stride on {self.stem_stride_on}\n"
            f"  body sizes {self.block_sizes}\n"
            f"  layers: {self.layers}"
        )
340311
341312
xresnet34 = ModelConstructor.from_cfg(
    CfgMC(name="xresnet34", expansion=1, layers=[3, 4, 6, 3])
)

# FIX: this constructor was mislabeled name="xresnet34" (copy-paste error);
# it is the expansion=4 (bottleneck) variant, i.e. xresnet50.
xresnet50 = ModelConstructor.from_cfg(
    CfgMC(name="xresnet50", expansion=4, layers=[3, 4, 6, 3])
)
0 commit comments