Skip to content

Commit c889a6e

Browse files
committed
rename cfg
1 parent bb83046 commit c889a6e

File tree

1 file changed

+16
-14
lines changed

1 file changed

+16
-14
lines changed

src/model_constructor/model_constructor.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -123,8 +123,8 @@ def forward(self, x):
123123

124124

125125
@dataclass
126-
class ModelConstructorCfg:
127-
"""Model constructor. As default - xresnet18"""
126+
class CfgMC:
127+
"""Model constructor Config. As default - xresnet18"""
128128

129129
name: str = "MC"
130130
in_chans: int = 3
@@ -167,7 +167,7 @@ def init_cnn(module: nn.Module):
167167
init_cnn(layer)
168168

169169

170-
def _make_stem(self: ModelConstructorCfg) -> nn.Sequential:
170+
def _make_stem(self: CfgMC) -> nn.Sequential:
171171
stem: List[tuple[str, nn.Module]] = [
172172
(f"conv_{i}", self.conv_layer(
173173
self.stem_sizes[i], # type: ignore
@@ -188,7 +188,7 @@ def _make_stem(self: ModelConstructorCfg) -> nn.Sequential:
188188
return nn.Sequential(OrderedDict(stem))
189189

190190

191-
def _make_layer(self: ModelConstructorCfg, layer_num: int) -> nn.Sequential:
191+
def _make_layer(self: CfgMC, layer_num: int) -> nn.Sequential:
192192
# expansion, in_channels, out_channels, blocks, stride, sa):
193193
# if no pool on stem - stride = 2 for first layer block in body
194194
stride = 1 if self.stem_pool and layer_num == 0 else 2
@@ -224,7 +224,7 @@ def _make_layer(self: ModelConstructorCfg, layer_num: int) -> nn.Sequential:
224224
)
225225

226226

227-
def _make_body(self: ModelConstructorCfg) -> nn.Sequential:
227+
def _make_body(self: CfgMC) -> nn.Sequential:
228228
return nn.Sequential(
229229
OrderedDict(
230230
[
@@ -238,7 +238,7 @@ def _make_body(self: ModelConstructorCfg) -> nn.Sequential:
238238
)
239239

240240

241-
def _make_head(self: ModelConstructorCfg) -> nn.Sequential:
241+
def _make_head(self: CfgMC) -> nn.Sequential:
242242
head = [
243243
("pool", nn.AdaptiveAvgPool2d(1)),
244244
("flat", nn.Flatten()),
@@ -248,7 +248,7 @@ def _make_head(self: ModelConstructorCfg) -> nn.Sequential:
248248

249249

250250
@dataclass
251-
class ModelConstructor(ModelConstructorCfg):
251+
class ModelConstructor(CfgMC):
252252
"""Model constructor. As default - xresnet18"""
253253

254254
def __post_init__(self):
@@ -291,7 +291,7 @@ def body(self):
291291
return self._make_body(self) # type: ignore
292292

293293
@classmethod
294-
def from_cfg(cls, cfg: ModelConstructorCfg):
294+
def from_cfg(cls, cfg: CfgMC):
295295
return cls(**asdict(cfg))
296296

297297
def __call__(self):
@@ -313,9 +313,11 @@ def print_cfg(self):
313313
f" layers: {self.layers}"
314314
)
315315

316-
# xresnet34 = partial(
317-
# ModelConstructor, name="xresnet34", expansion=1, layers=[3, 4, 6, 3]
318-
# )
319-
# xresnet50 = partial(
320-
# ModelConstructor, name="xresnet34", expansion=4, layers=[3, 4, 6, 3]
321-
# )
316+
317+
xresnet34 = ModelConstructor.from_cfg(
318+
CfgMC(name="xresnet34", expansion=1, layers=[3, 4, 6, 3])
319+
)
320+
321+
xresnet50 = ModelConstructor.from_cfg(
322+
CfgMC(name="xresnet34", expansion=4, layers=[3, 4, 6, 3])
323+
)

0 commit comments

Comments
 (0)