Skip to content

Commit e1b6910

Browse files
committed
tests stem etc
1 parent b998f83 commit e1b6910

File tree

5 files changed

+70
-5
lines changed

5 files changed

+70
-5
lines changed

src/model_constructor/blocks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def __init__(
6666
id_layers: ListStrMod = []
6767
if (
6868
stride != 1 and pool is not None
69-
): # if pool - reduce by pool else stride 2 art id_conv
69+
): # if pool - reduce by pool else stride 2 at id_conv
7070
id_layers.append(("pool", pool()))
7171
if in_channels != out_channels or (stride != 1 and pool is None):
7272
id_layers.append(

src/model_constructor/model_constructor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def set_modules( # pylint: disable=no-self-argument
6767
return value
6868
if isinstance(value, str):
6969
return instantiate_module(value)
70-
raise ValueError(f"{info.field_name} must be str or nn.Module")
70+
# raise ValueError(f"{info.field_name} must be str or nn.Module")
7171

7272
@field_validator("se", "sa")
7373
def set_se( # pylint: disable=no-self-argument
@@ -77,7 +77,7 @@ def set_se( # pylint: disable=no-self-argument
7777
return DEFAULT_SE_SA[info.field_name]
7878
if is_module(value):
7979
return value
80-
raise ValueError(f"{info.field_name} must be bool or nn.Module")
80+
# raise ValueError(f"{info.field_name} must be bool or nn.Module") # no need - check at init
8181

8282
@field_validator("se_module", "se_reduction") # pragma: no cover
8383
def deprecation_warning( # pylint: disable=no-self-argument

tests/test_mc.py

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22

33
import torch
44

5-
from model_constructor.blocks import BottleneckBlock
5+
from model_constructor.blocks import BasicBlock, BottleneckBlock
66
from model_constructor.layers import SEModule, SEModuleConv, SimpleSelfAttention
7-
from model_constructor.model_constructor import ModelConstructor
7+
from model_constructor.model_constructor import ModelCfg, ModelConstructor
88

99
bs_test = 4
1010
in_chans = 3
@@ -94,3 +94,46 @@ def test_MC_bottleneck():
9494
assert model.body.l_0.bl_0.convs.conv_0.conv.in_channels == 64
9595
assert model.body.l_0.bl_0.convs.conv_0.conv.out_channels == 128
9696
assert model.body.l_0.bl_1.convs.conv_0.conv.in_channels == 256
97+
98+
99+
def test_ModelCfg():
100+
"""test ModelCfg"""
101+
# default - just create config with custom name
102+
cfg = ModelCfg(name="custom_name")
103+
repr_str = cfg.__repr__()
104+
assert repr_str.startswith("custom_name")
105+
# initiate from string
106+
cfg = ModelCfg(act_fn="torch.nn.Mish")
107+
assert cfg.act_fn is torch.nn.Mish
108+
# wrong name
109+
try:
110+
cfg = ModelCfg(act_fn="wrong_name")
111+
except ImportError as err:
112+
assert str(err) == "Module wrong_name not found at torch.nn"
113+
cfg = ModelCfg(act_fn="nn.Tanh")
114+
assert cfg.act_fn is torch.nn.Tanh
115+
cfg = ModelCfg(block="model_constructor.blocks.BottleneckBlock")
116+
assert cfg.block is BottleneckBlock
117+
118+
119+
def test_create_model_class_methods():
120+
"""test class methods ModelConstructor"""
121+
# create model
122+
model = ModelConstructor.create_model(act_fn="Mish", num_classes=10)
123+
assert str(model.body.l_0.bl_0.convs.conv_0.act_fn) == "Mish(inplace=True)"
124+
pred = model(xb)
125+
assert pred.shape == torch.Size([bs_test, 10])
126+
# from cfg
127+
cfg = ModelCfg(block=BottleneckBlock, num_classes=10)
128+
mc = ModelConstructor.from_cfg(cfg)
129+
model = mc()
130+
assert isinstance(model.body.l_0.bl_0, BottleneckBlock)
131+
pred = model(xb)
132+
assert pred.shape == torch.Size([bs_test, 10])
133+
134+
cfg.block = BasicBlock
135+
cfg.num_classes = 2
136+
model = ModelConstructor.create_model(cfg)
137+
assert isinstance(model.body.l_0.bl_0, BasicBlock)
138+
pred = model(xb)
139+
assert pred.shape == torch.Size([bs_test, 2])

tests/test_models.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,14 @@ def test_mc(model_constructor: type[ModelConstructor], act_fn: type[nn.Module]):
4040
model = mc()
4141
pred = model(xb)
4242
assert pred.shape == torch.Size([bs_test, 1000])
43+
44+
45+
def test_xresnet_stem():
46+
"""test xresnet stem"""
47+
mc = XResNet()
48+
assert mc.stem_bn_end == False
49+
mc.stem_bn_end = True
50+
stem = mc.stem
51+
assert isinstance(stem[-1], nn.BatchNorm2d)
52+
stem_out = stem(xb)
53+
assert stem_out.shape == torch.Size([bs_test, 64, 4, 4])

tests/test_models_universal_blocks.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,3 +35,14 @@ def test_mc(model_constructor: type[ModelConstructor], act_fn: type[nn.Module]):
3535
model = mc()
3636
pred = model(xb)
3737
assert pred.shape == torch.Size([bs_test, 1000])
38+
39+
40+
def test_stem_bnend():
41+
"""test stem"""
42+
mc = ModelConstructor()
43+
assert mc.stem_bn_end == False
44+
mc.stem_bn_end = True
45+
stem = mc.stem
46+
assert isinstance(stem[-1], nn.BatchNorm2d)
47+
stem_out = stem(xb)
48+
assert stem_out.shape == torch.Size([bs_test, 64, 4, 4])

0 commit comments

Comments
 (0)