Skip to content

Commit 1f7e9db

Browse files
authored
Add complex hypervector types (#81)
* WIP implement complex hypervectors
* Implement complex level and circular hypervector sets
* Add complex tests
* Simplify plotting utility
* Use unbind function
* Add unbind to docs
1 parent d8b3660 commit 1f7e9db

File tree

12 files changed

+345
-207
lines changed

12 files changed

+345
-207
lines changed

docs/functional.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ Operations
2828
:template: function.rst
2929

3030
bind
31+
unbind
3132
bundle
3233
permute
3334
cleanup

torchhd/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
level_hv,
1111
circular_hv,
1212
bind,
13+
unbind,
1314
bundle,
1415
permute,
1516
)
@@ -27,6 +28,7 @@
2728
"level_hv",
2829
"circular_hv",
2930
"bind",
31+
"unbind",
3032
"bundle",
3133
"permute",
3234
]

torchhd/functional.py

Lines changed: 238 additions & 72 deletions
Large diffs are not rendered by default.

torchhd/structures.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -212,10 +212,10 @@ def contains(self, input: Tensor) -> Tensor:
212212
Examples::
213213
214214
>>> M.contains(letters_hv[0])
215-
tensor([0.4575])
215+
tensor(0.4575)
216216
217217
"""
218-
return functional.cosine_similarity(input, self.value.unsqueeze(0))
218+
return functional.cosine_similarity(input, self.value)
219219

220220
def __len__(self) -> int:
221221
"""Returns the size of the multiset.
@@ -363,7 +363,7 @@ def get(self, key: Tensor) -> Tensor:
363363
tensor([ 1., -1., 1., ..., -1., 1., -1.])
364364
365365
"""
366-
return functional.bind(self.value, key)
366+
return functional.unbind(self.value, key)
367367

368368
def replace(self, key: Tensor, old: Tensor, new: Tensor) -> None:
369369
"""Replace the value from key-value pair in the hash table.
@@ -711,7 +711,7 @@ def pop(self, input: Tensor) -> None:
711711
712712
"""
713713
self.size -= 1
714-
self.value = functional.bind(self.value, input)
714+
self.value = functional.unbind(self.value, input)
715715
self.value = functional.permute(self.value, shifts=-1)
716716

717717
def popleft(self, input: Tensor) -> None:
@@ -727,7 +727,7 @@ def popleft(self, input: Tensor) -> None:
727727
"""
728728
self.size -= 1
729729
rotated_input = functional.permute(input, shifts=len(self))
730-
self.value = functional.bind(self.value, rotated_input)
730+
self.value = functional.unbind(self.value, rotated_input)
731731

732732
def replace(self, index: int, old: Tensor, new: Tensor) -> None:
733733
"""Replace the old hypervector value from the given index, for the new hypervector value.
@@ -744,7 +744,7 @@ def replace(self, index: int, old: Tensor, new: Tensor) -> None:
744744
745745
"""
746746
rotated_old = functional.permute(old, shifts=self.size - index - 1)
747-
self.value = functional.bind(self.value, rotated_old)
747+
self.value = functional.unbind(self.value, rotated_old)
748748

749749
rotated_new = functional.permute(new, shifts=self.size - index - 1)
750750
self.value = functional.bind(self.value, rotated_new)
@@ -880,13 +880,13 @@ def node_neighbors(self, input: Tensor, outgoing=True) -> Tensor:
880880
"""
881881
if self.is_directed:
882882
if outgoing:
883-
permuted_neighbors = functional.bind(self.value, input)
883+
permuted_neighbors = functional.unbind(self.value, input)
884884
return functional.permute(permuted_neighbors, shifts=-1)
885885
else:
886886
permuted_node = functional.permute(input, shifts=1)
887-
return functional.bind(self.value, permuted_node)
887+
return functional.unbind(self.value, permuted_node)
888888
else:
889-
return functional.bind(self.value, input)
889+
return functional.unbind(self.value, input)
890890

891891
def contains(self, input: Tensor) -> Tensor:
892892
"""Returns the cosine similarity of the input vector against the graph.
@@ -898,9 +898,9 @@ def contains(self, input: Tensor) -> Tensor:
898898
899899
>>> e = G.encode_edge(letters_hv[0], letters_hv[1])
900900
>>> G.contains(e)
901-
tensor([1.])
901+
tensor(1.)
902902
"""
903-
return functional.cosine_similarity(input, self.value.unsqueeze(0))
903+
return functional.cosine_similarity(input, self.value)
904904

905905
def clear(self) -> None:
906906
"""Empties the graph.
@@ -1012,7 +1012,7 @@ def get_leaf(self, path: List[str]) -> Tensor:
10121012
hv_path, functional.permute(self.right, shifts=idx)
10131013
)
10141014

