Skip to content

Commit e5880b8

Browse files
committed
rename _make funcs and self arg to cfg
1 parent c889a6e commit e5880b8

File tree

1 file changed

+26
-30
lines changed

1 file changed

+26
-30
lines changed

src/model_constructor/model_constructor.py

Lines changed: 26 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ def init_cnn(module: nn.Module):
167167
init_cnn(layer)
168168

169169

170-
def _make_stem(self: CfgMC) -> nn.Sequential:
170+
def make_stem(self: CfgMC) -> nn.Sequential:
171171
stem: List[tuple[str, nn.Module]] = [
172172
(f"conv_{i}", self.conv_layer(
173173
self.stem_sizes[i], # type: ignore
@@ -188,34 +188,34 @@ def _make_stem(self: CfgMC) -> nn.Sequential:
188188
return nn.Sequential(OrderedDict(stem))
189189

190190

191-
def _make_layer(self: CfgMC, layer_num: int) -> nn.Sequential:
191+
def make_layer(cfg: CfgMC, layer_num: int) -> nn.Sequential:
192192
# expansion, in_channels, out_channels, blocks, stride, sa):
193193
# if no pool on stem - stride = 2 for first layer block in body
194-
stride = 1 if self.stem_pool and layer_num == 0 else 2
195-
num_blocks = self.layers[layer_num]
196-
block_chs = [self.stem_sizes[-1] // self.expansion] + self.block_sizes
194+
stride = 1 if cfg.stem_pool and layer_num == 0 else 2
195+
num_blocks = cfg.layers[layer_num]
196+
block_chs = [cfg.stem_sizes[-1] // cfg.expansion] + cfg.block_sizes
197197
return nn.Sequential(
198198
OrderedDict(
199199
[
200200
(
201201
f"bl_{block_num}",
202-
self.block(
203-
self.expansion, # type: ignore
202+
cfg.block(
203+
cfg.expansion, # type: ignore
204204
block_chs[layer_num] if block_num == 0 else block_chs[layer_num + 1],
205205
block_chs[layer_num + 1],
206206
stride if block_num == 0 else 1,
207-
sa=self.sa
207+
sa=cfg.sa
208208
if (block_num == num_blocks - 1) and layer_num == 0
209209
else None,
210-
conv_layer=self.conv_layer,
211-
act_fn=self.act_fn,
212-
pool=self.pool,
213-
zero_bn=self.zero_bn,
214-
bn_1st=self.bn_1st,
215-
groups=self.groups,
216-
div_groups=self.div_groups,
217-
dw=self.dw,
218-
se=self.se,
210+
conv_layer=cfg.conv_layer,
211+
act_fn=cfg.act_fn,
212+
pool=cfg.pool,
213+
zero_bn=cfg.zero_bn,
214+
bn_1st=cfg.bn_1st,
215+
groups=cfg.groups,
216+
div_groups=cfg.div_groups,
217+
dw=cfg.dw,
218+
se=cfg.se,
219219
),
220220
)
221221
for block_num in range(num_blocks)
@@ -224,25 +224,25 @@ def _make_layer(self: CfgMC, layer_num: int) -> nn.Sequential:
224224
)
225225

226226

227-
def _make_body(self: CfgMC) -> nn.Sequential:
227+
def make_body(cfg: CfgMC) -> nn.Sequential:
228228
return nn.Sequential(
229229
OrderedDict(
230230
[
231231
(
232232
f"l_{layer_num}",
233-
self._make_layer(self, layer_num) # type: ignore
233+
cfg._make_layer(cfg, layer_num) # type: ignore
234234
)
235-
for layer_num in range(len(self.layers))
235+
for layer_num in range(len(cfg.layers))
236236
]
237237
)
238238
)
239239

240240

241-
def _make_head(self: CfgMC) -> nn.Sequential:
241+
def make_head(cfg: CfgMC) -> nn.Sequential:
242242
head = [
243243
("pool", nn.AdaptiveAvgPool2d(1)),
244244
("flat", nn.Flatten()),
245-
("fc", nn.Linear(self.block_sizes[-1] * self.expansion, self.num_classes)),
245+
("fc", nn.Linear(cfg.block_sizes[-1] * cfg.expansion, cfg.num_classes)),
246246
]
247247
return nn.Sequential(OrderedDict(head))
248248

@@ -255,13 +255,13 @@ def __post_init__(self):
255255
if self._init_cnn is None:
256256
self._init_cnn = init_cnn
257257
if self._make_stem is None:
258-
self._make_stem = _make_stem
258+
self._make_stem = make_stem
259259
if self._make_layer is None:
260-
self._make_layer = _make_layer
260+
self._make_layer = make_layer
261261
if self._make_body is None:
262-
self._make_body = _make_body
262+
self._make_body = make_body
263263
if self._make_head is None:
264-
self._make_head = _make_head
264+
self._make_head = make_head
265265

266266
if self.stem_sizes[0] != self.in_chans:
267267
self.stem_sizes = [self.in_chans] + self.stem_sizes
@@ -274,10 +274,6 @@ def __post_init__(self):
274274
"Deprecated. Pass se_module as se argument, se_reduction as arg to se."
275275
) # add deprecation warning.
276276

277-
# @property
278-
# def block_sizes(self):
279-
# return [self.stem_sizes[-1] // self.expansion] + self._block_sizes
280-
281277
@property
282278
def stem(self):
283279
return self._make_stem(self) # type: ignore

0 commit comments

Comments
 (0)