File tree Expand file tree Collapse file tree 2 files changed +11
-0
lines changed Expand file tree Collapse file tree 2 files changed +11
-0
lines changed Original file line number Diff line number Diff line change @@ -411,6 +411,12 @@ def body(self):
411411 def from_cfg (cls , cfg : ModelCfg ):
412412 return cls (** cfg .dict ())
413413
414+ @classmethod
415+ def create_model (cls , cfg : Union [ModelCfg , None ] = None ) -> nn .Sequential :
416+ if cfg :
417+ return cls (** cfg .dict ())()
418+ return cls ()()
419+
414420 def __call__ (self ) -> nn .Sequential :
415421 """Create model."""
416422 model_name = self .name or self .__class__ .__name__
Original file line number Diff line number Diff line change @@ -51,6 +51,11 @@ def test_MC():
5151 model = mc ()
5252 pred = model (xb )
5353 assert pred .shape == torch .Size ([bs_test , num_classes ])
54+ # create model from class. Default config, num_classes 1000. ??- how to change
55+ model = ModelConstructor .create_model ()
56+ pred = model (xb )
57+ assert pred .shape == torch .Size ([bs_test , 1000 ])
58+
5459
5560
5661def test_MC_bottleneck ():
You can’t perform that action at this time.
0 commit comments