Skip to content

Commit eb20472

Browse files
committed
fix models, tests, typing
1 parent 756ceca commit eb20472

File tree

2 files changed

+67
-33
lines changed

2 files changed

+67
-33
lines changed

src/model_constructor/model_constructor.py

Lines changed: 30 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# pylance: disable=overridden method
21
from collections import OrderedDict
32
from functools import partial
43
from typing import Any, Callable, Optional, TypeVar, Union
@@ -39,7 +38,6 @@ class BasicBlock(nn.Module):
3938

4039
def __init__(
4140
self,
42-
# expansion: int,
4341
in_channels: int,
4442
out_channels: int,
4543
stride: int = 1,
@@ -56,7 +54,6 @@ def __init__(
5654
):
5755
super().__init__()
5856
# pool defined at ModelConstructor.
59-
# out_channels, in_channels = mid_channels * expansion, in_channels * expansion
6057
if div_groups is not None: # check if groups != 1 and div_groups
6158
groups = int(out_channels / div_groups)
6259
layers: ListStrMod = [
@@ -66,7 +63,7 @@ def __init__(
6663
in_channels,
6764
out_channels,
6865
3,
69-
stride=stride, # type: ignore
66+
stride=stride,
7067
act_fn=act_fn,
7168
bn_1st=bn_1st,
7269
groups=in_channels if dw else groups,
@@ -114,7 +111,7 @@ def __init__(
114111
self.id_conv = None
115112
self.act_fn = get_act(act_fn)
116113

117-
def forward(self, x: torch.Tensor) -> torch.Tensor:
114+
def forward(self, x: torch.Tensor) -> torch.Tensor: # type: ignore
118115
identity = self.id_conv(x) if self.id_conv is not None else x
119116
return self.act_fn(self.convs(x) + identity)
120117

@@ -177,7 +174,7 @@ def __init__(
177174
act_fn=False,
178175
bn_1st=bn_1st,
179176
),
180-
), # noqa E501
177+
),
181178
]
182179
if se:
183180
layers.append(("se", se(out_channels)))
@@ -208,7 +205,7 @@ def __init__(
208205
self.id_conv = None
209206
self.act_fn = get_act(act_fn)
210207

211-
def forward(self, x: torch.Tensor) -> torch.Tensor:
208+
def forward(self, x: torch.Tensor) -> torch.Tensor: # type: ignore
212209
identity = self.id_conv(x) if self.id_conv is not None else x
213210
return self.act_fn(self.convs(x) + identity)
214211

@@ -234,7 +231,7 @@ def make_stem(cfg: TModelCfg) -> nn.Sequential: # type: ignore
234231
stem.append(("stem_pool", cfg.stem_pool()))
235232
if cfg.stem_bn_end:
236233
stem.append(("norm", cfg.norm(cfg.stem_sizes[-1]))) # type: ignore
237-
return nn.Sequential(OrderedDict(stem))
234+
return nn_seq(stem)
238235

239236

240237
def make_layer(cfg: TModelCfg, layer_num: int) -> nn.Sequential: # type: ignore
@@ -247,15 +244,12 @@ def make_layer(cfg: TModelCfg, layer_num: int) -> nn.Sequential: # type: ignore
247244
(
248245
f"bl_{block_num}",
249246
cfg.block(
250-
# cfg.expansion, # type: ignore
251-
block_chs[layer_num]
247+
block_chs[layer_num] # type: ignore
252248
if block_num == 0
253249
else block_chs[layer_num + 1],
254250
block_chs[layer_num + 1],
255251
stride if block_num == 0 else 1,
256-
sa=cfg.sa
257-
if (block_num == num_blocks - 1) and layer_num == 0
258-
else None,
252+
sa=cfg.sa if (block_num == num_blocks - 1) and layer_num == 0 else None,
259253
conv_layer=cfg.conv_layer,
260254
act_fn=cfg.act_fn,
261255
pool=cfg.pool,
@@ -265,21 +259,17 @@ def make_layer(cfg: TModelCfg, layer_num: int) -> nn.Sequential: # type: ignore
265259
div_groups=cfg.div_groups,
266260
dw=cfg.dw,
267261
se=cfg.se,
268-
)
262+
),
269263
)
270264
for block_num in range(num_blocks)
271265
)
272266

273267

