 from .layers import ConvBnAct, SEModule, SimpleSelfAttention


-__all__ = ['init_cnn', 'act_fn', 'ResBlock', 'ModelConstructor', 'xresnet34', 'xresnet50']
+__all__ = [
+    "init_cnn",
+    "act_fn",
+    "ResBlock",
+    "ModelConstructor",
+    "xresnet34",
+    "xresnet50",
+]


 act_fn = nn.ReLU(inplace=True)


 def init_cnn(module: nn.Module):
     "Init module - kaiming_normal for Conv2d and 0 for biases."
-    if getattr(module, 'bias', None) is not None:
+    if getattr(module, "bias", None) is not None:
         nn.init.constant_(module.bias, 0)  # type: ignore
     if isinstance(module, (nn.Conv2d, nn.Linear)):
         nn.init.kaiming_normal_(module.weight)
@@ -24,7 +31,7 @@ def init_cnn(module: nn.Module):


 class ResBlock(nn.Module):
-    '''Universal Resnet block. Basic block if expansion is 1, otherwise is Bottleneck.'''
+    """Universal ResNet block. Basic block if expansion is 1, otherwise Bottleneck."""

     def __init__(
         self,
@@ -49,31 +56,70 @@ def __init__(
         if div_groups is not None:  # check if groups != 1 and div_groups
             groups = int(mid_channels / div_groups)
         if expansion == 1:
-            layers = [("conv_0", conv_layer(in_channels, mid_channels, 3, stride=stride,  # type: ignore
-                       act_fn=act_fn, bn_1st=bn_1st, groups=in_channels if dw else groups)),
-                      ("conv_1", conv_layer(mid_channels, out_channels, 3, zero_bn=zero_bn,
-                       act_fn=False, bn_1st=bn_1st, groups=mid_channels if dw else groups))
-                      ]
+            layers = [
+                ("conv_0", conv_layer(
+                    in_channels,
+                    mid_channels,
+                    3,
+                    stride=stride,  # type: ignore
+                    act_fn=act_fn,
+                    bn_1st=bn_1st,
+                    groups=in_channels if dw else groups,
+                ),),
+                ("conv_1", conv_layer(
+                    mid_channels,
+                    out_channels,
+                    3,
+                    zero_bn=zero_bn,
+                    act_fn=False,
+                    bn_1st=bn_1st,
+                    groups=mid_channels if dw else groups,
+                ),),
+            ]
         else:
-            layers = [("conv_0", conv_layer(in_channels, mid_channels, 1, act_fn=act_fn, bn_1st=bn_1st)),
-                      ("conv_1", conv_layer(mid_channels, mid_channels, 3, stride=stride, act_fn=act_fn, bn_1st=bn_1st,
-                                            groups=mid_channels if dw else groups)),
-                      ("conv_2", conv_layer(mid_channels, out_channels, 1, zero_bn=zero_bn, act_fn=False, bn_1st=bn_1st))  # noqa E501
-                      ]
+            layers = [
+                ("conv_0", conv_layer(
+                    in_channels,
+                    mid_channels,
+                    1,
+                    act_fn=act_fn,
+                    bn_1st=bn_1st,
+                ),),
+                ("conv_1", conv_layer(
+                    mid_channels,
+                    mid_channels,
+                    3,
+                    stride=stride,
+                    act_fn=act_fn,
+                    bn_1st=bn_1st,
+                    groups=mid_channels if dw else groups,
+                ),),
+                ("conv_2", conv_layer(
+                    mid_channels,
+                    out_channels,
+                    1,
+                    zero_bn=zero_bn,
+                    act_fn=False,
+                    bn_1st=bn_1st,
+                ),),  # noqa E501
+            ]
         if se:
-            layers.append(('se', se(out_channels)))
+            layers.append(("se", se(out_channels)))
         if sa:
-            layers.append(('sa', sa(out_channels)))
+            layers.append(("sa", sa(out_channels)))
         self.convs = nn.Sequential(OrderedDict(layers))
         if stride != 1 or in_channels != out_channels:
             id_layers = []
             if stride != 1 and pool is not None:  # if pool - reduce by pool, else stride 2 at id_conv
                 id_layers.append(("pool", pool))
             if in_channels != out_channels or (stride != 1 and pool is None):
                 id_layers += [("id_conv", conv_layer(
-                    in_channels, out_channels, 1,
+                    in_channels,
+                    out_channels,
+                    1,
                     stride=1 if pool else stride,
-                    act_fn=False))]
+                    act_fn=False,
+                ),)]
             self.id_conv = nn.Sequential(OrderedDict(id_layers))
         else:
             self.id_conv = None
@@ -85,15 +131,23 @@ def forward(self, x):


 def _make_stem(self):
-    stem = [(f"conv_{i}", self.conv_layer(self.stem_sizes[i], self.stem_sizes[i + 1],
-             stride=2 if i == self.stem_stride_on else 1,
-             bn_layer=(not self.stem_bn_end) if i == (len(self.stem_sizes) - 2) else True,
-             act_fn=self.act_fn, bn_1st=self.bn_1st))
-            for i in range(len(self.stem_sizes) - 1)]
+    stem = [
+        (f"conv_{i}", self.conv_layer(
+            self.stem_sizes[i],
+            self.stem_sizes[i + 1],
+            stride=2 if i == self.stem_stride_on else 1,
+            bn_layer=(not self.stem_bn_end)
+            if i == (len(self.stem_sizes) - 2)
+            else True,
+            act_fn=self.act_fn,
+            bn_1st=self.bn_1st,
+        ),)
+        for i in range(len(self.stem_sizes) - 1)
+    ]
     if self.stem_pool:
-        stem.append(('stem_pool', self.stem_pool))
+        stem.append(("stem_pool", self.stem_pool))
     if self.stem_bn_end:
-        stem.append(('norm', self.norm(self.stem_sizes[-1])))
+        stem.append(("norm", self.norm(self.stem_sizes[-1])))
     return nn.Sequential(OrderedDict(stem))


@@ -102,43 +156,67 @@ def _make_layer(self, layer_num: int) -> nn.Module:
     # if no pool on stem - stride = 2 for first layer block in body
     stride = 1 if self.stem_pool and layer_num == 0 else 2
     num_blocks = self.layers[layer_num]
-    return nn.Sequential(OrderedDict([
-        (f"bl_{block_num}", self.block(
-            self.expansion,
-            self.block_sizes[layer_num] if block_num == 0 else self.block_sizes[layer_num + 1],
-            self.block_sizes[layer_num + 1],
-            stride if block_num == 0 else 1,
-            sa=self.sa if (block_num == num_blocks - 1) and layer_num == 0 else None,
-            conv_layer=self.conv_layer,
-            act_fn=self.act_fn,
-            pool=self.pool,
-            zero_bn=self.zero_bn, bn_1st=self.bn_1st,
-            groups=self.groups, div_groups=self.div_groups,
-            dw=self.dw, se=self.se
-        ))
-        for block_num in range(num_blocks)
-    ]))
+    return nn.Sequential(
+        OrderedDict(
+            [
+                (
+                    f"bl_{block_num}",
+                    self.block(
+                        self.expansion,
+                        self.block_sizes[layer_num]
+                        if block_num == 0
+                        else self.block_sizes[layer_num + 1],
+                        self.block_sizes[layer_num + 1],
+                        stride if block_num == 0 else 1,
+                        sa=self.sa
+                        if (block_num == num_blocks - 1) and layer_num == 0
+                        else None,
+                        conv_layer=self.conv_layer,
+                        act_fn=self.act_fn,
+                        pool=self.pool,
+                        zero_bn=self.zero_bn,
+                        bn_1st=self.bn_1st,
+                        groups=self.groups,
+                        div_groups=self.div_groups,
+                        dw=self.dw,
+                        se=self.se,
+                    ),
+                )
+                for block_num in range(num_blocks)
+            ]
+        )
+    )


 def _make_body(self):
-    return nn.Sequential(OrderedDict([
-        (f"l_{layer_num}", self._make_layer(self, layer_num))
-        for layer_num in range(len(self.layers))
-    ]))
+    return nn.Sequential(
+        OrderedDict(
+            [
+                (
+                    f"l_{layer_num}",
+                    self._make_layer(self, layer_num)
+                )
+                for layer_num in range(len(self.layers))
+            ]
+        )
+    )


 def _make_head(self):
-    head = [('pool', nn.AdaptiveAvgPool2d(1)),
-            ('flat', nn.Flatten()),
-            ('fc', nn.Linear(self.block_sizes[-1] * self.expansion, self.num_classes))]
+    head = [
+        ("pool", nn.AdaptiveAvgPool2d(1)),
+        ("flat", nn.Flatten()),
+        ("fc", nn.Linear(self.block_sizes[-1] * self.expansion, self.num_classes)),
+    ]
     return nn.Sequential(OrderedDict(head))


-class ModelConstructor():
+class ModelConstructor:
     """Model constructor. As default - xresnet18"""
+
     def __init__(
         self,
-        name: str = 'MC',
+        name: str = "MC",
         in_chans: int = 3,
         num_classes: int = 1000,
         block=ResBlock,
@@ -221,7 +299,9 @@ def __init__(
         else:
             self.sa = sa
         if se_module or se_reduction:  # pragma: no cover
-            print("Deprecated. Pass se_module as se argument, se_reduction as arg to se.")  # add deprecation warning.
+            print(
+                "Deprecated. Pass se_module as se argument, se_reduction as arg to se."
+            )  # add deprecation warning.

     @property
     def block_sizes(self):
@@ -240,23 +320,28 @@ def body(self):
         return self._make_body(self)

     def __call__(self):
-        model = nn.Sequential(OrderedDict([
-            ('stem', self.stem),
-            ('body', self.body),
-            ('head', self.head)]))
+        model = nn.Sequential(
+            OrderedDict([("stem", self.stem), ("body", self.body), ("head", self.head)])
+        )
         self._init_cnn(model)
         model.extra_repr = lambda: f"{self.name}"
         return model

     def __repr__(self):
-        return (f"{self.name} constructor\n"
-                f"  in_chans: {self.in_chans}, num_classes: {self.num_classes}\n"
-                f"  expansion: {self.expansion}, groups: {self.groups}, dw: {self.dw}, div_groups: {self.div_groups}\n"
-                f"  sa: {self.sa}, se: {self.se}\n"
-                f"  stem sizes: {self.stem_sizes}, stride on {self.stem_stride_on}\n"
-                f"  body sizes {self._block_sizes}\n"
-                f"  layers: {self.layers}")
-
-
-xresnet34 = partial(ModelConstructor, name='xresnet34', expansion=1, layers=[3, 4, 6, 3])
-xresnet50 = partial(ModelConstructor, name='xresnet34', expansion=4, layers=[3, 4, 6, 3])
+        return (
+            f"{self.name} constructor\n"
+            f"  in_chans: {self.in_chans}, num_classes: {self.num_classes}\n"
+            f"  expansion: {self.expansion}, groups: {self.groups}, dw: {self.dw}, div_groups: {self.div_groups}\n"
+            f"  sa: {self.sa}, se: {self.se}\n"
+            f"  stem sizes: {self.stem_sizes}, stride on {self.stem_stride_on}\n"
+            f"  body sizes {self._block_sizes}\n"
+            f"  layers: {self.layers}"
+        )
+
+
+xresnet34 = partial(
+    ModelConstructor, name="xresnet34", expansion=1, layers=[3, 4, 6, 3]
+)
+xresnet50 = partial(
+    ModelConstructor, name="xresnet50", expansion=4, layers=[3, 4, 6, 3]
+)
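
A quick usage sketch for review, not part of the commit. The import path model_constructor.net is an assumption about the package layout; everything else follows the code in the diff above.

    # Hypothetical usage of the refactored constructor.
    import torch
    from model_constructor.net import xresnet34

    mc = xresnet34()  # partial -> ModelConstructor configured as xresnet34
    print(mc)         # __repr__ prints the configuration summary
    model = mc()      # __call__ assembles stem, body and head into an nn.Sequential
    out = model(torch.randn(2, 3, 224, 224))  # defaults: in_chans=3, num_classes=1000
    assert out.shape == (2, 1000)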