Skip to content

Commit 6d2a0fa

Browse files
authored
Add binary hypervector support (#71)
* Add binary hypervector creation * Implement multiset and mulitbind * Fix bundle boolean implementation * Formatting * Remove out parameter * Add tie kwargs to bundle for boolean hvs * Add dtype testing * WIP testing binary hypervectors * WIP binary hypervector testing * WIP operation value tests * Test binary type encodings
1 parent c745b5c commit 6d2a0fa

File tree

10 files changed

+1081
-854
lines changed

10 files changed

+1081
-854
lines changed

torchhd/functional.py

Lines changed: 90 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import math
22
import torch
3-
from torch import LongTensor, Tensor
3+
from torch import BoolTensor, LongTensor, Tensor
44
import torch.nn.functional as F
55

66
from collections import deque
@@ -64,14 +64,21 @@ def identity_hv(
6464
if dtype is None:
6565
dtype = torch.get_default_dtype()
6666

67-
if dtype in {torch.bool, torch.complex64, torch.complex128}:
68-
raise NotImplementedError(
69-
"Boolean, and Complex hypervectors are not supported yet."
70-
)
67+
if dtype in {torch.complex64, torch.complex128}:
68+
raise NotImplementedError("Complex hypervectors are not supported yet.")
7169

7270
if dtype == torch.uint8:
7371
raise ValueError("Unsigned integer hypervectors are not supported.")
7472

73+
if dtype == torch.bool:
74+
return torch.zeros(
75+
num_embeddings,
76+
embedding_dim,
77+
dtype=dtype,
78+
device=device,
79+
requires_grad=requires_grad,
80+
)
81+
7582
return torch.ones(
7683
num_embeddings,
7784
embedding_dim,
@@ -122,10 +129,8 @@ def random_hv(
122129
if dtype is None:
123130
dtype = torch.get_default_dtype()
124131

125-
if dtype in {torch.bool, torch.complex64, torch.complex128}:
126-
raise NotImplementedError(
127-
"Boolean, and Complex hypervectors are not supported yet."
128-
)
132+
if dtype in {torch.complex64, torch.complex128}:
133+
raise NotImplementedError("Complex hypervectors are not supported yet.")
129134

130135
if dtype == torch.uint8:
131136
raise ValueError("Unsigned integer hypervectors are not supported.")
@@ -137,6 +142,11 @@ def random_hv(
137142
),
138143
dtype=torch.bool,
139144
).bernoulli_(1.0 - sparsity, generator=generator)
145+
146+
if dtype == torch.bool:
147+
select.requires_grad = requires_grad
148+
return select
149+
140150
result = torch.where(select, -1, +1).to(dtype=dtype, device=device)
141151
result.requires_grad = requires_grad
142152
return result
@@ -183,10 +193,8 @@ def level_hv(
183193
if dtype is None:
184194
dtype = torch.get_default_dtype()
185195

186-
if dtype in {torch.bool, torch.complex64, torch.complex128}:
187-
raise NotImplementedError(
188-
"Boolean, and Complex hypervectors are not supported yet."
189-
)
196+
if dtype in {torch.complex64, torch.complex128}:
197+
raise NotImplementedError("Complex hypervectors are not supported yet.")
190198

191199
if dtype == torch.uint8:
192200
raise ValueError("Unsigned integer hypervectors are not supported.")
@@ -200,6 +208,8 @@ def level_hv(
200208

201209
# convert from normalized "randomness" variable r to number of orthogonal vectors sets "span"
202210
levels_per_span = (1 - randomness) * (num_embeddings - 1) + randomness * 1
211+
# must be at least one to deal with the case that num_embeddings is less than 2
212+
levels_per_span = max(levels_per_span, 1)
203213
span = (num_embeddings - 1) / levels_per_span
204214
# generate the set of orthogonal vectors within the level vector set
205215
span_hv = random_hv(
@@ -287,10 +297,8 @@ def circular_hv(
287297
if dtype is None:
288298
dtype = torch.get_default_dtype()
289299

290-
if dtype in {torch.bool, torch.complex64, torch.complex128}:
291-
raise NotImplementedError(
292-
"Boolean, and Complex hypervectors are not supported yet."
293-
)
300+
if dtype in {torch.complex64, torch.complex128}:
301+
raise NotImplementedError("Complex hypervectors are not supported yet.")
294302

295303
if dtype == torch.uint8:
296304
raise ValueError("Unsigned integer hypervectors are not supported.")
@@ -354,15 +362,15 @@ def circular_hv(
354362

355363
temp_hv = torch.where(threshold_v[span_idx] < t, span_start_hv, span_end_hv)
356364

357-
mutation_history.append(temp_hv * mutation_hv)
365+
mutation_history.append(bind(temp_hv, mutation_hv))
358366
mutation_hv = temp_hv
359367

360368
if i % 2 == 0:
361369
hv[i // 2] = mutation_hv
362370

363371
for i in range(num_embeddings + 1, num_embeddings * 2 - 1):
364372
mut = mutation_history.popleft()
365-
mutation_hv *= mut
373+
mutation_hv = bind(mutation_hv, mut)
366374

367375
if i % 2 == 0:
368376
hv[i // 2] = mutation_hv
@@ -371,7 +379,7 @@ def circular_hv(
371379
return hv
372380

373381

374-
def bind(input: Tensor, other: Tensor, *, out=None) -> Tensor:
382+
def bind(input: Tensor, other: Tensor) -> Tensor:
375383
r"""Binds two hypervectors which produces a hypervector dissimilar to both.
376384
377385
Binding is used to associate information, for instance, to assign values to variables.
@@ -385,7 +393,6 @@ def bind(input: Tensor, other: Tensor, *, out=None) -> Tensor:
385393
Args:
386394
input (Tensor): input hypervector
387395
other (Tensor): other input hypervector
388-
out (Tensor, optional): the output tensor.
389396
390397
Shapes:
391398
- Input: :math:`(*)`
@@ -402,18 +409,21 @@ def bind(input: Tensor, other: Tensor, *, out=None) -> Tensor:
402409
tensor([ 1., -1., -1.])
403410
404411
"""
405-
if input.dtype in {torch.bool, torch.complex64, torch.complex128}:
406-
raise NotImplementedError(
407-
"Boolean, and Complex hypervectors are not supported yet."
408-
)
412+
dtype = input.dtype
409413

410-
if input.dtype == torch.uint8:
414+
if torch.is_complex(input):
415+
raise NotImplementedError("Complex hypervectors are not supported yet.")
416+
417+
if dtype == torch.uint8:
411418
raise ValueError("Unsigned integer hypervectors are not supported.")
412419

413-
return torch.mul(input, other, out=out)
420+
if dtype == torch.bool:
421+
return torch.logical_xor(input, other)
414422

423+
return torch.mul(input, other)
415424

416-
def bundle(input: Tensor, other: Tensor, *, out=None) -> Tensor:
425+
426+
def bundle(input: Tensor, other: Tensor, *, tie: BoolTensor = None) -> Tensor:
417427
r"""Bundles two hypervectors which produces a hypervector maximally similar to both.
418428
419429
The bundling operation is used to aggregate information into a single hypervector.
@@ -427,7 +437,7 @@ def bundle(input: Tensor, other: Tensor, *, out=None) -> Tensor:
427437
Args:
428438
input (Tensor): input hypervector
429439
other (Tensor): other input hypervector
430-
out (Tensor, optional): the output tensor.
440+
tie (BoolTensor, optional): specifies how to break a tie while bundling boolean hypervectors. Default: only set bit if both ``input`` and ``other`` are ``True``.
431441
432442
Shapes:
433443
- Input: :math:`(*)`
@@ -444,15 +454,21 @@ def bundle(input: Tensor, other: Tensor, *, out=None) -> Tensor:
444454
tensor([0., 2., 0.])
445455
446456
"""
447-
if input.dtype in {torch.bool, torch.complex64, torch.complex128}:
448-
raise NotImplementedError(
449-
"Boolean, and Complex hypervectors are not supported yet."
450-
)
457+
dtype = input.dtype
451458

452-
if input.dtype == torch.uint8:
459+
if torch.is_complex(input):
460+
raise NotImplementedError("Complex hypervectors are not supported yet.")
461+
462+
if dtype == torch.uint8:
453463
raise ValueError("Unsigned integer hypervectors are not supported.")
454464

455-
return torch.add(input, other, out=out)
465+
if dtype == torch.bool:
466+
if tie is not None:
467+
return torch.where(input == other, input, tie)
468+
else:
469+
return torch.logical_and(input, other)
470+
471+
return torch.add(input, other)
456472

457473

458474
def permute(input: Tensor, *, shifts=1, dims=-1) -> Tensor:
@@ -484,15 +500,19 @@ def permute(input: Tensor, *, shifts=1, dims=-1) -> Tensor:
484500
tensor([ -1., 1., -1.])
485501
486502
"""
503+
dtype = input.dtype
504+
505+
if dtype == torch.uint8:
506+
raise ValueError("Unsigned integer hypervectors are not supported.")
507+
487508
return torch.roll(input, shifts=shifts, dims=dims)
488509

489510

490-
def soft_quantize(input: Tensor, *, out=None):
511+
def soft_quantize(input: Tensor):
491512
"""Applies the hyperbolic tanh function to all elements of the input tensor.
492513
493514
Args:
494515
input (Tensor): input tensor.
495-
out (Tensor, optional): output tensor. Defaults to None.
496516
497517
Shapes:
498518
- Input: :math:`(*)`
@@ -508,15 +528,14 @@ def soft_quantize(input: Tensor, *, out=None):
508528
tensor([0.0000, 0.9640, 0.0000])
509529
510530
"""
511-
return torch.tanh(input, out=out)
531+
return torch.tanh(input)
512532

513533

514-
def hard_quantize(input: Tensor, *, out=None):
534+
def hard_quantize(input: Tensor):
515535
"""Applies binary quantization to all elements of the input tensor.
516536
517537
Args:
518538
input (Tensor): input tensor
519-
out (Tensor, optional): output tensor. Defaults to None.
520539
521540
Shapes:
522541
- Input: :math:`(*)`
@@ -537,13 +556,7 @@ def hard_quantize(input: Tensor, *, out=None):
537556
positive = torch.tensor(1.0, dtype=input.dtype, device=input.device)
538557
negative = torch.tensor(-1.0, dtype=input.dtype, device=input.device)
539558

540-
if out != None:
541-
out[:] = torch.where(input > 0, positive, negative)
542-
result = out
543-
else:
544-
result = torch.where(input > 0, positive, negative)
545-
546-
return result
559+
return torch.where(input > 0, positive, negative)
547560

548561

549562
def cosine_similarity(input: Tensor, others: Tensor) -> Tensor:
@@ -650,16 +663,21 @@ def multiset(input: Tensor) -> Tensor:
650663
tensor([-1., 3., 1.])
651664
652665
"""
666+
dim = -2
667+
dtype = input.dtype
653668

654-
if input.dtype in {torch.bool, torch.complex64, torch.complex128}:
655-
raise NotImplementedError(
656-
"Boolean, and Complex hypervectors are not supported yet."
657-
)
669+
if dtype in {torch.complex64, torch.complex128}:
670+
raise NotImplementedError("Complex hypervectors are not supported yet.")
658671

659-
if input.dtype == torch.uint8:
672+
if dtype == torch.uint8:
660673
raise ValueError("Unsigned integer hypervectors are not supported.")
661674

662-
return torch.sum(input, dim=-2, dtype=input.dtype)
675+
if dtype == torch.bool:
676+
count = torch.sum(input, dim=dim, dtype=torch.long)
677+
threshold = input.size(dim) // 2
678+
return torch.greater(count, threshold)
679+
680+
return torch.sum(input, dim=dim, dtype=dtype)
663681

664682

665683
multibundle = multiset
@@ -681,6 +699,10 @@ def multibind(input: Tensor) -> Tensor:
681699
- Input: :math:`(*, n, d)`
682700
- Output: :math:`(*, d)`
683701
702+
.. note::
703+
704+
This method is not supported for ``torch.float16`` and ``torch.bfloat16`` input data types on a CPU device.
705+
684706
Examples::
685707
686708
>>> x = functional.random_hv(3, 3)
@@ -692,14 +714,21 @@ def multibind(input: Tensor) -> Tensor:
692714
tensor([ 1., 1., -1.])
693715
694716
"""
695-
if input.dtype in {torch.bool, torch.complex64, torch.complex128}:
696-
raise NotImplementedError(
697-
"Boolean, and Complex hypervectors are not supported yet."
698-
)
717+
if input.dtype in {torch.complex64, torch.complex128}:
718+
raise NotImplementedError("Complex hypervectors are not supported yet.")
699719

700720
if input.dtype == torch.uint8:
701721
raise ValueError("Unsigned integer hypervectors are not supported.")
702722

723+
if input.dtype == torch.bool:
724+
hvs = torch.unbind(input, -2)
725+
result = hvs[0]
726+
727+
for i in range(1, len(hvs)):
728+
result = torch.logical_xor(result, hvs[i])
729+
730+
return result
731+
703732
return torch.prod(input, dim=-2, dtype=input.dtype)
704733

705734

@@ -870,6 +899,10 @@ def bind_sequence(input: Tensor) -> Tensor:
870899
- Input: :math:`(*, n, d)`
871900
- Output: :math:`(*, d)`
872901
902+
.. note::
903+
904+
This method is not supported for ``torch.float16`` and ``torch.bfloat16`` input data types on a CPU device.
905+
873906
Examples::
874907
875908
>>> x = functional.random_hv(5, 3)

torchhd/tests/basis_hv/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)