Skip to content

Commit db440bc

Browse files
committed
tests for constructors - Mc & Net
1 parent 883d2df commit db440bc

File tree

2 files changed

+28
-2
lines changed

2 files changed

+28
-2
lines changed

tests/test_Net.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@ def pytest_generate_tests(metafunc):
4646
def test_Net(
4747
block, expansion,
4848
groups,
49-
# dw, div_groups,
5049
):
5150
"""test Net"""
5251
c_in = 3
@@ -72,6 +71,30 @@ def test_Net(
7271
assert pred.shape == torch.Size([bs_test, c_out])
7372

7473

74+
def test_Net_SE_SA(
75+
block, expansion,
76+
se, sa
77+
):
78+
"""test Net"""
79+
c_in = 3
80+
img_size = 16
81+
c_out = 8
82+
name = "Test name"
83+
84+
mc = Net(
85+
name, c_in, c_out, block,
86+
expansion=expansion,
87+
stem_sizes=[8, 16],
88+
block_sizes=[16, 32, 64, 128],
89+
se=se, sa=sa
90+
)
91+
assert f"{name} constructor" in str(mc)
92+
model = mc()
93+
xb = torch.randn(bs_test, c_in, img_size, img_size)
94+
pred = model(xb)
95+
assert pred.shape == torch.Size([bs_test, c_out])
96+
97+
7598
def test_Net_div_gr(
7699
block, expansion,
77100
div_groups,

tests/test_mc.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@ def test_MC():
2424
model = mc()
2525
pred = model(xb)
2626
assert pred.shape == torch.Size([bs_test, num_classes])
27-
mc = ModelConstructor(sa=1, se=1)
27+
mc = ModelConstructor(sa=1, se=1, num_classes=num_classes)
2828
assert mc.se is SEModule
2929
assert mc.sa is SimpleSelfAttention
30+
model = mc()
31+
pred = model(xb)
32+
assert pred.shape == torch.Size([bs_test, num_classes])

0 commit comments

Comments
 (0)