Skip to content

Commit c458410

Browse files
committed
tests models
1 parent 6b3aaaf commit c458410

File tree

5 files changed

+74
-25
lines changed

5 files changed

+74
-25
lines changed

src/model_constructor/model_constructor.py

Lines changed: 11 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from collections import OrderedDict
22
from functools import partial
3-
from typing import Any, Callable, List, Optional, Type, TypeVar, Union
3+
from typing import Any, Callable, List, Type, TypeVar, Union
44

55
import torch.nn as nn
66
from pydantic import BaseModel, root_validator
@@ -236,11 +236,11 @@ class ModelCfg(BaseModel):
236236
stem_sizes: List[int] = [32, 32, 64]
237237
stem_pool: Union[Callable[[], nn.Module], None] = partial(nn.MaxPool2d, kernel_size=3, stride=2, padding=1)
238238
stem_bn_end: bool = False
239-
init_cnn: Optional[Callable[[nn.Module], None]] = init_cnn
240-
make_stem: Optional[Callable[[TModelCfg], Union[nn.Module, nn.Sequential]]] = make_stem
241-
make_layer: Optional[Callable[[TModelCfg, int], Union[nn.Module, nn.Sequential]]] = make_layer
242-
make_body: Optional[Callable[[TModelCfg], Union[nn.Module, nn.Sequential]]] = make_body
243-
make_head: Optional[Callable[[TModelCfg], Union[nn.Module, nn.Sequential]]] = make_head
239+
init_cnn: Callable[[nn.Module], None] = init_cnn
240+
make_stem: Callable[[TModelCfg], Union[nn.Module, nn.Sequential]] = make_stem
241+
make_layer: Callable[[TModelCfg, int], Union[nn.Module, nn.Sequential]] = make_layer
242+
make_body: Callable[[TModelCfg], Union[nn.Module, nn.Sequential]] = make_body
243+
make_head: Callable[[TModelCfg], Union[nn.Module, nn.Sequential]] = make_head
244244

245245
class Config:
246246
arbitrary_types_allowed = True
@@ -262,17 +262,6 @@ class ModelConstructor(ModelCfg):
262262

263263
@root_validator
264264
def post_init(cls, values): # pylint: disable=E0213
265-
# if values["init_cnn"] is None:
266-
# values["init_cnn"] = init_cnn
267-
# if values["make_stem"] is None:
268-
# values["make_stem"] = make_stem
269-
# if values["make_layer"] is None:
270-
# values["make_layer"] = make_layer
271-
# if values["make_body"] is None:
272-
# values["make_body"] = make_body
273-
# if values["make_head"] is None:
274-
# values["make_head"] = make_head
275-
276265
if values["stem_sizes"][0] != values["in_chans"]:
277266
values["stem_sizes"] = [values["in_chans"]] + values["stem_sizes"]
278267
if values["se"] and isinstance(values["se"], (bool, int)): # if se=1 or se=True
@@ -287,15 +276,15 @@ def post_init(cls, values): # pylint: disable=E0213
287276

288277
@property
289278
def stem(self):
290-
return self.make_stem(self) # type: ignore
279+
return self.make_stem(self) # pylint: disable=too-many-function-args
291280

292281
@property
293282
def head(self):
294-
return self.make_head(self) # type: ignore
283+
return self.make_head(self) # pylint: disable=too-many-function-args
295284

296285
@property
297286
def body(self):
298-
return self.make_body(self) # type: ignore
287+
return self.make_body(self) # pylint: disable=too-many-function-args
299288

300289
@classmethod
301290
def from_cfg(cls, cfg: ModelCfg):
@@ -305,12 +294,12 @@ def __call__(self):
305294
model = nn.Sequential(
306295
OrderedDict([("stem", self.stem), ("body", self.body), ("head", self.head)])
307296
)
308-
self.init_cnn(model) # type: ignore
297+
self.init_cnn(model) # pylint: disable=too-many-function-args
309298
model.extra_repr = lambda: f"{self.name}"
310299
return model
311300

312301
def __repr__(self):
313-
se_repr = self.se.__name__ if self.se else "False"
302+
se_repr = self.se.__name__ if self.se else "False" # type: ignore
314303
return (
315304
f"{self.name} constructor\n"
316305
f" in_chans: {self.in_chans}, num_classes: {self.num_classes}\n"

src/model_constructor/xresnet.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
import torch.nn as nn
21
from collections import OrderedDict
32

3+
import torch.nn as nn
4+
45
from .base_constructor import Net
56
from .layers import ConvLayer, Noop, act
67

7-
88
__all__ = ['DownsampleLayer', 'XResBlock', 'xresnet18', 'xresnet34', 'xresnet50']
99

1010

src/model_constructor/yaresnet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ class YaResNet34(ModelConstructor):
129129
block: Type[nn.Module] = YaResBlock
130130
expansion: int = 1
131131
layers: List[int] = [3, 4, 6, 3]
132-
act_fn: nn.Module = Mish()
132+
act_fn: Type[nn.Module] = Mish
133133

134134

135135
class YaResNet50(YaResNet34):

tests/test_models.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import pytest
2+
import torch
3+
import torch.nn as nn
4+
5+
from model_constructor.model_constructor import XResNet34, XResNet50
6+
from model_constructor.yaresnet import YaResNet34, YaResNet50
7+
8+
bs_test = 2
9+
img_size = 16
10+
xb = torch.rand(bs_test, 3, img_size, img_size)
11+
12+
mc_list = [XResNet34, XResNet50, YaResNet34, YaResNet50]
13+
act_fn_list = [nn.ReLU, nn.Mish, nn.GELU]
14+
15+
16+
@pytest.mark.parametrize("model_constructor", mc_list)
17+
@pytest.mark.parametrize("act_fn", act_fn_list)
18+
def test_mc(model_constructor, act_fn):
19+
"""test models"""
20+
mc = model_constructor(act_fn=act_fn)
21+
# assert "name='MC'" in str()
22+
model = mc()
23+
pred = model(xb)
24+
assert pred.shape == torch.Size([bs_test, 1000])

tests/test_models_old.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import pytest
2+
import torch
3+
import torch.nn as nn
4+
5+
from model_constructor.xresnet import xresnet18, xresnet34, xresnet50
6+
from model_constructor.mxresnet import mxresnet34, mxresnet50
7+
8+
bs_test = 2
9+
img_size = 16
10+
xb = torch.rand(bs_test, 3, img_size, img_size)
11+
12+
num_classes = 10
13+
14+
models_list = [xresnet18, xresnet34, xresnet50]
15+
act_fn_list = [nn.ReLU, nn.Mish, nn.GELU]
16+
mx_list = [mxresnet34, mxresnet50]
17+
18+
19+
@pytest.mark.parametrize("model", models_list)
20+
def test_model(model):
21+
"""test models"""
22+
mod = model(num_classes=num_classes)
23+
pred = mod(xb)
24+
assert pred.shape == torch.Size([bs_test, num_classes])
25+
26+
27+
@pytest.mark.parametrize("model_constructor", mx_list)
28+
@pytest.mark.parametrize("act_fn", act_fn_list)
29+
def test_model_mx(model_constructor, act_fn):
30+
"""test models"""
31+
mc = model_constructor(c_out=num_classes)
32+
assert mc.c_out == num_classes
33+
mc.act_fn = act_fn()
34+
model = mc()
35+
pred = model(xb)
36+
assert pred.shape == torch.Size([bs_test, num_classes])

0 commit comments

Comments
 (0)