1015-
return functional.bind(hv_path, self.value)
1015+
return functional.unbind(self.value, hv_path)
10161016

10171017
def clear(self) -> None:
10181018
"""Empties the tree.
@@ -1084,8 +1084,8 @@ def transition(self, state: Tensor, action: Tensor) -> Tensor:
10841084
tensor([ 1., 1., -1., ..., -1., -1., 1.])
10851085
10861086
"""
1087-
next_state = functional.bind(self.value, state)
1088-
next_state = functional.bind(next_state, action)
1087+
next_state = functional.unbind(self.value, state)
1088+
next_state = functional.unbind(next_state, action)
10891089
return functional.permute(next_state, shifts=-1)
10901090

10911091
def clear(self) -> None:

torchhd/tests/basis_hv/test_circular_hv.py

Lines changed: 30 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -47,30 +47,51 @@ def test_value(self, dtype):
4747
assert torch.all(
4848
(hv == True) | (hv == False)
4949
).item(), "values are either 1 or 0"
50+
elif dtype in torch_complex_dtypes:
51+
magnitudes= hv.abs()
52+
assert torch.allclose(magnitudes, torch.tensor(1.0, dtype=magnitudes.dtype)), "magnitude must be 1"
5053
else:
5154
assert torch.all(
5255
(hv == -1) | (hv == 1)
5356
).item(), "values are either -1 or +1"
5457

58+
5559
hv = functional.circular_hv(8, 1000000, generator=generator, dtype=dtype)
56-
sims = functional.hamming_similarity(hv[0], hv).float() / 1000000
57-
sims_diff = sims[:-1] - sims[1:]
60+
if dtype in torch_complex_dtypes:
61+
sims = functional.cosine_similarity(hv[0], hv)
62+
sims_diff = sims[:-1] - sims[1:]
5863

59-
assert torch.all(
60-
sims_diff.sign() == torch.tensor([1, 1, 1, 1, -1, -1, -1])
61-
), "second half must get more similar"
64+
assert torch.all(
65+
sims_diff.sign() == torch.tensor([1, 1, 1, 1, -1, -1, -1])
66+
), "second half must get more similar"
6267

63-
abs_sims_diff = sims_diff.abs()
64-
assert torch.all(
65-
(0.124 < abs_sims_diff) & (abs_sims_diff < 0.126)
66-
).item(), "similarity changes linearly"
68+
abs_sims_diff = sims_diff.abs()
69+
assert torch.all(
70+
(0.248 < abs_sims_diff) & (abs_sims_diff < 0.252)
71+
).item(), "similarity changes linearly"
72+
else:
73+
sims = functional.hamming_similarity(hv[0], hv).float() / 1000000
74+
sims_diff = sims[:-1] - sims[1:]
75+
76+
assert torch.all(
77+
sims_diff.sign() == torch.tensor([1, 1, 1, 1, -1, -1, -1])
78+
), "second half must get more similar"
79+
80+
abs_sims_diff = sims_diff.abs()
81+
assert torch.all(
82+
(0.124 < abs_sims_diff) & (abs_sims_diff < 0.126)
83+
).item(), "similarity changes linearly"
6784

6885
@pytest.mark.parametrize("sparsity", [0.0, 0.1, 0.756, 1.0])
6986
@pytest.mark.parametrize("dtype", torch_dtypes)
7087
def test_sparsity(self, sparsity, dtype):
7188
if not supported_dtype(dtype):
7289
return
7390

91+
if dtype in torch_complex_dtypes:
92+
# Complex hypervectors don't support sparsity.
93+
return
94+
7495
generator = torch.Generator()
7596
generator.manual_seed(seed)
7697

