Skip to content

Commit f2aa281

Browse files
committed
move xresnet from model_constructor to xresnet module
1 parent 8f4b92a commit f2aa281

File tree

3 files changed

+82
-21
lines changed

3 files changed

+82
-21
lines changed

src/model_constructor/model_constructor.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ def init_cnn(module: nn.Module) -> None:
3434

3535

3636
class BasicBlock(nn.Module):
37-
"""Basic Resnet block."""
37+
"""Basic Resnet block.
38+
Configurable - can use pool to reduce at identity path, change act etc. """
3839

3940
def __init__(
4041
self,
@@ -117,7 +118,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # type: ignore
117118

118119

119120
class BottleneckBlock(nn.Module):
120-
"""Bottleneck Resnet block."""
121+
"""Bottleneck Resnet block.
122+
Configurable - can use pool to reduce at identity path, change act etc. """
121123

122124
def __init__(
123125
self,
@@ -211,21 +213,21 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # type: ignore
211213

212214

213215
def make_stem(cfg: TModelCfg) -> nn.Sequential: # type: ignore
214-
"""Create xResnet stem -> 3 conv 3*3 instead 1 conv 7*7"""
215-
len_stem = len(cfg.stem_sizes)
216+
"""Create Resnet stem."""
216217
stem: ListStrMod = [
217218
(
218-
f"conv_{i}",
219+
"conv_1",
219220
cfg.conv_layer(
220-
cfg.stem_sizes[i - 1] if i else cfg.in_chans, # type: ignore
221-
cfg.stem_sizes[i],
222-
stride=2 if i == cfg.stem_stride_on else 1,
223-
bn_layer=(not cfg.stem_bn_end) if i == (len_stem - 1) else True,
221+
cfg.in_chans, # type: ignore
222+
cfg.stem_sizes[-1],
223+
kernel_size=7,
224+
stride=2,
225+
padding=3,
226+
bn_layer=not cfg.stem_bn_end,
224227
act_fn=cfg.act_fn,
225228
bn_1st=cfg.bn_1st,
226229
),
227230
)
228-
for i in range(len_stem)
229231
]
230232
if cfg.stem_pool:
231233
stem.append(("stem_pool", cfg.stem_pool()))
@@ -295,9 +297,7 @@ class ModelCfg(BaseModel):
295297
layers: list[int] = [2, 2, 2, 2]
296298
norm: type[nn.Module] = nn.BatchNorm2d
297299
act_fn: type[nn.Module] = nn.ReLU
298-
pool: Callable[[Any], nn.Module] = partial(
299-
nn.AvgPool2d, kernel_size=2, ceil_mode=True
300-
)
300+
pool: Optional[Callable[[Any], nn.Module]] = None
301301
expansion: int = 1
302302
groups: int = 1
303303
dw: bool = False
@@ -309,7 +309,7 @@ class ModelCfg(BaseModel):
309309
bn_1st: bool = True
310310
zero_bn: bool = True
311311
stem_stride_on: int = 0
312-
stem_sizes: list[int] = [32, 32, 64]
312+
stem_sizes: list[int] = [64]
313313
stem_pool: Union[Callable[[], nn.Module], None] = partial(
314314
nn.MaxPool2d, kernel_size=3, stride=2, padding=1
315315
)
@@ -368,7 +368,7 @@ def print_changed(self) -> None:
368368

369369

370370
class ModelConstructor(ModelCfg):
371-
"""Model constructor. As default - xresnet18"""
371+
"""Model constructor. As default - resnet18"""
372372

373373
@validator("se")
374374
def set_se( # pylint: disable=no-self-argument
@@ -446,10 +446,10 @@ def __repr__(self) -> str:
446446
)
447447

448448

449-
class XResNet34(ModelConstructor):
449+
class ResNet34(ModelConstructor):
450450
layers: list[int] = [3, 4, 6, 3]
451451

452452

453-
class XResNet50(XResNet34):
453+
class ResNet50(ResNet34):
454454
block: type[nn.Module] = BottleneckBlock
455455
block_sizes: list[int] = [256, 512, 1024, 2048]

src/model_constructor/xresnet.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
from functools import partial
2+
from typing import Any, Callable, Optional, Union
3+
4+
from torch import nn
5+
6+
from .helpers import nn_seq
7+
from .model_constructor import (BottleneckBlock, ListStrMod, ModelCfg,
8+
ModelConstructor)
9+
10+
__all__ = [
11+
"XResNet",
12+
"XResNet34",
13+
"XResNet50",
14+
]
15+
16+
17+
def xresnet_stem(cfg: ModelCfg) -> nn.Sequential:  # type: ignore
    """Build the xResNet-style stem from a model config.

    One `cfg.conv_layer` is created per entry of `cfg.stem_sizes` (a stack of
    convs in place of the single 7x7 conv of the plain ResNet stem; kernel
    size is left to `cfg.conv_layer`'s own default). The first conv takes
    `cfg.in_chans` input channels, each following conv takes the previous
    entry of `cfg.stem_sizes`.

    Args:
        cfg: config object providing `stem_sizes`, `in_chans`, `conv_layer`,
            `stem_stride_on`, `stem_bn_end`, `act_fn`, `bn_1st`, `stem_pool`
            and `norm`.

    Returns:
        The result of `nn_seq` applied to the list of named stem layers.
    """
    last_idx = len(cfg.stem_sizes) - 1
    layers: ListStrMod = []
    in_ch = cfg.in_chans  # type: ignore
    for idx, out_ch in enumerate(cfg.stem_sizes):
        # Only the conv at index `stem_stride_on` downsamples.
        stride = 2 if idx == cfg.stem_stride_on else 1
        # The last conv drops its own norm when a separate norm layer is
        # appended at the very end of the stem (`stem_bn_end`).
        use_bn = not (idx == last_idx and cfg.stem_bn_end)
        conv = cfg.conv_layer(
            in_ch,
            out_ch,
            stride=stride,
            bn_layer=use_bn,
            act_fn=cfg.act_fn,
            bn_1st=cfg.bn_1st,
        )
        layers.append((f"conv_{idx}", conv))
        in_ch = out_ch

    if cfg.stem_pool:
        layers.append(("stem_pool", cfg.stem_pool()))
    if cfg.stem_bn_end:
        layers.append(("norm", cfg.norm(cfg.stem_sizes[-1])))  # type: ignore
    return nn_seq(layers)
39+
40+
41+
class XResNet(ModelConstructor):
    """Model constructor preconfigured for xResNet.

    Differs from the base `ModelConstructor` only in the three defaults
    declared below: the stem builder, the stem widths and the pool factory.
    """

    # Stem is built by `xresnet_stem` (one conv per `stem_sizes` entry).
    make_stem: Callable[[ModelCfg], Union[nn.Module, nn.Sequential]] = xresnet_stem
    # Output channels of each stem conv.
    stem_sizes: list[int] = [32, 32, 64]
    # Factory for a 2x2 average pool with ceil_mode enabled.
    # NOTE(review): presumably used on the identity path of downsampling
    # blocks - confirm against the block implementations.
    pool: Optional[Callable[[Any], nn.Module]] = partial(
        nn.AvgPool2d, kernel_size=2, ceil_mode=True
    )
47+
48+
49+
class XResNet34(XResNet):
    """XResNet with the `layers` default set to [3, 4, 6, 3]."""

    # NOTE(review): presumably the number of blocks in each stage - confirm.
    layers: list[int] = [3, 4, 6, 3]
51+
52+
53+
class XResNet50(XResNet34):
    """XResNet34 layout built from `BottleneckBlock` with wider block sizes."""

    # Use the bottleneck block instead of the inherited block type.
    block: type[nn.Module] = BottleneckBlock
    # NOTE(review): presumably the output channels of each stage - confirm.
    block_sizes: list[int] = [256, 512, 1024, 2048]

src/model_constructor/yaresnet.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# YaResBlock - former NewResBlock.
22
# Yet another ResNet.
33

4-
from collections import OrderedDict
5-
from typing import Callable, Union
4+
from functools import partial
5+
from typing import Any, Callable, Optional, Union
66

77
import torch
88
from torch import nn
@@ -11,7 +11,9 @@
1111
from model_constructor.helpers import nn_seq
1212

1313
from .layers import ConvBnAct, get_act
14-
from .model_constructor import ListStrMod, ModelConstructor
14+
from .model_constructor import ListStrMod, ModelConstructor, ModelCfg
15+
from .xresnet import xresnet_stem
16+
1517

1618
__all__ = [
1719
"YaBasicBlock",
@@ -202,9 +204,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # type: ignore
202204

203205

204206
class YaResNet(ModelConstructor):
205-
block: type[nn.Module] = YaBasicBlock
207+
make_stem: Callable[[ModelCfg], Union[nn.Module, nn.Sequential]] = xresnet_stem
206208
stem_sizes: list[int] = [3, 32, 64, 64]
209+
block: type[nn.Module] = YaBasicBlock
207210
act_fn: type[nn.Module] = Mish
211+
pool: Optional[Callable[[Any], nn.Module]] = partial(
212+
nn.AvgPool2d, kernel_size=2, ceil_mode=True
213+
)
208214

209215

210216
class YaResNet34(YaResNet):

0 commit comments

Comments
 (0)