Skip to content

Commit 9da2a92

Browse files
committed
yaresnet
1 parent 36bdd35 commit 9da2a92

File tree

1 file changed

+94
-28
lines changed

1 file changed

+94
-28
lines changed

src/model_constructor/yaresnet.py

Lines changed: 94 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,22 @@
11
# YaResBlock - former NewResBlock.
22
# Yet another ResNet.
33

4-
import torch.nn as nn
5-
from functools import partial
64
from collections import OrderedDict
7-
from .layers import ConvBnAct
8-
from .net import Net
5+
from functools import partial
6+
from typing import Union
7+
8+
import torch.nn as nn
99
from torch.nn import Mish
1010

11+
from .layers import ConvBnAct
12+
from .net import Net
1113

12-
__all__ = ['YaResBlock', 'yaresnet_parameters', 'yaresnet34', 'yaresnet50']
14+
__all__ = [
15+
'YaResBlock',
16+
# 'yaresnet_parameters',
17+
# 'yaresnet34',
18+
# 'yaresnet50',
19+
]
1320

1421

1522
act_fn = nn.ReLU(inplace=True)
@@ -18,16 +25,29 @@
1825
class YaResBlock(nn.Module):
1926
'''YaResBlock. Reduce by pool instead of stride 2'''
2027

21-
def __init__(self, expansion, in_channels, mid_channels, stride=1,
22-
conv_layer=ConvBnAct, act_fn=act_fn, zero_bn=True, bn_1st=True,
23-
groups=1, dw=False, div_groups=None,
24-
pool=None,
25-
se=None, sa=None,
26-
):
28+
def __init__(
29+
self,
30+
expansion: int,
31+
in_channels: int,
32+
mid_channels: int,
33+
stride: int = 1,
34+
conv_layer=ConvBnAct,
35+
act_fn: nn.Module = act_fn,
36+
zero_bn: bool = True,
37+
bn_1st: bool = True,
38+
groups: int = 1,
39+
dw: bool = False,
40+
div_groups: Union[None, int] = None,
41+
pool: Union[nn.Module, None] = None,
42+
se: Union[nn.Module, None] = None,
43+
sa: Union[nn.Module, None] = None,
44+
):
2745
super().__init__()
46+
# pool defined at ModelConstructor.
2847
out_channels, in_channels = mid_channels * expansion, in_channels * expansion
2948
if div_groups is not None: # check if groups != 1 and div_groups
3049
groups = int(mid_channels / div_groups)
50+
3151
if stride != 1:
3252
if pool is None:
3353
self.reduce = conv_layer(in_channels, in_channels, 1, stride=2)
@@ -36,23 +56,69 @@ def __init__(self, expansion, in_channels, mid_channels, stride=1,
3656
self.reduce = pool
3757
else:
3858
self.reduce = None
39-
layers = [("conv_0", conv_layer(in_channels, mid_channels, 3, stride=1,
40-
act_fn=act_fn, bn_1st=bn_1st, groups=in_channels if dw else groups)),
41-
("conv_1", conv_layer(mid_channels, out_channels, 3, zero_bn=zero_bn,
42-
act_fn=False, bn_1st=bn_1st, groups=mid_channels if dw else groups))
43-
] if expansion == 1 else [
44-
("conv_0", conv_layer(in_channels, mid_channels, 1, act_fn=act_fn, bn_1st=bn_1st)),
45-
("conv_1", conv_layer(mid_channels, mid_channels, 3, stride=1, act_fn=act_fn, bn_1st=bn_1st,
46-
groups=mid_channels if dw else groups)),
47-
("conv_2", conv_layer(
48-
mid_channels, out_channels, 1, zero_bn=zero_bn, act_fn=False, bn_1st=bn_1st))
49-
]
59+
if expansion == 1:
60+
layers = [
61+
("conv_0", conv_layer(
62+
in_channels,
63+
mid_channels,
64+
3,
65+
stride=1,
66+
act_fn=act_fn,
67+
bn_1st=bn_1st,
68+
groups=in_channels if dw else groups,
69+
),),
70+
("conv_1", conv_layer(
71+
mid_channels,
72+
out_channels,
73+
3,
74+
zero_bn=zero_bn,
75+
act_fn=False,
76+
bn_1st=bn_1st,
77+
groups=mid_channels if dw else groups,
78+
),),
79+
]
80+
else:
81+
layers = [
82+
("conv_0", conv_layer(
83+
in_channels,
84+
mid_channels,
85+
1,
86+
act_fn=act_fn,
87+
bn_1st=bn_1st,
88+
),),
89+
("conv_1", conv_layer(
90+
mid_channels,
91+
mid_channels,
92+
3,
93+
stride=1,
94+
act_fn=act_fn,
95+
bn_1st=bn_1st,
96+
groups=mid_channels if dw else groups,
97+
),),
98+
("conv_2", conv_layer(
99+
mid_channels,
100+
out_channels,
101+
1,
102+
zero_bn=zero_bn,
103+
act_fn=False,
104+
bn_1st=bn_1st,
105+
),), # noqa E501
106+
]
50107
if se:
51-
layers.append(('se', se(out_channels)))
108+
layers.append(("se", se(out_channels)))
52109
if sa:
53-
layers.append(('sa', sa(out_channels)))
110+
layers.append(("sa", sa(out_channels)))
54111
self.convs = nn.Sequential(OrderedDict(layers))
55-
self.id_conv = None if in_channels == out_channels else conv_layer(in_channels, out_channels, 1, act_fn=False)
112+
if in_channels != out_channels:
113+
self.id_conv = conv_layer(
114+
in_channels,
115+
out_channels,
116+
1,
117+
stride=1,
118+
act_fn=False,
119+
)
120+
else:
121+
self.id_conv = None
56122
self.merge = act_fn
57123

58124
def forward(self, x):
@@ -62,6 +128,6 @@ def forward(self, x):
62128
return self.merge(self.convs(x) + identity)
63129

64130

65-
yaresnet_parameters = {'block': YaResBlock, 'stem_sizes': [3, 32, 64, 64], 'act_fn': Mish(), 'stem_stride_on': 1}
66-
yaresnet34 = partial(Net, name='YaResnet34', expansion=1, layers=[3, 4, 6, 3], **yaresnet_parameters)
67-
yaresnet50 = partial(Net, name='YaResnet50', expansion=4, layers=[3, 4, 6, 3], **yaresnet_parameters)
131+
# yaresnet_parameters = {'block': YaResBlock, 'stem_sizes': [3, 32, 64, 64], 'act_fn': Mish(), 'stem_stride_on': 1}
132+
# yaresnet34 = partial(Net, name='YaResnet34', expansion=1, layers=[3, 4, 6, 3], **yaresnet_parameters)
133+
# yaresnet50 = partial(Net, name='YaResnet50', expansion=4, layers=[3, 4, 6, 3], **yaresnet_parameters)

0 commit comments

Comments
 (0)