@@ -1,6 +1,8 @@
+from dataclasses import dataclass, field, asdict
+
 from collections import OrderedDict
-from functools import partial
-from typing import Callable, List, Type, Union
+# from functools import partial
+from typing import Callable, List, Optional, Type, Union

 import torch.nn as nn

@@ -12,24 +14,14 @@
     "act_fn",
     "ResBlock",
     "ModelConstructor",
-    "xresnet34",
-    "xresnet50",
+    # "xresnet34",
+    # "xresnet50",
 ]


 act_fn = nn.ReLU(inplace=True)


-def init_cnn(module: nn.Module):
-    "Init module - kaiming_normal for Conv2d and 0 for biases."
-    if getattr(module, "bias", None) is not None:
-        nn.init.constant_(module.bias, 0)  # type: ignore
-    if isinstance(module, (nn.Conv2d, nn.Linear)):
-        nn.init.kaiming_normal_(module.weight)
-    for layer in module.children():
-        init_cnn(layer)
-
-
 class ResBlock(nn.Module):
     """Universal Resnet block. Basic block if expansion is 1, otherwise is Bottleneck."""

@@ -130,10 +122,55 @@ def forward(self, x):
         return self.act_fn(self.convs(x) + identity)


-def _make_stem(self):
-    stem = [
+@dataclass
+class ModelConstructorCfg:
+    """Model constructor. As default - xresnet18"""
+
+    name: str = "MC"
+    in_chans: int = 3
+    num_classes: int = 1000
+    block: Type[nn.Module] = ResBlock
+    conv_layer: Type[nn.Module] = ConvBnAct
+    _block_sizes: List[int] = field(default_factory=lambda: [64, 128, 256, 512])
+    layers: List[int] = field(default_factory=lambda: [2, 2, 2, 2])
+    norm: Type[nn.Module] = nn.BatchNorm2d
+    act_fn: nn.Module = nn.ReLU(inplace=True)
+    pool: nn.Module = nn.AvgPool2d(2, ceil_mode=True)
+    expansion: int = 1
+    groups: int = 1
+    dw: bool = False
+    div_groups: Union[int, None] = None
+    sa: Union[bool, int, Type[nn.Module]] = False
+    se: Union[bool, int, Type[nn.Module]] = False
+    se_module = None
+    se_reduction = None
+    bn_1st: bool = True
+    zero_bn: bool = True
+    stem_stride_on: int = 0
+    stem_sizes: List[int] = field(default_factory=lambda: [32, 32, 64])
+    stem_pool: Union[nn.Module, None] = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)  # type: ignore
+    stem_bn_end: bool = False
+    _init_cnn: Optional[Callable[[nn.Module], None]] = None
+    _make_stem: Optional[Callable] = None
+    _make_layer: Optional[Callable] = None
+    _make_body: Optional[Callable] = None
+    _make_head: Optional[Callable] = None
+
+
+def init_cnn(module: nn.Module):
+    "Init module - kaiming_normal for Conv2d and 0 for biases."
+    if getattr(module, "bias", None) is not None:
+        nn.init.constant_(module.bias, 0)  # type: ignore
+    if isinstance(module, (nn.Conv2d, nn.Linear)):
+        nn.init.kaiming_normal_(module.weight)
+    for layer in module.children():
+        init_cnn(layer)
+
+
+def _make_stem(self: ModelConstructorCfg) -> nn.Sequential:
+    stem: List[tuple[str, nn.Module]] = [
         (f"conv_{i}", self.conv_layer(
-            self.stem_sizes[i],
+            self.stem_sizes[i],  # type: ignore
             self.stem_sizes[i + 1],
             stride=2 if i == self.stem_stride_on else 1,
             bn_layer=(not self.stem_bn_end)
@@ -147,7 +184,7 @@ def _make_stem(self):
     if self.stem_pool:
         stem.append(("stem_pool", self.stem_pool))
     if self.stem_bn_end:
-        stem.append(("norm", self.norm(self.stem_sizes[-1])))
+        stem.append(("norm", self.norm(self.stem_sizes[-1])))  # type: ignore
     return nn.Sequential(OrderedDict(stem))


@@ -202,7 +239,7 @@ def _make_body(self):
     )


-def _make_head(self):
+def _make_head(self: ModelConstructorCfg) -> nn.Sequential:
     head = [
         ("pool", nn.AdaptiveAvgPool2d(1)),
         ("flat", nn.Flatten()),
@@ -211,94 +248,29 @@ def _make_head(self):
     return nn.Sequential(OrderedDict(head))


-class ModelConstructor:
+@dataclass
+class ModelConstructor(ModelConstructorCfg):
     """Model constructor. As default - xresnet18"""

-    def __init__(
-        self,
-        name: str = "MC",
-        in_chans: int = 3,
-        num_classes: int = 1000,
-        block=ResBlock,
-        conv_layer=ConvBnAct,
-        block_sizes: List[int] = [64, 128, 256, 512],
-        layers: List[int] = [2, 2, 2, 2],
-        norm: Type[nn.Module] = nn.BatchNorm2d,
-        act_fn: nn.Module = nn.ReLU(inplace=True),
-        pool: nn.Module = nn.AvgPool2d(2, ceil_mode=True),
-        expansion: int = 1,
-        groups: int = 1,
-        dw: bool = False,
-        div_groups: Union[int, None] = None,
-        sa: Union[bool, int, Type[nn.Module]] = False,
-        se: Union[bool, int, Type[nn.Module]] = False,
-        se_module=None,
-        se_reduction=None,
-        bn_1st: bool = True,
-        zero_bn: bool = True,
-        stem_stride_on: int = 0,
-        stem_sizes: List[int] = [32, 32, 64],
-        stem_pool: Union[Type[nn.Module], None] = nn.MaxPool2d(kernel_size=3, stride=2, padding=1),  # type: ignore
-        stem_bn_end: bool = False,
-        _init_cnn: Callable = init_cnn,
-        _make_stem: Callable = _make_stem,
-        _make_layer: Callable = _make_layer,
-        _make_body: Callable = _make_body,
-        _make_head: Callable = _make_head,
-    ):
-        super().__init__()
-        # se can be bool, int (0, 1) or nn.Module
-        # se_module - deprecated. Leaved for warning and checks.
-        # if stem_pool is False - no pool at stem
-
-        self.name = name
-        self.in_chans = in_chans
-        self.num_classes = num_classes
-        self.block = block
-        self.conv_layer = conv_layer
-        self._block_sizes = block_sizes
-        self.layers = layers
-        self.norm = norm
-        self.act_fn = act_fn
-        self.pool = pool
-        self.expansion = expansion
-        self.groups = groups
-        self.dw = dw
-        self.div_groups = div_groups
-        # se_module
-        # se_reduction
-        self.bn_1st = bn_1st
-        self.zero_bn = zero_bn
-        self.stem_stride_on = stem_stride_on
-        self.stem_pool = stem_pool
-        self.stem_bn_end = stem_bn_end
-        self._init_cnn = _init_cnn
-        self._make_stem = _make_stem
-        self._make_layer = _make_layer
-        self._make_body = _make_body
-        self._make_head = _make_head
-
-        # params = locals()
-        # del params['self']
-        # self.__dict__ = params
-
-        # self._block_sizes = params['block_sizes']
-        self.stem_sizes = stem_sizes
+    def __post_init__(self):
+        if self._init_cnn is None:
+            self._init_cnn = init_cnn
+        if self._make_stem is None:
+            self._make_stem = _make_stem
+        if self._make_layer is None:
+            self._make_layer = _make_layer
+        if self._make_body is None:
+            self._make_body = _make_body
+        if self._make_head is None:
+            self._make_head = _make_head
+
         if self.stem_sizes[0] != self.in_chans:
             self.stem_sizes = [self.in_chans] + self.stem_sizes
-        self.se = se
-        if self.se:
-            if type(self.se) in (bool, int):  # if se=1 or se=True
-                self.se = SEModule
-            else:
-                self.se = se  # TODO add check issubclass or isinstance of nn.Module
-        self.sa = sa
-        if self.sa:  # if sa=1 or sa=True
-            if type(self.sa) in (bool, int):
-                self.sa = SimpleSelfAttention  # default: ks=1, sym=sym
-            else:
-                self.sa = sa
-        if se_module or se_reduction:  # pragma: no cover
+        if self.se and isinstance(self.se, (bool, int)):  # if se=1 or se=True
+            self.se = SEModule
+        if self.sa and isinstance(self.sa, (bool, int)):  # if sa=1 or sa=True
+            self.sa = SimpleSelfAttention  # default: ks=1, sym=sym
+        if self.se_module or self.se_reduction:  # pragma: no cover
             print(
                 "Deprecated. Pass se_module as se argument, se_reduction as arg to se."
             )  # add deprecation warning.
@@ -319,6 +291,10 @@ def head(self):
     def body(self):
         return self._make_body(self)

+    @classmethod
+    def from_cfg(cls, cfg: ModelConstructorCfg):
+        return cls(**asdict(cfg))
+
     def __call__(self):
         model = nn.Sequential(
             OrderedDict([("stem", self.stem), ("body", self.body), ("head", self.head)])
@@ -339,9 +315,9 @@ def __repr__(self):
         )


-xresnet34 = partial(
-    ModelConstructor, name="xresnet34", expansion=1, layers=[3, 4, 6, 3]
-)
-xresnet50 = partial(
-    ModelConstructor, name="xresnet34", expansion=4, layers=[3, 4, 6, 3]
-)
+# xresnet34 = partial(
+#     ModelConstructor, name="xresnet34", expansion=1, layers=[3, 4, 6, 3]
+# )
+# xresnet50 = partial(
+#     ModelConstructor, name="xresnet34", expansion=4, layers=[3, 4, 6, 3]
+# )