1- from typing import Any , List , Optional , Tuple
1+ from typing import Any , List , Optional , Tuple , overload
22import torch
3+ from torch import Tensor
34
45import torchhd .functional as functional
56
@@ -9,19 +10,19 @@ class Memory:
910
1011 def __init__ (self , threshold = 0.5 ):
1112 self .threshold = threshold
12- self .keys : List [torch . Tensor ] = []
13+ self .keys : List [Tensor ] = []
1314 self .values : List [Any ] = []
1415
1516 def __len__ (self ) -> int :
1617 """Returns the number of items in memory"""
1718 return len (self .values )
1819
19- def add (self , key : torch . Tensor , value : Any ) -> None :
20+ def add (self , key : Tensor , value : Any ) -> None :
2021 """Adds one (key, value) pair to memory"""
2122 self .keys .append (key )
2223 self .values .append (value )
2324
24- def _get_index (self , key : torch . Tensor ) -> int :
25+ def _get_index (self , key : Tensor ) -> int :
2526 key_stack = torch .stack (self .keys , dim = 0 )
2627 sim = functional .cosine_similarity (key , key_stack )
2728 value , index = torch .max (sim , 0 )
@@ -31,147 +32,249 @@ def _get_index(self, key: torch.Tensor) -> int:
3132
3233 return index
3334
34- def __getitem__ (self , key : torch . Tensor ) -> Tuple [torch . Tensor , Any ]:
35+ def __getitem__ (self , key : Tensor ) -> Tuple [Tensor , Any ]:
3536 """Get the (key, value) pair with an approximate key"""
3637 index = self ._get_index (key )
3738 return self .keys [index ], self .values [index ]
3839
39- def __setitem__ (self , key : torch . Tensor , value : Any ) -> None :
40+ def __setitem__ (self , key : Tensor , value : Any ) -> None :
4041 """Set the value of an (key, value) pair with an approximate key"""
4142 index = self ._get_index (key )
4243 self .values [index ] = value
4344
44- def __delitem__ (self , key : torch . Tensor ) -> None :
45+ def __delitem__ (self , key : Tensor ) -> None :
4546 """Delete the (key, value) pair with an approximate key"""
4647 index = self ._get_index (key )
4748 del self .keys [index ]
4849 del self .values [index ]
4950
5051
5152class Multiset :
52- def __init__ (self , dimensions , threshold = 0.5 , device = None , dtype = None ):
53- self .threshold = threshold
54- self .cardinality = 0
55- dtype = dtype if dtype is not None else torch .get_default_dtype ()
56- self .value = torch .zeros (dimensions , dtype = dtype , device = device )
53+ @overload
54+ def __init__ (self , dimensions : int , * , device = None , dtype = None ):
55+ ...
56+
57+ @overload
58+ def __init__ (self , input : Tensor , * , size = 0 ):
59+ ...
60+
61+ def __init__ (self , dim_or_input : int , ** kwargs ):
62+ self .size = kwargs .get ("size" , 0 )
63+ if torch .is_tensor (dim_or_input ):
64+ self .value = dim_or_input
65+ else :
66+ dtype = kwargs .get ("dtype" , torch .get_default_dtype ())
67+ device = kwargs .get ("device" , None )
68+ self .value = torch .zeros (dim_or_input , dtype = dtype , device = device )
5769
58- def add (self , input : torch . Tensor ) -> None :
70+ def add (self , input : Tensor ) -> None :
5971 self .value = functional .bundle (self .value , input )
60- self .cardinality += 1
72+ self .size += 1
6173
62- def remove (self , input : torch .Tensor ) -> None :
63- if input not in self :
64- return
74+ def remove (self , input : Tensor ) -> None :
6575 self .value = functional .bundle (self .value , - input )
66- self .cardinality -= 1
76+ self .size -= 1
6777
68- def __contains__ (self , input : torch .Tensor ):
69- sim = functional .cosine_similarity (input , self .values .unsqueeze (0 ))
70- return sim .item () > self .threshold
78+ def contains (self , input : Tensor ) -> Tensor :
79+ return functional .cosine_similarity (input , self .value .unsqueeze (0 ))
7180
7281 def __len__ (self ) -> int :
73- return self .cardinality
82+ return self .size
7483
7584 @classmethod
76- def from_ngrams (cls , input : torch .Tensor , n = 3 , threshold = 0.5 ):
77- instance = cls (input .size (- 1 ), threshold , input .device , input .dtype )
78- instance .value = functional .ngrams (input , n )
79- return instance
85+ def from_ngrams (cls , input : Tensor , n = 3 ):
86+ value = functional .ngrams (input , n )
87+ return cls (value , size = input .size (- 2 ) - n + 1 )
8088
8189 @classmethod
82- def from_tensors (cls , input : torch .Tensor , dim = - 2 , threshold = 0.5 ):
83- instance = cls (input .size (- 1 ), threshold , input .device , input .dtype )
84- instance .value = functional .multiset (input = input , dim = dim )
85- return instance
90+ def from_tensor (cls , input : Tensor ):
91+ value = functional .multiset (input , dim = - 2 )
92+ return cls (value , size = input .size (- 2 ))
8693
8794
8895class Sequence :
89- def __init__ (self , dimensions , threshold = 0.5 , device = None , dtype = None ):
90- self .length = 0
91- self .threshold = threshold
92- dtype = dtype if dtype is not None else torch .get_default_dtype ()
93- self .value = torch .zeros (dimensions , dtype = dtype , device = device )
96+ @overload
97+ def __init__ (self , dimensions : int , * , device = None , dtype = None ):
98+ ...
99+
100+ @overload
101+ def __init__ (self , input : Tensor , * , length = 0 ):
102+ ...
103+
104+ def __init__ (self , dim_or_input : int , ** kwargs ):
105+ self .length = kwargs .get ("length" , 0 )
106+ if torch .is_tensor (dim_or_input ):
107+ self .value = dim_or_input
108+ else :
109+ dtype = kwargs .get ("dtype" , torch .get_default_dtype ())
110+ device = kwargs .get ("device" , None )
111+ self .value = torch .zeros (dim_or_input , dtype = dtype , device = device )
94112
95- def append (self , input : torch . Tensor ) -> None :
113+ def append (self , input : Tensor ) -> None :
96114 rotated_value = functional .permute (self .value , shifts = 1 )
97115 self .value = functional .bundle (input , rotated_value )
116+ self .length += 1
98117
99- def appendleft (self , input : torch . Tensor ) -> None :
118+ def appendleft (self , input : Tensor ) -> None :
100119 rotated_input = functional .permute (input , shifts = len (self ))
101120 self .value = functional .bundle (self .value , rotated_input )
121+ self .length += 1
102122
103- def pop (self , input : torch .Tensor ) -> Optional [torch .Tensor ]:
123+ def pop (self , input : Tensor ) -> None :
124+ self .length -= 1
104125 self .value = functional .bundle (self .value , - input )
105126 self .value = functional .permute (self .value , shifts = - 1 )
106- self .length -= 1
107127
108- def popleft (self , input : torch .Tensor ) -> None :
109- rotated_input = functional .permute (input , shifts = len (self ) + 1 )
128+ def popleft (self , input : Tensor ) -> None :
129+ self .length -= 1
130+ rotated_input = functional .permute (input , shifts = len (self ))
110131 self .value = functional .bundle (self .value , - rotated_input )
132+
133+ def replace (self , index : int , old : Tensor , new : Tensor ) -> None :
134+ rotated_old = functional .permute (old , shifts = - self .length + index + 1 )
135+ self .value = functional .bundle (self .value , - rotated_old )
136+
137+ rotated_new = functional .permute (new , shifts = - self .length + index + 1 )
138+ self .value = functional .bundle (self .value , rotated_new )
139+
140+ def concat (self , seq : "Sequence" ) -> "Sequence" :
141+ value = functional .permute (self .value , shifts = len (seq ))
142+ value = functional .bundle (value , seq .value )
143+ return Sequence (value , length = len (self ) + len (seq ))
144+
145+ def __getitem__ (self , index : int ) -> Tensor :
146+ return functional .permute (self .value , shifts = - self .length + index + 1 )
147+
148+ def __len__ (self ) -> int :
149+ return self .length
150+
151+
152+ class DistinctSequence :
153+ @overload
154+ def __init__ (self , dimensions : int , * , device = None , dtype = None ):
155+ ...
156+
157+ @overload
158+ def __init__ (self , input : Tensor , * , length = 0 ):
159+ ...
160+
161+ def __init__ (self , dim_or_input : int , ** kwargs ):
162+ self .length = kwargs .get ("length" , 0 )
163+ if torch .is_tensor (dim_or_input ):
164+ self .value = dim_or_input
165+ else :
166+ dtype = kwargs .get ("dtype" , torch .get_default_dtype ())
167+ device = kwargs .get ("device" , None )
168+ self .value = torch .zeros (dim_or_input , dtype = dtype , device = device )
169+
170+ def append (self , input : Tensor ) -> None :
171+ rotated_value = functional .permute (self .value , shifts = 1 )
172+ self .value = functional .bind (input , rotated_value )
173+ self .length += 1
174+
175+ def appendleft (self , input : Tensor ) -> None :
176+ rotated_input = functional .permute (input , shifts = len (self ))
177+ self .value = functional .bind (self .value , rotated_input )
178+ self .length += 1
179+
180+ def pop (self , input : Tensor ) -> None :
181+ self .length -= 1
182+ self .value = functional .bind (self .value , input )
183+ self .value = functional .permute (self .value , shifts = - 1 )
184+
185+ def popleft (self , input : Tensor ) -> None :
111186 self .length -= 1
187+ rotated_input = functional .permute (input , shifts = len (self ))
188+ self .value = functional .bind (self .value , rotated_input )
112189
113- def __getitem__ (self , index : int ) -> torch .Tensor :
114- rotated_value = functional .permute (self .value , shifts = - index )
115- return rotated_value
190+ def replace (self , index : int , old : Tensor , new : Tensor ) -> None :
191+ rotated_old = functional .permute (old , shifts = - self .length + index + 1 )
192+ self .value = functional .bind (self .value , rotated_old )
193+
194+ rotated_new = functional .permute (new , shifts = - self .length + index + 1 )
195+ self .value = functional .bind (self .value , rotated_new )
116196
117197 def __len__ (self ) -> int :
118198 return self .length
119199
120200
121201class Graph :
122- def __init__ (
123- self , dimensions , threshold = 0.5 , directed = False , device = None , dtype = None
124- ):
202+ def __init__ (self , dimensions , directed = False , device = None , dtype = None ):
125203 self .length = 0
126- self .threshold = threshold
204+ self .directed = directed
127205 self .dtype = dtype if dtype is not None else torch .get_default_dtype ()
128206 self .value = torch .zeros (dimensions , dtype = dtype , device = device )
129- self .directed = directed
130207
131- def add_edge (self , node1 : torch .Tensor , node2 : torch .Tensor ):
132- if self .directed :
133- edge = functional .bind (node1 , node2 )
134- else :
135- edge = functional .bind (node1 , functional .permute (node2 ))
208+ def add_edge (self , node1 : Tensor , node2 : Tensor ) -> None :
209+ edge = self .encode_edge (node1 , node2 )
136210 self .value = functional .bundle (self .value , edge )
137211
138- def edge_exists (self , node1 : torch . Tensor , node2 : torch . Tensor ):
212+ def encode_edge (self , node1 : Tensor , node2 : Tensor ) -> Tensor :
139213 if self .directed :
140- edge = functional .bind (node1 , node2 )
214+ return functional .bind (node1 , node2 )
141215 else :
142- edge = functional .bind (node1 , functional .permute (node2 ))
143- return edge in self
216+ return functional .bind (node1 , functional .permute (node2 ))
144217
145- def node_neighbours (self , input : torch .Tensor ):
146- return functional .bind (self .value , input )
218+ def node_neighbors (self , input : Tensor , outgoing = True ) -> Tensor :
219+ if self .directed :
220+ if outgoing :
221+ return functional .permute (functional .bind (self .value , input ), shifts = - 1 )
222+ else :
223+ return functional .bind (self .value , functional .permute (input , shifts = 1 ))
224+ else :
225+ return functional .bind (self .value , input )
147226
148- def __contains__ (self , input : torch .Tensor ):
149- sim = functional .cosine_similarity (input , self .value .unsqueeze (0 ))
150- return sim .item () > self .threshold
227+ def contains (self , input : Tensor ) -> Tensor :
228+ return functional .cosine_similarity (input , self .value .unsqueeze (0 ))
151229
152230
153231class Tree :
154232 def __init__ (self , dimensions , device = None , dtype = None ):
233+ self .dimensions = dimensions
155234 self .dtype = dtype if dtype is not None else torch .get_default_dtype ()
156235 self .value = torch .zeros (dimensions , dtype = dtype , device = device )
157- self .l_r = functional .random_hv (2 , dimensions )
236+ self .l_r = functional .random_hv (2 , dimensions , dtype = dtype , device = device )
158237
159- def add_leaf (self , value , path ) :
160- for i in path :
238+ def add_leaf (self , value : Tensor , path : List [ str ]) -> None :
239+ for idx , i in enumerate ( path ) :
161240 if i == "l" :
162- value = functional .bind (value , self .left )
241+ value = functional .bind (
242+ value , functional .permute (self .left , shifts = idx )
243+ )
163244 else :
164- value = functional .bind (value , self .right )
245+ value = functional .bind (
246+ value , functional .permute (self .right , shifts = idx )
247+ )
248+
165249 self .value = functional .bundle (self .value , value )
166250
167251 @property
168- def left (self ):
252+ def left (self ) -> Tensor :
169253 return self .l_r [0 ]
170254
171255 @property
172- def right (self ):
256+ def right (self ) -> Tensor :
173257 return self .l_r [1 ]
174258
259+ def get_leaf (self , path : List [str ]) -> Tensor :
260+ for idx , i in enumerate (path ):
261+ if i == "l" :
262+ if idx == 0 :
263+ hv_path = self .left
264+ else :
265+ hv_path = functional .bind (
266+ hv_path , functional .permute (self .left , shifts = idx )
267+ )
268+ else :
269+ if idx == 0 :
270+ hv_path = self .right
271+ else :
272+ hv_path = functional .bind (
273+ hv_path , functional .permute (self .right , shifts = idx )
274+ )
275+
276+ return functional .bind (hv_path , self .value )
277+
175278
176279class FiniteStateAutomata :
177280 def __init__ (self , dimensions , device = None , dtype = None ):
@@ -180,18 +283,18 @@ def __init__(self, dimensions, device=None, dtype=None):
180283
181284 def add_transition (
182285 self ,
183- token : torch . Tensor ,
184- initial_state : torch . Tensor ,
185- final_state : torch . Tensor ,
186- ):
286+ token : Tensor ,
287+ initial_state : Tensor ,
288+ final_state : Tensor ,
289+ ) -> None :
187290 transition_edge = functional .bind (
188291 initial_state , functional .permute (final_state )
189292 )
190293 transition = functional .bind (token , transition_edge )
191294 self .value = functional .bundle (self .value , transition )
192295
193- def change_state (self , token : torch . Tensor , current_state : torch . Tensor ):
296+ def transition (self , state : Tensor , action : Tensor ) -> Tensor :
194297 # Returns the next state + some noise
195- next_state = functional .bind (self .value , current_state )
196- next_state = functional .bind (next_state , token )
298+ next_state = functional .bind (self .value , state )
299+ next_state = functional .bind (next_state , action )
197300 return functional .permute (next_state , shifts = - 1 )
0 commit comments