Skip to content

Commit f10cad8

Browse files
committed
test blocks, renamed tests uni blocks
1 parent ece3c83 commit f10cad8

File tree

2 files changed

+52
-2
lines changed

2 files changed

+52
-2
lines changed

tests/test_blocks.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# import pytest
2+
from functools import partial
3+
4+
import torch
5+
from torch import nn
6+
7+
from model_constructor.layers import SEModule, SimpleSelfAttention
8+
from model_constructor.model_constructor import BasicBlock, BottleneckBlock
9+
10+
from .parameters import ids_fn
11+
12+
bs_test = 4
13+
img_size = 16
14+
15+
16+
params = dict(
17+
Block=[BasicBlock, BottleneckBlock],
18+
# expansion=[1, 2],
19+
out_channels=[8, 16],
20+
stride=[1, 2],
21+
div_groups=[None, 2],
22+
pool=[None, partial(nn.AvgPool2d, kernel_size=2, ceil_mode=True)],
23+
se=[None, SEModule],
24+
sa=[None, SimpleSelfAttention],
25+
)
26+
27+
28+
def pytest_generate_tests(metafunc):
29+
for key, value in params.items():
30+
if key in metafunc.fixturenames:
31+
metafunc.parametrize(key, value, ids=ids_fn(key, value))
32+
33+
34+
def test_block(Block, out_channels, stride, div_groups, pool, se, sa):
35+
"""test block"""
36+
in_channels = 8
37+
# out_channels = mid_channels * expansion
38+
block = Block(
39+
in_channels,
40+
out_channels,
41+
stride,
42+
div_groups=div_groups,
43+
pool=pool,
44+
se=se,
45+
sa=sa,
46+
)
47+
xb = torch.randn(bs_test, in_channels, img_size, img_size)
48+
out = block(xb)
49+
out_size = img_size if stride == 1 else img_size // stride
50+
assert out.shape == torch.Size([bs_test, out_channels, out_size, out_size])

tests/test_universal_block.py renamed to tests/test_blocks_universal.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,6 @@ def test_block(Block, expansion, mid_channels, stride, div_groups, pool, se, sa)
4646
sa=sa,
4747
)
4848
xb = torch.randn(bs_test, in_channels * expansion, img_size, img_size)
49-
y = block(xb)
49+
out = block(xb)
5050
out_size = img_size if stride == 1 else img_size // stride
51-
assert y.shape == torch.Size([bs_test, out_channels, out_size, out_size])
51+
assert out.shape == torch.Size([bs_test, out_channels, out_size, out_size])

0 commit comments

Comments
 (0)