-import torch.nn as nn
-import torch
-from torch.nn.utils import spectral_norm
 from collections import OrderedDict
+from typing import List, Optional, Union
 
+import torch
+import torch.nn as nn
+from torch.nn.utils.spectral_norm import spectral_norm
 
-__all__ = ['Flatten', 'noop', 'Noop', 'ConvLayer', 'act_fn',
-           'conv1d', 'SimpleSelfAttention', 'SEBlock', 'SEBlockConv']
+__all__ = [
+    "Flatten",
+    "noop",
+    "Noop",
+    "ConvLayer",
+    "act_fn",
+    "conv1d",
+    "SimpleSelfAttention",
+    "SEBlock",
+    "SEBlockConv",
+]
 
 
 class Flatten(nn.Module):
-    '''flat x to vector'''
+    """flat x to vector"""
+
     def __init__(self):
         super().__init__()
 
@@ -18,12 +29,13 @@ def forward(self, x):
 
 
 def noop(x):
-    ''' Dummy func. Return input'''
+    """ Dummy func. Return input"""
     return x
 
 
 class Noop(nn.Module):
-    '''Dummy module'''
+    """Dummy module"""
+
     def __init__(self):
         super().__init__()
 
@@ -36,83 +48,128 @@ def forward(self, x):
 
 class ConvBnAct(nn.Sequential):
     """Basic Conv + Bn + Act block"""
+
     convolution_module = nn.Conv2d  # can be changed in models like twist.
     batchnorm_module = nn.BatchNorm2d
 
-    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
-                 padding=None, bias=False, groups=1,
-                 act_fn=act_fn, pre_act=False,
-                 bn_layer=True, bn_1st=True, zero_bn=False,
-                 ):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int = 3,
+        stride: int = 1,
+        padding: Optional[int] = None,
+        bias: bool = False,
+        groups: int = 1,
+        act_fn: Union[nn.Module, bool] = act_fn,
+        pre_act: bool = False,
+        bn_layer: bool = True,
+        bn_1st: bool = True,
+        zero_bn: bool = False,
+    ):
 
         if padding is None:
             padding = kernel_size // 2
-        layers = [('conv', self.convolution_module(in_channels, out_channels, kernel_size, stride=stride,
-                                                   padding=padding, bias=bias, groups=groups))]  # if no bn - bias True?
+        layers: List[tuple[str, nn.Module]] = [
+            (
+                "conv",
+                self.convolution_module(
+                    in_channels,
+                    out_channels,
+                    kernel_size,
+                    stride=stride,
+                    padding=padding,
+                    bias=bias,
+                    groups=groups,
+                ),
+            )
+        ]  # if no bn - bias True?
         if bn_layer:
             bn = self.batchnorm_module(out_channels)
-            nn.init.constant_(bn.weight, 0. if zero_bn else 1.)
-            layers.append(('bn', bn))
-        if act_fn:
+            nn.init.constant_(bn.weight, 0.0 if zero_bn else 1.0)
+            layers.append(("bn", bn))
+        if isinstance(act_fn, nn.Module):  # act_fn either nn.Module or False
             if pre_act:
                 act_position = 0
             elif not bn_1st:
                 act_position = 1
             else:
                 act_position = len(layers)
-            layers.insert(act_position, ('act_fn', act_fn))
+            layers.insert(act_position, ("act_fn", act_fn))
         super().__init__(OrderedDict(layers))
 
 
 # NOTE First version. Leaved for backwards compatibility with old blocks, models.
 class ConvLayer(nn.Sequential):
     """Basic conv layers block"""
+
     Conv2d = nn.Conv2d
 
-    def __init__(self, ni, nf, ks=3, stride=1,
-                 act=True, act_fn=act_fn,
-                 bn_layer=True, bn_1st=True, zero_bn=False,
-                 padding=None, bias=False, groups=1, **kwargs):
+    def __init__(
+        self,
+        ni,
+        nf,
+        ks=3,
+        stride=1,
+        act=True,
+        act_fn=act_fn,
+        bn_layer=True,
+        bn_1st=True,
+        zero_bn=False,
+        padding=None,
+        bias=False,
+        groups=1,
+        **kwargs
+    ):
 
         if padding is None:
             padding = ks // 2
-        layers = [('conv', self.Conv2d(ni, nf, ks, stride=stride,
-                                       padding=padding, bias=bias, groups=groups))]
-        act_bn = [('act_fn', act_fn)] if act else []
+        layers = [
+            (
+                "conv",
+                self.Conv2d(
+                    ni, nf, ks, stride=stride, padding=padding, bias=bias, groups=groups
+                ),
+            )
+        ]
+        act_bn = [("act_fn", act_fn)] if act else []
         if bn_layer:
             bn = nn.BatchNorm2d(nf)
-            nn.init.constant_(bn.weight, 0. if zero_bn else 1.)
-            act_bn += [('bn', bn)]
+            nn.init.constant_(bn.weight, 0.0 if zero_bn else 1.0)
+            act_bn += [("bn", bn)]
         if bn_1st:
             act_bn.reverse()
         layers += act_bn
         super().__init__(OrderedDict(layers))
 
+
 # Cell
 # SA module from mxresnet at fastai. todo - add persons!!!
 # Unmodified from https://github.com/fastai/fastai/blob/5c51f9eabf76853a89a9bc5741804d2ed4407e49/fastai/layers.py
 
 
-def conv1d(ni: int, no: int, ks: int = 1, stride: int = 1, padding: int = 0, bias: bool = False):
+def conv1d(
+    ni: int, no: int, ks: int = 1, stride: int = 1, padding: int = 0, bias: bool = False
+):
     "Create and initialize a `nn.Conv1d` layer with spectral normalization."
     conv = nn.Conv1d(ni, no, ks, stride=stride, padding=padding, bias=bias)
     nn.init.kaiming_normal_(conv.weight)
     if bias:
-        conv.bias.data.zero_()
+        conv.bias.data.zero_()  # type: ignore
     return spectral_norm(conv)
 
 
 class SimpleSelfAttention(nn.Module):
-    ''' SimpleSelfAttention module.  # noqa W291
-    Adapted from SelfAttention layer at
-    https://github.com/fastai/fastai/blob/5c51f9eabf76853a89a9bc5741804d2ed4407e49/fastai/layers.py
-    Inspired by https://arxiv.org/pdf/1805.08318.pdf
-    '''
+    """ SimpleSelfAttention module.  # noqa W291
+    Adapted from SelfAttention layer at
+    https://github.com/fastai/fastai/blob/5c51f9eabf76853a89a9bc5741804d2ed4407e49/fastai/layers.py
+    Inspired by https://arxiv.org/pdf/1805.08318.pdf
+    """
 
     def __init__(self, n_in: int, ks=1, sym=False, use_bias=False):
         super().__init__()
         self.conv = conv1d(n_in, n_in, ks, padding=ks // 2, bias=use_bias)
-        self.gamma = nn.Parameter(torch.tensor([0.]))
+        self.gamma = torch.nn.Parameter(torch.tensor([0.0]))  # type: ignore
         self.sym = sym
         self.n_in = n_in
 
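A quick illustration of what the new isinstance(act_fn, nn.Module) check and the act_position logic in the ConvBnAct hunk above produce (a minimal sketch, assuming the module is importable as model_constructor.layers, which the diff does not state):

import torch.nn as nn
from model_constructor.layers import ConvBnAct  # assumed import path

# bn_1st=True (default): activation is appended after bn -> conv, bn, act_fn
print([n for n, _ in ConvBnAct(3, 16, act_fn=nn.ReLU(inplace=True)).named_children()])
# ['conv', 'bn', 'act_fn']

# bn_1st=False: activation is inserted at position 1 -> conv, act_fn, bn
print([n for n, _ in ConvBnAct(3, 16, act_fn=nn.ReLU(inplace=True), bn_1st=False).named_children()])
# ['conv', 'act_fn', 'bn']

# act_fn=False fails the isinstance check, so no activation layer is added
print([n for n, _ in ConvBnAct(3, 16, act_fn=False).named_children()])
# ['conv', 'bn']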
@@ -123,17 +180,19 @@ def forward(self, x):
             c = (c + c.t()) / 2
             self.conv.weight = c.view(self.n_in, self.n_in, 1)
         size = x.size()
-        x = x.view(*size[:2], -1)   # (C,N)
+        x = x.view(*size[:2], -1)  # (C,N)
         # changed the order of multiplication to avoid O(N^2) complexity
         # (x*xT)*(W*x) instead of (x*(xT*(W*x)))
-        convx = self.conv(x)   # (C,C) * (C,N) = (C,N)  => O(NC^2)
-        xxT = torch.bmm(x, x.permute(0, 2, 1).contiguous())   # (C,N) * (N,C) = (C,C)  => O(NC^2)
-        o = torch.bmm(xxT, convx)   # (C,C) * (C,N) = (C,N)  => O(NC^2)
+        convx = self.conv(x)  # (C,C) * (C,N) = (C,N) => O(NC^2)
+        xxT = torch.bmm(
+            x, x.permute(0, 2, 1).contiguous()
+        )  # (C,N) * (N,C) = (C,C) => O(NC^2)
+        o = torch.bmm(xxT, convx)  # (C,C) * (C,N) = (C,N) => O(NC^2)
         o = self.gamma * o + x
         return o.view(*size).contiguous()
 
 
-class SEBlock(nn.Module):  # todo: deprecation worning.
+class SEBlock(nn.Module):  # todo: deprecation warning.
     "se block"
     se_layer = nn.Linear
     act_fn = nn.ReLU(inplace=True)
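The "changed the order of multiplication" comment in the hunk above relies on associativity: with x of shape (C, N) and W the 1x1 conv weight, (x @ xT) @ (W @ x) only ever builds (C, C) matrices, while x @ (xT @ (W @ x)) materializes an (N, N) matrix. A minimal check of that equivalence (shapes and names are illustrative only; float64 keeps the comparison tight):

import torch

C, N = 8, 1024
x = torch.randn(1, C, N, dtype=torch.float64)  # (B, C, N), as after x.view(*size[:2], -1)
w = torch.randn(1, C, C, dtype=torch.float64)  # stand-in for the spectral-normed 1x1 conv weight

wx = torch.bmm(w, x)                                     # (C, N), O(N*C^2)
o_fast = torch.bmm(torch.bmm(x, x.transpose(1, 2)), wx)  # (x xT)(W x): stays (C, C), O(N*C^2)
o_slow = torch.bmm(x, torch.bmm(x.transpose(1, 2), wx))  # x (xT (W x)): builds (N, N), O(N^2*C)

print(torch.allclose(o_fast, o_slow))  # True: same result, very different cost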
@@ -144,11 +203,15 @@ def __init__(self, c, r=16):
         ch = max(c // r, 1)
         self.squeeze = nn.AdaptiveAvgPool2d(1)
         self.excitation = nn.Sequential(
-            OrderedDict([('fc_reduce', self.se_layer(c, ch, bias=self.use_bias)),
-                         ('se_act', self.act_fn),
-                         ('fc_expand', self.se_layer(ch, c, bias=self.use_bias)),
-                         ('sigmoid', nn.Sigmoid())
-                         ]))
+            OrderedDict(
+                [
+                    ("fc_reduce", self.se_layer(c, ch, bias=self.use_bias)),
+                    ("se_act", self.act_fn),
+                    ("fc_expand", self.se_layer(ch, c, bias=self.use_bias)),
+                    ("sigmoid", nn.Sigmoid()),
+                ]
+            )
+        )
 
     def forward(self, x):
         bs, c, _, _ = x.shape
@@ -157,24 +220,27 @@ def forward(self, x):
         return x * y.expand_as(x)
 
 
-class SEBlockConv(nn.Module):  # todo: deprecation worning.
+class SEBlockConv(nn.Module):  # todo: deprecation warning.
     "se block with conv on excitation"
     se_layer = nn.Conv2d
     act_fn = nn.ReLU(inplace=True)
     use_bias = True
 
     def __init__(self, c, r=16):
         super().__init__()
-        #  c_in = math.ceil(c//r/8)*8
+        # c_in = math.ceil(c//r/8)*8
         c_in = max(c // r, 1)
         self.squeeze = nn.AdaptiveAvgPool2d(1)
         self.excitation = nn.Sequential(
-            OrderedDict([
-                ('conv_reduce', self.se_layer(c, c_in, 1, bias=self.use_bias)),
-                ('se_act', self.act_fn),
-                ('conv_expand', self.se_layer(c_in, c, 1, bias=self.use_bias)),
-                ('sigmoid', nn.Sigmoid())
-            ]))
+            OrderedDict(
+                [
+                    ("conv_reduce", self.se_layer(c, c_in, 1, bias=self.use_bias)),
+                    ("se_act", self.act_fn),
+                    ("conv_expand", self.se_layer(c_in, c, 1, bias=self.use_bias)),
+                    ("sigmoid", nn.Sigmoid()),
+                ]
+            )
+        )
 
     def forward(self, x):
         y = self.squeeze(x)
@@ -185,16 +251,17 @@ def forward(self, x):
 class SEModule(nn.Module):
     "se block"
 
-    def __init__(self,
-                 channels,
-                 reduction=16,
-                 rd_channels=None,
-                 rd_max=False,
-                 se_layer=nn.Linear,
-                 act_fn=nn.ReLU(inplace=True),
-                 use_bias=True,
-                 gate=nn.Sigmoid
-                 ):
+    def __init__(
+        self,
+        channels,
+        reduction=16,
+        rd_channels=None,
+        rd_max=False,
+        se_layer=nn.Linear,
+        act_fn=nn.ReLU(inplace=True),
+        use_bias=True,
+        gate=nn.Sigmoid,
+    ):
         super().__init__()
         reducted = max(channels // reduction, 1)  # preserve zero-element tensors
         if rd_channels is None:
@@ -204,11 +271,15 @@ def __init__(self,
             rd_channels = max(rd_channels, reducted)
         self.squeeze = nn.AdaptiveAvgPool2d(1)
         self.excitation = nn.Sequential(
-            OrderedDict([('reduce', se_layer(channels, rd_channels, bias=use_bias)),
-                         ('se_act', act_fn),
-                         ('expand', se_layer(rd_channels, channels, bias=use_bias)),
-                         ('se_gate', gate())
-                         ]))
+            OrderedDict(
+                [
+                    ("reduce", se_layer(channels, rd_channels, bias=use_bias)),
+                    ("se_act", act_fn),
+                    ("expand", se_layer(rd_channels, channels, bias=use_bias)),
+                    ("se_gate", gate()),
+                ]
+            )
+        )
 
     def forward(self, x):
         bs, c, _, _ = x.shape
@@ -220,18 +291,19 @@ def forward(self, x):
 class SEModuleConv(nn.Module):
     "se block with conv on excitation"
 
-    def __init__(self,
-                 channels,
-                 reduction=16,
-                 rd_channels=None,
-                 rd_max=False,
-                 se_layer=nn.Conv2d,
-                 act_fn=nn.ReLU(inplace=True),
-                 use_bias=True,
-                 gate=nn.Sigmoid
-                 ):
+    def __init__(
+        self,
+        channels,
+        reduction=16,
+        rd_channels=None,
+        rd_max=False,
+        se_layer=nn.Conv2d,
+        act_fn=nn.ReLU(inplace=True),
+        use_bias=True,
+        gate=nn.Sigmoid,
+    ):
         super().__init__()
-        #  rd_channels = math.ceil(channels//reduction/8)*8
+        # rd_channels = math.ceil(channels//reduction/8)*8
         reducted = max(channels // reduction, 1)  # preserve zero-element tensors
         if rd_channels is None:
             rd_channels = reducted
@@ -240,12 +312,15 @@ def __init__(self,
             rd_channels = max(rd_channels, reducted)
         self.squeeze = nn.AdaptiveAvgPool2d(1)
         self.excitation = nn.Sequential(
-            OrderedDict([
-                ('reduce', se_layer(channels, rd_channels, 1, bias=use_bias)),
-                ('se_act', act_fn),
-                ('expand', se_layer(rd_channels, channels, 1, bias=use_bias)),
-                ('gate', gate())
-            ]))
+            OrderedDict(
+                [
+                    ("reduce", se_layer(channels, rd_channels, 1, bias=use_bias)),
+                    ("se_act", act_fn),
+                    ("expand", se_layer(rd_channels, channels, 1, bias=use_bias)),
+                    ("gate", gate()),
+                ]
+            )
+        )
 
     def forward(self, x):
         y = self.squeeze(x)
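The refactoring keeps the public call signatures, so a smoke test of the touched blocks could look like this (a sketch only; the model_constructor.layers import path and the chosen sizes are assumptions, not part of the commit):

import torch
import torch.nn as nn
from model_constructor.layers import ConvBnAct, SEModule, SimpleSelfAttention  # assumed path

x = torch.randn(2, 32, 16, 16)
conv = ConvBnAct(32, 64, kernel_size=3, stride=2, act_fn=nn.ReLU(inplace=True))
sa = SimpleSelfAttention(64, ks=1, sym=False)  # gamma starts at 0, so initially an identity map
se = SEModule(64, reduction=16)                # channel-wise re-weighting

y = se(sa(conv(x)))
print(y.shape)  # torch.Size([2, 64, 8, 8])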