11import math
22import torch
3- from torch import LongTensor , Tensor
3+ from torch import BoolTensor , LongTensor , Tensor
44import torch .nn .functional as F
55
66from collections import deque
@@ -64,14 +64,21 @@ def identity_hv(
6464 if dtype is None :
6565 dtype = torch .get_default_dtype ()
6666
67- if dtype in {torch .bool , torch .complex64 , torch .complex128 }:
68- raise NotImplementedError (
69- "Boolean, and Complex hypervectors are not supported yet."
70- )
67+ if dtype in {torch .complex64 , torch .complex128 }:
68+ raise NotImplementedError ("Complex hypervectors are not supported yet." )
7169
7270 if dtype == torch .uint8 :
7371 raise ValueError ("Unsigned integer hypervectors are not supported." )
7472
73+ if dtype == torch .bool :
74+ return torch .zeros (
75+ num_embeddings ,
76+ embedding_dim ,
77+ dtype = dtype ,
78+ device = device ,
79+ requires_grad = requires_grad ,
80+ )
81+
7582 return torch .ones (
7683 num_embeddings ,
7784 embedding_dim ,
@@ -122,10 +129,8 @@ def random_hv(
122129 if dtype is None :
123130 dtype = torch .get_default_dtype ()
124131
125- if dtype in {torch .bool , torch .complex64 , torch .complex128 }:
126- raise NotImplementedError (
127- "Boolean, and Complex hypervectors are not supported yet."
128- )
132+ if dtype in {torch .complex64 , torch .complex128 }:
133+ raise NotImplementedError ("Complex hypervectors are not supported yet." )
129134
130135 if dtype == torch .uint8 :
131136 raise ValueError ("Unsigned integer hypervectors are not supported." )
@@ -137,6 +142,11 @@ def random_hv(
137142 ),
138143 dtype = torch .bool ,
139144 ).bernoulli_ (1.0 - sparsity , generator = generator )
145+
146+ if dtype == torch .bool :
147+ select .requires_grad = requires_grad
148+ return select
149+
140150 result = torch .where (select , - 1 , + 1 ).to (dtype = dtype , device = device )
141151 result .requires_grad = requires_grad
142152 return result
@@ -183,10 +193,8 @@ def level_hv(
183193 if dtype is None :
184194 dtype = torch .get_default_dtype ()
185195
186- if dtype in {torch .bool , torch .complex64 , torch .complex128 }:
187- raise NotImplementedError (
188- "Boolean, and Complex hypervectors are not supported yet."
189- )
196+ if dtype in {torch .complex64 , torch .complex128 }:
197+ raise NotImplementedError ("Complex hypervectors are not supported yet." )
190198
191199 if dtype == torch .uint8 :
192200 raise ValueError ("Unsigned integer hypervectors are not supported." )
@@ -200,6 +208,8 @@ def level_hv(
200208
201209 # convert from normalized "randomness" variable r to number of orthogonal vectors sets "span"
202210 levels_per_span = (1 - randomness ) * (num_embeddings - 1 ) + randomness * 1
211+ # must be at least one to deal with the case that num_embeddings is less than 2
212+ levels_per_span = max (levels_per_span , 1 )
203213 span = (num_embeddings - 1 ) / levels_per_span
204214 # generate the set of orthogonal vectors within the level vector set
205215 span_hv = random_hv (
@@ -287,10 +297,8 @@ def circular_hv(
287297 if dtype is None :
288298 dtype = torch .get_default_dtype ()
289299
290- if dtype in {torch .bool , torch .complex64 , torch .complex128 }:
291- raise NotImplementedError (
292- "Boolean, and Complex hypervectors are not supported yet."
293- )
300+ if dtype in {torch .complex64 , torch .complex128 }:
301+ raise NotImplementedError ("Complex hypervectors are not supported yet." )
294302
295303 if dtype == torch .uint8 :
296304 raise ValueError ("Unsigned integer hypervectors are not supported." )
@@ -354,15 +362,15 @@ def circular_hv(
354362
355363 temp_hv = torch .where (threshold_v [span_idx ] < t , span_start_hv , span_end_hv )
356364
357- mutation_history .append (temp_hv * mutation_hv )
365+ mutation_history .append (bind ( temp_hv , mutation_hv ) )
358366 mutation_hv = temp_hv
359367
360368 if i % 2 == 0 :
361369 hv [i // 2 ] = mutation_hv
362370
363371 for i in range (num_embeddings + 1 , num_embeddings * 2 - 1 ):
364372 mut = mutation_history .popleft ()
365- mutation_hv *= mut
373+ mutation_hv = bind ( mutation_hv , mut )
366374
367375 if i % 2 == 0 :
368376 hv [i // 2 ] = mutation_hv
@@ -371,7 +379,7 @@ def circular_hv(
371379 return hv
372380
373381
374- def bind (input : Tensor , other : Tensor , * , out = None ) -> Tensor :
382+ def bind (input : Tensor , other : Tensor ) -> Tensor :
375383 r"""Binds two hypervectors which produces a hypervector dissimilar to both.
376384
377385 Binding is used to associate information, for instance, to assign values to variables.
@@ -385,7 +393,6 @@ def bind(input: Tensor, other: Tensor, *, out=None) -> Tensor:
385393 Args:
386394 input (Tensor): input hypervector
387395 other (Tensor): other input hypervector
388- out (Tensor, optional): the output tensor.
389396
390397 Shapes:
391398 - Input: :math:`(*)`
@@ -402,18 +409,21 @@ def bind(input: Tensor, other: Tensor, *, out=None) -> Tensor:
402409 tensor([ 1., -1., -1.])
403410
404411 """
405- if input .dtype in {torch .bool , torch .complex64 , torch .complex128 }:
406- raise NotImplementedError (
407- "Boolean, and Complex hypervectors are not supported yet."
408- )
412+ dtype = input .dtype
409413
410- if input .dtype == torch .uint8 :
414+ if torch .is_complex (input ):
415+ raise NotImplementedError ("Complex hypervectors are not supported yet." )
416+
417+ if dtype == torch .uint8 :
411418 raise ValueError ("Unsigned integer hypervectors are not supported." )
412419
413- return torch .mul (input , other , out = out )
420+ if dtype == torch .bool :
421+ return torch .logical_xor (input , other )
414422
423+ return torch .mul (input , other )
415424
416- def bundle (input : Tensor , other : Tensor , * , out = None ) -> Tensor :
425+
426+ def bundle (input : Tensor , other : Tensor , * , tie : BoolTensor = None ) -> Tensor :
417427 r"""Bundles two hypervectors which produces a hypervector maximally similar to both.
418428
419429 The bundling operation is used to aggregate information into a single hypervector.
@@ -427,7 +437,7 @@ def bundle(input: Tensor, other: Tensor, *, out=None) -> Tensor:
427437 Args:
428438 input (Tensor): input hypervector
429439 other (Tensor): other input hypervector
430- out (Tensor , optional): the output tensor .
440+ tie (BoolTensor , optional): specifies how to break a tie while bundling boolean hypervectors. Default: only set bit if both ``input`` and ``other`` are ``True`` .
431441
432442 Shapes:
433443 - Input: :math:`(*)`
@@ -444,15 +454,21 @@ def bundle(input: Tensor, other: Tensor, *, out=None) -> Tensor:
444454 tensor([0., 2., 0.])
445455
446456 """
447- if input .dtype in {torch .bool , torch .complex64 , torch .complex128 }:
448- raise NotImplementedError (
449- "Boolean, and Complex hypervectors are not supported yet."
450- )
457+ dtype = input .dtype
451458
452- if input .dtype == torch .uint8 :
459+ if torch .is_complex (input ):
460+ raise NotImplementedError ("Complex hypervectors are not supported yet." )
461+
462+ if dtype == torch .uint8 :
453463 raise ValueError ("Unsigned integer hypervectors are not supported." )
454464
455- return torch .add (input , other , out = out )
465+ if dtype == torch .bool :
466+ if tie is not None :
467+ return torch .where (input == other , input , tie )
468+ else :
469+ return torch .logical_and (input , other )
470+
471+ return torch .add (input , other )
456472
457473
458474def permute (input : Tensor , * , shifts = 1 , dims = - 1 ) -> Tensor :
@@ -484,15 +500,19 @@ def permute(input: Tensor, *, shifts=1, dims=-1) -> Tensor:
484500 tensor([ -1., 1., -1.])
485501
486502 """
503+ dtype = input .dtype
504+
505+ if dtype == torch .uint8 :
506+ raise ValueError ("Unsigned integer hypervectors are not supported." )
507+
487508 return torch .roll (input , shifts = shifts , dims = dims )
488509
489510
490- def soft_quantize (input : Tensor , * , out = None ):
511+ def soft_quantize (input : Tensor ):
491512 """Applies the hyperbolic tanh function to all elements of the input tensor.
492513
493514 Args:
494515 input (Tensor): input tensor.
495- out (Tensor, optional): output tensor. Defaults to None.
496516
497517 Shapes:
498518 - Input: :math:`(*)`
@@ -508,15 +528,14 @@ def soft_quantize(input: Tensor, *, out=None):
508528 tensor([0.0000, 0.9640, 0.0000])
509529
510530 """
511- return torch .tanh (input , out = out )
531+ return torch .tanh (input )
512532
513533
514- def hard_quantize (input : Tensor , * , out = None ):
534+ def hard_quantize (input : Tensor ):
515535 """Applies binary quantization to all elements of the input tensor.
516536
517537 Args:
518538 input (Tensor): input tensor
519- out (Tensor, optional): output tensor. Defaults to None.
520539
521540 Shapes:
522541 - Input: :math:`(*)`
@@ -537,13 +556,7 @@ def hard_quantize(input: Tensor, *, out=None):
537556 positive = torch .tensor (1.0 , dtype = input .dtype , device = input .device )
538557 negative = torch .tensor (- 1.0 , dtype = input .dtype , device = input .device )
539558
540- if out != None :
541- out [:] = torch .where (input > 0 , positive , negative )
542- result = out
543- else :
544- result = torch .where (input > 0 , positive , negative )
545-
546- return result
559+ return torch .where (input > 0 , positive , negative )
547560
548561
549562def cosine_similarity (input : Tensor , others : Tensor ) -> Tensor :
@@ -650,16 +663,21 @@ def multiset(input: Tensor) -> Tensor:
650663 tensor([-1., 3., 1.])
651664
652665 """
666+ dim = - 2
667+ dtype = input .dtype
653668
654- if input .dtype in {torch .bool , torch .complex64 , torch .complex128 }:
655- raise NotImplementedError (
656- "Boolean, and Complex hypervectors are not supported yet."
657- )
669+ if dtype in {torch .complex64 , torch .complex128 }:
670+ raise NotImplementedError ("Complex hypervectors are not supported yet." )
658671
659- if input . dtype == torch .uint8 :
672+ if dtype == torch .uint8 :
660673 raise ValueError ("Unsigned integer hypervectors are not supported." )
661674
662- return torch .sum (input , dim = - 2 , dtype = input .dtype )
675+ if dtype == torch .bool :
676+ count = torch .sum (input , dim = dim , dtype = torch .long )
677+ threshold = input .size (dim ) // 2
678+ return torch .greater (count , threshold )
679+
680+ return torch .sum (input , dim = dim , dtype = dtype )
663681
664682
665683multibundle = multiset
@@ -681,6 +699,10 @@ def multibind(input: Tensor) -> Tensor:
681699 - Input: :math:`(*, n, d)`
682700 - Output: :math:`(*, d)`
683701
702+ .. note::
703+
704+ This method is not supported for ``torch.float16`` and ``torch.bfloat16`` input data types on a CPU device.
705+
684706 Examples::
685707
686708 >>> x = functional.random_hv(3, 3)
@@ -692,14 +714,21 @@ def multibind(input: Tensor) -> Tensor:
692714 tensor([ 1., 1., -1.])
693715
694716 """
695- if input .dtype in {torch .bool , torch .complex64 , torch .complex128 }:
696- raise NotImplementedError (
697- "Boolean, and Complex hypervectors are not supported yet."
698- )
717+ if input .dtype in {torch .complex64 , torch .complex128 }:
718+ raise NotImplementedError ("Complex hypervectors are not supported yet." )
699719
700720 if input .dtype == torch .uint8 :
701721 raise ValueError ("Unsigned integer hypervectors are not supported." )
702722
723+ if input .dtype == torch .bool :
724+ hvs = torch .unbind (input , - 2 )
725+ result = hvs [0 ]
726+
727+ for i in range (1 , len (hvs )):
728+ result = torch .logical_xor (result , hvs [i ])
729+
730+ return result
731+
703732 return torch .prod (input , dim = - 2 , dtype = input .dtype )
704733
705734
@@ -870,6 +899,10 @@ def bind_sequence(input: Tensor) -> Tensor:
870899 - Input: :math:`(*, n, d)`
871900 - Output: :math:`(*, d)`
872901
902+ .. note::
903+
904+ This method is not supported for ``torch.float16`` and ``torch.bfloat16`` input data types on a CPU device.
905+
873906 Examples::
874907
875908 >>> x = functional.random_hv(5, 3)
0 commit comments