Skip to content

Commit 49a0fa4

Browse files
mikeheddespverges
andauthored
Add distinct sequence (#13)
* Fixed bugs structures * Fixed bugs structures 3 * Add distinct sequence and replace methods * Fix missing device allocation in structures * Remove thresholding from structures and add overloads * Format Python code Co-authored-by: verges <pverges8@gmail.com> Co-authored-by: formatting <mikeheddes@users.noreply.github.com>
1 parent 2e7f4ad commit 49a0fa4

File tree

3 files changed

+182
-78
lines changed

3 files changed

+182
-78
lines changed

docs/index.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
Welcome to the Torchhd documentation!
2-
===============================
2+
=====================================
33

44
*Torchhd* is a Python library dedicated to Hyperdimensional Computing and the operations related to it.
55

docs/structures.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ Structures
1010
Memory
1111
Multiset
1212
Sequence
13+
DistinctSequence
1314
Graph
1415
Tree
1516
FiniteStateAutomata

torchhd/structures.py

Lines changed: 180 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
from typing import Any, List, Optional, Tuple
1+
from typing import Any, List, Optional, Tuple, overload
22
import torch
3+
from torch import Tensor
34

45
import torchhd.functional as functional
56

@@ -9,19 +10,19 @@ class Memory:
910

1011
def __init__(self, threshold=0.5):
1112
self.threshold = threshold
12-
self.keys: List[torch.Tensor] = []
13+
self.keys: List[Tensor] = []
1314
self.values: List[Any] = []
1415

1516
def __len__(self) -> int:
1617
"""Returns the number of items in memory"""
1718
return len(self.values)
1819

19-
def add(self, key: torch.Tensor, value: Any) -> None:
20+
def add(self, key: Tensor, value: Any) -> None:
2021
"""Adds one (key, value) pair to memory"""
2122
self.keys.append(key)
2223
self.values.append(value)
2324

24-
def _get_index(self, key: torch.Tensor) -> int:
25+
def _get_index(self, key: Tensor) -> int:
2526
key_stack = torch.stack(self.keys, dim=0)
2627
sim = functional.cosine_similarity(key, key_stack)
2728
value, index = torch.max(sim, 0)
@@ -31,147 +32,249 @@ def _get_index(self, key: torch.Tensor) -> int:
3132

3233
return index
3334

34-
def __getitem__(self, key: torch.Tensor) -> Tuple[torch.Tensor, Any]:
35+
def __getitem__(self, key: Tensor) -> Tuple[Tensor, Any]:
3536
"""Get the (key, value) pair with an approximate key"""
3637
index = self._get_index(key)
3738
return self.keys[index], self.values[index]
3839

39-
def __setitem__(self, key: torch.Tensor, value: Any) -> None:
40+
def __setitem__(self, key: Tensor, value: Any) -> None:
4041
"""Set the value of an (key, value) pair with an approximate key"""
4142
index = self._get_index(key)
4243
self.values[index] = value
4344

44-
def __delitem__(self, key: torch.Tensor) -> None:
45+
def __delitem__(self, key: Tensor) -> None:
4546
"""Delete the (key, value) pair with an approximate key"""
4647
index = self._get_index(key)
4748
del self.keys[index]
4849
del self.values[index]
4950

5051

5152
class Multiset:
52-
def __init__(self, dimensions, threshold=0.5, device=None, dtype=None):
53-
self.threshold = threshold
54-
self.cardinality = 0
55-
dtype = dtype if dtype is not None else torch.get_default_dtype()
56-
self.value = torch.zeros(dimensions, dtype=dtype, device=device)
53+
@overload
54+
def __init__(self, dimensions: int, *, device=None, dtype=None):
55+
...
56+
57+
@overload
58+
def __init__(self, input: Tensor, *, size=0):
59+
...
60+
61+
def __init__(self, dim_or_input: int, **kwargs):
62+
self.size = kwargs.get("size", 0)
63+
if torch.is_tensor(dim_or_input):
64+
self.value = dim_or_input
65+
else:
66+
dtype = kwargs.get("dtype", torch.get_default_dtype())
67+
device = kwargs.get("device", None)
68+
self.value = torch.zeros(dim_or_input, dtype=dtype, device=device)
5769

58-
def add(self, input: torch.Tensor) -> None:
70+
def add(self, input: Tensor) -> None:
5971
self.value = functional.bundle(self.value, input)
60-
self.cardinality += 1
72+
self.size += 1
6173

62-
def remove(self, input: torch.Tensor) -> None:
63-
if input not in self:
64-
return
74+
def remove(self, input: Tensor) -> None:
6575
self.value = functional.bundle(self.value, -input)
66-
self.cardinality -= 1
76+
self.size -= 1
6777

68-
def __contains__(self, input: torch.Tensor):
69-
sim = functional.cosine_similarity(input, self.values.unsqueeze(0))
70-
return sim.item() > self.threshold
78+
def contains(self, input: Tensor) -> Tensor:
79+
return functional.cosine_similarity(input, self.value.unsqueeze(0))
7180

7281
def __len__(self) -> int:
73-
return self.cardinality
82+
return self.size
7483

7584
@classmethod
76-
def from_ngrams(cls, input: torch.Tensor, n=3, threshold=0.5):
77-
instance = cls(input.size(-1), threshold, input.device, input.dtype)
78-
instance.value = functional.ngrams(input, n)
79-
return instance
85+
def from_ngrams(cls, input: Tensor, n=3):
86+
value = functional.ngrams(input, n)
87+
return cls(value, size=input.size(-2) - n + 1)
8088

8189
@classmethod
82-
def from_tensors(cls, input: torch.Tensor, dim=-2, threshold=0.5):
83-
instance = cls(input.size(-1), threshold, input.device, input.dtype)
84-
instance.value = functional.multiset(input=input, dim=dim)
85-
return instance
90+
def from_tensor(cls, input: Tensor):
91+
value = functional.multiset(input, dim=-2)
92+
return cls(value, size=input.size(-2))
8693

8794

8895
class Sequence:
89-
def __init__(self, dimensions, threshold=0.5, device=None, dtype=None):
90-
self.length = 0
91-
self.threshold = threshold
92-
dtype = dtype if dtype is not None else torch.get_default_dtype()
93-
self.value = torch.zeros(dimensions, dtype=dtype, device=device)
96+
@overload
97+
def __init__(self, dimensions: int, *, device=None, dtype=None):
98+
...
99+
100+
@overload
101+
def __init__(self, input: Tensor, *, length=0):
102+
...
103+
104+
def __init__(self, dim_or_input: int, **kwargs):
105+
self.length = kwargs.get("length", 0)
106+
if torch.is_tensor(dim_or_input):
107+
self.value = dim_or_input
108+
else:
109+
dtype = kwargs.get("dtype", torch.get_default_dtype())
110+
device = kwargs.get("device", None)
111+
self.value = torch.zeros(dim_or_input, dtype=dtype, device=device)
94112

95-
def append(self, input: torch.Tensor) -> None:
113+
def append(self, input: Tensor) -> None:
96114
rotated_value = functional.permute(self.value, shifts=1)
97115
self.value = functional.bundle(input, rotated_value)
116+
self.length += 1
98117

99-
def appendleft(self, input: torch.Tensor) -> None:
118+
def appendleft(self, input: Tensor) -> None:
100119
rotated_input = functional.permute(input, shifts=len(self))
101120
self.value = functional.bundle(self.value, rotated_input)
121+
self.length += 1
102122

103-
def pop(self, input: torch.Tensor) -> Optional[torch.Tensor]:
123+
def pop(self, input: Tensor) -> None:
124+
self.length -= 1
104125
self.value = functional.bundle(self.value, -input)
105126
self.value = functional.permute(self.value, shifts=-1)
106-
self.length -= 1
107127

108-
def popleft(self, input: torch.Tensor) -> None:
109-
rotated_input = functional.permute(input, shifts=len(self) + 1)
128+
def popleft(self, input: Tensor) -> None:
129+
self.length -= 1
130+
rotated_input = functional.permute(input, shifts=len(self))
110131
self.value = functional.bundle(self.value, -rotated_input)
132+
133+
def replace(self, index: int, old: Tensor, new: Tensor) -> None:
134+
rotated_old = functional.permute(old, shifts=-self.length + index + 1)
135+
self.value = functional.bundle(self.value, -rotated_old)
136+
137+
rotated_new = functional.permute(new, shifts=-self.length + index + 1)
138+
self.value = functional.bundle(self.value, rotated_new)
139+
140+
def concat(self, seq: "Sequence") -> "Sequence":
141+
value = functional.permute(self.value, shifts=len(seq))
142+
value = functional.bundle(value, seq.value)
143+
return Sequence(value, length=len(self) + len(seq))
144+
145+
def __getitem__(self, index: int) -> Tensor:
146+
return functional.permute(self.value, shifts=-self.length + index + 1)
147+
148+
def __len__(self) -> int:
149+
return self.length
150+
151+
152+
class DistinctSequence:
153+
@overload
154+
def __init__(self, dimensions: int, *, device=None, dtype=None):
155+
...
156+
157+
@overload
158+
def __init__(self, input: Tensor, *, length=0):
159+
...
160+
161+
def __init__(self, dim_or_input: int, **kwargs):
162+
self.length = kwargs.get("length", 0)
163+
if torch.is_tensor(dim_or_input):
164+
self.value = dim_or_input
165+
else:
166+
dtype = kwargs.get("dtype", torch.get_default_dtype())
167+
device = kwargs.get("device", None)
168+
self.value = torch.zeros(dim_or_input, dtype=dtype, device=device)
169+
170+
def append(self, input: Tensor) -> None:
171+
rotated_value = functional.permute(self.value, shifts=1)
172+
self.value = functional.bind(input, rotated_value)
173+
self.length += 1
174+
175+
def appendleft(self, input: Tensor) -> None:
176+
rotated_input = functional.permute(input, shifts=len(self))
177+
self.value = functional.bind(self.value, rotated_input)
178+
self.length += 1
179+
180+
def pop(self, input: Tensor) -> None:
181+
self.length -= 1
182+
self.value = functional.bind(self.value, input)
183+
self.value = functional.permute(self.value, shifts=-1)
184+
185+
def popleft(self, input: Tensor) -> None:
111186
self.length -= 1
187+
rotated_input = functional.permute(input, shifts=len(self))
188+
self.value = functional.bind(self.value, rotated_input)
112189

113-
def __getitem__(self, index: int) -> torch.Tensor:
114-
rotated_value = functional.permute(self.value, shifts=-index)
115-
return rotated_value
190+
def replace(self, index: int, old: Tensor, new: Tensor) -> None:
191+
rotated_old = functional.permute(old, shifts=-self.length + index + 1)
192+
self.value = functional.bind(self.value, rotated_old)
193+
194+
rotated_new = functional.permute(new, shifts=-self.length + index + 1)
195+
self.value = functional.bind(self.value, rotated_new)
116196

117197
def __len__(self) -> int:
118198
return self.length
119199

120200

121201
class Graph:
122-
def __init__(
123-
self, dimensions, threshold=0.5, directed=False, device=None, dtype=None
124-
):
202+
def __init__(self, dimensions, directed=False, device=None, dtype=None):
125203
self.length = 0
126-
self.threshold = threshold
204+
self.directed = directed
127205
self.dtype = dtype if dtype is not None else torch.get_default_dtype()
128206
self.value = torch.zeros(dimensions, dtype=dtype, device=device)
129-
self.directed = directed
130207

131-
def add_edge(self, node1: torch.Tensor, node2: torch.Tensor):
132-
if self.directed:
133-
edge = functional.bind(node1, node2)
134-
else:
135-
edge = functional.bind(node1, functional.permute(node2))
208+
def add_edge(self, node1: Tensor, node2: Tensor) -> None:
209+
edge = self.encode_edge(node1, node2)
136210
self.value = functional.bundle(self.value, edge)
137211

138-
def edge_exists(self, node1: torch.Tensor, node2: torch.Tensor):
212+
def encode_edge(self, node1: Tensor, node2: Tensor) -> Tensor:
139213
if self.directed:
140-
edge = functional.bind(node1, node2)
214+
return functional.bind(node1, node2)
141215
else:
142-
edge = functional.bind(node1, functional.permute(node2))
143-
return edge in self
216+
return functional.bind(node1, functional.permute(node2))
144217

145-
def node_neighbours(self, input: torch.Tensor):
146-
return functional.bind(self.value, input)
218+
def node_neighbors(self, input: Tensor, outgoing=True) -> Tensor:
219+
if self.directed:
220+
if outgoing:
221+
return functional.permute(functional.bind(self.value, input), shifts=-1)
222+
else:
223+
return functional.bind(self.value, functional.permute(input, shifts=1))
224+
else:
225+
return functional.bind(self.value, input)
147226

148-
def __contains__(self, input: torch.Tensor):
149-
sim = functional.cosine_similarity(input, self.value.unsqueeze(0))
150-
return sim.item() > self.threshold
227+
def contains(self, input: Tensor) -> Tensor:
228+
return functional.cosine_similarity(input, self.value.unsqueeze(0))
151229

152230

153231
class Tree:
154232
def __init__(self, dimensions, device=None, dtype=None):
233+
self.dimensions = dimensions
155234
self.dtype = dtype if dtype is not None else torch.get_default_dtype()
156235
self.value = torch.zeros(dimensions, dtype=dtype, device=device)
157-
self.l_r = functional.random_hv(2, dimensions)
236+
self.l_r = functional.random_hv(2, dimensions, dtype=dtype, device=device)
158237

159-
def add_leaf(self, value, path):
160-
for i in path:
238+
def add_leaf(self, value: Tensor, path: List[str]) -> None:
239+
for idx, i in enumerate(path):
161240
if i == "l":
162-
value = functional.bind(value, self.left)
241+
value = functional.bind(
242+
value, functional.permute(self.left, shifts=idx)
243+
)
163244
else:
164-
value = functional.bind(value, self.right)
245+
value = functional.bind(
246+
value, functional.permute(self.right, shifts=idx)
247+
)
248+
165249
self.value = functional.bundle(self.value, value)
166250

167251
@property
168-
def left(self):
252+
def left(self) -> Tensor:
169253
return self.l_r[0]
170254

171255
@property
172-
def right(self):
256+
def right(self) -> Tensor:
173257
return self.l_r[1]
174258

259+
def get_leaf(self, path: List[str]) -> Tensor:
260+
for idx, i in enumerate(path):
261+
if i == "l":
262+
if idx == 0:
263+
hv_path = self.left
264+
else:
265+
hv_path = functional.bind(
266+
hv_path, functional.permute(self.left, shifts=idx)
267+
)
268+
else:
269+
if idx == 0:
270+
hv_path = self.right
271+
else:
272+
hv_path = functional.bind(
273+
hv_path, functional.permute(self.right, shifts=idx)
274+
)
275+
276+
return functional.bind(hv_path, self.value)
277+
175278

176279
class FiniteStateAutomata:
177280
def __init__(self, dimensions, device=None, dtype=None):
@@ -180,18 +283,18 @@ def __init__(self, dimensions, device=None, dtype=None):
180283

181284
def add_transition(
182285
self,
183-
token: torch.Tensor,
184-
initial_state: torch.Tensor,
185-
final_state: torch.Tensor,
186-
):
286+
token: Tensor,
287+
initial_state: Tensor,
288+
final_state: Tensor,
289+
) -> None:
187290
transition_edge = functional.bind(
188291
initial_state, functional.permute(final_state)
189292
)
190293
transition = functional.bind(token, transition_edge)
191294
self.value = functional.bundle(self.value, transition)
192295

193-
def change_state(self, token: torch.Tensor, current_state: torch.Tensor):
296+
def transition(self, state: Tensor, action: Tensor) -> Tensor:
194297
# Returns the next state + some noise
195-
next_state = functional.bind(self.value, current_state)
196-
next_state = functional.bind(next_state, token)
298+
next_state = functional.bind(self.value, state)
299+
next_state = functional.bind(next_state, action)
197300
return functional.permute(next_state, shifts=-1)

0 commit comments

Comments
 (0)