
Commit 76ae20d

black mc

1 parent 8a8eb35

File tree

1 file changed: +150 -65 lines changed

model_constructor/model_constructor.py

Lines changed: 150 additions & 65 deletions
@@ -7,15 +7,22 @@
 from .layers import ConvBnAct, SEModule, SimpleSelfAttention


-__all__ = ['init_cnn', 'act_fn', 'ResBlock', 'ModelConstructor', 'xresnet34', 'xresnet50']
+__all__ = [
+    "init_cnn",
+    "act_fn",
+    "ResBlock",
+    "ModelConstructor",
+    "xresnet34",
+    "xresnet50",
+]


 act_fn = nn.ReLU(inplace=True)


 def init_cnn(module: nn.Module):
     "Init module - kaiming_normal for Conv2d and 0 for biases."
-    if getattr(module, 'bias', None) is not None:
+    if getattr(module, "bias", None) is not None:
         nn.init.constant_(module.bias, 0)  # type: ignore
     if isinstance(module, (nn.Conv2d, nn.Linear)):
         nn.init.kaiming_normal_(module.weight)
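
A tiny hedged check of init_cnn (not part of the commit); it exercises only the behavior visible in this hunk, i.e. zeroed biases and kaiming-initialized weights for Conv2d/Linear, and assumes the module is importable as model_constructor.model_constructor:

import torch.nn as nn
from model_constructor.model_constructor import init_cnn

conv = nn.Conv2d(3, 8, kernel_size=3)
init_cnn(conv)  # kaiming_normal_ on weight, bias set to 0
assert conv.bias is not None
assert conv.bias.abs().sum().item() == 0.0  # biases zeroed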
@@ -24,7 +31,7 @@ def init_cnn(module: nn.Module):


 class ResBlock(nn.Module):
-    '''Universal Resnet block. Basic block if expansion is 1, otherwise is Bottleneck.'''
+    """Universal Resnet block. Basic block if expansion is 1, otherwise is Bottleneck."""

     def __init__(
         self,
@@ -49,31 +56,70 @@ def __init__(
         if div_groups is not None:  # check if groups != 1 and div_groups
             groups = int(mid_channels / div_groups)
         if expansion == 1:
-            layers = [("conv_0", conv_layer(in_channels, mid_channels, 3, stride=stride,  # type: ignore
-                                            act_fn=act_fn, bn_1st=bn_1st, groups=in_channels if dw else groups)),
-                      ("conv_1", conv_layer(mid_channels, out_channels, 3, zero_bn=zero_bn,
-                                            act_fn=False, bn_1st=bn_1st, groups=mid_channels if dw else groups))
-                      ]
+            layers = [
+                ("conv_0", conv_layer(
+                    in_channels,
+                    mid_channels,
+                    3,
+                    stride=stride,  # type: ignore
+                    act_fn=act_fn,
+                    bn_1st=bn_1st,
+                    groups=in_channels if dw else groups,
+                ),),
+                ("conv_1", conv_layer(
+                    mid_channels,
+                    out_channels,
+                    3,
+                    zero_bn=zero_bn,
+                    act_fn=False,
+                    bn_1st=bn_1st,
+                    groups=mid_channels if dw else groups,
+                ),),
+            ]
         else:
-            layers = [("conv_0", conv_layer(in_channels, mid_channels, 1, act_fn=act_fn, bn_1st=bn_1st)),
-                      ("conv_1", conv_layer(mid_channels, mid_channels, 3, stride=stride, act_fn=act_fn, bn_1st=bn_1st,
-                                            groups=mid_channels if dw else groups)),
-                      ("conv_2", conv_layer(mid_channels, out_channels, 1, zero_bn=zero_bn, act_fn=False, bn_1st=bn_1st))  # noqa E501
-                      ]
+            layers = [
+                ("conv_0", conv_layer(
+                    in_channels,
+                    mid_channels,
+                    1,
+                    act_fn=act_fn,
+                    bn_1st=bn_1st,
+                ),),
+                ("conv_1", conv_layer(
+                    mid_channels,
+                    mid_channels,
+                    3,
+                    stride=stride,
+                    act_fn=act_fn,
+                    bn_1st=bn_1st,
+                    groups=mid_channels if dw else groups,
+                ),),
+                ("conv_2", conv_layer(
+                    mid_channels,
+                    out_channels,
+                    1,
+                    zero_bn=zero_bn,
+                    act_fn=False,
+                    bn_1st=bn_1st,
+                ),),  # noqa E501
+            ]
         if se:
-            layers.append(('se', se(out_channels)))
+            layers.append(("se", se(out_channels)))
         if sa:
-            layers.append(('sa', sa(out_channels)))
+            layers.append(("sa", sa(out_channels)))
         self.convs = nn.Sequential(OrderedDict(layers))
         if stride != 1 or in_channels != out_channels:
             id_layers = []
             if stride != 1 and pool is not None:  # if pool - reduce by pool, else stride 2 at id_conv
                 id_layers.append(("pool", pool))
             if in_channels != out_channels or (stride != 1 and pool is None):
                 id_layers += [("id_conv", conv_layer(
-                    in_channels, out_channels, 1,
+                    in_channels,
+                    out_channels,
+                    1,
                     stride=1 if pool else stride,
-                    act_fn=False))]
+                    act_fn=False,
+                ),)]
             self.id_conv = nn.Sequential(OrderedDict(id_layers))
         else:
             self.id_conv = None
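
For review purposes, a quick smoke test of the refactored block (not part of the commit). It assumes ResBlock's positional order (expansion, in_channels, mid_channels, ...) implied by the self.block(...) call in the _make_layer hunk below, that out_channels is mid_channels * expansion, and that the package is importable as model_constructor:

import torch
from model_constructor.model_constructor import ResBlock

# Basic block: expansion == 1, so out_channels == mid_channels and the
# identity path needs no id_conv when stride == 1.
block = ResBlock(1, 64, 64)
x = torch.randn(2, 64, 32, 32)
print(block(x).shape)  # expected: torch.Size([2, 64, 32, 32])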
@@ -85,15 +131,23 @@ def forward(self, x):


 def _make_stem(self):
-    stem = [(f"conv_{i}", self.conv_layer(self.stem_sizes[i], self.stem_sizes[i + 1],
-                                          stride=2 if i == self.stem_stride_on else 1,
-                                          bn_layer=(not self.stem_bn_end) if i == (len(self.stem_sizes) - 2) else True,
-                                          act_fn=self.act_fn, bn_1st=self.bn_1st))
-            for i in range(len(self.stem_sizes) - 1)]
+    stem = [
+        (f"conv_{i}", self.conv_layer(
+            self.stem_sizes[i],
+            self.stem_sizes[i + 1],
+            stride=2 if i == self.stem_stride_on else 1,
+            bn_layer=(not self.stem_bn_end)
+            if i == (len(self.stem_sizes) - 2)
+            else True,
+            act_fn=self.act_fn,
+            bn_1st=self.bn_1st,
+        ),)
+        for i in range(len(self.stem_sizes) - 1)
+    ]
     if self.stem_pool:
-        stem.append(('stem_pool', self.stem_pool))
+        stem.append(("stem_pool", self.stem_pool))
     if self.stem_bn_end:
-        stem.append(('norm', self.norm(self.stem_sizes[-1])))
+        stem.append(("norm", self.norm(self.stem_sizes[-1])))
     return nn.Sequential(OrderedDict(stem))
@@ -102,43 +156,67 @@ def _make_layer(self, layer_num: int) -> nn.Module:
     # if no pool on stem - stride = 2 for first layer block in body
     stride = 1 if self.stem_pool and layer_num == 0 else 2
     num_blocks = self.layers[layer_num]
-    return nn.Sequential(OrderedDict([
-        (f"bl_{block_num}", self.block(
-            self.expansion,
-            self.block_sizes[layer_num] if block_num == 0 else self.block_sizes[layer_num + 1],
-            self.block_sizes[layer_num + 1],
-            stride if block_num == 0 else 1,
-            sa=self.sa if (block_num == num_blocks - 1) and layer_num == 0 else None,
-            conv_layer=self.conv_layer,
-            act_fn=self.act_fn,
-            pool=self.pool,
-            zero_bn=self.zero_bn, bn_1st=self.bn_1st,
-            groups=self.groups, div_groups=self.div_groups,
-            dw=self.dw, se=self.se
-        ))
-        for block_num in range(num_blocks)
-    ]))
+    return nn.Sequential(
+        OrderedDict(
+            [
+                (
+                    f"bl_{block_num}",
+                    self.block(
+                        self.expansion,
+                        self.block_sizes[layer_num]
+                        if block_num == 0
+                        else self.block_sizes[layer_num + 1],
+                        self.block_sizes[layer_num + 1],
+                        stride if block_num == 0 else 1,
+                        sa=self.sa
+                        if (block_num == num_blocks - 1) and layer_num == 0
+                        else None,
+                        conv_layer=self.conv_layer,
+                        act_fn=self.act_fn,
+                        pool=self.pool,
+                        zero_bn=self.zero_bn,
+                        bn_1st=self.bn_1st,
+                        groups=self.groups,
+                        div_groups=self.div_groups,
+                        dw=self.dw,
+                        se=self.se,
+                    ),
+                )
+                for block_num in range(num_blocks)
+            ]
+        )
+    )


 def _make_body(self):
-    return nn.Sequential(OrderedDict([
-        (f"l_{layer_num}", self._make_layer(self, layer_num))
-        for layer_num in range(len(self.layers))
-    ]))
+    return nn.Sequential(
+        OrderedDict(
+            [
+                (
+                    f"l_{layer_num}",
+                    self._make_layer(self, layer_num)
+                )
+                for layer_num in range(len(self.layers))
+            ]
+        )
+    )


 def _make_head(self):
-    head = [('pool', nn.AdaptiveAvgPool2d(1)),
-            ('flat', nn.Flatten()),
-            ('fc', nn.Linear(self.block_sizes[-1] * self.expansion, self.num_classes))]
+    head = [
+        ("pool", nn.AdaptiveAvgPool2d(1)),
+        ("flat", nn.Flatten()),
+        ("fc", nn.Linear(self.block_sizes[-1] * self.expansion, self.num_classes)),
+    ]
     return nn.Sequential(OrderedDict(head))


-class ModelConstructor():
+class ModelConstructor:
     """Model constructor. As default - xresnet18"""
+
     def __init__(
         self,
-        name: str = 'MC',
+        name: str = "MC",
         in_chans: int = 3,
         num_classes: int = 1000,
         block=ResBlock,
@@ -221,7 +299,9 @@ def __init__(
         else:
             self.sa = sa
         if se_module or se_reduction:  # pragma: no cover
-            print("Deprecated. Pass se_module as se argument, se_reduction as arg to se.")  # add deprecation warning.
+            print(
+                "Deprecated. Pass se_module as se argument, se_reduction as arg to se."
+            )  # add deprecation warning.

     @property
     def block_sizes(self):
@@ -240,23 +320,28 @@ def body(self):
         return self._make_body(self)

     def __call__(self):
-        model = nn.Sequential(OrderedDict([
-            ('stem', self.stem),
-            ('body', self.body),
-            ('head', self.head)]))
+        model = nn.Sequential(
+            OrderedDict([("stem", self.stem), ("body", self.body), ("head", self.head)])
+        )
         self._init_cnn(model)
         model.extra_repr = lambda: f"{self.name}"
         return model

     def __repr__(self):
-        return (f"{self.name} constructor\n"
-                f"  in_chans: {self.in_chans}, num_classes: {self.num_classes}\n"
-                f"  expansion: {self.expansion}, groups: {self.groups}, dw: {self.dw}, div_groups: {self.div_groups}\n"
-                f"  sa: {self.sa}, se: {self.se}\n"
-                f"  stem sizes: {self.stem_sizes}, stride on {self.stem_stride_on}\n"
-                f"  body sizes {self._block_sizes}\n"
-                f"  layers: {self.layers}")
-
-
-xresnet34 = partial(ModelConstructor, name='xresnet34', expansion=1, layers=[3, 4, 6, 3])
-xresnet50 = partial(ModelConstructor, name='xresnet34', expansion=4, layers=[3, 4, 6, 3])
+        return (
+            f"{self.name} constructor\n"
+            f"  in_chans: {self.in_chans}, num_classes: {self.num_classes}\n"
+            f"  expansion: {self.expansion}, groups: {self.groups}, dw: {self.dw}, div_groups: {self.div_groups}\n"
+            f"  sa: {self.sa}, se: {self.se}\n"
+            f"  stem sizes: {self.stem_sizes}, stride on {self.stem_stride_on}\n"
+            f"  body sizes {self._block_sizes}\n"
+            f"  layers: {self.layers}"
+        )
+
+
+xresnet34 = partial(
+    ModelConstructor, name="xresnet34", expansion=1, layers=[3, 4, 6, 3]
+)
+xresnet50 = partial(
+    ModelConstructor, name="xresnet34", expansion=4, layers=[3, 4, 6, 3]
+)
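
To close the review, a hedged usage sketch (not part of the commit) showing how the pieces above fit together; it assumes the module is importable as model_constructor.model_constructor and relies only on the __call__, __repr__ and partial definitions visible in the diff:

import torch
from model_constructor.model_constructor import xresnet34, xresnet50

mc = xresnet34(num_classes=10)  # partial -> configured constructor
print(mc)                       # __repr__ summarizes the configuration
model = mc()                    # __call__ assembles stem, body and head
out = model(torch.randn(2, 3, 224, 224))
print(out.shape)                # expected: torch.Size([2, 10])

# Both partials in the diff pass name="xresnet34"; for xresnet50 that reads
# like a copy-paste leftover, so a caller may want to override the name:
model50 = xresnet50(name="xresnet50")()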
