@@ -247,15 +247,25 @@ class Config:
247247 arbitrary_types_allowed = True
248248 extra = "forbid"
249249
250- def extra_repr (self ) -> str :
251- res = ""
252- for k , v in self .dict ().items ():
253- if v is not None :
254- res += f"{ k } : { v } \n "
255- return res
250+ def _get_str_value (self , field : str ) -> str :
251+ value = getattr (self , field )
252+ if isinstance (value , type ):
253+ value = value .__name__
254+ elif isinstance (value , partial ):
255+ value = f"{ value .func .__name__ } { value .keywords } "
256+ elif callable (value ):
257+ value = value .__name__
258+ return value
259+
260+ def __repr__ (self ) -> str :
261+ return f"{ self .__repr_name__ ()} (\n { self .__repr_str__ (chr (10 ) + ' ' )} )"
256262
257- def pprint (self ) -> None :
258- print (self .extra_repr ())
263+ def __repr_args__ (self ):
264+ return [
265+ (field , str_value )
266+ for field in self .__fields__
267+ if (str_value := self ._get_str_value (field ))
268+ ]
259269
260270
261271class ModelConstructor (ModelCfg ):
@@ -296,25 +306,17 @@ def __call__(self):
296306 OrderedDict ([("stem" , self .stem ), ("body" , self .body ), ("head" , self .head )])
297307 )
298308 self .init_cnn (model ) # pylint: disable=too-many-function-args
299- extra_repr = self .get_extra_repr ()
309+ extra_repr = self ._get_extra_repr ()
300310 if extra_repr :
301311 model .extra_repr = lambda : extra_repr
302312 return model
303313
304- def get_extra_repr (self ) -> str :
314+ def _get_extra_repr (self ) -> str :
305315 return " " .join (
306- f"{ field } : { self .get_str_value (field )} ,"
316+ f"{ field } : { self ._get_str_value (field )} ,"
307317 for field in self .__fields_set__ if field != "name"
308318 )[:- 1 ]
309319
310- def get_str_value (self , field : str ) -> str :
311- value = getattr (self , field )
312- if isinstance (value , type ):
313- value = value .__name__
314- if isinstance (value , partial ):
315- value = f"{ value .func .__name__ } { value .keywords } "
316- return value
317-
318320 def __repr__ (self ):
319321 se_repr = self .se .__name__ if self .se else "False" # type: ignore
320322 model_name = self .name or self .__class__ .__name__
@@ -328,6 +330,10 @@ def __repr__(self):
328330 f" layers: { self .layers } "
329331 )
330332
333+ def print_cfg (self ) -> None :
334+ """Print full config"""
335+ print (f"{ self .__repr_name__ ()} (\n { self .__repr_str__ (chr (10 ) + ' ' )} )" )
336+
331337
332338class XResNet34 (ModelConstructor ):
333339 layers : list [int ] = [3 , 4 , 6 , 3 ]
0 commit comments