Skip to content

Commit d5c9c5f

Browse files
committed
Make ridge regression test deterministic
1 parent ae38ec8 commit d5c9c5f

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

torchhd/tests/test_models.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -105,11 +105,11 @@ def test_initialization(self, dtype):
105105
assert model.weight.device.type == device.type
106106

107107
def test_fit_ridge_regression(self):
108-
samples = torch.randn(10, 12)
109-
targets = torch.randint(0, 3, (10,))
108+
samples = torch.eye(10, 12)
109+
targets = torch.arange(10)
110110

111-
model = models.IntRVFL(12, 1245, 3)
111+
model = models.IntRVFL(12, 1245, 10)
112112
model.fit_ridge_regression(samples, targets)
113113

114114
logits = model(samples)
115-
assert logits.shape == (10, 3)
115+
assert logits.shape == (10, 10)

0 commit comments

Comments
 (0)