Skip to content

Commit 69e2479

Browse files
committed
resblock, yaresblock
1 parent de7a7e0 commit 69e2479

File tree

2 files changed

+31
-31
lines changed

2 files changed

+31
-31
lines changed

model_constructor/model_constructor.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -25,35 +25,35 @@ def init_cnn(module: nn.Module):
2525
class ResBlock(nn.Module):
2626
'''Resnet block'''
2727

28-
def __init__(self, expansion, ni, nh, stride=1,
28+
def __init__(self, expansion, in_channels, mid_channels, stride=1,
2929
conv_layer=ConvLayer, act_fn=act_fn, zero_bn=True, bn_1st=True,
3030
pool=nn.AvgPool2d(2, ceil_mode=True), sa=False, sym=False,
3131
groups=1, dw=False, div_groups=None,
3232
se_module=SEModule, se=False, se_reduction=16
3333
):
3434
super().__init__()
35-
nf, ni = nh * expansion, ni * expansion
35+
out_channels, in_channels = mid_channels * expansion, in_channels * expansion
3636
        if div_groups is not None:  # check if groups != 1 and div_groups
37-
groups = int(nh / div_groups)
37+
groups = int(mid_channels / div_groups)
3838
if expansion == 1:
39-
layers = [("conv_0", conv_layer(ni, nh, 3, stride=stride,
40-
act_fn=act_fn, bn_1st=bn_1st, groups=ni if dw else groups)),
41-
("conv_1", conv_layer(nh, nf, 3, zero_bn=zero_bn,
42-
act=False, bn_1st=bn_1st, groups=nh if dw else groups))
39+
layers = [("conv_0", conv_layer(in_channels, mid_channels, 3, stride=stride,
40+
act_fn=act_fn, bn_1st=bn_1st, groups=in_channels if dw else groups)),
41+
("conv_1", conv_layer(mid_channels, out_channels, 3, zero_bn=zero_bn,
42+
act=False, bn_1st=bn_1st, groups=mid_channels if dw else groups))
4343
]
4444
else:
45-
layers = [("conv_0", conv_layer(ni, nh, 1, act_fn=act_fn, bn_1st=bn_1st)),
46-
("conv_1", conv_layer(nh, nh, 3, stride=stride, act_fn=act_fn, bn_1st=bn_1st,
47-
groups=nh if dw else groups)),
48-
("conv_2", conv_layer(nh, nf, 1, zero_bn=zero_bn, act=False, bn_1st=bn_1st))
45+
layers = [("conv_0", conv_layer(in_channels, mid_channels, 1, act_fn=act_fn, bn_1st=bn_1st)),
46+
("conv_1", conv_layer(mid_channels, mid_channels, 3, stride=stride, act_fn=act_fn, bn_1st=bn_1st,
47+
groups=mid_channels if dw else groups)),
48+
("conv_2", conv_layer(mid_channels, out_channels, 1, zero_bn=zero_bn, act=False, bn_1st=bn_1st))
4949
]
5050
if se:
51-
layers.append(('se', se_module(nf, se_reduction)))
51+
layers.append(('se', se_module(out_channels, se_reduction)))
5252
if sa:
53-
layers.append(('sa', SimpleSelfAttention(nf, ks=1, sym=sym)))
53+
layers.append(('sa', SimpleSelfAttention(out_channels, ks=1, sym=sym)))
5454
self.convs = nn.Sequential(OrderedDict(layers))
5555
self.pool = noop if stride == 1 else pool
56-
self.idconv = noop if ni == nf else conv_layer(ni, nf, 1, act=False)
56+
self.idconv = noop if in_channels == out_channels else conv_layer(in_channels, out_channels, 1, act=False)
5757
self.act_fn = act_fn
5858

5959
def forward(self, x):
@@ -73,8 +73,8 @@ def _make_stem(self):
7373
return nn.Sequential(OrderedDict(stem))
7474

7575

76-
def _make_layer(self, expansion, ni, nf, blocks, stride, sa):
77-
layers = [(f"bl_{i}", self.block(expansion, ni if i == 0 else nf, nf,
76+
def _make_layer(self, expansion, in_channels, out_channels, blocks, stride, sa):
77+
layers = [(f"bl_{i}", self.block(expansion, in_channels if i == 0 else out_channels, out_channels,
7878
stride if i == 0 else 1, sa=sa if i == blocks - 1 else False,
7979
conv_layer=self.conv_layer, act_fn=self.act_fn, pool=self.pool,
8080
zero_bn=self.zero_bn, bn_1st=self.bn_1st,
@@ -87,7 +87,7 @@ def _make_layer(self, expansion, ni, nf, blocks, stride, sa):
8787
def _make_body(self):
8888
stride = 2 if self.stem_pool is None else 1 # if no pool on stem - stride = 2 for first block in body
8989
blocks = [(f"l_{i}", self._make_layer(self, self.expansion,
90-
ni=self.block_sizes[i], nf=self.block_sizes[i + 1],
90+
in_channels=self.block_sizes[i], out_channels=self.block_sizes[i + 1],
9191
blocks=l, stride=stride if i == 0 else 2,
9292
sa=self.sa if i == 0 else False))
9393
for i, l in enumerate(self.layers)]

model_constructor/yaresnet.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,33 +15,33 @@
1515
class YaResBlock(nn.Module):
1616
'''YaResBlock. Reduce by pool instead of stride 2'''
1717

18-
def __init__(self, expansion, ni, nh, stride=1,
18+
def __init__(self, expansion, in_channels, mid_channels, stride=1,
1919
conv_layer=ConvLayer, act_fn=act_fn, zero_bn=True, bn_1st=True,
2020
pool=nn.AvgPool2d(2, ceil_mode=True), sa=False, sym=False,
2121
groups=1, dw=False, div_groups=None,
2222
se_module=SEModule, se=False, se_reduction=16
2323
):
2424
super().__init__()
25-
nf, ni = nh * expansion, ni * expansion
25+
out_channels, in_channels = mid_channels * expansion, in_channels * expansion
2626
        if div_groups is not None:  # check if groups != 1 and div_groups
27-
groups = int(nh / div_groups)
27+
groups = int(mid_channels / div_groups)
2828
self.reduce = noop if stride == 1 else pool
29-
layers = [("conv_0", conv_layer(ni, nh, 3, stride=1,
30-
act_fn=act_fn, bn_1st=bn_1st, groups=ni if dw else groups)),
31-
("conv_1", conv_layer(nh, nf, 3, zero_bn=zero_bn,
32-
act=False, bn_1st=bn_1st, groups=nh if dw else groups))
29+
layers = [("conv_0", conv_layer(in_channels, mid_channels, 3, stride=1,
30+
act_fn=act_fn, bn_1st=bn_1st, groups=in_channels if dw else groups)),
31+
("conv_1", conv_layer(mid_channels, out_channels, 3, zero_bn=zero_bn,
32+
act=False, bn_1st=bn_1st, groups=mid_channels if dw else groups))
3333
] if expansion == 1 else [
34-
("conv_0", conv_layer(ni, nh, 1, act_fn=act_fn, bn_1st=bn_1st)),
35-
("conv_1", conv_layer(nh, nh, 3, stride=1, act_fn=act_fn, bn_1st=bn_1st,
36-
groups=nh if dw else groups)),
37-
("conv_2", conv_layer(nh, nf, 1, zero_bn=zero_bn, act=False, bn_1st=bn_1st))
34+
("conv_0", conv_layer(in_channels, mid_channels, 1, act_fn=act_fn, bn_1st=bn_1st)),
35+
("conv_1", conv_layer(mid_channels, mid_channels, 3, stride=1, act_fn=act_fn, bn_1st=bn_1st,
36+
groups=mid_channels if dw else groups)),
37+
("conv_2", conv_layer(mid_channels, out_channels, 1, zero_bn=zero_bn, act=False, bn_1st=bn_1st))
3838
]
3939
if se:
40-
layers.append(('se', se_module(nf, se_reduction)))
40+
layers.append(('se', se_module(out_channels, se_reduction)))
4141
if sa:
42-
layers.append(('sa', SimpleSelfAttention(nf, ks=1, sym=sym)))
42+
layers.append(('sa', SimpleSelfAttention(out_channels, ks=1, sym=sym)))
4343
self.convs = nn.Sequential(OrderedDict(layers))
44-
self.idconv = noop if ni == nf else conv_layer(ni, nf, 1, act=False)
44+
self.idconv = noop if in_channels == out_channels else conv_layer(in_channels, out_channels, 1, act=False)
4545
self.merge = act_fn
4646

4747
def forward(self, x):

0 commit comments

Comments
 (0)