Skip to content

Commit 812950c

Browse files
committed
kwargs to mc create model
1 parent a3751e5 commit 812950c

File tree

2 files changed

+5
-2
lines changed

2 files changed

+5
-2
lines changed

src/model_constructor/model_constructor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -412,10 +412,10 @@ def from_cfg(cls, cfg: ModelCfg):
412412
return cls(**cfg.dict())
413413

414414
@classmethod
415-
def create_model(cls, cfg: Union[ModelCfg, None] = None) -> nn.Sequential:
415+
def create_model(cls, cfg: Union[ModelCfg, None] = None, **kwargs: dict[str, Any]) -> nn.Sequential:
416416
if cfg:
417417
return cls(**cfg.dict())()
418-
return cls()()
418+
return cls(**kwargs)()
419419

420420
def __call__(self) -> nn.Sequential:
421421
"""Create model."""

tests/test_mc.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,9 @@ def test_MC():
5656
pred = model(xb)
5757
assert pred.shape == torch.Size([bs_test, 1000])
5858

59+
model = ModelConstructor.create_model(num_classes=num_classes)
60+
pred = model(xb)
61+
assert pred.shape == torch.Size([bs_test, num_classes])
5962

6063

6164
def test_MC_bottleneck():

0 commit comments

Comments
 (0)