Commit 90c6926
Support for all VSA models in the embeddings module (#111)
* Update empty, identity and random embeddings
* Level embeddings
* [github-action] formatting fixes
* Implement VSA model choice for Thermometer and Circular
* [github-action] formatting fixes
* Update examples to latest version of torchmetrics
* [github-action] formatting fixes

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent 909e12d commit 90c6926
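
The commit description names the feature but shows no usage, so the sketch below is only an illustration of what selecting a VSA model for an embedding might look like after this change. The vsa keyword name and the "FHRR" identifier are assumptions based on the commit description, not taken from this diff.

```python
import torch
from torchhd import embeddings

# Hypothetical sketch: pick the VSA model when constructing an embedding.
# The keyword name ("vsa") and the model identifier ("FHRR") are assumptions.
keys = embeddings.Random(10, 10000, vsa="FHRR")
hv = keys(torch.tensor([0, 3]))  # two hypervectors from the chosen VSA model
```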

14 files changed (+639, -182 lines)

docs/embeddings.rst

Lines changed: 1 addition & 0 deletions
```diff
@@ -9,6 +9,7 @@ torchhd.embeddings
     :toctree: generated/
     :template: class.rst
 
+    Empty
     Identity
     Random
     Level
```
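
This hunk only adds the new Empty class to the autosummary list next to Identity, Random, and Level. As a rough usage sketch (the index-lookup forward and the role of empty hypervectors as a bundling identity are assumptions, not shown in this diff):

```python
import torch
from torchhd import embeddings

emb = embeddings.Empty(4, 10000)  # 4 empty hypervectors of 10,000 dimensions
hv = emb(torch.tensor([0, 1]))    # assumed to behave like an nn.Embedding lookup
# Assumption: empty hypervectors act as a neutral element when bundling.
```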

examples/emg_hand_gestures.py

Lines changed: 3 additions & 2 deletions
```diff
@@ -66,7 +66,8 @@ def experiment(subjects=[0]):
     train_ld = data.DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
     test_ld = data.DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)
 
-    model = Model(len(ds.classes), ds[0][0].size(-2), ds[0][0].size(-1))
+    num_classes = len(ds.classes)
+    model = Model(num_classes, ds[0][0].size(-2), ds[0][0].size(-1))
     model = model.to(device)
 
     with torch.no_grad():
@@ -79,7 +80,7 @@ def experiment(subjects=[0]):
 
     model.classify.weight[:] = F.normalize(model.classify.weight)
 
-    accuracy = torchmetrics.Accuracy()
+    accuracy = torchmetrics.Accuracy("multiclass", num_classes=num_classes)
 
     with torch.no_grad():
         for samples, labels in tqdm(test_ld, desc="Testing"):
```
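
The Accuracy change here (and in the examples below) tracks newer torchmetrics releases, which require the task and the number of classes up front. A minimal sketch of how the updated metric accumulates over batches; the shapes and class count are illustrative assumptions:

```python
import torch
import torchmetrics

num_classes = 6
accuracy = torchmetrics.Accuracy("multiclass", num_classes=num_classes)

for _ in range(3):                               # stand-in for the test loader
    logits = torch.randn(8, num_classes)         # model outputs for one batch
    labels = torch.randint(0, num_classes, (8,))
    accuracy.update(logits, labels)              # argmax over classes handled internally

print(accuracy.compute().item())                 # accuracy over all accumulated batches
```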

examples/graphhd.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -121,7 +121,7 @@ def forward(self, x):
 
 model.classify.weight[:] = F.normalize(model.classify.weight)
 
-accuracy = torchmetrics.Accuracy()
+accuracy = torchmetrics.Accuracy("multiclass", num_classes=graphs.num_classes)
 
 with torch.no_grad():
     for samples in tqdm(test_ld, desc="Testing"):
```

examples/language_recognition.py

Lines changed: 3 additions & 2 deletions
```diff
@@ -74,7 +74,8 @@ def forward(self, x):
         return logit
 
 
-model = Model(len(train_ds.classes), NUM_TOKENS)
+num_classes = len(train_ds.classes)
+model = Model(num_classes, NUM_TOKENS)
 model = model.to(device)
 
 with torch.no_grad():
@@ -87,7 +88,7 @@ def forward(self, x):
 
 model.classify.weight[:] = F.normalize(model.classify.weight)
 
-accuracy = torchmetrics.Accuracy()
+accuracy = torchmetrics.Accuracy("multiclass", num_classes=num_classes)
 
 with torch.no_grad():
     for samples, labels in tqdm(test_ld, desc="Testing"):
```

examples/mnist.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -53,7 +53,8 @@ def forward(self, x):
         return logit
 
 
-model = Model(len(train_ds.classes), IMG_SIZE)
+num_classes = len(train_ds.classes)
+model = Model(num_classes, IMG_SIZE)
 model = model.to(device)
 
 with torch.no_grad():
@@ -66,8 +67,7 @@ def forward(self, x):
 
 model.classify.weight[:] = F.normalize(model.classify.weight)
 
-accuracy = torchmetrics.Accuracy()
-
+accuracy = torchmetrics.Accuracy("multiclass", num_classes=num_classes)
 
 with torch.no_grad():
     for samples, labels in tqdm(test_ld, desc="Testing"):
```

examples/mnist_hugging_face.py

Lines changed: 6 additions & 4 deletions
```diff
@@ -6,8 +6,10 @@
 
 # Note: this example requires the torchmetrics library: https://torchmetrics.readthedocs.io
 import torchmetrics
-from tqdm import tqdm
+
+# Note: this example requires the accelerate library: https://github.com/huggingface/accelerate
 from accelerate import Accelerator
+from tqdm import tqdm
 
 import torchhd
 from torchhd import embeddings
@@ -54,7 +56,8 @@ def forward(self, x):
         return logit
 
 
-model = Model(len(train_ds.classes), IMG_SIZE)
+num_classes = len(train_ds.classes)
+model = Model(num_classes, IMG_SIZE)
 model.to(device)
 
 model, train_ld, test_ld = accelerator.prepare(model, train_ld, test_ld)
@@ -69,8 +72,7 @@ def forward(self, x):
 
 model.classify.weight[:] = F.normalize(model.classify.weight)
 
-accuracy = torchmetrics.Accuracy()
-
+accuracy = torchmetrics.Accuracy("multiclass", num_classes=num_classes)
 
 with torch.no_grad():
     for samples, labels in tqdm(test_ld, desc="Testing"):
```
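
The new import note points at the accelerate setup this example already relies on (the accelerator.prepare call in the context above). Below is a self-contained sketch of that pattern with placeholder model and data, not code from this commit:

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

accelerator = Accelerator()
model = nn.Linear(10, 2)                                              # placeholder model
loader = DataLoader(TensorDataset(torch.randn(8, 10)), batch_size=4)  # placeholder data

# accelerate moves the model and batches to the appropriate device(s).
model, loader = accelerator.prepare(model, loader)
device = accelerator.device
```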

examples/mnist_nonlinear.py

Lines changed: 3 additions & 2 deletions
```diff
@@ -28,6 +28,7 @@
 test_ds = MNIST("../data", train=False, transform=transform, download=True)
 test_ld = torch.utils.data.DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)
 
+
 class Model(nn.Module):
     def __init__(self, num_classes, size):
         super(Model, self).__init__()
@@ -50,6 +51,7 @@ def forward(self, x):
         return logit
 
 
+num_classes = len(train_ds.classes)
 model = Model(len(train_ds.classes), IMG_SIZE)
 model = model.to(device)
 
@@ -63,8 +65,7 @@ def forward(self, x):
 
 model.classify.weight[:] = F.normalize(model.classify.weight)
 
-accuracy = torchmetrics.Accuracy()
-
+accuracy = torchmetrics.Accuracy("multiclass", num_classes=num_classes)
 
 with torch.no_grad():
     for samples, labels in tqdm(test_ld, desc="Testing"):
```

examples/mnist_torch_lightning.py

Lines changed: 6 additions & 4 deletions
```diff
@@ -6,8 +6,10 @@
 
 # Note: this example requires the torchmetrics library: https://torchmetrics.readthedocs.io
 import torchmetrics
-from tqdm import tqdm
+
+# Note: this example requires the pytorch-lightning library: https://www.pytorchlightning.ai
 import pytorch_lightning as pl
+from tqdm import tqdm
 
 import torchhd
 from torchhd import embeddings
@@ -60,7 +62,8 @@ def configure_optimizers(self):
         return
 
 
-model = Model(len(train_ds.classes), IMG_SIZE)
+num_classes = len(train_ds.classes)
+model = Model(num_classes, IMG_SIZE)
 trainer = pl.Trainer(
     accelerator="cpu",
     devices=1,
@@ -77,8 +80,7 @@ def configure_optimizers(self):
 
 model.classify.weight[:] = F.normalize(model.classify.weight)
 
-accuracy = torchmetrics.Accuracy()
-
+accuracy = torchmetrics.Accuracy("multiclass", num_classes=num_classes)
 
 with torch.no_grad():
     for samples, labels in tqdm(test_ld, desc="Testing"):
```

examples/random_projection.py

Lines changed: 2 additions & 5 deletions
```diff
@@ -58,16 +58,13 @@ def __init__(self, num_classes, size):
         self.target = embeddings.Level(
             500, DIMENSIONS, low=MIN_TEMPERATURE, high=MAX_TEMPERATURE
         )
-        self.project = embeddings.Projection(size, DIMENSIONS)
-        self.bias = nn.parameter.Parameter(torch.empty(DIMENSIONS), requires_grad=False)
-        self.bias.data.uniform_(0, 2 * math.pi)
+        self.project = embeddings.Sinusoid(size, DIMENSIONS)
 
         self.regression = nn.Linear(DIMENSIONS, num_classes, bias=False)
         self.regression.weight.data.fill_(0.0)
 
     def encode(self, x):
-        enc = self.project(x)
-        sample_hv = torch.cos(enc + self.bias) * torch.sin(enc)
+        sample_hv = self.project(x)
         return torchhd.hard_quantize(sample_hv)
 
     def forward(self, x):
```
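
Here (and in reghd.py below) the hand-rolled projection, random phase bias, and cos·sin nonlinearity are replaced by embeddings.Sinusoid. The sketch below contrasts the two forms; the internals attributed to Sinusoid are an assumption inferred from the removed lines, not from this diff:

```python
import math
import torch
from torchhd import embeddings

size, DIMENSIONS = 5, 10000
x = torch.randn(1, size)

# New style: one module produces the nonlinear random-projection encoding.
project = embeddings.Sinusoid(size, DIMENSIONS)
sample_hv = project(x)

# Roughly what the removed code computed by hand (weights here are illustrative):
weight = torch.randn(DIMENSIONS, size)
bias = torch.empty(DIMENSIONS).uniform_(0, 2 * math.pi)
enc = x @ weight.T
manual_hv = torch.cos(enc + bias) * torch.sin(enc)
```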

examples/reghd.py

Lines changed: 3 additions & 7 deletions
```diff
@@ -18,7 +18,7 @@
 DIMENSIONS = 10000  # number of hypervector dimensions
 NUM_FEATURES = 5  # number of features in dataset
 
-ds = AirfoilSelfNoise("../data", download=False)
+ds = AirfoilSelfNoise("../data", download=True)
 
 # Get necessary statistics for data and target transform
 STD_DEVS = ds.data.std(0)
@@ -57,14 +57,10 @@ def __init__(self, num_classes, size):
 
         self.lr = 0.00001
         self.M = torch.zeros(1, DIMENSIONS)
-        self.project = embeddings.Projection(size, DIMENSIONS)
-        self.project.weight.data.normal_(0, 1)
-        self.bias = nn.parameter.Parameter(torch.empty(DIMENSIONS), requires_grad=False)
-        self.bias.data.uniform_(0, 2 * math.pi)
+        self.project = embeddings.Sinusoid(size, DIMENSIONS)
 
     def encode(self, x):
-        enc = self.project(x)
-        sample_hv = torch.cos(enc + self.bias) * torch.sin(enc)
+        sample_hv = self.project(x)
         return torchhd.hard_quantize(sample_hv)
 
     def model_update(self, x, y):
```
