
Commit 9a662a1

Merge pull request #55 from ayasyrev/mc_init
Mc init
2 parents 7a4f0aa + d8af738 commit 9a662a1


2 files changed: +400 -188 lines changed


model_constructor/layers.py

Lines changed: 161 additions & 86 deletions
@@ -1,15 +1,26 @@
-import torch.nn as nn
-import torch
-from torch.nn.utils import spectral_norm
 from collections import OrderedDict
+from typing import List, Optional, Union
 
+import torch
+import torch.nn as nn
+from torch.nn.utils.spectral_norm import spectral_norm
 
-__all__ = ['Flatten', 'noop', 'Noop', 'ConvLayer', 'act_fn',
-           'conv1d', 'SimpleSelfAttention', 'SEBlock', 'SEBlockConv']
+__all__ = [
+    "Flatten",
+    "noop",
+    "Noop",
+    "ConvLayer",
+    "act_fn",
+    "conv1d",
+    "SimpleSelfAttention",
+    "SEBlock",
+    "SEBlockConv",
+]
 
 
 class Flatten(nn.Module):
-    '''flat x to vector'''
+    """flat x to vector"""
+
     def __init__(self):
         super().__init__()
 
@@ -18,12 +29,13 @@ def forward(self, x):
 
 
 def noop(x):
-    '''Dummy func. Return input'''
+    """Dummy func. Return input"""
     return x
 
 
 class Noop(nn.Module):
-    '''Dummy module'''
+    """Dummy module"""
+
     def __init__(self):
         super().__init__()
 
@@ -36,83 +48,128 @@ def forward(self, x):
 
 class ConvBnAct(nn.Sequential):
     """Basic Conv + Bn + Act block"""
+
     convolution_module = nn.Conv2d  # can be changed in models like twist.
     batchnorm_module = nn.BatchNorm2d
 
-    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
-                 padding=None, bias=False, groups=1,
-                 act_fn=act_fn, pre_act=False,
-                 bn_layer=True, bn_1st=True, zero_bn=False,
-                 ):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int = 3,
+        stride: int = 1,
+        padding: Optional[int] = None,
+        bias: bool = False,
+        groups: int = 1,
+        act_fn: Union[nn.Module, bool] = act_fn,
+        pre_act: bool = False,
+        bn_layer: bool = True,
+        bn_1st: bool = True,
+        zero_bn: bool = False,
+    ):
 
         if padding is None:
             padding = kernel_size // 2
-        layers = [('conv', self.convolution_module(in_channels, out_channels, kernel_size, stride=stride,
-                                                    padding=padding, bias=bias, groups=groups))]  # if no bn - bias True?
+        layers: List[tuple[str, nn.Module]] = [
+            (
+                "conv",
+                self.convolution_module(
+                    in_channels,
+                    out_channels,
+                    kernel_size,
+                    stride=stride,
+                    padding=padding,
+                    bias=bias,
+                    groups=groups,
+                ),
+            )
+        ]  # if no bn - bias True?
         if bn_layer:
             bn = self.batchnorm_module(out_channels)
-            nn.init.constant_(bn.weight, 0. if zero_bn else 1.)
-            layers.append(('bn', bn))
-        if act_fn:
+            nn.init.constant_(bn.weight, 0.0 if zero_bn else 1.0)
+            layers.append(("bn", bn))
+        if isinstance(act_fn, nn.Module):  # act_fn either nn.Module or False
             if pre_act:
                 act_position = 0
             elif not bn_1st:
                 act_position = 1
             else:
                 act_position = len(layers)
-            layers.insert(act_position, ('act_fn', act_fn))
+            layers.insert(act_position, ("act_fn", act_fn))
         super().__init__(OrderedDict(layers))
 
 
 # NOTE First version. Leaved for backwards compatibility with old blocks, models.
 class ConvLayer(nn.Sequential):
     """Basic conv layers block"""
+
     Conv2d = nn.Conv2d
 
-    def __init__(self, ni, nf, ks=3, stride=1,
-                 act=True, act_fn=act_fn,
-                 bn_layer=True, bn_1st=True, zero_bn=False,
-                 padding=None, bias=False, groups=1, **kwargs):
+    def __init__(
+        self,
+        ni,
+        nf,
+        ks=3,
+        stride=1,
+        act=True,
+        act_fn=act_fn,
+        bn_layer=True,
+        bn_1st=True,
+        zero_bn=False,
+        padding=None,
+        bias=False,
+        groups=1,
+        **kwargs
+    ):
 
         if padding is None:
             padding = ks // 2
-        layers = [('conv', self.Conv2d(ni, nf, ks, stride=stride,
-                                       padding=padding, bias=bias, groups=groups))]
-        act_bn = [('act_fn', act_fn)] if act else []
+        layers = [
+            (
+                "conv",
+                self.Conv2d(
+                    ni, nf, ks, stride=stride, padding=padding, bias=bias, groups=groups
+                ),
+            )
+        ]
+        act_bn = [("act_fn", act_fn)] if act else []
         if bn_layer:
             bn = nn.BatchNorm2d(nf)
-            nn.init.constant_(bn.weight, 0. if zero_bn else 1.)
-            act_bn += [('bn', bn)]
+            nn.init.constant_(bn.weight, 0.0 if zero_bn else 1.0)
+            act_bn += [("bn", bn)]
         if bn_1st:
             act_bn.reverse()
         layers += act_bn
         super().__init__(OrderedDict(layers))
 
+
 # Cell
 # SA module from mxresnet at fastai. todo - add persons!!!
 # Unmodified from https://github.com/fastai/fastai/blob/5c51f9eabf76853a89a9bc5741804d2ed4407e49/fastai/layers.py
 
 
-def conv1d(ni: int, no: int, ks: int = 1, stride: int = 1, padding: int = 0, bias: bool = False):
+def conv1d(
+    ni: int, no: int, ks: int = 1, stride: int = 1, padding: int = 0, bias: bool = False
+):
     "Create and initialize a `nn.Conv1d` layer with spectral normalization."
     conv = nn.Conv1d(ni, no, ks, stride=stride, padding=padding, bias=bias)
     nn.init.kaiming_normal_(conv.weight)
     if bias:
-        conv.bias.data.zero_()
+        conv.bias.data.zero_()  # type: ignore
     return spectral_norm(conv)
 
 
 class SimpleSelfAttention(nn.Module):
-    '''SimpleSelfAttention module. # noqa W291
-    Adapted from SelfAttention layer at
-    https://github.com/fastai/fastai/blob/5c51f9eabf76853a89a9bc5741804d2ed4407e49/fastai/layers.py
-    Inspired by https://arxiv.org/pdf/1805.08318.pdf
-    '''
+    """SimpleSelfAttention module. # noqa W291
+    Adapted from SelfAttention layer at
+    https://github.com/fastai/fastai/blob/5c51f9eabf76853a89a9bc5741804d2ed4407e49/fastai/layers.py
+    Inspired by https://arxiv.org/pdf/1805.08318.pdf
+    """
 
     def __init__(self, n_in: int, ks=1, sym=False, use_bias=False):
         super().__init__()
         self.conv = conv1d(n_in, n_in, ks, padding=ks // 2, bias=use_bias)
-        self.gamma = nn.Parameter(torch.tensor([0.]))
+        self.gamma = torch.nn.Parameter(torch.tensor([0.0]))  # type: ignore
         self.sym = sym
         self.n_in = n_in
 
@@ -123,17 +180,19 @@ def forward(self, x):
             c = (c + c.t()) / 2
             self.conv.weight = c.view(self.n_in, self.n_in, 1)
         size = x.size()
-        x = x.view(*size[:2], -1)   # (C,N)
+        x = x.view(*size[:2], -1)  # (C,N)
         # changed the order of multiplication to avoid O(N^2) complexity
         # (x*xT)*(W*x) instead of (x*(xT*(W*x)))
-        convx = self.conv(x)   # (C,C) * (C,N) = (C,N)   => O(NC^2)
-        xxT = torch.bmm(x, x.permute(0, 2, 1).contiguous())   # (C,N) * (N,C) = (C,C)   => O(NC^2)
-        o = torch.bmm(xxT, convx)   # (C,C) * (C,N) = (C,N)   => O(NC^2)
+        convx = self.conv(x)  # (C,C) * (C,N) = (C,N) => O(NC^2)
+        xxT = torch.bmm(
+            x, x.permute(0, 2, 1).contiguous()
+        )  # (C,N) * (N,C) = (C,C) => O(NC^2)
+        o = torch.bmm(xxT, convx)  # (C,C) * (C,N) = (C,N) => O(NC^2)
         o = self.gamma * o + x
         return o.view(*size).contiguous()
 
 
-class SEBlock(nn.Module):  # todo: deprecation worning.
+class SEBlock(nn.Module):  # todo: deprecation warning.
     "se block"
     se_layer = nn.Linear
     act_fn = nn.ReLU(inplace=True)
@@ -144,11 +203,15 @@ def __init__(self, c, r=16):
         ch = max(c // r, 1)
         self.squeeze = nn.AdaptiveAvgPool2d(1)
         self.excitation = nn.Sequential(
-            OrderedDict([('fc_reduce', self.se_layer(c, ch, bias=self.use_bias)),
-                         ('se_act', self.act_fn),
-                         ('fc_expand', self.se_layer(ch, c, bias=self.use_bias)),
-                         ('sigmoid', nn.Sigmoid())
-                         ]))
+            OrderedDict(
+                [
+                    ("fc_reduce", self.se_layer(c, ch, bias=self.use_bias)),
+                    ("se_act", self.act_fn),
+                    ("fc_expand", self.se_layer(ch, c, bias=self.use_bias)),
+                    ("sigmoid", nn.Sigmoid()),
+                ]
+            )
+        )
 
     def forward(self, x):
         bs, c, _, _ = x.shape
@@ -157,24 +220,27 @@ def forward(self, x):
         return x * y.expand_as(x)
 
 
-class SEBlockConv(nn.Module):  # todo: deprecation worning.
+class SEBlockConv(nn.Module):  # todo: deprecation warning.
     "se block with conv on excitation"
     se_layer = nn.Conv2d
     act_fn = nn.ReLU(inplace=True)
     use_bias = True
 
     def __init__(self, c, r=16):
         super().__init__()
-        # c_in = math.ceil(c//r/8)*8
+        # c_in = math.ceil(c//r/8)*8
         c_in = max(c // r, 1)
         self.squeeze = nn.AdaptiveAvgPool2d(1)
         self.excitation = nn.Sequential(
-            OrderedDict([
-                ('conv_reduce', self.se_layer(c, c_in, 1, bias=self.use_bias)),
-                ('se_act', self.act_fn),
-                ('conv_expand', self.se_layer(c_in, c, 1, bias=self.use_bias)),
-                ('sigmoid', nn.Sigmoid())
-            ]))
+            OrderedDict(
+                [
+                    ("conv_reduce", self.se_layer(c, c_in, 1, bias=self.use_bias)),
+                    ("se_act", self.act_fn),
+                    ("conv_expand", self.se_layer(c_in, c, 1, bias=self.use_bias)),
+                    ("sigmoid", nn.Sigmoid()),
+                ]
+            )
+        )
 
     def forward(self, x):
         y = self.squeeze(x)
@@ -185,16 +251,17 @@ def forward(self, x):
 class SEModule(nn.Module):
     "se block"
 
-    def __init__(self,
-                 channels,
-                 reduction=16,
-                 rd_channels=None,
-                 rd_max=False,
-                 se_layer=nn.Linear,
-                 act_fn=nn.ReLU(inplace=True),
-                 use_bias=True,
-                 gate=nn.Sigmoid
-                 ):
+    def __init__(
+        self,
+        channels,
+        reduction=16,
+        rd_channels=None,
+        rd_max=False,
+        se_layer=nn.Linear,
+        act_fn=nn.ReLU(inplace=True),
+        use_bias=True,
+        gate=nn.Sigmoid,
+    ):
         super().__init__()
         reducted = max(channels // reduction, 1)  # preserve zero-element tensors
         if rd_channels is None:
@@ -204,11 +271,15 @@ def __init__(self,
             rd_channels = max(rd_channels, reducted)
         self.squeeze = nn.AdaptiveAvgPool2d(1)
         self.excitation = nn.Sequential(
-            OrderedDict([('reduce', se_layer(channels, rd_channels, bias=use_bias)),
-                         ('se_act', act_fn),
-                         ('expand', se_layer(rd_channels, channels, bias=use_bias)),
-                         ('se_gate', gate())
-                         ]))
+            OrderedDict(
+                [
+                    ("reduce", se_layer(channels, rd_channels, bias=use_bias)),
+                    ("se_act", act_fn),
+                    ("expand", se_layer(rd_channels, channels, bias=use_bias)),
+                    ("se_gate", gate()),
+                ]
+            )
+        )
 
     def forward(self, x):
         bs, c, _, _ = x.shape
@@ -220,18 +291,19 @@ def forward(self, x):
 class SEModuleConv(nn.Module):
     "se block with conv on excitation"
 
-    def __init__(self,
-                 channels,
-                 reduction=16,
-                 rd_channels=None,
-                 rd_max=False,
-                 se_layer=nn.Conv2d,
-                 act_fn=nn.ReLU(inplace=True),
-                 use_bias=True,
-                 gate=nn.Sigmoid
-                 ):
+    def __init__(
+        self,
+        channels,
+        reduction=16,
+        rd_channels=None,
+        rd_max=False,
+        se_layer=nn.Conv2d,
+        act_fn=nn.ReLU(inplace=True),
+        use_bias=True,
+        gate=nn.Sigmoid,
+    ):
         super().__init__()
-        # rd_channels = math.ceil(channels//reduction/8)*8
+        # rd_channels = math.ceil(channels//reduction/8)*8
         reducted = max(channels // reduction, 1)  # preserve zero-element tensors
         if rd_channels is None:
             rd_channels = reducted
@@ -240,12 +312,15 @@ def __init__(self,
             rd_channels = max(rd_channels, reducted)
         self.squeeze = nn.AdaptiveAvgPool2d(1)
         self.excitation = nn.Sequential(
-            OrderedDict([
-                ('reduce', se_layer(channels, rd_channels, 1, bias=use_bias)),
-                ('se_act', act_fn),
-                ('expand', se_layer(rd_channels, channels, 1, bias=use_bias)),
-                ('gate', gate())
-            ]))
+            OrderedDict(
+                [
+                    ("reduce", se_layer(channels, rd_channels, 1, bias=use_bias)),
+                    ("se_act", act_fn),
+                    ("expand", se_layer(rd_channels, channels, 1, bias=use_bias)),
+                    ("gate", gate()),
+                ]
+            )
+        )
 
     def forward(self, x):
         y = self.squeeze(x)
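For reference, a minimal usage sketch of the reworked ConvBnAct signature shown in the diff above. This is hypothetical and not part of the commit: the import path is assumed from the changed file's location, and act_fn is assumed to default to an nn.ReLU instance (it is exported in __all__ but its definition is not shown in this diff).

# Hypothetical usage sketch (not part of the commit) for the reworked ConvBnAct.
# Assumptions: import path follows the changed file's location; act_fn defaults
# to an nn.ReLU instance defined elsewhere in model_constructor/layers.py.
import torch
from model_constructor.layers import ConvBnAct

# Default layer order is conv -> bn -> act (bn_1st=True);
# padding falls back to kernel_size // 2 when not given.
block = ConvBnAct(in_channels=3, out_channels=64, kernel_size=3, stride=2)
out = block(torch.randn(1, 3, 32, 32))
print(out.shape)  # expected: torch.Size([1, 64, 16, 16])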
