Skip to content

Commit 883d2df

Browse files
committed
tests Net
1 parent 76e3257 commit 883d2df

File tree

9 files changed

+287
-11
lines changed

9 files changed

+287
-11
lines changed

model_constructor/convmixer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,5 +104,5 @@ def __init__(self, dim: int, depth: int,
104104
nn.AdaptiveAvgPool2d((1, 1)),
105105
nn.Flatten(),
106106
nn.Linear(dim, n_classes))
107-
if init_func is not None:
107+
if init_func is not None: # pragma: no cover
108108
init_func(self)

model_constructor/layers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,9 +109,9 @@ class SimpleSelfAttention(nn.Module):
109109
Inspired by https://arxiv.org/pdf/1805.08318.pdf
110110
'''
111111

112-
def __init__(self, n_in: int, ks=1, sym=False):
112+
def __init__(self, n_in: int, ks=1, sym=False, use_bias=False):
113113
super().__init__()
114-
self.conv = conv1d(n_in, n_in, ks, padding=ks // 2, bias=False)
114+
self.conv = conv1d(n_in, n_in, ks, padding=ks // 2, bias=use_bias)
115115
self.gamma = nn.Parameter(torch.tensor([0.]))
116116
self.sym = sym
117117
self.n_in = n_in

model_constructor/model_constructor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ def __init__(self, name='MC', in_chans=3, num_classes=1000,
152152
if self.sa: # if sa=1 or sa=True
153153
if type(self.sa) in (bool, int):
154154
self.sa = SimpleSelfAttention # default: ks=1, sym=sym
155-
if self.se_module or se_reduction:
155+
if self.se_module or se_reduction: # pragma: no cover
156156
print("Deprecated. Pass se_module as se argument, se_reduction as arg to se.") # add deprecation worning.
157157

158158
@property

model_constructor/net.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def __init__(self, expansion, ni, nh, stride=1,
3636
groups = int(nh / div_groups)
3737
if expansion == 1:
3838
layers = [("conv_0", conv_layer(ni, nh, 3, stride=stride, act_fn=act_fn, bn_1st=bn_1st,
39-
groups=nh if dw else groups)),
39+
groups=ni if dw else groups)),
4040
("conv_1", conv_layer(nh, nf, 3, zero_bn=zero_bn, act=False, bn_1st=bn_1st))
4141
]
4242
else:
@@ -78,7 +78,7 @@ def __init__(self, expansion, ni, nh, stride=1,
7878
self.reduce = noop if stride == 1 else pool
7979
if expansion == 1:
8080
layers = [("conv_0", conv_layer(ni, nh, 3, stride=1, act_fn=act_fn, bn_1st=bn_1st,
81-
groups=nh if dw else groups)),
81+
groups=ni if dw else groups)),
8282
("conv_1", conv_layer(nh, nf, 3, zero_bn=zero_bn, act=False, bn_1st=bn_1st))
8383
]
8484
else:

tests/test_Net.py

Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
import torch
2+
import torch.nn as nn
3+
4+
from model_constructor.net import Net, NewResBlock, ResBlock
5+
# from model_constructor.layers import SEModule, SimpleSelfAttention
6+
7+
8+
bs_test = 4
9+
10+
11+
params = dict(
12+
block=[ResBlock, NewResBlock],
13+
expansion=[1, 2],
14+
groups=[1, 2],
15+
dw=[0, 1],
16+
div_groups=[None, 2],
17+
sa=[0, 1],
18+
se=[0, 1],
19+
bn_1st=[True, False],
20+
zero_bn=[True, False],
21+
stem_bn_end=[True, False],
22+
stem_stride_on=[0, 1]
23+
)
24+
25+
26+
def value_name(value) -> str: # pragma: no cover
27+
name = getattr(value, "__name__", None)
28+
if name is not None:
29+
return name
30+
if isinstance(value, nn.Module):
31+
return value._get_name()
32+
else:
33+
return value
34+
35+
36+
def ids_fn(key, value):
37+
return [f"{key[:2]}_{value_name(v)}" for v in value]
38+
39+
40+
def pytest_generate_tests(metafunc):
41+
for key, value in params.items():
42+
if key in metafunc.fixturenames:
43+
metafunc.parametrize(key, value, ids=ids_fn(key, value))
44+
45+
46+
def test_Net(
47+
block, expansion,
48+
groups,
49+
# dw, div_groups,
50+
):
51+
"""test Net"""
52+
c_in = 3
53+
img_size = 16
54+
c_out = 8
55+
name = "Test name"
56+
57+
mc = Net(
58+
name, c_in, c_out, block,
59+
expansion=expansion,
60+
stem_sizes=[8, 16],
61+
block_sizes=[16, 32, 64, 128],
62+
groups=groups,
63+
# dw=dw,
64+
# div_groups=div_groups,
65+
# bn_1st=bn_1st, zero_bn=zero_bn,
66+
# stem_bn_end=stem_bn_end,
67+
)
68+
assert f"{name} constructor" in str(mc)
69+
model = mc()
70+
xb = torch.randn(bs_test, c_in, img_size, img_size)
71+
pred = model(xb)
72+
assert pred.shape == torch.Size([bs_test, c_out])
73+
74+
75+
def test_Net_div_gr(
76+
block, expansion,
77+
div_groups,
78+
):
79+
"""test Net"""
80+
c_in = 3
81+
img_size = 16
82+
c_out = 8
83+
name = "Test name"
84+
85+
mc = Net(
86+
name, c_in, c_out, block,
87+
expansion=expansion,
88+
stem_sizes=[8, 16],
89+
block_sizes=[16, 32, 64, 128],
90+
div_groups=div_groups,
91+
)
92+
assert f"{name} constructor" in str(mc)
93+
model = mc()
94+
xb = torch.randn(bs_test, c_in, img_size, img_size)
95+
pred = model(xb)
96+
assert pred.shape == torch.Size([bs_test, c_out])
97+
98+
99+
def test_Net_dw(
100+
block, expansion,
101+
dw
102+
):
103+
"""test Net"""
104+
c_in = 3
105+
img_size = 16
106+
c_out = 8
107+
name = "Test name"
108+
109+
mc = Net(
110+
name, c_in, c_out, block,
111+
expansion=expansion,
112+
stem_sizes=[8, 16],
113+
block_sizes=[16, 32, 64, 128],
114+
dw=dw
115+
)
116+
assert f"{name} constructor" in str(mc)
117+
model = mc()
118+
xb = torch.randn(bs_test, c_in, img_size, img_size)
119+
pred = model(xb)
120+
assert pred.shape == torch.Size([bs_test, c_out])
121+
122+
123+
def test_Net_2(
124+
block, expansion,
125+
bn_1st, zero_bn,
126+
):
127+
"""test Net"""
128+
c_in = 3
129+
img_size = 16
130+
c_out = 8
131+
name = "Test name"
132+
133+
mc = Net(
134+
name, c_in, c_out, block,
135+
expansion=expansion,
136+
stem_sizes=[8, 16],
137+
block_sizes=[16, 32, 64, 128],
138+
bn_1st=bn_1st, zero_bn=zero_bn,
139+
)
140+
assert f"{name} constructor" in str(mc)
141+
model = mc()
142+
xb = torch.randn(bs_test, c_in, img_size, img_size)
143+
pred = model(xb)
144+
assert pred.shape == torch.Size([bs_test, c_out])
145+
146+
147+
def test_Net_stem(
148+
stem_bn_end,
149+
stem_stride_on
150+
):
151+
"""test Net"""
152+
c_in = 3
153+
img_size = 16
154+
c_out = 8
155+
name = "Test name"
156+
157+
mc = Net(
158+
name, c_in, c_out,
159+
stem_sizes=[8, 16],
160+
block_sizes=[16, 32, 64, 128],
161+
stem_bn_end=stem_bn_end,
162+
stem_stride_on=stem_stride_on
163+
)
164+
assert f"{name} constructor" in str(mc)
165+
model = mc()
166+
xb = torch.randn(bs_test, c_in, img_size, img_size)
167+
pred = model(xb)
168+
assert pred.shape == torch.Size([bs_test, c_out])

tests/test_convmixer.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
import torch
2+
import torch.nn as nn
3+
4+
from model_constructor.convmixer import ConvMixer, ConvMixerOriginal
5+
6+
bs_test = 4
7+
8+
9+
params = dict(
10+
bn_1st=[True, False],
11+
pre_act=[True, False],
12+
)
13+
14+
15+
def value_name(value) -> str: # pragma: no cover
16+
name = getattr(value, "__name__", None)
17+
if name is not None:
18+
return name
19+
if isinstance(value, nn.Module):
20+
return value._get_name()
21+
else:
22+
return value
23+
24+
25+
def ids_fn(key, value):
26+
return [f"{key[:2]}_{value_name(v)}" for v in value]
27+
28+
29+
def pytest_generate_tests(metafunc):
30+
for key, value in params.items():
31+
if key in metafunc.fixturenames:
32+
metafunc.parametrize(key, value, ids=ids_fn(key, value))
33+
34+
35+
def test_ConvMixer(bn_1st, pre_act):
36+
"""test ConvMixer"""
37+
bs_test = 4
38+
img_size = 16
39+
model = ConvMixer(dim=64, depth=4, bn_1st=bn_1st, pre_act=pre_act)
40+
xb = torch.randn(bs_test, 3, img_size, img_size)
41+
pred = model(xb)
42+
assert pred.shape == torch.Size([bs_test, 1000])
43+
44+
45+
def test_ConvMixerOriginal():
46+
"""test ConvMixerOriginal"""
47+
bs_test = 4
48+
img_size = 16
49+
model = ConvMixerOriginal(dim=64, depth=4)
50+
xb = torch.randn(bs_test, 3, img_size, img_size)
51+
pred = model(xb)
52+
assert pred.shape == torch.Size([bs_test, 1000])

tests/test_layers.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def value_name(value) -> str:
3333
name = getattr(value, "__name__", None)
3434
if name is not None:
3535
return name
36-
if isinstance(value, nn.Module):
36+
if isinstance(value, nn.Module): # pragma: no cover
3737
return value._get_name()
3838
else:
3939
return value
@@ -87,12 +87,12 @@ def test_ConvBnAct(kernel_size, stride, bias, groups, pre_act, bn_layer, bn_1st,
8787
assert out.shape == torch.Size([bs_test, out_channels, out_size, out_size])
8888

8989

90-
def test_SimpleSelfAttention(sym):
90+
def test_SimpleSelfAttention(sym, use_bias):
9191
"""test SimpleSelfAttention"""
9292
in_channels = 4
9393
kernel_size = 1 # ? can be 3? if so check sym hack.
9494
channel_size = 4
95-
sa = SimpleSelfAttention(in_channels, kernel_size, sym)
95+
sa = SimpleSelfAttention(in_channels, kernel_size, sym, use_bias)
9696
xb = torch.randn(bs_test, in_channels, channel_size, channel_size)
9797
out = sa(xb)
9898
assert out.shape == torch.Size([bs_test, in_channels, channel_size, channel_size])

tests/test_layers_depr.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import torch
33
import torch.nn as nn
44

5-
from model_constructor.layers import SEBlock, SEBlockConv
5+
from model_constructor.layers import ConvLayer, SEBlock, SEBlockConv
66

77

88
bs_test = 4
@@ -15,6 +15,16 @@
1515
rd_channels=[None, 2],
1616
rd_max=[False, True],
1717
use_bias=[True, False],
18+
# ConvLayer
19+
nf=[8, 16],
20+
ks=[3, 1],
21+
stride=[1, 2],
22+
act=[True, False],
23+
bn_layer=[True, False],
24+
bn_1st=[True, False],
25+
zero_bn=[False, True],
26+
bias=[False, True],
27+
groups=[1, 2]
1828
)
1929

2030

@@ -23,7 +33,7 @@ def value_name(value) -> str:
2333
if name is not None:
2434
return name
2535
if isinstance(value, nn.Module):
26-
return value._get_name()
36+
return value._get_name() # pragma: no cover
2737
else:
2838
return value
2939

@@ -47,3 +57,20 @@ def test_SE(se_module, reduction, use_bias):
4757
xb = torch.randn(bs_test, in_channels, channel_size, channel_size)
4858
out = se(xb)
4959
assert out.shape == torch.Size([bs_test, in_channels, channel_size, channel_size])
60+
61+
62+
def test_ConvLayer(nf, ks, stride, bn_layer, bn_1st, zero_bn, bias, groups):
63+
"""test ConvLayer"""
64+
ni = 8
65+
channel_size = 4
66+
block = ConvLayer(
67+
ni, nf, ks, stride,
68+
bn_layer=bn_layer, bn_1st=bn_1st, zero_bn=zero_bn,
69+
bias=bias, groups=groups)
70+
xb = torch.randn(bs_test, ni, channel_size, channel_size)
71+
out = block(xb)
72+
# out_ch = nf
73+
out_size = channel_size
74+
if stride == 2:
75+
out_size = channel_size // stride
76+
assert out.shape == torch.Size([bs_test, nf, out_size, out_size])

tests/test_mc.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import torch
2+
3+
from model_constructor import ModelConstructor
4+
from model_constructor.layers import SEModule, SimpleSelfAttention
5+
6+
7+
bs_test = 4
8+
9+
10+
def test_MC():
11+
"""test ModelConstructor"""
12+
img_size = 16
13+
mc = ModelConstructor()
14+
assert "MC constructor" in str(mc)
15+
model = mc()
16+
xb = torch.randn(bs_test, 3, img_size, img_size)
17+
pred = model(xb)
18+
assert pred.shape == torch.Size([bs_test, 1000])
19+
num_classes = 10
20+
mc.num_classes = num_classes
21+
mc.se = SEModule
22+
mc.sa = SimpleSelfAttention
23+
mc.stem_bn_end = True
24+
model = mc()
25+
pred = model(xb)
26+
assert pred.shape == torch.Size([bs_test, num_classes])
27+
mc = ModelConstructor(sa=1, se=1)
28+
assert mc.se is SEModule
29+
assert mc.sa is SimpleSelfAttention

0 commit comments

Comments
 (0)