@@ -296,6 +296,28 @@ def __repr_args__(self) -> list[tuple[str, str]]:
296296 if (str_value := self ._get_str_value (field ))
297297 ]
298298
299+ def __repr_changed_args__ (self ) -> list [str ]:
300+ """Return list repr for changed fields"""
301+ return [
302+ f"{ field } : { self ._get_str_value (field )} "
303+ for field in self .__fields_set__
304+ if field != "name"
305+ ]
306+
307+ def print_cfg (self ) -> None :
308+ """Print full config"""
309+ print (f"{ self .__repr_name__ ()} (\n { self .__repr_str__ (chr (10 ) + ' ' )} )" )
310+
311+ def print_changed (self ) -> None :
312+ """Print changed fields."""
313+ changed_fields = self .__repr_changed_args__ ()
314+ if changed_fields :
315+ print ("Changed fields:" )
316+ for i in changed_fields :
317+ print (" " , i )
318+ else :
319+ print ("Nothing changed" )
320+
299321
300322class ModelConstructor (ModelCfg ):
301323 """Model constructor. As default - xresnet18"""
@@ -331,24 +353,16 @@ def from_cfg(cls, cfg: ModelCfg):
331353 def __call__ (self ) -> nn .Sequential :
332354 """Create model."""
333355 model_name = self .name or self .__class__ .__name__
334- named_sequential = type (model_name , (nn .Sequential ,), {})
356+ named_sequential = type (model_name , (nn .Sequential ,), {}) # create type named as model
335357 model = named_sequential (
336358 OrderedDict ([("stem" , self .stem ), ("body" , self .body ), ("head" , self .head )])
337359 )
338360 self .init_cnn (model ) # pylint: disable=too-many-function-args
339- extra_repr = self ._get_extra_repr ()
361+ extra_repr = self .__repr_changed_args__ ()
340362 if extra_repr :
341- model .extra_repr = lambda : extra_repr
363+ model .extra_repr = lambda : ", " . join ( extra_repr )
342364 return model
343365
344- def _get_extra_repr (self ) -> str :
345- """Repr for changed fields"""
346- return " " .join (
347- f"{ field } : { self ._get_str_value (field )} ,"
348- for field in self .__fields_set__
349- if field != "name"
350- )[:- 1 ] # strip last comma.
351-
352366 def __repr__ (self ) -> str :
353367 se_repr = self .se .__name__ if self .se else "False" # type: ignore
354368 model_name = self .name or self .__class__ .__name__
@@ -362,10 +376,6 @@ def __repr__(self) -> str:
362376 f" layers: { self .layers } "
363377 )
364378
365- def print_cfg (self ) -> None :
366- """Print full config"""
367- print (f"{ self .__repr_name__ ()} (\n { self .__repr_str__ (chr (10 ) + ' ' )} )" )
368-
369379
370380class XResNet34 (ModelConstructor ):
371381 layers : list [int ] = [3 , 4 , 6 , 3 ]
0 commit comments