274268
def make_body(cfg: TModelCfg) -> nn.Sequential: # type: ignore
275269
"""Create model body."""
276-
return nn.Sequential(
277-
OrderedDict(
278-
[
279-
(f"l_{layer_num}", cfg.make_layer(cfg, layer_num)) # type: ignore
280-
for layer_num in range(len(cfg.layers))
281-
]
282-
)
270+
return nn_seq(
271+
(f"l_{layer_num}", cfg.make_layer(cfg, layer_num)) # type: ignore
272+
for layer_num in range(len(cfg.layers))
283273
)
284274

285275

@@ -290,7 +280,7 @@ def make_head(cfg: TModelCfg) -> nn.Sequential: # type: ignore
290280
("flat", nn.Flatten()),
291281
("fc", nn.Linear(cfg.block_sizes[-1], cfg.num_classes)),
292282
]
293-
return nn.Sequential(OrderedDict(head))
283+
return nn_seq(head)
294284

295285

296286
class ModelCfg(BaseModel):
@@ -381,25 +371,29 @@ class ModelConstructor(ModelCfg):
381371
"""Model constructor. As default - xresnet18"""
382372

383373
@validator("se")
384-
def set_se(cls, value: Union[bool, type[nn.Module]]) -> Union[bool, type[nn.Module]]:
374+
def set_se( # pylint: disable=no-self-argument
375+
cls, value: Union[bool, type[nn.Module]]
376+
) -> Union[bool, type[nn.Module]]:
385377
if value:
386378
if isinstance(value, (int, bool)):
387379
return SEModule
388380
return value
389381

390382
@validator("sa")
391-
def set_sa(cls, value: Union[bool, type[nn.Module]]) -> Union[bool, type[nn.Module]]:
383+
def set_sa( # pylint: disable=no-self-argument
384+
cls, value: Union[bool, type[nn.Module]]
385+
) -> Union[bool, type[nn.Module]]:
392386
if value:
393387
if isinstance(value, (int, bool)):
394388
return SimpleSelfAttention # default: ks=1, sym=sym
395389
return value
396390

397-
@validator("se_module", "se_reduction")
398-
def deprecation_warning(cls, value): # pragma: no cover
399-
print(
400-
"Deprecated. Pass se_module as se argument, se_reduction as arg to se."
401-
)
402-
return value
391+
@validator("se_module", "se_reduction") # pragma: no cover
392+
def deprecation_warning( # pylint: disable=no-self-argument
393+
cls, value: Union[bool, int, None]
394+
) -> Union[bool, int, None]:
395+
print("Deprecated. Pass se_module as se argument, se_reduction as arg to se.")
396+
return value
403397

404398
@property
405399
def stem(self):
@@ -420,9 +414,11 @@ def from_cfg(cls, cfg: ModelCfg):
420414
def __call__(self) -> nn.Sequential:
421415
"""Create model."""
422416
model_name = self.name or self.__class__.__name__
423-
named_sequential = type(model_name, (nn.Sequential,), {}) # create type named as model
417+
named_sequential = type(
418+
model_name, (nn.Sequential,), {}
419+
) # create type named as model
424420
model = named_sequential(
425-
OrderedDict([("stem", self.stem), ("body", self.body), ("head", self.head)])
421+
OrderedDict([("stem", self.stem), ("body", self.body), ("head", self.head)]) # type: ignore
426422
)
427423
self.init_cnn(model) # pylint: disable=too-many-function-args
428424
extra_repr = self.__repr_changed_args__()
@@ -449,4 +445,5 @@ class XResNet34(ModelConstructor):
449445

450446

451447
class XResNet50(XResNet34):
452-
expansion: int = 4
448+
block: type[nn.Module] = BottleneckBlock
449+
block_sizes: list[int] = [256, 512, 1024, 2048]

tests/test_models.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import pytest
2+
import torch
3+
from torch import nn
4+
5+
from model_constructor.model_constructor import (
6+
ModelConstructor,
7+
# XResNet,
8+
XResNet34,
9+
XResNet50,
10+
# YaResNet,
11+
# YaResNet34,
12+
# YaResNet50,
13+
)
14+
15+
bs_test = 2
16+
img_size = 16
17+
xb = torch.rand(bs_test, 3, img_size, img_size)
18+
19+
mc_list = [
20+
# XResNet,
21+
XResNet34,
22+
XResNet50,
23+
# YaResNet,
24+
# YaResNet34,
25+
# YaResNet50,
26+
]
27+
act_fn_list = [nn.ReLU, nn.Mish, nn.GELU]
28+
29+
30+
@pytest.mark.parametrize("model_constructor", mc_list)
31+
@pytest.mark.parametrize("act_fn", act_fn_list)
32+
def test_mc(model_constructor: type[ModelConstructor], act_fn: type[nn.Module]):
33+
"""test models"""
34+
mc = model_constructor(act_fn=act_fn)
35+
model = mc()
36+
pred = model(xb)
37+
assert pred.shape == torch.Size([bs_test, 1000])

0 commit comments

Comments
 (0)