Skip to content

Commit 2272b58

Browse files
committed
nn_seq, some fixes
1 parent 91947eb commit 2272b58

File tree

4 files changed

+50
-49
lines changed

4 files changed

+50
-49
lines changed

src/model_constructor/helpers.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from collections import OrderedDict
2+
3+
from torch import nn
4+
5+
6+
def nn_seq(list_of_tuples: list[tuple[str, nn.Module]]) -> nn.Sequential:
7+
"""return nn.Sequential from OrderedDict from list of tuples"""
8+
return nn.Sequential(OrderedDict(list_of_tuples))

src/model_constructor/layers.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -195,8 +195,9 @@ def forward(self, x):
195195
return o.view(*size).contiguous()
196196

197197

198-
class SEBlock(nn.Module): # todo: deprecation warning.
199-
"se block"
198+
class SEBlock(nn.Module):
199+
"""se block"""
200+
# first version
200201
se_layer = nn.Linear
201202
act_fn = nn.ReLU(inplace=True)
202203
use_bias = True
@@ -223,8 +224,9 @@ def forward(self, x):
223224
return x * y.expand_as(x)
224225

225226

226-
class SEBlockConv(nn.Module): # todo: deprecation warning.
227-
"se block with conv on excitation"
227+
class SEBlockConv(nn.Module):
228+
"""se block with conv on excitation"""
229+
# first version
228230
se_layer = nn.Conv2d
229231
act_fn = nn.ReLU(inplace=True)
230232
use_bias = True

src/model_constructor/model_constructor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,7 @@ class ModelCfg(BaseModel):
272272
make_body: Callable[[TModelCfg], Union[nn.Module, nn.Sequential]] = make_body # type: ignore
273273
make_head: Callable[[TModelCfg], Union[nn.Module, nn.Sequential]] = make_head # type: ignore
274274

275-
class Config:
275+
class Config: # pylint: disable=too-few-public-methods
276276
arbitrary_types_allowed = True
277277
extra = "forbid"
278278

src/model_constructor/universal_blocks.py

Lines changed: 35 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,21 @@
1-
from collections import OrderedDict
21
from typing import Callable, Union
32

43
from torch import nn
54

5+
from .helpers import nn_seq
66
from .layers import ConvBnAct, get_act
77
from .model_constructor import ModelCfg, ModelConstructor
88

99
__all__ = [
1010
"XResBlock",
11-
"ModelConstructor",
1211
"XResNet34",
1312
"XResNet50",
13+
"YaResNet",
14+
"YaResNet34",
15+
"YaResNet50",
1416
]
1517

1618

17-
# TModelCfg = TypeVar("TModelCfg", bound="ModelCfg")
18-
19-
2019
class XResBlock(nn.Module):
2120
"""Universal XResnet block. Basic block if expansion is 1, otherwise is Bottleneck."""
2221

@@ -109,7 +108,7 @@ def __init__(
109108
layers.append(("se", se(out_channels)))
110109
if sa:
111110
layers.append(("sa", sa(out_channels)))
112-
self.convs = nn.Sequential(OrderedDict(layers))
111+
self.convs = nn_seq(layers)
113112
if stride != 1 or in_channels != out_channels:
114113
id_layers = []
115114
if (
@@ -129,7 +128,7 @@ def __init__(
129128
),
130129
)
131130
]
132-
self.id_conv = nn.Sequential(OrderedDict(id_layers))
131+
self.id_conv = nn_seq(id_layers)
133132
else:
134133
self.id_conv = None
135134
self.act_fn = get_act(act_fn)
@@ -240,7 +239,7 @@ def __init__(
240239
layers.append(("se", se(out_channels))) # type: ignore
241240
if sa:
242241
layers.append(("sa", sa(out_channels))) # type: ignore
243-
self.convs = nn.Sequential(OrderedDict(layers))
242+
self.convs = nn_seq(layers)
244243
if in_channels != out_channels:
245244
self.id_conv = conv_layer(
246245
in_channels,
@@ -281,7 +280,7 @@ def make_stem(cfg: ModelCfg) -> nn.Sequential: # type: ignore
281280
stem.append(("stem_pool", cfg.stem_pool()))
282281
if cfg.stem_bn_end:
283282
stem.append(("norm", cfg.norm(cfg.stem_sizes[-1]))) # type: ignore
284-
return nn.Sequential(OrderedDict(stem))
283+
return nn_seq(stem)
285284

286285

287286
def make_layer(cfg: ModelCfg, layer_num: int) -> nn.Sequential: # type: ignore
@@ -290,47 +289,39 @@ def make_layer(cfg: ModelCfg, layer_num: int) -> nn.Sequential: # type: ignore
290289
stride = 1 if cfg.stem_pool and layer_num == 0 else 2
291290
num_blocks = cfg.layers[layer_num]
292291
block_chs = [cfg.stem_sizes[-1] // cfg.expansion] + cfg.block_sizes
293-
return nn.Sequential(
294-
OrderedDict(
295-
[
296-
(
297-
f"bl_{block_num}",
298-
cfg.block(
299-
cfg.expansion, # type: ignore
300-
block_chs[layer_num]
301-
if block_num == 0
302-
else block_chs[layer_num + 1],
303-
block_chs[layer_num + 1],
304-
stride if block_num == 0 else 1,
305-
sa=cfg.sa
306-
if (block_num == num_blocks - 1) and layer_num == 0
307-
else None,
308-
conv_layer=cfg.conv_layer,
309-
act_fn=cfg.act_fn,
310-
pool=cfg.pool,
311-
zero_bn=cfg.zero_bn,
312-
bn_1st=cfg.bn_1st,
313-
groups=cfg.groups,
314-
div_groups=cfg.div_groups,
315-
dw=cfg.dw,
316-
se=cfg.se,
317-
),
318-
)
319-
for block_num in range(num_blocks)
320-
]
292+
return nn_seq(
293+
(
294+
f"bl_{block_num}",
295+
cfg.block(
296+
cfg.expansion, # type: ignore
297+
block_chs[layer_num]
298+
if block_num == 0
299+
else block_chs[layer_num + 1],
300+
block_chs[layer_num + 1],
301+
stride if block_num == 0 else 1,
302+
sa=cfg.sa
303+
if (block_num == num_blocks - 1) and layer_num == 0
304+
else None,
305+
conv_layer=cfg.conv_layer,
306+
act_fn=cfg.act_fn,
307+
pool=cfg.pool,
308+
zero_bn=cfg.zero_bn,
309+
bn_1st=cfg.bn_1st,
310+
groups=cfg.groups,
311+
div_groups=cfg.div_groups,
312+
dw=cfg.dw,
313+
se=cfg.se,
314+
),
321315
)
316+
for block_num in range(num_blocks)
322317
)
323318

324319

325320
def make_body(cfg: ModelCfg) -> nn.Sequential: # type: ignore
326321
"""Create model body."""
327-
return nn.Sequential(
328-
OrderedDict(
329-
[
330-
(f"l_{layer_num}", cfg.make_layer(cfg, layer_num)) # type: ignore
331-
for layer_num in range(len(cfg.layers))
332-
]
333-
)
322+
return nn_seq(
323+
(f"l_{layer_num}", cfg.make_layer(cfg, layer_num)) # type: ignore
324+
for layer_num in range(len(cfg.layers))
334325
)
335326

336327

0 commit comments

Comments
 (0)