@@ -96,12 +117,6 @@ def test_device(self, dtype):
96117

97118
@pytest.mark.parametrize("dtype", torch_dtypes)
98119
def test_dtype(self, dtype):
99-
if dtype in torch_complex_dtypes:
100-
with pytest.raises(NotImplementedError):
101-
functional.circular_hv(3, 26, dtype=dtype)
102-
103-
return
104-
105120
if dtype == torch.uint8:
106121
with pytest.raises(ValueError):
107122
functional.circular_hv(3, 26, dtype=dtype)

torchhd/tests/basis_hv/test_identity_hv.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,12 +47,6 @@ def test_device(self, dtype):
4747

4848
@pytest.mark.parametrize("dtype", torch_dtypes)
4949
def test_dtype(self, dtype):
50-
if dtype in torch_complex_dtypes:
51-
with pytest.raises(NotImplementedError):
52-
functional.identity_hv(3, 26, dtype=dtype)
53-
54-
return
55-
5650
if dtype == torch.uint8:
5751
with pytest.raises(ValueError):
5852
functional.identity_hv(3, 26, dtype=dtype)

torchhd/tests/basis_hv/test_level_hv.py

Lines changed: 28 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -47,29 +47,48 @@ def test_value(self, dtype):
4747
assert torch.all(
4848
(hv == True) | (hv == False)
4949
).item(), "values are either 1 or 0"
50+
elif dtype in torch_complex_dtypes:
51+
magnitudes= hv.abs()
52+
assert torch.allclose(magnitudes, torch.tensor(1.0, dtype=magnitudes.dtype)), "magnitude must be 1"
5053
else:
5154
assert torch.all(
5255
(hv == -1) | (hv == 1)
5356
).item(), "values are either -1 or +1"
5457

5558
# look at the similarity profile w.r.t. the first hypervector
56-
sims = functional.hamming_similarity(hv[0], hv).float() / 10000
57-
sims_diff = sims[:-1] - sims[1:]
58-
assert torch.all(sims_diff > 0).item(), "similarity must be decreasing"
59+
if dtype in torch_complex_dtypes:
60+
sims = functional.cosine_similarity(hv[0], hv)
61+
sims_diff = sims[:-1] - sims[1:]
62+
assert torch.all(sims_diff > 0).item(), "similarity must be decreasing"
5963

60-
hv = functional.level_hv(5, 1000000, generator=generator, dtype=dtype)
61-
sims = functional.hamming_similarity(hv[0], hv).float() / 1000000
62-
sims_diff = sims[:-1] - sims[1:]
63-
assert torch.all(
64-
(0.124 < sims_diff) & (sims_diff < 0.126)
65-
).item(), "similarity decreases linearly"
64+
hv = functional.level_hv(5, 1000000, generator=generator, dtype=dtype)
65+
sims = functional.cosine_similarity(hv[0], hv)
66+
sims_diff = sims[:-1] - sims[1:]
67+
assert torch.all(
68+
(0.248 < sims_diff) & (sims_diff < 0.252)
69+
).item(), "similarity decreases linearly"
70+
else:
71+
sims = functional.hamming_similarity(hv[0], hv).float() / 10000
72+
sims_diff = sims[:-1] - sims[1:]
73+
assert torch.all(sims_diff > 0).item(), "similarity must be decreasing"
74+
75+
hv = functional.level_hv(5, 1000000, generator=generator, dtype=dtype)
76+
sims = functional.hamming_similarity(hv[0], hv).float() / 1000000
77+
sims_diff = sims[:-1] - sims[1:]
78+
assert torch.all(
79+
(0.124 < sims_diff) & (sims_diff < 0.126)
80+
).item(), "similarity decreases linearly"
6681

6782
@pytest.mark.parametrize("sparsity", [0.0, 0.1, 0.756, 1.0])
6883
@pytest.mark.parametrize("dtype", torch_dtypes)
6984
def test_sparsity(self, sparsity, dtype):
7085
if not supported_dtype(dtype):
7186
return
7287

88+
if dtype in torch_complex_dtypes:
89+
# Complex hypervectors don't support sparsity.
90+
return
91+
7392
generator = torch.Generator()
7493
generator.manual_seed(seed)
7594

