Skip to content

Commit c03b9ce

Browse files
committed
rename config: CfgMC -> ModelCfg
1 parent cc4b865 commit c03b9ce

File tree

3 files changed

+11
-11
lines changed

3 files changed

+11
-11
lines changed

src/model_constructor/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
from model_constructor.convmixer import ConvMixer # noqa F401
2-
from model_constructor.model_constructor import ModelConstructor, ResBlock, CfgMC # noqa F401
2+
from model_constructor.model_constructor import ModelConstructor, ResBlock, ModelCfg # noqa F401
33

44
from model_constructor.version import __version__ # noqa F401

src/model_constructor/model_constructor.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
"act_fn",
1212
"ResBlock",
1313
"ModelConstructor",
14-
# "xresnet34",
15-
# "xresnet50",
14+
"XResNet34",
15+
"XResNet50",
1616
]
1717

1818

@@ -119,7 +119,7 @@ def forward(self, x):
119119
return self.act_fn(self.convs(x) + identity)
120120

121121

122-
class CfgMC(BaseModel):
122+
class ModelCfg(BaseModel):
123123
"""Model constructor Config. As default - xresnet18"""
124124

125125
name: str = "MC"
@@ -176,7 +176,7 @@ def init_cnn(module: nn.Module):
176176
init_cnn(layer)
177177

178178

179-
def make_stem(self: CfgMC) -> nn.Sequential:
179+
def make_stem(self: ModelCfg) -> nn.Sequential:
180180
stem: List[tuple[str, nn.Module]] = [
181181
(f"conv_{i}", self.conv_layer(
182182
self.stem_sizes[i], # type: ignore
@@ -197,7 +197,7 @@ def make_stem(self: CfgMC) -> nn.Sequential:
197197
return nn.Sequential(OrderedDict(stem))
198198

199199

200-
def make_layer(cfg: CfgMC, layer_num: int) -> nn.Sequential:
200+
def make_layer(cfg: ModelCfg, layer_num: int) -> nn.Sequential:
201201
# expansion, in_channels, out_channels, blocks, stride, sa):
202202
# if no pool on stem - stride = 2 for first layer block in body
203203
stride = 1 if cfg.stem_pool and layer_num == 0 else 2
@@ -233,7 +233,7 @@ def make_layer(cfg: CfgMC, layer_num: int) -> nn.Sequential:
233233
)
234234

235235

236-
def make_body(cfg: CfgMC) -> nn.Sequential:
236+
def make_body(cfg: ModelCfg) -> nn.Sequential:
237237
return nn.Sequential(
238238
OrderedDict(
239239
[
@@ -247,7 +247,7 @@ def make_body(cfg: CfgMC) -> nn.Sequential:
247247
)
248248

249249

250-
def make_head(cfg: CfgMC) -> nn.Sequential:
250+
def make_head(cfg: ModelCfg) -> nn.Sequential:
251251
head = [
252252
("pool", nn.AdaptiveAvgPool2d(1)),
253253
("flat", nn.Flatten()),
@@ -256,7 +256,7 @@ def make_head(cfg: CfgMC) -> nn.Sequential:
256256
return nn.Sequential(OrderedDict(head))
257257

258258

259-
class ModelConstructor(CfgMC):
259+
class ModelConstructor(ModelCfg):
260260
"""Model constructor. As default - xresnet18"""
261261

262262
def __init__(self, **data):
@@ -296,7 +296,7 @@ def body(self):
296296
return self.make_body(self) # type: ignore
297297

298298
@classmethod
299-
def from_cfg(cls, cfg: CfgMC):
299+
def from_cfg(cls, cfg: ModelCfg):
300300
return cls(**cfg.dict())
301301

302302
def __call__(self):

src/model_constructor/yaresnet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from torch.nn import Mish
99

1010
from .layers import ConvBnAct
11-
from .model_constructor import CfgMC, ModelConstructor
11+
from .model_constructor import ModelConstructor
1212

1313
__all__ = [
1414
'YaResBlock',

0 commit comments

Comments
 (0)