22import torch
33from torch import BoolTensor , LongTensor , Tensor
44import torch .nn .functional as F
5-
65from collections import deque
76
87
@@ -688,6 +687,8 @@ def hard_quantize(input: Tensor):
688687def dot_similarity (input : Tensor , others : Tensor ) -> Tensor :
689688 """Dot product between the input vector and each vector in others.
690689
690+ Aliased as ``torchhd.dot_similarity``.
691+
691692 Args:
692693 input (Tensor): hypervectors to compare against others
693694 others (Tensor): hypervectors to compare with
@@ -697,6 +698,12 @@ def dot_similarity(input: Tensor, others: Tensor) -> Tensor:
697698 - Others: :math:`(n, d)` or :math:`(d)`
698699 - Output: :math:`(*, n)` or :math:`(*)`, depends on shape of others
699700
701+ .. note::
702+
703+ Output ``dtype`` for ``torch.bool`` is ``torch.long``,
704+ for ``torch.complex64`` is ``torch.float``,
705+ for ``torch.complex128`` is ``torch.double``, otherwise same as input ``dtype``.
706+
700707 Examples::
701708
702709 >>> x = functional.random_hv(3, 6)
@@ -720,6 +727,12 @@ def dot_similarity(input: Tensor, others: Tensor) -> Tensor:
720727 [ 0.6771, -4.2506, 6.0000]])
721728
722729 """
730+ if input .dtype == torch .bool :
731+ input_as_bipolar = torch .where (input , - 1 , 1 )
732+ others_as_bipolar = torch .where (others , - 1 , 1 )
733+
734+ return F .linear (input_as_bipolar , others_as_bipolar )
735+
723736 if torch .is_complex (input ):
724737 return F .linear (input , others .conj ()).real
725738
@@ -729,6 +742,8 @@ def dot_similarity(input: Tensor, others: Tensor) -> Tensor:
729742def cosine_similarity (input : Tensor , others : Tensor , * , eps = 1e-08 ) -> Tensor :
730743 """Cosine similarity between the input vector and each vector in others.
731744
745+ Aliased as ``torchhd.cosine_similarity``.
746+
732747 Args:
733748 input (Tensor): hypervectors to compare against others
734749 others (Tensor): hypervectors to compare with
@@ -738,6 +753,10 @@ def cosine_similarity(input: Tensor, others: Tensor, *, eps=1e-08) -> Tensor:
738753 - Others: :math:`(n, d)` or :math:`(d)`
739754 - Output: :math:`(*, n)` or :math:`(*)`, depends on shape of others
740755
756+ .. note::
757+
758+ Output ``dtype`` is ``torch.get_default_dtype()``.
759+
741760 Examples::
742761
743762 >>> x = functional.random_hv(3, 6)
@@ -761,43 +780,75 @@ def cosine_similarity(input: Tensor, others: Tensor, *, eps=1e-08) -> Tensor:
761780 [0.1806, 0.2607, 1.0000]])
762781
763782 """
764- if torch .is_complex (input ):
765- input_mag = torch .real (input * input .conj ()).sum (dim = - 1 ).sqrt ()
766- others_mag = torch .real (others * others .conj ()).sum (dim = - 1 ).sqrt ()
783+ out_dtype = torch .get_default_dtype ()
784+
785+ # calculate vector magnitude
786+ if input .dtype == torch .bool :
787+ input_mag = torch .full (
788+ input .shape [:- 1 ],
789+ math .sqrt (input .size (- 1 )),
790+ dtype = out_dtype ,
791+ device = input .device ,
792+ )
793+ others_mag = torch .full (
794+ others .shape [:- 1 ],
795+ math .sqrt (others .size (- 1 )),
796+ dtype = out_dtype ,
797+ device = others .device ,
798+ )
799+
800+ elif torch .is_complex (input ):
801+ input_dot = torch .real (input * input .conj ()).sum (dim = - 1 , dtype = out_dtype )
802+ input_mag = input_dot .sqrt ()
803+
804+ others_dot = torch .real (others * others .conj ()).sum (dim = - 1 , dtype = out_dtype )
805+ others_mag = others_dot .sqrt ()
806+
767807 else :
768- input_mag = torch .sum (input * input , dim = - 1 ).sqrt ()
769- others_mag = torch .sum (others * others , dim = - 1 ).sqrt ()
808+ input_dot = torch .sum (input * input , dim = - 1 , dtype = out_dtype )
809+ input_mag = input_dot .sqrt ()
810+
811+ others_dot = torch .sum (others * others , dim = - 1 , dtype = out_dtype )
812+ others_mag = others_dot .sqrt ()
770813
771814 if input .dim () > 1 :
772815 magnitude = input_mag .unsqueeze (- 1 ) * others_mag .unsqueeze (0 )
773816 else :
774817 magnitude = input_mag * others_mag
775818
776- return dot_similarity (input , others ) / (magnitude + eps )
819+ return dot_similarity (input , others ). to ( out_dtype ) / (magnitude + eps )
777820
778821
779822def hamming_similarity (input : Tensor , others : Tensor ) -> LongTensor :
780- """Number of equal elements between the input vector and each vector in others.
823+ """Number of equal elements between the input vectors and each vector in others.
781824
782825 Args:
783- input (Tensor): one-dimensional tensor
784- others (Tensor): two-dimensional tensor
826+ input (Tensor): hypervectors to compare against others
827+ others (Tensor): hypervectors to compare with
785828
786829 Shapes:
787- - Input: :math:`(d)`
788- - Others: :math:`(n, d)`
789- - Output: :math:`(n)`
830+ - Input: :math:`(*, d)`
831+ - Others: :math:`(n, d)` or :math:`(d)`
832+ - Output: :math:`(*, n)` or :math:`(*)`, depends on shape of others
790833
791834 Examples::
792835
793- >>> x = functional.random_hv(2, 3 )
836+ >>> x = functional.random_hv(3, 6 )
794837 >>> x
795- tensor([[ 1., 1., -1.],
796- [-1., -1., -1.]])
797- >>> functional.hamming_similarity(x[0], x)
798- tensor([3., 1.])
838+ tensor([[ 1., 1., -1., -1., 1., 1.],
839+ [ 1., 1., 1., 1., -1., -1.],
840+ [ 1., 1., -1., -1., -1., 1.]])
841+ >>> functional.hamming_similarity(x, x)
842+ tensor([[6, 2, 5],
843+ [2, 6, 3],
844+ [5, 3, 6]])
799845
800846 """
847+ if input .dim () > 1 and others .dim () > 1 :
848+ return torch .sum (
849+ input .unsqueeze (- 2 ) == others .unsqueeze (- 3 ), dim = - 1 , dtype = torch .long
850+ )
851+
801852 return torch .sum (input == others , dim = - 1 , dtype = torch .long )
802853
803854
0 commit comments