Skip to content

Commit a3751e5

Browse files
committed
create model at MC
1 parent f10cad8 commit a3751e5

File tree

2 files changed

+11
-0
lines changed

2 files changed

+11
-0
lines changed

src/model_constructor/model_constructor.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -411,6 +411,12 @@ def body(self):
411411
def from_cfg(cls, cfg: ModelCfg):
412412
return cls(**cfg.dict())
413413

414+
@classmethod
415+
def create_model(cls, cfg: Union[ModelCfg, None] = None) -> nn.Sequential:
416+
if cfg:
417+
return cls(**cfg.dict())()
418+
return cls()()
419+
414420
def __call__(self) -> nn.Sequential:
415421
"""Create model."""
416422
model_name = self.name or self.__class__.__name__

tests/test_mc.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,11 @@ def test_MC():
5151
model = mc()
5252
pred = model(xb)
5353
assert pred.shape == torch.Size([bs_test, num_classes])
54+
# create model from class. Default config, num_classes 1000. ??- how to change
55+
model = ModelConstructor.create_model()
56+
pred = model(xb)
57+
assert pred.shape == torch.Size([bs_test, 1000])
58+
5459

5560

5661
def test_MC_bottleneck():

0 commit comments

Comments
 (0)