Commit 246a185

TESTS FIX
1 parent e5a1f81 commit 246a185

7 files changed: +104 -68 lines changed

src/model_constructor/yaresnet.py

Lines changed: 3 additions & 4 deletions
@@ -11,13 +11,12 @@
 from .model_constructor import ModelConstructor
 
 __all__ = [
-    'YaResBlock',
+    "YaResBlock",
+    "YaResNet34",
+    "YaResNet50",
 ]
 
 
-# act_fn = nn.ReLU(inplace=True)
-
-
 class YaResBlock(nn.Module):
     '''YaResBlock. Reduce by pool instead of stride 2'''
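
Note on the __all__ change above: adding "YaResNet34" and "YaResNet50" widens the module's public star-import surface. A minimal sketch of the effect, assuming both constructors are defined in yaresnet.py (which the updated __all__ implies):

    # "from ... import *" now re-exports the names listed in __all__
    from model_constructor.yaresnet import *

    block_cls = YaResBlock   # exported before and after this commit
    resnet34 = YaResNet34    # newly exported
    resnet50 = YaResNet50    # newly exported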

tests/test_Net.py

Lines changed: 45 additions & 35 deletions
@@ -2,6 +2,7 @@
 import torch.nn as nn
 
 from model_constructor.net import Net, NewResBlock, ResBlock
+
 # from model_constructor.layers import SEModule, SimpleSelfAttention
 
 
@@ -19,7 +20,7 @@
     bn_1st=[True, False],
     zero_bn=[True, False],
     stem_bn_end=[True, False],
-    stem_stride_on=[0, 1]
+    stem_stride_on=[0, 1],
 )
 
 
@@ -28,9 +29,8 @@ def value_name(value) -> str: # pragma: no cover
     if name is not None:
         return name
     if isinstance(value, nn.Module):
-        return value._get_name()
-    else:
-        return value
+        return value._get_name() # pylint: disable=W0212
+    return value
 
 
 def ids_fn(key, value):
@@ -44,7 +44,8 @@ def pytest_generate_tests(metafunc):
 
 
 def test_Net(
-    block, expansion,
+    block,
+    expansion,
     groups,
 ):
     """test Net"""
@@ -54,15 +55,14 @@ def test_Net(
     name = "Test name"
 
     mc = Net(
-        name, c_in, c_out, block,
+        name,
+        c_in,
+        c_out,
+        block,
         expansion=expansion,
         stem_sizes=[8, 16],
         block_sizes=[16, 32, 64, 128],
         groups=groups,
-        # dw=dw,
-        # div_groups=div_groups,
-        # bn_1st=bn_1st, zero_bn=zero_bn,
-        # stem_bn_end=stem_bn_end,
     )
     assert f"{name} constructor" in str(mc)
     model = mc()
@@ -71,22 +71,23 @@
     assert pred.shape == torch.Size([bs_test, c_out])
 
 
-def test_Net_SE_SA(
-    block, expansion,
-    se, sa
-):
+def test_Net_SE_SA(block, expansion, se, sa):
     """test Net"""
     c_in = 3
     img_size = 16
     c_out = 8
     name = "Test name"
 
     mc = Net(
-        name, c_in, c_out, block,
+        name,
+        c_in,
+        c_out,
+        block,
         expansion=expansion,
         stem_sizes=[8, 16],
         block_sizes=[16, 32, 64, 128],
-        se=se, sa=sa
+        se=se,
+        sa=sa,
     )
     assert f"{name} constructor" in str(mc)
     model = mc()
@@ -96,7 +97,8 @@ def test_Net_SE_SA(
 
 
 def test_Net_div_gr(
-    block, expansion,
+    block,
+    expansion,
     div_groups,
 ):
     """test Net"""
@@ -106,7 +108,10 @@ def test_Net_div_gr(
     name = "Test name"
 
     mc = Net(
-        name, c_in, c_out, block,
+        name,
+        c_in,
+        c_out,
+        block,
         expansion=expansion,
         stem_sizes=[8, 16],
         block_sizes=[16, 32, 64, 128],
@@ -119,22 +124,22 @@
     assert pred.shape == torch.Size([bs_test, c_out])
 
 
-def test_Net_dw(
-    block, expansion,
-    dw
-):
+def test_Net_dw(block, expansion, dw):
     """test Net"""
     c_in = 3
     img_size = 16
     c_out = 8
     name = "Test name"
 
     mc = Net(
-        name, c_in, c_out, block,
+        name,
+        c_in,
+        c_out,
+        block,
         expansion=expansion,
         stem_sizes=[8, 16],
         block_sizes=[16, 32, 64, 128],
-        dw=dw
+        dw=dw,
     )
     assert f"{name} constructor" in str(mc)
     model = mc()
@@ -144,8 +149,10 @@ def test_Net_dw(
 
 
 def test_Net_2(
-    block, expansion,
-    bn_1st, zero_bn,
+    block,
+    expansion,
+    bn_1st,
+    zero_bn,
 ):
     """test Net"""
     c_in = 3
@@ -154,11 +161,15 @@ def test_Net_2(
     name = "Test name"
 
     mc = Net(
-        name, c_in, c_out, block,
+        name,
+        c_in,
+        c_out,
+        block,
         expansion=expansion,
         stem_sizes=[8, 16],
         block_sizes=[16, 32, 64, 128],
-        bn_1st=bn_1st, zero_bn=zero_bn,
+        bn_1st=bn_1st,
+        zero_bn=zero_bn,
     )
     assert f"{name} constructor" in str(mc)
     model = mc()
@@ -167,25 +178,24 @@ def test_Net_2(
     assert pred.shape == torch.Size([bs_test, c_out])
 
 
-def test_Net_stem(
-    stem_bn_end,
-    stem_stride_on
-):
+def test_Net_stem(stem_bn_end, stem_stride_on):
     """test Net"""
     c_in = 3
     img_size = 16
     c_out = 8
     name = "Test name"
 
     mc = Net(
-        name, c_in, c_out,
+        name,
+        c_in,
+        c_out,
         stem_sizes=[8, 16],
         block_sizes=[16, 32, 64, 128],
         stem_bn_end=stem_bn_end,
-        stem_stride_on=stem_stride_on
+        stem_stride_on=stem_stride_on,
     )
     assert f"{name} constructor" in str(mc)
     model = mc()
     xb = torch.randn(bs_test, c_in, img_size, img_size)
     pred = model(xb)
-    assert pred.shape == torch.Size([bs_test, c_out])
+    assert pred.shape == torch.Size([bs_test, c_out])
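
The reformatted tests above still get their arguments from the params dict through the value_name / ids_fn helpers and a pytest_generate_tests hook. The hook body is not part of this diff, so the following is only a sketch of the usual pattern, assuming every key in params that matches a test argument is parametrized over its listed values and that ids_fn returns one readable id per value:

    # Sketch only - the real hook lives in tests/test_Net.py and is not shown here.
    def pytest_generate_tests(metafunc):
        for key, values in params.items():
            if key in metafunc.fixturenames:
                metafunc.parametrize(key, values, ids=[ids_fn(key, v) for v in values])

With this pattern, a test such as test_Net(block, expansion, groups) runs once for every combination of the block, expansion, and groups values declared in params.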

tests/test_block.py

Lines changed: 10 additions & 5 deletions
@@ -29,8 +29,7 @@ def value_name(value) -> str:
         return name
     if isinstance(value, nn.Module):
         return value._get_name() # pylint: disable=W0212
-    else:
-        return value
+    return value
 
 
 def ids_fn(key, value):
@@ -48,9 +47,15 @@ def test_block(Block, expansion, mid_channels, stride, div_groups, pool, se, sa)
     in_channels = 8
     out_channels = mid_channels * expansion
     block = Block(
-        expansion, in_channels, mid_channels,
-        stride, div_groups=div_groups,
-        pool=pool, se=se, sa=sa)
+        expansion,
+        in_channels,
+        mid_channels,
+        stride,
+        div_groups=div_groups,
+        pool=pool,
+        se=se,
+        sa=sa,
+    )
     xb = torch.randn(bs_test, in_channels * expansion, img_size, img_size)
     y = block(xb)
     out_size = img_size if stride == 1 else img_size // stride

tests/test_convmixer.py

Lines changed: 3 additions & 7 deletions
@@ -4,6 +4,7 @@
 from model_constructor.convmixer import ConvMixer, ConvMixerOriginal
 
 bs_test = 4
+img_size = 16
 
 
 params = dict(
@@ -17,9 +18,8 @@ def value_name(value) -> str: # pragma: no cover
     if name is not None:
         return name
     if isinstance(value, nn.Module):
-        return value._get_name()
-    else:
-        return value
+        return value._get_name() # pylint: disable=W0212
+    return value
 
 
 def ids_fn(key, value):
@@ -34,8 +34,6 @@ def pytest_generate_tests(metafunc):
 
 def test_ConvMixer(bn_1st, pre_act):
     """test ConvMixer"""
-    bs_test = 4
-    img_size = 16
     model = ConvMixer(dim=64, depth=4, bn_1st=bn_1st, pre_act=pre_act)
     xb = torch.randn(bs_test, 3, img_size, img_size)
     pred = model(xb)
@@ -44,8 +42,6 @@ def test_ConvMixer(bn_1st, pre_act):
 
 def test_ConvMixerOriginal():
     """test ConvMixerOriginal"""
-    bs_test = 4
-    img_size = 16
     model = ConvMixerOriginal(dim=64, depth=4)
     xb = torch.randn(bs_test, 3, img_size, img_size)
     pred = model(xb)

tests/test_layers.py

Lines changed: 26 additions & 8 deletions
@@ -1,7 +1,15 @@
 import torch
 import torch.nn as nn
 
-from model_constructor.layers import ConvBnAct, Flatten, Noop, SEModule, SEModuleConv, SimpleSelfAttention, noop
+from model_constructor.layers import (
+    ConvBnAct,
+    Flatten,
+    Noop,
+    SEModule,
+    SEModuleConv,
+    SimpleSelfAttention,
+    noop,
+)
 
 
 bs_test = 4
@@ -34,9 +42,8 @@ def value_name(value) -> str:
     if name is not None:
         return name
     if isinstance(value, nn.Module): # pragma: no cover
-        return value._get_name()
-    else:
-        return value
+        return value._get_name() # pylint: disable=W0212
+    return value
 
 
 def ids_fn(key, value):
@@ -71,14 +78,25 @@ def test_noop():
     assert all(out.eq(xb_copy))
 
 
-def test_ConvBnAct(kernel_size, stride, bias, groups, pre_act, bn_layer, bn_1st, zero_bn):
+def test_ConvBnAct(
+    kernel_size, stride, bias, groups, pre_act, bn_layer, bn_1st, zero_bn
+):
     """test ConvBnAct"""
     in_channels = out_channels = 4
     channel_size = 4
     block = ConvBnAct(
-        in_channels, out_channels, kernel_size, stride,
-        padding=None, bias=bias, groups=groups,
-        pre_act=pre_act, bn_layer=bn_layer, bn_1st=bn_1st, zero_bn=zero_bn)
+        in_channels,
+        out_channels,
+        kernel_size,
+        stride,
+        padding=None,
+        bias=bias,
+        groups=groups,
+        pre_act=pre_act,
+        bn_layer=bn_layer,
+        bn_1st=bn_1st,
+        zero_bn=zero_bn,
+    )
     xb = torch.randn(bs_test, in_channels, channel_size, channel_size)
     out = block(xb)
     out_size = channel_size
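
For context on the helpers imported above: test_noop checks that the output equals an untouched copy of the input. A tiny usage sketch, assuming noop is the plain identity function and Noop its nn.Module counterpart (which is what the test appears to verify):

    import torch

    from model_constructor.layers import Noop, noop

    xb = torch.randn(4, 8)
    assert torch.equal(noop(xb), xb)    # function form: returns its input
    assert torch.equal(Noop()(xb), xb)  # module form: forward is the identity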

tests/test_layers_depr.py

Lines changed: 14 additions & 8 deletions
@@ -24,18 +24,17 @@
     bn_1st=[True, False],
     zero_bn=[False, True],
     bias=[False, True],
-    groups=[1, 2]
+    groups=[1, 2],
 )
 
 
-def value_name(value) -> str:
+def value_name(value) -> str: # pragma: no cover
     name = getattr(value, "__name__", None)
     if name is not None:
         return name
     if isinstance(value, nn.Module):
-        return value._get_name() # pragma: no cover
-    else:
-        return value
+        return value._get_name() # pylint: disable=W0212
+    return value
 
 
 def ids_fn(key, value):
@@ -64,9 +63,16 @@ def test_ConvLayer(nf, ks, stride, bn_layer, bn_1st, zero_bn, bias, groups):
     ni = 8
     channel_size = 4
    block = ConvLayer(
-        ni, nf, ks, stride,
-        bn_layer=bn_layer, bn_1st=bn_1st, zero_bn=zero_bn,
-        bias=bias, groups=groups)
+        ni,
+        nf,
+        ks,
+        stride,
+        bn_layer=bn_layer,
+        bn_1st=bn_1st,
+        zero_bn=zero_bn,
+        bias=bias,
+        groups=groups,
+    )
     xb = torch.randn(bs_test, ni, channel_size, channel_size)
     out = block(xb)
     # out_ch = nf

0 commit comments