@@ -95,12 +114,6 @@ def test_device(self, dtype):
95114

96115
@pytest.mark.parametrize("dtype", torch_dtypes)
97116
def test_dtype(self, dtype):
98-
if dtype in torch_complex_dtypes:
99-
with pytest.raises(NotImplementedError):
100-
functional.level_hv(3, 26, dtype=dtype)
101-
102-
return
103-
104117
if dtype == torch.uint8:
105118
with pytest.raises(ValueError):
106119
functional.level_hv(3, 26, dtype=dtype)

torchhd/tests/basis_hv/test_random_hv.py

Lines changed: 29 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -42,21 +42,26 @@ def test_value(self, dtype):
4242
generator = torch.Generator()
4343
generator.manual_seed(seed)
4444

45+
hv = functional.random_hv(100, 10000, dtype=dtype, generator=generator)
46+
4547
if dtype == torch.bool:
46-
hv = functional.random_hv(100, 10000, dtype=dtype, generator=generator)
4748
assert torch.all((hv == False) | (hv == True)).item()
48-
49-
return
50-
51-
hv = functional.random_hv(100, 10000, dtype=dtype, generator=generator)
52-
assert torch.all((hv == -1) | (hv == 1)).item()
49+
elif dtype in torch_complex_dtypes:
50+
magnitudes= hv.abs()
51+
assert torch.allclose(magnitudes, torch.tensor(1.0, dtype=magnitudes.dtype)), "magnitude must be 1"
52+
else:
53+
assert torch.all((hv == -1) | (hv == 1)).item()
5354

5455
@pytest.mark.parametrize("sparsity", [0.0, 0.1, 0.756, 1.0])
5556
@pytest.mark.parametrize("dtype", torch_dtypes)
5657
def test_sparsity(self, sparsity, dtype):
5758
if not supported_dtype(dtype):
5859
return
5960

61+
if dtype in torch_complex_dtypes:
62+
# Complex hypervectors don't support sparsity.
63+
return
64+
6065
generator = torch.Generator()
6166
generator.manual_seed(seed)
6267

@@ -83,14 +88,24 @@ def test_orthogonality(self, dtype):
8388
generator = torch.Generator()
8489
generator.manual_seed(seed)
8590

86-
sims = [None] * 100
87-
for i in range(100):
88-
hv = functional.random_hv(2, 10000, dtype=dtype, generator=generator)
89-
sims[i] = functional.hamming_similarity(hv[0], hv[1].unsqueeze(0))
90-
91-
sims = torch.cat(sims).float() / 10000
92-
assert within(sims.mean().item(), 0.5, 0.001)
93-
assert sims.std().item() < 0.01
91+
if dtype in torch_complex_dtypes:
92+
sims = [None] * 100
93+
for i in range(100):
94+
hv = functional.random_hv(2, 10000, dtype=dtype, generator=generator)
95+
sims[i] = functional.cosine_similarity(hv[0], hv[1])
96+
97+
sims = torch.stack(sims).float() / 10000
98+
assert within(sims.mean().item(), 0.0, 0.001)
99+
assert sims.std().item() < 0.01
100+
else:
101+
sims = [None] * 100
102+
for i in range(100):
103+
hv = functional.random_hv(2, 10000, dtype=dtype, generator=generator)
104+
sims[i] = functional.hamming_similarity(hv[0], hv[1].unsqueeze(0))
105+
106+
sims = torch.stack(sims).float() / 10000
107+
assert within(sims.mean().item(), 0.5, 0.001)
108+
assert sims.std().item() < 0.01
94109

95110
@pytest.mark.parametrize("dtype", torch_dtypes)
96111
def test_device(self, dtype):
@@ -103,12 +118,6 @@ def test_device(self, dtype):
103118

104119
@pytest.mark.parametrize("dtype", torch_dtypes)
105120
def test_dtype(self, dtype):
106-
if dtype in torch_complex_dtypes:
107-
with pytest.raises(NotImplementedError):
108-
functional.random_hv(3, 26, dtype=dtype)
109-
110-
return
111-
112121
if dtype == torch.uint8:
113122
with pytest.raises(ValueError):
114123
functional.random_hv(3, 26, dtype=dtype)

0 commit comments

Comments
 (0)