Skip to content

Commit 76e3257

Browse files
committed
tests layers old
1 parent 851426f commit 76e3257

File tree

2 files changed

+51
-2
lines changed

2 files changed

+51
-2
lines changed

model_constructor/layers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ class SEBlock(nn.Module): # todo: deprecation worning.
141141

142142
def __init__(self, c, r=16):
143143
super().__init__()
144-
ch = c // r
144+
ch = max(c // r, 1)
145145
self.squeeze = nn.AdaptiveAvgPool2d(1)
146146
self.excitation = nn.Sequential(
147147
OrderedDict([('fc_reduce', self.se_layer(c, ch, bias=self.use_bias)),
@@ -166,7 +166,7 @@ class SEBlockConv(nn.Module): # todo: deprecation worning.
166166
def __init__(self, c, r=16):
167167
super().__init__()
168168
# c_in = math.ceil(c//r/8)*8
169-
c_in = c // r
169+
c_in = max(c // r, 1)
170170
self.squeeze = nn.AdaptiveAvgPool2d(1)
171171
self.excitation = nn.Sequential(
172172
OrderedDict([

tests/test_layers_depr.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# old (deprecated layers)
2+
import torch
3+
import torch.nn as nn
4+
5+
from model_constructor.layers import SEBlock, SEBlockConv
6+
7+
8+
bs_test = 4
9+
10+
11+
params = dict(
12+
# SE
13+
se_module=[SEBlock, SEBlockConv],
14+
reduction=[16, 2],
15+
rd_channels=[None, 2],
16+
rd_max=[False, True],
17+
use_bias=[True, False],
18+
)
19+
20+
21+
def value_name(value) -> str:
22+
name = getattr(value, "__name__", None)
23+
if name is not None:
24+
return name
25+
if isinstance(value, nn.Module):
26+
return value._get_name()
27+
else:
28+
return value
29+
30+
31+
def ids_fn(key, value):
32+
return [f"{key[:2]}_{value_name(v)}" for v in value]
33+
34+
35+
def pytest_generate_tests(metafunc):
36+
for key, value in params.items():
37+
if key in metafunc.fixturenames:
38+
metafunc.parametrize(key, value, ids=ids_fn(key, value))
39+
40+
41+
def test_SE(se_module, reduction, use_bias):
42+
"""test SE"""
43+
in_channels = 8
44+
channel_size = 4
45+
se = se_module(in_channels, reduction)
46+
se.use_bias = use_bias
47+
xb = torch.randn(bs_test, in_channels, channel_size, channel_size)
48+
out = se(xb)
49+
assert out.shape == torch.Size([bs_test, in_channels, channel_size, channel_size])

0 commit comments

Comments
 (0)