Skip to content

Commit 104d32d

Browse files
authored
Merge pull request #63 from ayasyrev/models
Models
2 parents 36bdd35 + 8675071 commit 104d32d

File tree

2 files changed

+117
-34
lines changed

2 files changed

+117
-34
lines changed

src/model_constructor/model_constructor.py

Lines changed: 10 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -310,10 +310,14 @@ def print_cfg(self):
310310
)
311311

312312

313-
xresnet34 = ModelConstructor.from_cfg(
314-
CfgMC(name="xresnet34", expansion=1, layers=[3, 4, 6, 3])
315-
)
313+
@dataclass
314+
class XResNet34(ModelConstructor):
315+
name: str = "xresnet34"
316+
layers: list[int] = field(default_factory=lambda: [3, 4, 6, 3])
317+
316318

317-
xresnet50 = ModelConstructor.from_cfg(
318-
CfgMC(name="xresnet34", expansion=4, layers=[3, 4, 6, 3])
319-
)
319+
@dataclass
320+
class XResNet50(ModelConstructor):
321+
name: str = "xresnet50"
322+
expansion: int = 4
323+
layers: list[int] = field(default_factory=lambda: [3, 4, 6, 3])

src/model_constructor/yaresnet.py

Lines changed: 107 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -1,15 +1,20 @@
11
# YaResBlock - former NewResBlock.
22
# Yet another ResNet.
33

4-
import torch.nn as nn
5-
from functools import partial
64
from collections import OrderedDict
7-
from .layers import ConvBnAct
8-
from .net import Net
5+
from typing import Union
6+
7+
import torch.nn as nn
98
from torch.nn import Mish
109

10+
from .layers import ConvBnAct
11+
from .model_constructor import CfgMC, ModelConstructor
1112

12-
__all__ = ['YaResBlock', 'yaresnet_parameters', 'yaresnet34', 'yaresnet50']
13+
__all__ = [
14+
'YaResBlock',
15+
'yaresnet34',
16+
'yaresnet50',
17+
]
1318

1419

1520
act_fn = nn.ReLU(inplace=True)
@@ -18,16 +23,29 @@
1823
class YaResBlock(nn.Module):
1924
'''YaResBlock. Reduce by pool instead of stride 2'''
2025

