We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent ae38ec8 commit d5c9c5fCopy full SHA for d5c9c5f
torchhd/tests/test_models.py
@@ -105,11 +105,11 @@ def test_initialization(self, dtype):
105
assert model.weight.device.type == device.type
106
107
def test_fit_ridge_regression(self):
108
- samples = torch.randn(10, 12)
109
- targets = torch.randint(0, 3, (10,))
+ samples = torch.eye(10, 12)
+ targets = torch.arange(10)
110
111
- model = models.IntRVFL(12, 1245, 3)
+ model = models.IntRVFL(12, 1245, 10)
112
model.fit_ridge_regression(samples, targets)
113
114
logits = model(samples)
115
- assert logits.shape == (10, 3)
+ assert logits.shape == (10, 10)
0 commit comments