Commit 91947eb

move universal blocks to module
1 parent 2cdd6a0 commit 91947eb

File tree

6 files changed, +395 -9 lines changed


src/model_constructor/layers.py

Lines changed: 2 additions & 2 deletions
@@ -114,15 +114,15 @@ def __init__(
        nf,
        ks=3,
        stride=1,
-       act=True,
+       act=True,  # pylint: disable=redefined-outer-name
        act_fn=act,
        bn_layer=True,
        bn_1st=True,
        zero_bn=False,
        padding=None,
        bias=False,
        groups=1,
-       **kwargs
+       **kwargs  # pylint: disable=unused-argument
    ):

        if padding is None:
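
The new universal-blocks module added below builds all of its convolutions through ConvBnAct from this file. A minimal sketch of that call pattern, assuming the keywords are accepted exactly as the new module passes them (channel sizes are placeholders; the import path is inferred from the file layout):

from torch import nn
from model_constructor.layers import ConvBnAct  # path assumed from src/model_constructor/layers.py

# 3x3 conv + BN + activation, as built for conv_0 in the basic (expansion == 1) block
conv_0 = ConvBnAct(64, 64, 3, stride=1, act_fn=nn.ReLU, bn_1st=True, groups=1)
# 1x1 projection with no activation, as built for the identity path
id_conv = ConvBnAct(64, 128, 1, stride=2, act_fn=False)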
Lines changed: 367 additions & 0 deletions
@@ -0,0 +1,367 @@
from collections import OrderedDict
from typing import Callable, Union

from torch import nn

from .layers import ConvBnAct, get_act
from .model_constructor import ModelCfg, ModelConstructor

__all__ = [
    "XResBlock",
    "ModelConstructor",
    "XResNet34",
    "XResNet50",
]


# TModelCfg = TypeVar("TModelCfg", bound="ModelCfg")


class XResBlock(nn.Module):
    """Universal XResnet block. Basic block if expansion is 1, otherwise Bottleneck."""

    def __init__(
        self,
        expansion: int,
        in_channels: int,
        mid_channels: int,
        stride: int = 1,
        conv_layer=ConvBnAct,
        act_fn: type[nn.Module] = nn.ReLU,
        zero_bn: bool = True,
        bn_1st: bool = True,
        groups: int = 1,
        dw: bool = False,
        div_groups: Union[None, int] = None,
        pool: Union[Callable[[], nn.Module], None] = None,
        se: Union[nn.Module, None] = None,
        sa: Union[nn.Module, None] = None,
    ):
        super().__init__()
        # pool defined at ModelConstructor.
        out_channels, in_channels = mid_channels * expansion, in_channels * expansion
        if div_groups is not None:  # check if groups != 1 and div_groups
            groups = int(mid_channels / div_groups)
        if expansion == 1:
            layers = [
                (
                    "conv_0",
                    conv_layer(
                        in_channels,
                        mid_channels,
                        3,
                        stride=stride,  # type: ignore
                        act_fn=act_fn,
                        bn_1st=bn_1st,
                        groups=in_channels if dw else groups,
                    ),
                ),
                (
                    "conv_1",
                    conv_layer(
                        mid_channels,
                        out_channels,
                        3,
                        zero_bn=zero_bn,
                        act_fn=False,
                        bn_1st=bn_1st,
                        groups=mid_channels if dw else groups,
                    ),
                ),
            ]
        else:
            layers = [
                (
                    "conv_0",
                    conv_layer(
                        in_channels,
                        mid_channels,
                        1,
                        act_fn=act_fn,
                        bn_1st=bn_1st,
                    ),
                ),
                (
                    "conv_1",
                    conv_layer(
                        mid_channels,
                        mid_channels,
                        3,
                        stride=stride,
                        act_fn=act_fn,
                        bn_1st=bn_1st,
                        groups=mid_channels if dw else groups,
                    ),
                ),
                (
                    "conv_2",
                    conv_layer(
                        mid_channels,
                        out_channels,
                        1,
                        zero_bn=zero_bn,
                        act_fn=False,
                        bn_1st=bn_1st,
                    ),
                ),  # noqa E501
            ]
        if se:
            layers.append(("se", se(out_channels)))
        if sa:
            layers.append(("sa", sa(out_channels)))
        self.convs = nn.Sequential(OrderedDict(layers))
        if stride != 1 or in_channels != out_channels:
            id_layers = []
            if (
                stride != 1 and pool is not None
            ):  # if pool - reduce by pool, else stride 2 at id_conv
                id_layers.append(("pool", pool()))
            if in_channels != out_channels or (stride != 1 and pool is None):
                id_layers += [
                    (
                        "id_conv",
                        conv_layer(
                            in_channels,
                            out_channels,
                            1,
                            stride=1 if pool else stride,
                            act_fn=False,
                        ),
                    )
                ]
            self.id_conv = nn.Sequential(OrderedDict(id_layers))
        else:
            self.id_conv = None
        self.act_fn = get_act(act_fn)

    def forward(self, x):
        identity = self.id_conv(x) if self.id_conv is not None else x
        return self.act_fn(self.convs(x) + identity)


class YaResBlock(nn.Module):
    """YaResBlock. Reduce by pool instead of stride 2."""

    def __init__(
        self,
        expansion: int,
        in_channels: int,
        mid_channels: int,
        stride: int = 1,
        conv_layer=ConvBnAct,
        act_fn: type[nn.Module] = nn.ReLU,
        zero_bn: bool = True,
        bn_1st: bool = True,
        groups: int = 1,
        dw: bool = False,
        div_groups: Union[None, int] = None,
        pool: Union[Callable[[], nn.Module], None] = None,
        se: Union[type[nn.Module], None] = None,
        sa: Union[type[nn.Module], None] = None,
    ):
        super().__init__()
        # pool defined at ModelConstructor.
        out_channels, in_channels = mid_channels * expansion, in_channels * expansion
        if div_groups is not None:  # check if groups != 1 and div_groups
            groups = int(mid_channels / div_groups)

        if stride != 1:
            if pool is None:
                self.reduce = conv_layer(in_channels, in_channels, 1, stride=2)
                # warnings.warn("pool not passed")  # need to warn?
            else:
                self.reduce = pool()
        else:
            self.reduce = None
        if expansion == 1:
            layers = [
                (
                    "conv_0",
                    conv_layer(
                        in_channels,
                        mid_channels,
                        3,
                        stride=1,
                        act_fn=act_fn,
                        bn_1st=bn_1st,
                        groups=in_channels if dw else groups,
                    ),
                ),
                (
                    "conv_1",
                    conv_layer(
                        mid_channels,
                        out_channels,
                        3,
                        zero_bn=zero_bn,
                        act_fn=False,
                        bn_1st=bn_1st,
                        groups=mid_channels if dw else groups,
                    ),
                ),
            ]
        else:
            layers = [
                (
                    "conv_0",
                    conv_layer(
                        in_channels,
                        mid_channels,
                        1,
                        act_fn=act_fn,
                        bn_1st=bn_1st,
                    ),
                ),
                (
                    "conv_1",
                    conv_layer(
                        mid_channels,
                        mid_channels,
                        3,
                        stride=1,
                        act_fn=act_fn,
                        bn_1st=bn_1st,
                        groups=mid_channels if dw else groups,
                    ),
                ),
                (
                    "conv_2",
                    conv_layer(
                        mid_channels,
                        out_channels,
                        1,
                        zero_bn=zero_bn,
                        act_fn=False,
                        bn_1st=bn_1st,
                    ),
                ),  # noqa E501
            ]
        if se:
            layers.append(("se", se(out_channels)))  # type: ignore
        if sa:
            layers.append(("sa", sa(out_channels)))  # type: ignore
        self.convs = nn.Sequential(OrderedDict(layers))
        if in_channels != out_channels:
            self.id_conv = conv_layer(
                in_channels,
                out_channels,
                1,
                stride=1,
                act_fn=False,
            )
        else:
            self.id_conv = None
        self.merge = get_act(act_fn)

    def forward(self, x):
        if self.reduce:
            x = self.reduce(x)
        identity = self.id_conv(x) if self.id_conv is not None else x
        return self.merge(self.convs(x) + identity)


def make_stem(cfg: ModelCfg) -> nn.Sequential:  # type: ignore
    """Create xResnet stem -> 3 conv 3*3 instead of 1 conv 7*7."""
    len_stem = len(cfg.stem_sizes)
    stem: list[tuple[str, nn.Module]] = [
        (
            f"conv_{i}",
            cfg.conv_layer(
                cfg.stem_sizes[i - 1] if i else cfg.in_chans,  # type: ignore
                cfg.stem_sizes[i],
                stride=2 if i == cfg.stem_stride_on else 1,
                bn_layer=(not cfg.stem_bn_end) if i == (len_stem - 1) else True,
                act_fn=cfg.act_fn,
                bn_1st=cfg.bn_1st,
            ),
        )
        for i in range(len_stem)
    ]
    if cfg.stem_pool:
        stem.append(("stem_pool", cfg.stem_pool()))
    if cfg.stem_bn_end:
        stem.append(("norm", cfg.norm(cfg.stem_sizes[-1])))  # type: ignore
    return nn.Sequential(OrderedDict(stem))


def make_layer(cfg: ModelCfg, layer_num: int) -> nn.Sequential:  # type: ignore
    """Create layer (stage)."""
    # if no pool on stem - stride = 2 for first layer block in body
    stride = 1 if cfg.stem_pool and layer_num == 0 else 2
    num_blocks = cfg.layers[layer_num]
    block_chs = [cfg.stem_sizes[-1] // cfg.expansion] + cfg.block_sizes
    return nn.Sequential(
        OrderedDict(
            [
                (
                    f"bl_{block_num}",
                    cfg.block(
                        cfg.expansion,  # type: ignore
                        block_chs[layer_num]
                        if block_num == 0
                        else block_chs[layer_num + 1],
                        block_chs[layer_num + 1],
                        stride if block_num == 0 else 1,
                        sa=cfg.sa
                        if (block_num == num_blocks - 1) and layer_num == 0
                        else None,
                        conv_layer=cfg.conv_layer,
                        act_fn=cfg.act_fn,
                        pool=cfg.pool,
                        zero_bn=cfg.zero_bn,
                        bn_1st=cfg.bn_1st,
                        groups=cfg.groups,
                        div_groups=cfg.div_groups,
                        dw=cfg.dw,
                        se=cfg.se,
                    ),
                )
                for block_num in range(num_blocks)
            ]
        )
    )


def make_body(cfg: ModelCfg) -> nn.Sequential:  # type: ignore
    """Create model body."""
    return nn.Sequential(
        OrderedDict(
            [
                (f"l_{layer_num}", cfg.make_layer(cfg, layer_num))  # type: ignore
                for layer_num in range(len(cfg.layers))
            ]
        )
    )


class XResNet(ModelConstructor):
    """Base Xresnet constructor."""

    make_stem: Callable[[ModelCfg], Union[nn.Module, nn.Sequential]] = make_stem
    make_layer: Callable[[ModelCfg, int], Union[nn.Module, nn.Sequential]] = make_layer
    make_body: Callable[[ModelCfg], Union[nn.Module, nn.Sequential]] = make_body
    block: type[nn.Module] = XResBlock


class XResNet34(XResNet):
    layers: list[int] = [3, 4, 6, 3]


class XResNet50(XResNet34):
    expansion: int = 4


class YaResNet(XResNet):
    """Base Yaresnet constructor.

    YaResBlock, Mish activation, custom stem.
    """

    block: type[nn.Module] = YaResBlock
    stem_sizes: list[int] = [3, 32, 64, 64]
    act_fn: type[nn.Module] = nn.Mish


class YaResNet34(YaResNet):
    layers: list[int] = [3, 4, 6, 3]


class YaResNet50(YaResNet34):
    expansion: int = 4
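
Usage note (not part of the diff): a minimal sketch of how the new blocks and constructors might be used. It assumes the constructor classes keep ModelConstructor's behavior of building the model when called, and that the new file is importable under a path like model_constructor.universal_blocks (the file name is not shown in this view); channel sizes and tensor shapes are placeholders.

import torch
# module path is an assumption; the new file's name is not visible in this diff
from model_constructor.universal_blocks import XResBlock, XResNet34, YaResNet50

# standalone block: expansion=1 gives the basic two-conv variant
block = XResBlock(expansion=1, in_channels=64, mid_channels=64)
out = block(torch.randn(1, 64, 32, 32))  # same channels, stride 1 -> identity path is a no-op

# constructors: assumed to build the full model (stem, body, head) when called
mc = XResNet34()
model = mc()

ya = YaResNet50()  # expansion 4 -> bottleneck YaResBlocks, Mish activation
ya_model = ya()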
