Skip to content

Commit c6b1f56

Browse files
authored
Merge pull request #82 from ayasyrev/model_name
Model name refactored
2 parents 6c55cb0 + fbd5b00 commit c6b1f56

File tree

3 files changed

+26
-9
lines changed

3 files changed

+26
-9
lines changed

src/model_constructor/model_constructor.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from collections import OrderedDict
22
from functools import partial
3-
from typing import Any, Callable, List, Type, TypeVar, Union
3+
from typing import Any, Callable, List, Optional, Type, TypeVar, Union
44

55
import torch.nn as nn
66
from pydantic import BaseModel, root_validator
@@ -212,7 +212,7 @@ def make_head(cfg: TModelCfg) -> nn.Sequential: # type: ignore
212212
class ModelCfg(BaseModel):
213213
"""Model constructor Config. As default - xresnet18"""
214214

215-
name: str = "MC"
215+
name: Optional[str] = None
216216
in_chans: int = 3
217217
num_classes: int = 1000
218218
block: Type[nn.Module] = ResBlock
@@ -291,17 +291,36 @@ def from_cfg(cls, cfg: ModelCfg):
291291
return cls(**cfg.dict())
292292

293293
def __call__(self):
294-
model = nn.Sequential(
294+
model_name = self.name or self.__class__.__name__
295+
named_sequential = type(model_name, (nn.Sequential, ), {})
296+
model = named_sequential(
295297
OrderedDict([("stem", self.stem), ("body", self.body), ("head", self.head)])
296298
)
297299
self.init_cnn(model) # pylint: disable=too-many-function-args
298-
model.extra_repr = lambda: f"{self.name}"
300+
extra_repr = self.get_extra_repr()
301+
if extra_repr:
302+
model.extra_repr = lambda: extra_repr
299303
return model
300304

305+
def get_extra_repr(self) -> str:
306+
return " ".join(
307+
f"{field}: {self.get_str_value(field)},"
308+
for field in self.__fields_set__ if field != "name"
309+
)[:-1]
310+
311+
def get_str_value(self, field: str) -> str:
312+
value = getattr(self, field)
313+
if isinstance(value, type):
314+
value = value.__name__
315+
if isinstance(value, partial):
316+
value = f"{value.func.__name__} {value.keywords}"
317+
return value
318+
301319
def __repr__(self):
302320
se_repr = self.se.__name__ if self.se else "False" # type: ignore
321+
model_name = self.name or self.__class__.__name__
303322
return (
304-
f"{self.name} constructor\n"
323+
f"{model_name}\n"
305324
f" in_chans: {self.in_chans}, num_classes: {self.num_classes}\n"
306325
f" expansion: {self.expansion}, groups: {self.groups}, dw: {self.dw}, div_groups: {self.div_groups}\n"
307326
f" act_fn: {self.act_fn.__name__}, sa: {self.sa}, se: {se_repr}\n"
@@ -312,10 +331,8 @@ def __repr__(self):
312331

313332

314333
class XResNet34(ModelConstructor):
315-
name: str = "xresnet34"
316334
layers: list[int] = [3, 4, 6, 3]
317335

318336

319337
class XResNet50(XResNet34):
320-
name: str = "xresnet50"
321338
expansion: int = 4

src/model_constructor/yaresnet.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,13 +124,11 @@ def forward(self, x):
124124

125125

126126
class YaResNet34(ModelConstructor):
127-
name: str = 'YaResnet34'
128127
block: Type[nn.Module] = YaResBlock
129128
expansion: int = 1
130129
layers: List[int] = [3, 4, 6, 3]
131130
act_fn: Type[nn.Module] = Mish
132131

133132

134133
class YaResNet50(YaResNet34):
135-
name: str = 'YaResnet50'
136134
expansion: int = 4

tests/test_mc.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ def test_MC():
1111
"""test ModelConstructor"""
1212
img_size = 16
1313
mc = ModelConstructor()
14+
assert "name=None" in str(mc)
15+
mc.name = "MC"
1416
assert "name='MC'" in str(mc)
1517
model = mc()
1618
xb = torch.randn(bs_test, 3, img_size, img_size)

0 commit comments

Comments
 (0)