11from collections import OrderedDict
22from functools import partial
3- from typing import Any , Callable , List , Type , TypeVar , Union
3+ from typing import Any , Callable , List , Optional , Type , TypeVar , Union
44
55import torch .nn as nn
66from pydantic import BaseModel , root_validator
@@ -212,7 +212,7 @@ def make_head(cfg: TModelCfg) -> nn.Sequential: # type: ignore
212212class ModelCfg (BaseModel ):
213213 """Model constructor Config. As default - xresnet18"""
214214
215- name : str = "MC"
215+ name : Optional [ str ] = None
216216 in_chans : int = 3
217217 num_classes : int = 1000
218218 block : Type [nn .Module ] = ResBlock
@@ -291,17 +291,36 @@ def from_cfg(cls, cfg: ModelCfg):
291291 return cls (** cfg .dict ())
292292
293293 def __call__ (self ):
294- model = nn .Sequential (
294+ model_name = self .name or self .__class__ .__name__
295+ named_sequential = type (model_name , (nn .Sequential , ), {})
296+ model = named_sequential (
295297 OrderedDict ([("stem" , self .stem ), ("body" , self .body ), ("head" , self .head )])
296298 )
297299 self .init_cnn (model ) # pylint: disable=too-many-function-args
298- model .extra_repr = lambda : f"{ self .name } "
300+ extra_repr = self .get_extra_repr ()
301+ if extra_repr :
302+ model .extra_repr = lambda : extra_repr
299303 return model
300304
305+ def get_extra_repr (self ) -> str :
306+ return " " .join (
307+ f"{ field } : { self .get_str_value (field )} ,"
308+ for field in self .__fields_set__ if field != "name"
309+ )[:- 1 ]
310+
311+ def get_str_value (self , field : str ) -> str :
312+ value = getattr (self , field )
313+ if isinstance (value , type ):
314+ value = value .__name__
315+ if isinstance (value , partial ):
316+ value = f"{ value .func .__name__ } { value .keywords } "
317+ return value
318+
301319 def __repr__ (self ):
302320 se_repr = self .se .__name__ if self .se else "False" # type: ignore
321+ model_name = self .name or self .__class__ .__name__
303322 return (
304- f"{ self . name } constructor \n "
323+ f"{ model_name } \n "
305324 f" in_chans: { self .in_chans } , num_classes: { self .num_classes } \n "
306325 f" expansion: { self .expansion } , groups: { self .groups } , dw: { self .dw } , div_groups: { self .div_groups } \n "
307326 f" act_fn: { self .act_fn .__name__ } , sa: { self .sa } , se: { se_repr } \n "
@@ -312,10 +331,8 @@ def __repr__(self):
312331
313332
314333class XResNet34 (ModelConstructor ):
315- name : str = "xresnet34"
316334 layers : list [int ] = [3 , 4 , 6 , 3 ]
317335
318336
319337class XResNet50 (XResNet34 ):
320- name : str = "xresnet50"
321338 expansion : int = 4
0 commit comments