Skip to content

Commit 371cac0

Browse files
committed
repr for modelCfg
1 parent 2272b58 commit 371cac0

File tree

1 file changed

+25
-15
lines changed

1 file changed

+25
-15
lines changed

src/model_constructor/model_constructor.py

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,28 @@ def __repr_args__(self) -> list[tuple[str, str]]:
296296
if (str_value := self._get_str_value(field))
297297
]
298298

299+
def __repr_changed_args__(self) -> list[str]:
300+
"""Return list repr for changed fields"""
301+
return [
302+
f"{field}: {self._get_str_value(field)}"
303+
for field in self.__fields_set__
304+
if field != "name"
305+
]
306+
307+
def print_cfg(self) -> None:
308+
"""Print full config"""
309+
print(f"{self.__repr_name__()}(\n {self.__repr_str__(chr(10) + ' ')})")
310+
311+
def print_changed(self) -> None:
312+
"""Print changed fields."""
313+
changed_fields = self.__repr_changed_args__()
314+
if changed_fields:
315+
print("Changed fields:")
316+
for i in changed_fields:
317+
print(" ", i)
318+
else:
319+
print("Nothing changed")
320+
299321

300322
class ModelConstructor(ModelCfg):
301323
"""Model constructor. As default - xresnet18"""
@@ -331,24 +353,16 @@ def from_cfg(cls, cfg: ModelCfg):
331353
def __call__(self) -> nn.Sequential:
332354
"""Create model."""
333355
model_name = self.name or self.__class__.__name__
334-
named_sequential = type(model_name, (nn.Sequential,), {})
356+
named_sequential = type(model_name, (nn.Sequential,), {}) # create type named as model
335357
model = named_sequential(
336358
OrderedDict([("stem", self.stem), ("body", self.body), ("head", self.head)])
337359
)
338360
self.init_cnn(model) # pylint: disable=too-many-function-args
339-
extra_repr = self._get_extra_repr()
361+
extra_repr = self.__repr_changed_args__()
340362
if extra_repr:
341-
model.extra_repr = lambda: extra_repr
363+
model.extra_repr = lambda: ", ".join(extra_repr)
342364
return model
343365

344-
def _get_extra_repr(self) -> str:
345-
"""Repr for changed fields"""
346-
return " ".join(
347-
f"{field}: {self._get_str_value(field)},"
348-
for field in self.__fields_set__
349-
if field != "name"
350-
)[:-1] # strip last comma.
351-
352366
def __repr__(self) -> str:
353367
se_repr = self.se.__name__ if self.se else "False" # type: ignore
354368
model_name = self.name or self.__class__.__name__
@@ -362,10 +376,6 @@ def __repr__(self) -> str:
362376
f" layers: {self.layers}"
363377
)
364378

365-
def print_cfg(self) -> None:
366-
"""Print full config"""
367-
print(f"{self.__repr_name__()}(\n {self.__repr_str__(chr(10) + ' ')})")
368-
369379

370380
class XResNet34(ModelConstructor):
371381
layers: list[int] = [3, 4, 6, 3]

0 commit comments

Comments
 (0)