Skip to content

Commit 6e8f3b3

Browse files
committed
Merge branch 'tests' into dev_011
2 parents 5eebcfd + db440bc commit 6e8f3b3

File tree

12 files changed

+544
-21
lines changed

12 files changed

+544
-21
lines changed

model_constructor/convmixer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,5 +104,5 @@ def __init__(self, dim: int, depth: int,
104104
nn.AdaptiveAvgPool2d((1, 1)),
105105
nn.Flatten(),
106106
nn.Linear(dim, n_classes))
107-
if init_func is not None:
107+
if init_func is not None: # pragma: no cover
108108
init_func(self)

model_constructor/layers.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -109,15 +109,15 @@ class SimpleSelfAttention(nn.Module):
109109
Inspired by https://arxiv.org/pdf/1805.08318.pdf
110110
'''
111111

112-
def __init__(self, n_in: int, ks=1, sym=False):
112+
def __init__(self, n_in: int, ks=1, sym=False, use_bias=False):
113113
super().__init__()
114-
self.conv = conv1d(n_in, n_in, ks, padding=ks // 2, bias=False)
114+
self.conv = conv1d(n_in, n_in, ks, padding=ks // 2, bias=use_bias)
115115
self.gamma = nn.Parameter(torch.tensor([0.]))
116116
self.sym = sym
117117
self.n_in = n_in
118118

119119
def forward(self, x):
120-
if self.sym:
120+
if self.sym: # check ks=3
121121
# symmetry hack by https://github.com/mgrankin
122122
c = self.conv.weight.view(self.n_in, self.n_in)
123123
c = (c + c.t()) / 2
@@ -141,7 +141,7 @@ class SEBlock(nn.Module): # todo: deprecation worning.
141141

142142
def __init__(self, c, r=16):
143143
super().__init__()
144-
ch = c // r
144+
ch = max(c // r, 1)
145145
self.squeeze = nn.AdaptiveAvgPool2d(1)
146146
self.excitation = nn.Sequential(
147147
OrderedDict([('fc_reduce', self.se_layer(c, ch, bias=self.use_bias)),
@@ -166,7 +166,7 @@ class SEBlockConv(nn.Module): # todo: deprecation worning.
166166
def __init__(self, c, r=16):
167167
super().__init__()
168168
# c_in = math.ceil(c//r/8)*8
169-
c_in = c // r
169+
c_in = max(c // r, 1)
170170
self.squeeze = nn.AdaptiveAvgPool2d(1)
171171
self.excitation = nn.Sequential(
172172
OrderedDict([
@@ -196,7 +196,7 @@ def __init__(self,
196196
gate=nn.Sigmoid
197197
):
198198
super().__init__()
199-
reducted = channels // reduction
199+
reducted = max(channels // reduction, 1) # preserve zero-element tensors
200200
if rd_channels is None:
201201
rd_channels = reducted
202202
else:
@@ -232,7 +232,7 @@ def __init__(self,
232232
):
233233
super().__init__()
234234
# rd_channels = math.ceil(channels//reduction/8)*8
235-
reducted = channels // reduction
235+
reducted = max(channels // reduction, 1) # preserve zero-element tensors
236236
if rd_channels is None:
237237
rd_channels = reducted
238238
else:

model_constructor/model_constructor.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -53,13 +53,18 @@ def __init__(self, expansion, in_channels, mid_channels, stride=1,
5353
if sa:
5454
layers.append(('sa', sa(out_channels)))
5555
self.convs = nn.Sequential(OrderedDict(layers))
56-
id_layers = []
57-
if stride != 1 and pool:
58-
id_layers.append(("pool", pool))
59-
id_layers += [] if in_channels == out_channels else [("id_conv", conv_layer(in_channels, out_channels, 1,
60-
stride=1 if pool else stride,
61-
act_fn=False))]
62-
self.id_conv = None if id_layers == [] else nn.Sequential(OrderedDict(id_layers))
56+
if stride != 1 or in_channels != out_channels:
57+
id_layers = []
58+
if stride != 1 and pool is not None:
59+
id_layers.append(("pool", pool))
60+
if in_channels != out_channels or (stride != 1 and pool is None):
61+
id_layers += [("id_conv", conv_layer(
62+
in_channels, out_channels, 1,
63+
stride=1 if pool else stride,
64+
act_fn=False))]
65+
self.id_conv = nn.Sequential(OrderedDict(id_layers))
66+
else:
67+
self.id_conv = None
6368
self.act_fn = act_fn
6469

6570
def forward(self, x):
@@ -147,7 +152,7 @@ def __init__(self, name='MC', in_chans=3, num_classes=1000,
147152
if self.sa: # if sa=1 or sa=True
148153
if type(self.sa) in (bool, int):
149154
self.sa = SimpleSelfAttention # default: ks=1, sym=sym
150-
if self.se_module or se_reduction:
155+
if self.se_module or se_reduction: # pragma: no cover
151156
print("Deprecated. Pass se_module as se argument, se_reduction as arg to se.") # add deprecation worning.
152157

153158
@property

model_constructor/net.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def __init__(self, expansion, ni, nh, stride=1,
3636
groups = int(nh / div_groups)
3737
if expansion == 1:
3838
layers = [("conv_0", conv_layer(ni, nh, 3, stride=stride, act_fn=act_fn, bn_1st=bn_1st,
39-
groups=nh if dw else groups)),
39+
groups=ni if dw else groups)),
4040
("conv_1", conv_layer(nh, nf, 3, zero_bn=zero_bn, act=False, bn_1st=bn_1st))
4141
]
4242
else:
@@ -78,7 +78,7 @@ def __init__(self, expansion, ni, nh, stride=1,
7878
self.reduce = noop if stride == 1 else pool
7979
if expansion == 1:
8080
layers = [("conv_0", conv_layer(ni, nh, 3, stride=1, act_fn=act_fn, bn_1st=bn_1st,
81-
groups=nh if dw else groups)),
81+
groups=ni if dw else groups)),
8282
("conv_1", conv_layer(nh, nf, 3, zero_bn=zero_bn, act=False, bn_1st=bn_1st))
8383
]
8484
else:

model_constructor/yaresnet.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,10 @@ def __init__(self, expansion, in_channels, mid_channels, stride=1,
3030
groups = int(mid_channels / div_groups)
3131
if stride != 1:
3232
if pool is None:
33-
raise Exception("pool not passed")
34-
self.reduce = pool
33+
self.reduce = conv_layer(in_channels, in_channels, 1, stride=2)
34+
# warnings.warn("pool not passed") # need to warn?
35+
else:
36+
self.reduce = pool
3537
else:
3638
self.reduce = None
3739
layers = [("conv_0", conv_layer(in_channels, mid_channels, 3, stride=1,
@@ -42,7 +44,8 @@ def __init__(self, expansion, in_channels, mid_channels, stride=1,
4244
("conv_0", conv_layer(in_channels, mid_channels, 1, act_fn=act_fn, bn_1st=bn_1st)),
4345
("conv_1", conv_layer(mid_channels, mid_channels, 3, stride=1, act_fn=act_fn, bn_1st=bn_1st,
4446
groups=mid_channels if dw else groups)),
45-
("conv_2", conv_layer(mid_channels, out_channels, 1, zero_bn=zero_bn, act_fn=False, bn_1st=bn_1st))
47+
("conv_2", conv_layer(
48+
mid_channels, out_channels, 1, zero_bn=zero_bn, act_fn=False, bn_1st=bn_1st))
4649
]
4750
if se:
4851
layers.append(('se', se(out_channels)))

tests/__init__.py

Whitespace-only changes.

tests/test_Net.py

Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
import torch
2+
import torch.nn as nn
3+
4+
from model_constructor.net import Net, NewResBlock, ResBlock
5+
# from model_constructor.layers import SEModule, SimpleSelfAttention
6+
7+
8+
# Batch size of the random input used by every test in this module.
bs_test = 4


# Candidate values per test-argument name; `pytest_generate_tests` turns each
# entry into a parametrization of the matching test argument.
params = {
    "block": [ResBlock, NewResBlock],
    "expansion": [1, 2],
    "groups": [1, 2],
    "dw": [0, 1],
    "div_groups": [None, 2],
    "sa": [0, 1],
    "se": [0, 1],
    "bn_1st": [True, False],
    "zero_bn": [True, False],
    "stem_bn_end": [True, False],
    "stem_stride_on": [0, 1],
}
24+
25+
26+
def value_name(value) -> str:  # pragma: no cover
    """Return a readable name for a parametrized value, for use in test ids.

    Objects that expose ``__name__`` (classes, functions) use it, ``nn.Module``
    instances use the name reported by the module itself, and anything else
    (``None``, numbers, bools) is converted with ``str`` so that the declared
    ``str`` return type always holds.
    """
    name = getattr(value, "__name__", None)
    if name is not None:
        return name
    if isinstance(value, nn.Module):
        return value._get_name()
    # Plain values have no name of their own; stringify them instead of
    # leaking a non-str object to the caller.
    return str(value)
34+
35+
36+
def ids_fn(key, value):
    """Build one test id per candidate value.

    Each id is the first two characters of ``key``, an underscore, and the
    result of ``value_name`` for that candidate.
    """
    prefix = key[:2]
    ids = []
    for candidate in value:
        ids.append(f"{prefix}_{value_name(candidate)}")
    return ids
38+
39+
40+
def pytest_generate_tests(metafunc):
    """Parametrize each test argument that has an entry in ``params``.

    Arguments a test does not request are left alone; requested ones get the
    candidate list from ``params`` with ids produced by ``ids_fn``.
    """
    for arg_name, candidates in params.items():
        if arg_name not in metafunc.fixturenames:
            continue
        test_ids = ids_fn(arg_name, candidates)
        metafunc.parametrize(arg_name, candidates, ids=test_ids)
44+
45+
46+
def test_Net(
    block, expansion,
    groups,
):
    """Build a Net for each block / expansion / groups combination and run it.

    Checks that the constructor's string form mentions its name and that the
    created model maps a (bs_test, 3, 16, 16) batch to shape (bs_test, 8).
    """
    in_chans, side, n_out = 3, 16, 8
    name = "Test name"

    # dw, div_groups, bn_1st / zero_bn and stem_bn_end are not varied here.
    constructor = Net(
        name, in_chans, n_out, block,
        expansion=expansion,
        stem_sizes=[8, 16],
        block_sizes=[16, 32, 64, 128],
        groups=groups,
    )
    assert f"{name} constructor" in str(constructor)

    model = constructor()
    batch = torch.randn(bs_test, in_chans, side, side)
    pred = model(batch)
    expected_shape = torch.Size([bs_test, n_out])
    assert pred.shape == expected_shape
72+
73+
74+
def test_Net_SE_SA(
    block, expansion,
    se, sa
):
    """Build a Net with the se / sa flags toggled and run a forward pass.

    Checks that the constructor's string form mentions its name and that the
    created model maps a (bs_test, 3, 16, 16) batch to shape (bs_test, 8).
    """
    in_chans, side, n_out = 3, 16, 8
    name = "Test name"

    constructor = Net(
        name, in_chans, n_out, block,
        expansion=expansion,
        stem_sizes=[8, 16],
        block_sizes=[16, 32, 64, 128],
        se=se, sa=sa
    )
    assert f"{name} constructor" in str(constructor)

    model = constructor()
    batch = torch.randn(bs_test, in_chans, side, side)
    pred = model(batch)
    expected_shape = torch.Size([bs_test, n_out])
    assert pred.shape == expected_shape
96+
97+
98+
def test_Net_div_gr(
    block, expansion,
    div_groups,
):
    """Build a Net for each div_groups setting and run a forward pass.

    Checks that the constructor's string form mentions its name and that the
    created model maps a (bs_test, 3, 16, 16) batch to shape (bs_test, 8).
    """
    in_chans, side, n_out = 3, 16, 8
    name = "Test name"

    constructor = Net(
        name, in_chans, n_out, block,
        expansion=expansion,
        stem_sizes=[8, 16],
        block_sizes=[16, 32, 64, 128],
        div_groups=div_groups,
    )
    assert f"{name} constructor" in str(constructor)

    model = constructor()
    batch = torch.randn(bs_test, in_chans, side, side)
    pred = model(batch)
    expected_shape = torch.Size([bs_test, n_out])
    assert pred.shape == expected_shape
120+
121+
122+
def test_Net_dw(
    block, expansion,
    dw
):
    """Build a Net with the dw flag off and on and run a forward pass.

    Checks that the constructor's string form mentions its name and that the
    created model maps a (bs_test, 3, 16, 16) batch to shape (bs_test, 8).
    """
    in_chans, side, n_out = 3, 16, 8
    name = "Test name"

    constructor = Net(
        name, in_chans, n_out, block,
        expansion=expansion,
        stem_sizes=[8, 16],
        block_sizes=[16, 32, 64, 128],
        dw=dw
    )
    assert f"{name} constructor" in str(constructor)

    model = constructor()
    batch = torch.randn(bs_test, in_chans, side, side)
    pred = model(batch)
    expected_shape = torch.Size([bs_test, n_out])
    assert pred.shape == expected_shape
144+
145+
146+
def test_Net_2(
    block, expansion,
    bn_1st, zero_bn,
):
    """Build a Net for each bn_1st / zero_bn combination and run a forward pass.

    Checks that the constructor's string form mentions its name and that the
    created model maps a (bs_test, 3, 16, 16) batch to shape (bs_test, 8).
    """
    in_chans, side, n_out = 3, 16, 8
    name = "Test name"

    constructor = Net(
        name, in_chans, n_out, block,
        expansion=expansion,
        stem_sizes=[8, 16],
        block_sizes=[16, 32, 64, 128],
        bn_1st=bn_1st, zero_bn=zero_bn,
    )
    assert f"{name} constructor" in str(constructor)

    model = constructor()
    batch = torch.randn(bs_test, in_chans, side, side)
    pred = model(batch)
    expected_shape = torch.Size([bs_test, n_out])
    assert pred.shape == expected_shape
168+
169+
170+
def test_Net_stem(
    stem_bn_end,
    stem_stride_on
):
    """Build a Net for each stem_bn_end / stem_stride_on combination and run it.

    No block is passed, so Net uses whatever block it defaults to. Checks that
    the constructor's string form mentions its name and that the created model
    maps a (bs_test, 3, 16, 16) batch to shape (bs_test, 8).
    """
    in_chans, side, n_out = 3, 16, 8
    name = "Test name"

    constructor = Net(
        name, in_chans, n_out,
        stem_sizes=[8, 16],
        block_sizes=[16, 32, 64, 128],
        stem_bn_end=stem_bn_end,
        stem_stride_on=stem_stride_on
    )
    assert f"{name} constructor" in str(constructor)

    model = constructor()
    batch = torch.randn(bs_test, in_chans, side, side)
    pred = model(batch)
    expected_shape = torch.Size([bs_test, n_out])
    assert pred.shape == expected_shape

tests/test_block.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# import pytest
2+
import torch
3+
import torch.nn as nn
4+
from model_constructor.layers import SEModule, SimpleSelfAttention
5+
6+
from model_constructor.model_constructor import ResBlock
7+
from model_constructor.yaresnet import YaResBlock
8+
9+
# Batch size and spatial side length of the random input fed to each block.
bs_test = 4
img_size = 16


# Candidate values per test-argument name; `pytest_generate_tests` turns each
# entry into a parametrization of the matching test argument.
params = {
    "Block": [ResBlock, YaResBlock],
    "expansion": [1, 2],
    "mid_channels": [8, 16],
    "stride": [1, 2],
    "div_groups": [None, 2],
    "pool": [None, nn.AvgPool2d(2, ceil_mode=True)],
    "se": [None, SEModule],
    "sa": [None, SimpleSelfAttention],
}
23+
24+
25+
def value_name(value) -> str:
    """Return a readable name for a parametrized value, for use in test ids.

    Objects that expose ``__name__`` (classes, functions) use it, ``nn.Module``
    instances use the name reported by the module itself, and anything else
    (``None``, numbers, bools) is converted with ``str`` so that the declared
    ``str`` return type always holds.
    """
    name = getattr(value, "__name__", None)
    if name is not None:
        return name
    if isinstance(value, nn.Module):
        return value._get_name()
    # Plain values have no name of their own; stringify them instead of
    # leaking a non-str object to the caller.
    return str(value)
33+
34+
35+
def ids_fn(key, value):
    """Build one test id per candidate value.

    Each id is the first two characters of ``key``, an underscore, and the
    result of ``value_name`` for that candidate.
    """
    prefix = key[:2]
    ids = []
    for candidate in value:
        ids.append(f"{prefix}_{value_name(candidate)}")
    return ids
37+
38+
39+
def pytest_generate_tests(metafunc):
    """Parametrize each test argument that has an entry in ``params``.

    Arguments a test does not request are left alone; requested ones get the
    candidate list from ``params`` with ids produced by ``ids_fn``.
    """
    for arg_name, candidates in params.items():
        if arg_name not in metafunc.fixturenames:
            continue
        test_ids = ids_fn(arg_name, candidates)
        metafunc.parametrize(arg_name, candidates, ids=test_ids)
43+
44+
45+
def test_block(Block, expansion, mid_channels, stride, div_groups, pool, se, sa):
    """Construct a block for one parameter combination and check its output shape.

    The block is fed a random (bs_test, 8 * expansion, img_size, img_size)
    batch; the result must have ``mid_channels * expansion`` channels and a
    spatial size of ``img_size`` for stride 1, ``img_size // stride`` otherwise.
    """
    in_channels = 8
    block = Block(
        expansion, in_channels, mid_channels,
        stride, div_groups=div_groups,
        pool=pool, se=se, sa=sa)

    batch = torch.randn(bs_test, in_channels * expansion, img_size, img_size)
    result = block(batch)

    if stride == 1:
        side = img_size
    else:
        side = img_size // stride
    expected_shape = torch.Size([bs_test, mid_channels * expansion, side, side])
    assert result.shape == expected_shape

0 commit comments

Comments
 (0)