Skip to content

Commit c0246df

Browse files
committed
fix blocks, tests blocks
1 parent 194a2b8 commit c0246df

File tree

5 files changed

+71
-11
lines changed

5 files changed

+71
-11
lines changed

model_constructor/layers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ def __init__(self,
196196
gate=nn.Sigmoid
197197
):
198198
super().__init__()
199-
reducted = channels // reduction
199+
reducted = max(channels // reduction, 1) # preserve zero-element tensors
200200
if rd_channels is None:
201201
rd_channels = reducted
202202
else:

model_constructor/model_constructor.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -53,13 +53,18 @@ def __init__(self, expansion, in_channels, mid_channels, stride=1,
5353
if sa:
5454
layers.append(('sa', sa(out_channels)))
5555
self.convs = nn.Sequential(OrderedDict(layers))
56-
id_layers = []
57-
if stride != 1 and pool:
58-
id_layers.append(("pool", pool))
59-
id_layers += [] if in_channels == out_channels else [("id_conv", conv_layer(in_channels, out_channels, 1,
60-
stride=1 if pool else stride,
61-
act_fn=False))]
62-
self.id_conv = None if id_layers == [] else nn.Sequential(OrderedDict(id_layers))
56+
if stride != 1 or in_channels != out_channels:
57+
id_layers = []
58+
if stride != 1 and pool is not None:
59+
id_layers.append(("pool", pool))
60+
if in_channels != out_channels or (stride != 1 and pool is None):
61+
id_layers += [("id_conv", conv_layer(
62+
in_channels, out_channels, 1,
63+
stride=1 if pool else stride,
64+
act_fn=False))]
65+
self.id_conv = nn.Sequential(OrderedDict(id_layers))
66+
else:
67+
self.id_conv = None
6368
self.act_fn = act_fn
6469

6570
def forward(self, x):

model_constructor/yaresnet.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,10 @@ def __init__(self, expansion, in_channels, mid_channels, stride=1,
3030
groups = int(mid_channels / div_groups)
3131
if stride != 1:
3232
if pool is None:
33-
raise Exception("pool not passed")
34-
self.reduce = pool
33+
self.reduce = conv_layer(in_channels, in_channels, 1, stride=2)
34+
# warnings.warn("pool not passed") # need to warn?
35+
else:
36+
self.reduce = pool
3537
else:
3638
self.reduce = None
3739
layers = [("conv_0", conv_layer(in_channels, mid_channels, 3, stride=1,
@@ -42,7 +44,8 @@ def __init__(self, expansion, in_channels, mid_channels, stride=1,
4244
("conv_0", conv_layer(in_channels, mid_channels, 1, act_fn=act_fn, bn_1st=bn_1st)),
4345
("conv_1", conv_layer(mid_channels, mid_channels, 3, stride=1, act_fn=act_fn, bn_1st=bn_1st,
4446
groups=mid_channels if dw else groups)),
45-
("conv_2", conv_layer(mid_channels, out_channels, 1, zero_bn=zero_bn, act_fn=False, bn_1st=bn_1st))
47+
("conv_2", conv_layer(
48+
mid_channels, out_channels, 1, zero_bn=zero_bn, act_fn=False, bn_1st=bn_1st))
4649
]
4750
if se:
4851
layers.append(('se', se(out_channels)))

tests/__init__.py

Whitespace-only changes.

tests/test_block.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# import pytest
2+
import torch
3+
import torch.nn as nn
4+
from model_constructor.layers import SEModule, SimpleSelfAttention
5+
6+
from model_constructor.model_constructor import ResBlock
7+
from model_constructor.yaresnet import YaResBlock
8+
9+
# Batch size and square image side length of the random input built in test_block.
bs_test = 4
10+
img_size = 16
11+
12+
13+
# Candidate values for each test_block argument, keyed by argument name.
# pytest_generate_tests parametrizes every requested fixture that has an entry here.
params = {
14+
"Block": [ResBlock, YaResBlock],
15+
"expansion": [1, 2],
16+
"mid_channels": [8, 16],
17+
"stride": [1, 2],
18+
"pool": [None, nn.AvgPool2d(2, ceil_mode=True)],
19+
"se": [None, SEModule],
20+
"sa": [None, SimpleSelfAttention],
21+
}
22+
23+
24+
def value_name(value) -> str:
    """Return a short, human-readable name for a parametrization value.

    Args:
        value: Any object used as a test parameter (a class, a function,
            an ``nn.Module`` instance, a number, ``None``, ...).

    Returns:
        ``value.__name__`` when the object has one (classes, functions),
        the module's class name for ``nn.Module`` instances, and
        ``str(value)`` for everything else, so the result is always a string.
    """
    name = getattr(value, "__name__", None)
    if name is not None:
        return name
    if isinstance(value, nn.Module):
        return value._get_name()
    # Plain values (ints, None, ...) have no name; stringify them so the
    # declared ``-> str`` return type actually holds.
    return str(value)
32+
33+
34+
def ids_fn(key, value):
    """Build one test id per candidate in ``value``.

    Each id is the first two characters of ``key``, an underscore, and the
    name that ``value_name`` gives for the candidate.
    """
    prefix = key[:2]
    labels = []
    for candidate in value:
        labels.append(f"{prefix}_{value_name(candidate)}")
    return labels
36+
37+
38+
def pytest_generate_tests(metafunc):
    """Parametrize every requested fixture that has an entry in ``params``.

    Fixtures are handled in the order they appear in ``params``; ids for
    the generated cases come from ``ids_fn``.
    """
    requested = metafunc.fixturenames
    for name, options in params.items():
        if name not in requested:
            continue
        metafunc.parametrize(name, options, ids=ids_fn(name, options))
42+
43+
44+
def test_block(Block, expansion, mid_channels, stride, pool, se, sa):
    """Check that a block maps a random batch to the expected output shape."""
    in_channels = 8
    block = Block(expansion, in_channels, mid_channels, stride, pool=pool, se=se, sa=sa)
    batch = torch.randn(bs_test, in_channels * expansion, img_size, img_size)
    result = block(batch)
    # Spatial side is unchanged for stride 1, otherwise divided by the stride.
    if stride == 1:
        side = img_size
    else:
        side = img_size // stride
    expected_shape = torch.Size([bs_test, mid_channels * expansion, side, side])
    assert result.shape == expected_shape

0 commit comments

Comments
 (0)