Skip to content

Commit 12513e9

Browse files
committed
refactor body and layer maker in MC
1 parent 4f41bef commit 12513e9

File tree

1 file changed

+26
-18
lines changed

1 file changed

+26
-18
lines changed

model_constructor/model_constructor.py

Lines changed: 26 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def __init__(self, expansion, in_channels, mid_channels, stride=1,
5656
self.convs = nn.Sequential(OrderedDict(layers))
5757
if stride != 1 or in_channels != out_channels:
5858
id_layers = []
59-
if stride != 1 and pool is not None:
59+
if stride != 1 and pool is not None: # if pool - reduce by pool else stride 2 art id_conv
6060
id_layers.append(("pool", pool))
6161
if in_channels != out_channels or (stride != 1 and pool is None):
6262
id_layers += [("id_conv", conv_layer(
@@ -86,25 +86,33 @@ def _make_stem(self):
8686
return nn.Sequential(OrderedDict(stem))
8787

8888

89-
def _make_layer(self, expansion, in_channels, out_channels, blocks, stride, sa):
90-
layers = [(f"bl_{i}", self.block(expansion, in_channels if i == 0 else out_channels, out_channels,
91-
stride if i == 0 else 1, sa=sa if i == blocks - 1 else None,
92-
conv_layer=self.conv_layer, act_fn=self.act_fn, pool=self.pool,
93-
zero_bn=self.zero_bn, bn_1st=self.bn_1st,
94-
groups=self.groups, div_groups=self.div_groups,
95-
dw=self.dw, se=self.se))
96-
for i in range(blocks)]
97-
return nn.Sequential(OrderedDict(layers))
89+
def _make_layer(self, layer_id: int) -> nn.Module:
90+
# expansion, in_channels, out_channels, blocks, stride, sa):
91+
stride = 1 if self.stem_pool and layer_id == 0 else 2 # if no pool on stem - stride = 2 for first layer block in body
92+
num_blocks = self.layers[layer_id]
93+
return nn.Sequential(OrderedDict([
94+
(f"bl_{block_num}", self.block(
95+
self.expansion,
96+
self.block_sizes[layer_id] if block_num == 0 else self.block_sizes[layer_id + 1],
97+
self.block_sizes[layer_id + 1],
98+
stride if block_num == 0 else 1,
99+
sa=self.sa if block_num == num_blocks - 1 else None,
100+
conv_layer=self.conv_layer,
101+
act_fn=self.act_fn,
102+
pool=self.pool,
103+
zero_bn=self.zero_bn, bn_1st=self.bn_1st,
104+
groups=self.groups, div_groups=self.div_groups,
105+
dw=self.dw, se=self.se
106+
))
107+
for block_num in range(num_blocks)
108+
]))
98109

99110

100111
def _make_body(self):
101-
stride = 1 if self.stem_pool else 1 # if no pool on stem - stride = 2 for first block in body
102-
blocks = [(f"l_{i}", self._make_layer(self, self.expansion,
103-
in_channels=self.block_sizes[i], out_channels=self.block_sizes[i + 1],
104-
blocks=l, stride=stride if i == 0 else 2,
105-
sa=self.sa if i == 0 else None))
106-
for i, l in enumerate(self.layers)]
107-
return nn.Sequential(OrderedDict(blocks))
112+
return nn.Sequential(OrderedDict([
113+
(f"l_{layer_num}", self._make_layer(self, layer_num))
114+
for layer_num in range(len(self.layers))
115+
]))
108116

109117

110118
def _make_head(self):
@@ -140,7 +148,7 @@ def __init__(self, name='MC', in_chans=3, num_classes=1000,
140148
):
141149
super().__init__()
142150
# se can be bool, int (0, 1) or nn.Module
143-
# se_module - deprecated. Leaved for worning and checks.
151+
# se_module - deprecated. Leaved for warning and checks.
144152
# if stem_pool is False - no pool at stem
145153

146154
params = locals()

0 commit comments

Comments
 (0)