
Commit d8af738

black layers
1 parent 76ae20d commit d8af738

1 file changed, +136 -74 lines changed


model_constructor/layers.py

Lines changed: 136 additions & 74 deletions
@@ -5,12 +5,22 @@
 import torch.nn as nn
 from torch.nn.utils.spectral_norm import spectral_norm

-__all__ = ['Flatten', 'noop', 'Noop', 'ConvLayer', 'act_fn',
-           'conv1d', 'SimpleSelfAttention', 'SEBlock', 'SEBlockConv']
+__all__ = [
+    "Flatten",
+    "noop",
+    "Noop",
+    "ConvLayer",
+    "act_fn",
+    "conv1d",
+    "SimpleSelfAttention",
+    "SEBlock",
+    "SEBlockConv",
+]


 class Flatten(nn.Module):
-    '''flat x to vector'''
+    """flat x to vector"""
+
     def __init__(self):
         super().__init__()

@@ -19,12 +29,13 @@ def forward(self, x):


 def noop(x):
-    '''Dummy func. Return input'''
+    """Dummy func. Return input"""
     return x


 class Noop(nn.Module):
-    '''Dummy module'''
+    """Dummy module"""
+
     def __init__(self):
         super().__init__()

@@ -37,6 +48,7 @@ def forward(self, x):

 class ConvBnAct(nn.Sequential):
     """Basic Conv + Bn + Act block"""
+
     convolution_module = nn.Conv2d  # can be changed in models like twist.
     batchnorm_module = nn.BatchNorm2d

@@ -59,54 +71,86 @@ def __init__(
         if padding is None:
             padding = kernel_size // 2
         layers: List[tuple[str, nn.Module]] = [
-            ('conv', self.convolution_module(
-                in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias, groups=groups))
+            (
+                "conv",
+                self.convolution_module(
+                    in_channels,
+                    out_channels,
+                    kernel_size,
+                    stride=stride,
+                    padding=padding,
+                    bias=bias,
+                    groups=groups,
+                ),
+            )
         ]  # if no bn - bias True?
         if bn_layer:
             bn = self.batchnorm_module(out_channels)
-            nn.init.constant_(bn.weight, 0. if zero_bn else 1.)
-            layers.append(('bn', bn))
+            nn.init.constant_(bn.weight, 0.0 if zero_bn else 1.0)
+            layers.append(("bn", bn))
         if isinstance(act_fn, nn.Module):  # act_fn either nn.Module or False
             if pre_act:
                 act_position = 0
             elif not bn_1st:
                 act_position = 1
             else:
                 act_position = len(layers)
-            layers.insert(act_position, ('act_fn', act_fn))
+            layers.insert(act_position, ("act_fn", act_fn))
         super().__init__(OrderedDict(layers))


 # NOTE First version. Leaved for backwards compatibility with old blocks, models.
 class ConvLayer(nn.Sequential):
     """Basic conv layers block"""
+
     Conv2d = nn.Conv2d

-    def __init__(self, ni, nf, ks=3, stride=1,
-                 act=True, act_fn=act_fn,
-                 bn_layer=True, bn_1st=True, zero_bn=False,
-                 padding=None, bias=False, groups=1, **kwargs):
+    def __init__(
+        self,
+        ni,
+        nf,
+        ks=3,
+        stride=1,
+        act=True,
+        act_fn=act_fn,
+        bn_layer=True,
+        bn_1st=True,
+        zero_bn=False,
+        padding=None,
+        bias=False,
+        groups=1,
+        **kwargs
+    ):

         if padding is None:
             padding = ks // 2
-        layers = [('conv', self.Conv2d(ni, nf, ks, stride=stride,
-                                       padding=padding, bias=bias, groups=groups))]
-        act_bn = [('act_fn', act_fn)] if act else []
+        layers = [
+            (
+                "conv",
+                self.Conv2d(
+                    ni, nf, ks, stride=stride, padding=padding, bias=bias, groups=groups
+                ),
+            )
+        ]
+        act_bn = [("act_fn", act_fn)] if act else []
         if bn_layer:
             bn = nn.BatchNorm2d(nf)
-            nn.init.constant_(bn.weight, 0. if zero_bn else 1.)
-            act_bn += [('bn', bn)]
+            nn.init.constant_(bn.weight, 0.0 if zero_bn else 1.0)
+            act_bn += [("bn", bn)]
         if bn_1st:
             act_bn.reverse()
         layers += act_bn
         super().__init__(OrderedDict(layers))

+
 # Cell
 # SA module from mxresnet at fastai. todo - add persons!!!
 # Unmodified from https://github.com/fastai/fastai/blob/5c51f9eabf76853a89a9bc5741804d2ed4407e49/fastai/layers.py


-def conv1d(ni: int, no: int, ks: int = 1, stride: int = 1, padding: int = 0, bias: bool = False):
+def conv1d(
+    ni: int, no: int, ks: int = 1, stride: int = 1, padding: int = 0, bias: bool = False
+):
     "Create and initialize a `nn.Conv1d` layer with spectral normalization."
     conv = nn.Conv1d(ni, no, ks, stride=stride, padding=padding, bias=bias)
     nn.init.kaiming_normal_(conv.weight)
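
Note (not part of this commit): the `act_position` logic in `ConvBnAct.__init__` above decides where the activation lands relative to the conv and batchnorm entries of the resulting `nn.Sequential`. A minimal usage sketch, assuming the package is importable as `model_constructor.layers` and that `pre_act` and `bn_1st` are constructor keywords as the diff suggests:

# Hypothetical usage sketch; keyword names taken from the diff, positions of
# in_channels/out_channels assumed.
from model_constructor.layers import ConvBnAct

default = ConvBnAct(in_channels=3, out_channels=16)                  # conv -> bn -> act_fn
act_first = ConvBnAct(in_channels=3, out_channels=16, pre_act=True)  # act_fn -> conv -> bn
act_mid = ConvBnAct(in_channels=3, out_channels=16, bn_1st=False)    # conv -> act_fn -> bn

for name, block in [("default", default), ("pre_act", act_first), ("bn_1st=False", act_mid)]:
    # child names are the OrderedDict keys built in __init__
    print(name, [child for child, _ in block.named_children()])
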
@@ -116,16 +160,16 @@ def conv1d(ni: int, no: int, ks: int = 1, stride: int = 1, padding: int = 0, bia


 class SimpleSelfAttention(nn.Module):
-    '''SimpleSelfAttention module. # noqa W291
-    Adapted from SelfAttention layer at
-    https://github.com/fastai/fastai/blob/5c51f9eabf76853a89a9bc5741804d2ed4407e49/fastai/layers.py
-    Inspired by https://arxiv.org/pdf/1805.08318.pdf
-    '''
+    """SimpleSelfAttention module. # noqa W291
+    Adapted from SelfAttention layer at
+    https://github.com/fastai/fastai/blob/5c51f9eabf76853a89a9bc5741804d2ed4407e49/fastai/layers.py
+    Inspired by https://arxiv.org/pdf/1805.08318.pdf
+    """

     def __init__(self, n_in: int, ks=1, sym=False, use_bias=False):
         super().__init__()
         self.conv = conv1d(n_in, n_in, ks, padding=ks // 2, bias=use_bias)
-        self.gamma = torch.nn.Parameter(torch.tensor([0.]))  # type: ignore
+        self.gamma = torch.nn.Parameter(torch.tensor([0.0]))  # type: ignore
         self.sym = sym
         self.n_in = n_in

@@ -136,12 +180,14 @@ def forward(self, x):
             c = (c + c.t()) / 2
             self.conv.weight = c.view(self.n_in, self.n_in, 1)
         size = x.size()
-        x = x.view(*size[:2], -1) # (C,N)
+        x = x.view(*size[:2], -1)  # (C,N)
         # changed the order of multiplication to avoid O(N^2) complexity
         # (x*xT)*(W*x) instead of (x*(xT*(W*x)))
-        convx = self.conv(x) # (C,C) * (C,N) = (C,N) => O(NC^2)
-        xxT = torch.bmm(x, x.permute(0, 2, 1).contiguous()) # (C,N) * (N,C) = (C,C) => O(NC^2)
-        o = torch.bmm(xxT, convx) # (C,C) * (C,N) = (C,N) => O(NC^2)
+        convx = self.conv(x)  # (C,C) * (C,N) = (C,N) => O(NC^2)
+        xxT = torch.bmm(
+            x, x.permute(0, 2, 1).contiguous()
+        )  # (C,N) * (N,C) = (C,C) => O(NC^2)
+        o = torch.bmm(xxT, convx)  # (C,C) * (C,N) = (C,N) => O(NC^2)
         o = self.gamma * o + x
         return o.view(*size).contiguous()
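
Note (not part of this commit): the comment reformatted in this hunk is the key idea of `SimpleSelfAttention.forward`. Because matrix multiplication is associative, `(x @ xT) @ (W @ x)` equals `x @ (xT @ (W @ x))`, but the first grouping keeps every product at O(N*C^2), while the second builds an N-by-N intermediate and costs O(N^2*C). A standalone sketch in plain PyTorch checking that claim:

# Standalone sketch of the associativity trick; names here are illustrative only.
import torch

C, N = 8, 1024                                   # channels, flattened spatial size H*W
x = torch.randn(1, C, N, dtype=torch.float64)    # as after x.view(*size[:2], -1)
w = torch.randn(1, C, C, dtype=torch.float64)    # stand-in for the ks=1 conv weight

wx = torch.bmm(w, x)                                   # (C,C)@(C,N) -> (C,N), O(N*C^2)
fast = torch.bmm(torch.bmm(x, x.transpose(1, 2)), wx)  # (C,C) intermediate, O(N*C^2)
slow = torch.bmm(x, torch.bmm(x.transpose(1, 2), wx))  # (N,N) intermediate, O(N^2*C)
print(torch.allclose(fast, slow))                      # True: same result, cheaper path
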

@@ -157,11 +203,15 @@ def __init__(self, c, r=16):
         ch = max(c // r, 1)
         self.squeeze = nn.AdaptiveAvgPool2d(1)
         self.excitation = nn.Sequential(
-            OrderedDict([('fc_reduce', self.se_layer(c, ch, bias=self.use_bias)),
-                         ('se_act', self.act_fn),
-                         ('fc_expand', self.se_layer(ch, c, bias=self.use_bias)),
-                         ('sigmoid', nn.Sigmoid())
-                         ]))
+            OrderedDict(
+                [
+                    ("fc_reduce", self.se_layer(c, ch, bias=self.use_bias)),
+                    ("se_act", self.act_fn),
+                    ("fc_expand", self.se_layer(ch, c, bias=self.use_bias)),
+                    ("sigmoid", nn.Sigmoid()),
+                ]
+            )
+        )

     def forward(self, x):
         bs, c, _, _ = x.shape
@@ -178,16 +228,19 @@ class SEBlockConv(nn.Module):  # todo: deprecation warning.

     def __init__(self, c, r=16):
         super().__init__()
-        # c_in = math.ceil(c//r/8)*8
+        # c_in = math.ceil(c//r/8)*8
         c_in = max(c // r, 1)
         self.squeeze = nn.AdaptiveAvgPool2d(1)
         self.excitation = nn.Sequential(
-            OrderedDict([
-                ('conv_reduce', self.se_layer(c, c_in, 1, bias=self.use_bias)),
-                ('se_act', self.act_fn),
-                ('conv_expand', self.se_layer(c_in, c, 1, bias=self.use_bias)),
-                ('sigmoid', nn.Sigmoid())
-            ]))
+            OrderedDict(
+                [
+                    ("conv_reduce", self.se_layer(c, c_in, 1, bias=self.use_bias)),
+                    ("se_act", self.act_fn),
+                    ("conv_expand", self.se_layer(c_in, c, 1, bias=self.use_bias)),
+                    ("sigmoid", nn.Sigmoid()),
+                ]
+            )
+        )

     def forward(self, x):
         y = self.squeeze(x)
@@ -198,16 +251,17 @@ def forward(self, x):
 class SEModule(nn.Module):
     "se block"

-    def __init__(self,
-                 channels,
-                 reduction=16,
-                 rd_channels=None,
-                 rd_max=False,
-                 se_layer=nn.Linear,
-                 act_fn=nn.ReLU(inplace=True),
-                 use_bias=True,
-                 gate=nn.Sigmoid
-                 ):
+    def __init__(
+        self,
+        channels,
+        reduction=16,
+        rd_channels=None,
+        rd_max=False,
+        se_layer=nn.Linear,
+        act_fn=nn.ReLU(inplace=True),
+        use_bias=True,
+        gate=nn.Sigmoid,
+    ):
         super().__init__()
         reducted = max(channels // reduction, 1)  # preserve zero-element tensors
         if rd_channels is None:
@@ -217,11 +271,15 @@ def __init__(self,
             rd_channels = max(rd_channels, reducted)
         self.squeeze = nn.AdaptiveAvgPool2d(1)
         self.excitation = nn.Sequential(
-            OrderedDict([('reduce', se_layer(channels, rd_channels, bias=use_bias)),
-                         ('se_act', act_fn),
-                         ('expand', se_layer(rd_channels, channels, bias=use_bias)),
-                         ('se_gate', gate())
-                         ]))
+            OrderedDict(
+                [
+                    ("reduce", se_layer(channels, rd_channels, bias=use_bias)),
+                    ("se_act", act_fn),
+                    ("expand", se_layer(rd_channels, channels, bias=use_bias)),
+                    ("se_gate", gate()),
+                ]
+            )
+        )

     def forward(self, x):
         bs, c, _, _ = x.shape
@@ -233,18 +291,19 @@ def forward(self, x):
 class SEModuleConv(nn.Module):
     "se block with conv on excitation"

-    def __init__(self,
-                 channels,
-                 reduction=16,
-                 rd_channels=None,
-                 rd_max=False,
-                 se_layer=nn.Conv2d,
-                 act_fn=nn.ReLU(inplace=True),
-                 use_bias=True,
-                 gate=nn.Sigmoid
-                 ):
+    def __init__(
+        self,
+        channels,
+        reduction=16,
+        rd_channels=None,
+        rd_max=False,
+        se_layer=nn.Conv2d,
+        act_fn=nn.ReLU(inplace=True),
+        use_bias=True,
+        gate=nn.Sigmoid,
+    ):
         super().__init__()
-        # rd_channels = math.ceil(channels//reduction/8)*8
+        # rd_channels = math.ceil(channels//reduction/8)*8
         reducted = max(channels // reduction, 1)  # preserve zero-element tensors
         if rd_channels is None:
             rd_channels = reducted
@@ -253,12 +312,15 @@ def __init__(self,
             rd_channels = max(rd_channels, reducted)
         self.squeeze = nn.AdaptiveAvgPool2d(1)
         self.excitation = nn.Sequential(
-            OrderedDict([
-                ('reduce', se_layer(channels, rd_channels, 1, bias=use_bias)),
-                ('se_act', act_fn),
-                ('expand', se_layer(rd_channels, channels, 1, bias=use_bias)),
-                ('gate', gate())
-            ]))
+            OrderedDict(
+                [
+                    ("reduce", se_layer(channels, rd_channels, 1, bias=use_bias)),
+                    ("se_act", act_fn),
+                    ("expand", se_layer(rd_channels, channels, 1, bias=use_bias)),
+                    ("gate", gate()),
+                ]
+            )
+        )

     def forward(self, x):
         y = self.squeeze(x)
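
Note (not part of this commit): in the SE blocks touched by the last hunks, the excitation path reduces to `max(channels // reduction, 1)` features, expands back, and gates the input channel-wise, so the output is expected to keep the input shape. A hypothetical shape check, assuming the package import path:

# Hypothetical usage sketch; assumes model_constructor is installed and that
# SEModule's forward returns the channel-gated input.
import torch
from model_constructor.layers import SEModule

se = SEModule(channels=64, reduction=16)   # hidden width: max(64 // 16, 1) = 4
x = torch.randn(2, 64, 8, 8)
print(se(x).shape)                         # expected: torch.Size([2, 64, 8, 8])
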