21-
def __init__(self, expansion, in_channels, mid_channels, stride=1,
22-
conv_layer=ConvBnAct, act_fn=act_fn, zero_bn=True, bn_1st=True,
23-
groups=1, dw=False, div_groups=None,
24-
pool=None,
25-
se=None, sa=None,
26-
):
26+
def __init__(
27+
self,
28+
expansion: int,
29+
in_channels: int,
30+
mid_channels: int,
31+
stride: int = 1,
32+
conv_layer=ConvBnAct,
33+
act_fn: nn.Module = act_fn,
34+
zero_bn: bool = True,
35+
bn_1st: bool = True,
36+
groups: int = 1,
37+
dw: bool = False,
38+
div_groups: Union[None, int] = None,
39+
pool: Union[nn.Module, None] = None,
40+
se: Union[nn.Module, None] = None,
41+
sa: Union[nn.Module, None] = None,
42+
):
2743
super().__init__()
44+
# pool defined at ModelConstructor.
2845
out_channels, in_channels = mid_channels * expansion, in_channels * expansion
2946
if div_groups is not None: # check if groups != 1 and div_groups
3047
groups = int(mid_channels / div_groups)
48+
3149
if stride != 1:
3250
if pool is None:
3351
self.reduce = conv_layer(in_channels, in_channels, 1, stride=2)
@@ -36,23 +54,69 @@ def __init__(self, expansion, in_channels, mid_channels, stride=1,
3654
self.reduce = pool
3755
else:
3856
self.reduce = None
39-
layers = [("conv_0", conv_layer(in_channels, mid_channels, 3, stride=1,
40-
act_fn=act_fn, bn_1st=bn_1st, groups=in_channels if dw else groups)),
41-
("conv_1", conv_layer(mid_channels, out_channels, 3, zero_bn=zero_bn,
42-
act_fn=False, bn_1st=bn_1st, groups=mid_channels if dw else groups))
43-
] if expansion == 1 else [
44-
("conv_0", conv_layer(in_channels, mid_channels, 1, act_fn=act_fn, bn_1st=bn_1st)),
45-
("conv_1", conv_layer(mid_channels, mid_channels, 3, stride=1, act_fn=act_fn, bn_1st=bn_1st,
46-
groups=mid_channels if dw else groups)),
47-
("conv_2", conv_layer(
48-
mid_channels, out_channels, 1, zero_bn=zero_bn, act_fn=False, bn_1st=bn_1st))
49-
]
57+
if expansion == 1:
58+
layers = [
59+
("conv_0", conv_layer(
60+
in_channels,
61+
mid_channels,
62+
3,
63+
stride=1,
64+
act_fn=act_fn,
65+
bn_1st=bn_1st,
66+
groups=in_channels if dw else groups,
67+
),),
68+
("conv_1", conv_layer(
69+
mid_channels,
70+
out_channels,
71+
3,
72+
zero_bn=zero_bn,
73+
act_fn=False,
74+
bn_1st=bn_1st,
75+
groups=mid_channels if dw else groups,
76+
),),
77+
]
78+
else:
79+
layers = [
80+
("conv_0", conv_layer(
81+
in_channels,
82+
mid_channels,
83+
1,
84+
act_fn=act_fn,
85+
bn_1st=bn_1st,
86+
),),
87+
("conv_1", conv_layer(
88+
mid_channels,
89+
mid_channels,
90+
3,
91+
stride=1,
92+
act_fn=act_fn,
93+
bn_1st=bn_1st,
94+
groups=mid_channels if dw else groups,
95+
),),
96+
("conv_2", conv_layer(
97+
mid_channels,
98+
out_channels,
99+
1,
100+
zero_bn=zero_bn,
101+
act_fn=False,
102+
bn_1st=bn_1st,
103+
),), # noqa E501
104+
]
50105
if se:
51-
layers.append(('se', se(out_channels)))
106+
layers.append(("se", se(out_channels)))
52107
if sa:
53-
layers.append(('sa', sa(out_channels)))
108+
layers.append(("sa", sa(out_channels)))
54109
self.convs = nn.Sequential(OrderedDict(layers))
55-
self.id_conv = None if in_channels == out_channels else conv_layer(in_channels, out_channels, 1, act_fn=False)
110+
if in_channels != out_channels:
111+
self.id_conv = conv_layer(
112+
in_channels,
113+
out_channels,
114+
1,
115+
stride=1,
116+
act_fn=False,
117+
)
118+
else:
119+
self.id_conv = None
56120
self.merge = act_fn
57121

58122
def forward(self, x):
@@ -62,6 +126,21 @@ def forward(self, x):
62126
return self.merge(self.convs(x) + identity)
63127

64128

65-
yaresnet_parameters = {'block': YaResBlock, 'stem_sizes': [3, 32, 64, 64], 'act_fn': Mish(), 'stem_stride_on': 1}
66-
yaresnet34 = partial(Net, name='YaResnet34', expansion=1, layers=[3, 4, 6, 3], **yaresnet_parameters)
67-
yaresnet50 = partial(Net, name='YaResnet50', expansion=4, layers=[3, 4, 6, 3], **yaresnet_parameters)
129+
yaresnet34 = ModelConstructor.from_cfg(
130+
CfgMC(
131+
name='YaResnet34',
132+
block=YaResBlock,
133+
expansion=1,
134+
layers=[3, 4, 6, 3],
135+
act_fn=Mish(),
136+
)
137+
)
138+
yaresnet50 = ModelConstructor.from_cfg(
139+
CfgMC(
140+
name='YaResnet50',
141+
block=YaResBlock,
142+
act_fn=Mish(),
143+
expansion=4,
144+
layers=[3, 4, 6, 3],
145+
)
146+
)

0 commit comments

Comments
 (0)