Commit 851426f

fix se, test layers
1 parent: c0246df

3 files changed (+125, -13 lines)

model_constructor/layers.py

Lines changed: 2 additions & 2 deletions
@@ -117,7 +117,7 @@ def __init__(self, n_in: int, ks=1, sym=False):
         self.n_in = n_in
 
     def forward(self, x):
-        if self.sym:
+        if self.sym:  # check ks=3
             # symmetry hack by https://github.com/mgrankin
             c = self.conv.weight.view(self.n_in, self.n_in)
             c = (c + c.t()) / 2
@@ -232,7 +232,7 @@ def __init__(self,
                  ):
         super().__init__()
         # rd_channels = math.ceil(channels//reduction/8)*8
-        reducted = channels // reduction
+        reducted = max(channels // reduction, 1)  # preserve zero-element tensors
         if rd_channels is None:
             rd_channels = reducted
         else:
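
Why the clamp matters: with the values exercised by the new tests below (in_channels=8, reduction=16), plain integer division yields 0 and the SE bottleneck collapses to zero channels; max(..., 1) keeps the reduced width at least 1. A minimal sketch of the arithmetic (illustration only, not part of the commit):

channels, reduction = 8, 16              # values exercised by tests/test_layers.py
old_rd = channels // reduction           # 0 -> zero-width squeeze layer
new_rd = max(channels // reduction, 1)   # 1 -> smallest valid bottleneck
print(old_rd, new_rd)                    # 0 1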

tests/test_block.py

Lines changed: 15 additions & 11 deletions
@@ -10,15 +10,16 @@
 img_size = 16
 
 
-params = {
-    "Block": [ResBlock, YaResBlock],
-    "expansion": [1, 2],
-    "mid_channels": [8, 16],
-    "stride": [1, 2],
-    "pool": [None, nn.AvgPool2d(2, ceil_mode=True)],
-    "se": [None, SEModule],
-    "sa": [None, SimpleSelfAttention],
-}
+params = dict(
+    Block=[ResBlock, YaResBlock],
+    expansion=[1, 2],
+    mid_channels=[8, 16],
+    stride=[1, 2],
+    div_groups=[None, 2],
+    pool=[None, nn.AvgPool2d(2, ceil_mode=True)],
+    se=[None, SEModule],
+    sa=[None, SimpleSelfAttention],
+)
 
 
 def value_name(value) -> str:
@@ -41,11 +42,14 @@ def pytest_generate_tests(metafunc):
             metafunc.parametrize(key, value, ids=ids_fn(key, value))
 
 
-def test_block(Block, expansion, mid_channels, stride, pool, se, sa):
+def test_block(Block, expansion, mid_channels, stride, div_groups, pool, se, sa):
     """test block"""
     in_channels = 8
     out_channels = mid_channels * expansion
-    block = Block(expansion, in_channels, mid_channels, stride, pool=pool, se=se, sa=sa)
+    block = Block(
+        expansion, in_channels, mid_channels,
+        stride, div_groups=div_groups,
+        pool=pool, se=se, sa=sa)
     xb = torch.randn(bs_test, in_channels * expansion, img_size, img_size)
     y = block(xb)
     out_size = img_size if stride == 1 else img_size // stride
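
Each key in params becomes a separate pytest parametrization axis via pytest_generate_tests, so the block tests run over the cartesian product of all value lists, and the new div_groups key doubles that grid. A rough sketch of the resulting grid size (illustration only; strings stand in for the real classes and modules):

from itertools import product

params = dict(
    Block=["ResBlock", "YaResBlock"],
    expansion=[1, 2],
    mid_channels=[8, 16],
    stride=[1, 2],
    div_groups=[None, 2],
    pool=[None, "AvgPool2d(2)"],
    se=[None, "SEModule"],
    sa=[None, "SimpleSelfAttention"],
)

n_cases = len(list(product(*params.values())))
print(n_cases)  # 256 generated test cases (128 before div_groups was added)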

tests/test_layers.py

Lines changed: 108 additions & 0 deletions
@@ -0,0 +1,108 @@
+import torch
+import torch.nn as nn
+
+from model_constructor.layers import ConvBnAct, Flatten, Noop, SEModule, SEModuleConv, SimpleSelfAttention, noop
+
+
+bs_test = 4
+
+
+params = dict(
+    kernel_size=[3, 1],
+    stride=[1, 2],
+    padding=[None, 1],
+    bias=[False, True],
+    groups=[1, 2],
+    # # act_fn=act_fn,
+    pre_act=[False, True],
+    bn_layer=[True, False],
+    bn_1st=[True, False],
+    zero_bn=[False, True],
+    # SA
+    sym=[False, True],
+    # SE
+    se_module=[SEModule, SEModuleConv],
+    reduction=[16, 2],
+    rd_channels=[None, 2],
+    rd_max=[False, True],
+    use_bias=[True, False],
+)
+
+
+def value_name(value) -> str:
+    name = getattr(value, "__name__", None)
+    if name is not None:
+        return name
+    if isinstance(value, nn.Module):
+        return value._get_name()
+    else:
+        return value
+
+
+def ids_fn(key, value):
+    return [f"{key[:2]}_{value_name(v)}" for v in value]
+
+
+def pytest_generate_tests(metafunc):
+    for key, value in params.items():
+        if key in metafunc.fixturenames:
+            metafunc.parametrize(key, value, ids=ids_fn(key, value))
+
+
+def test_Flatten():
+    """test Flatten"""
+    flatten = Flatten()
+    channels = 4
+    xb = torch.randn(bs_test, channels, channels)
+    out = flatten(xb)
+    assert out.shape == torch.Size([bs_test, channels * channels])
+
+
+def test_noop():
+    """test Noop, noop"""
+    xb = torch.randn(bs_test)
+    xb_copy = xb.clone().detach()
+    out = noop(xb)
+    assert out is xb
+    assert all(out.eq(xb_copy))
+    noop_module = Noop()
+    out = noop_module(xb)
+    assert out is xb
+    assert all(out.eq(xb_copy))
+
+
+def test_ConvBnAct(kernel_size, stride, bias, groups, pre_act, bn_layer, bn_1st, zero_bn):
+    """test ConvBnAct"""
+    in_channels = out_channels = 4
+    channel_size = 4
+    block = ConvBnAct(
+        in_channels, out_channels, kernel_size, stride,
+        padding=None, bias=bias, groups=groups,
+        pre_act=pre_act, bn_layer=bn_layer, bn_1st=bn_1st, zero_bn=zero_bn)
+    xb = torch.randn(bs_test, in_channels, channel_size, channel_size)
+    out = block(xb)
+    out_size = channel_size
+    if stride == 2:
+        out_size = channel_size // stride
+    assert out.shape == torch.Size([bs_test, out_channels, out_size, out_size])
+
+
+def test_SimpleSelfAttention(sym):
+    """test SimpleSelfAttention"""
+    in_channels = 4
+    kernel_size = 1  # ? can be 3? if so check sym hack.
+    channel_size = 4
+    sa = SimpleSelfAttention(in_channels, kernel_size, sym)
+    xb = torch.randn(bs_test, in_channels, channel_size, channel_size)
+    out = sa(xb)
+    assert out.shape == torch.Size([bs_test, in_channels, channel_size, channel_size])
+
+
+def test_SE(se_module, reduction, rd_channels, rd_max, use_bias):
+    """test SE"""
+    in_channels = 8
+    channel_size = 4
+    se = se_module(in_channels, reduction, rd_channels, rd_max, use_bias=use_bias)
+    xb = torch.randn(bs_test, in_channels, channel_size, channel_size)
+    out = se(xb)
+    assert out.shape == torch.Size([bs_test, in_channels, channel_size, channel_size])
