@@ -108,6 +108,7 @@ def __init__(
108108 sparse : bool = False ,
109109 device = None ,
110110 dtype = None ,
111+ ** kwargs ,
111112 ) -> None :
112113 factory_kwargs = {"device" : device , "dtype" : dtype }
113114 # Have to call Module init explicitly in order not to use the Embedding init
@@ -116,6 +117,7 @@ def __init__(
116117 self .num_embeddings = num_embeddings
117118 self .embedding_dim = embedding_dim
118119 self .vsa = vsa
120+ self .vsa_kwargs = kwargs
119121
120122 if padding_idx is not None :
121123 if padding_idx > 0 :
@@ -135,7 +137,7 @@ def __init__(
135137 self .sparse = sparse
136138
137139 embeddings = functional .empty (
138- num_embeddings , embedding_dim , self .vsa , ** factory_kwargs
140+ num_embeddings , embedding_dim , self .vsa , ** factory_kwargs , ** self . vsa_kwargs
139141 )
140142 # Have to provide requires grad at the creation of the parameters to
141143 # prevent errors when instantiating a non-float embedding
@@ -148,7 +150,11 @@ def reset_parameters(self) -> None:
148150
149151 with torch .no_grad ():
150152 embeddings = functional .empty (
151- self .num_embeddings , self .embedding_dim , self .vsa , ** factory_kwargs
153+ self .num_embeddings ,
154+ self .embedding_dim ,
155+ self .vsa ,
156+ ** factory_kwargs ,
157+ ** self .vsa_kwargs ,
152158 )
153159 self .weight .copy_ (embeddings )
154160
@@ -214,6 +220,7 @@ def __init__(
214220 sparse : bool = False ,
215221 device = None ,
216222 dtype = None ,
223+ ** kwargs ,
217224 ) -> None :
218225 factory_kwargs = {"device" : device , "dtype" : dtype }
219226 # Have to call Module init explicitly in order not to use the Embedding init
@@ -222,6 +229,7 @@ def __init__(
222229 self .num_embeddings = num_embeddings
223230 self .embedding_dim = embedding_dim
224231 self .vsa = vsa
232+ self .vsa_kwargs = kwargs
225233
226234 if padding_idx is not None :
227235 if padding_idx > 0 :
@@ -241,7 +249,7 @@ def __init__(
241249 self .sparse = sparse
242250
243251 embeddings = functional .identity (
244- num_embeddings , embedding_dim , self .vsa , ** factory_kwargs
252+ num_embeddings , embedding_dim , self .vsa , ** factory_kwargs , ** self . vsa_kwargs
245253 )
246254 # Have to provide requires grad at the creation of the parameters to
247255 # prevent errors when instantiating a non-float embedding
@@ -254,7 +262,11 @@ def reset_parameters(self) -> None:
254262
255263 with torch .no_grad ():
256264 embeddings = functional .identity (
257- self .num_embeddings , self .embedding_dim , self .vsa , ** factory_kwargs
265+ self .num_embeddings ,
266+ self .embedding_dim ,
267+ self .vsa ,
268+ ** factory_kwargs ,
269+ ** self .vsa_kwargs ,
258270 )
259271 self .weight .copy_ (embeddings )
260272
@@ -266,7 +278,7 @@ def _fill_padding_idx_with_empty(self) -> None:
266278 if self .padding_idx is not None :
267279 with torch .no_grad ():
268280 empty = functional .empty (
269- 1 , self .embedding_dim , self .vsa , ** factory_kwargs
281+ 1 , self .embedding_dim , self .vsa , ** factory_kwargs , ** self . vsa_kwargs
270282 )
271283 self .weight [self .padding_idx ].copy_ (empty .squeeze (0 ))
272284
@@ -332,6 +344,7 @@ def __init__(
332344 sparse : bool = False ,
333345 device = None ,
334346 dtype = None ,
347+ ** kwargs ,
335348 ) -> None :
336349 factory_kwargs = {"device" : device , "dtype" : dtype }
337350 # Have to call Module init explicitly in order not to use the Embedding init
@@ -340,6 +353,7 @@ def __init__(
340353 self .num_embeddings = num_embeddings
341354 self .embedding_dim = embedding_dim
342355 self .vsa = vsa
356+ self .vsa_kwargs = kwargs
343357
344358 if padding_idx is not None :
345359 if padding_idx > 0 :
@@ -359,7 +373,7 @@ def __init__(
359373 self .sparse = sparse
360374
361375 embeddings = functional .random (
362- num_embeddings , embedding_dim , self .vsa , ** factory_kwargs
376+ num_embeddings , embedding_dim , self .vsa , ** factory_kwargs , ** self . vsa_kwargs
363377 )
364378 # Have to provide requires grad at the creation of the parameters to
365379 # prevent errors when instantiating a non-float embedding
@@ -372,7 +386,11 @@ def reset_parameters(self) -> None:
372386
373387 with torch .no_grad ():
374388 embeddings = functional .random (
375- self .num_embeddings , self .embedding_dim , self .vsa , ** factory_kwargs
389+ self .num_embeddings ,
390+ self .embedding_dim ,
391+ self .vsa ,
392+ ** factory_kwargs ,
393+ ** self .vsa_kwargs ,
376394 )
377395 self .weight .copy_ (embeddings )
378396
@@ -384,7 +402,7 @@ def _fill_padding_idx_with_empty(self) -> None:
384402 if self .padding_idx is not None :
385403 with torch .no_grad ():
386404 empty = functional .empty (
387- 1 , self .embedding_dim , self .vsa , ** factory_kwargs
405+ 1 , self .embedding_dim , self .vsa , ** factory_kwargs , ** self . vsa_kwargs
388406 )
389407 self .weight [self .padding_idx ].copy_ (empty .squeeze (0 ))
390408
@@ -469,6 +487,7 @@ def __init__(
469487 sparse : bool = False ,
470488 device = None ,
471489 dtype = None ,
490+ ** kwargs ,
472491 ) -> None :
473492 factory_kwargs = {"device" : device , "dtype" : dtype }
474493 # Have to call Module init explicitly in order not to use the Embedding init
@@ -477,6 +496,7 @@ def __init__(
477496 self .num_embeddings = num_embeddings
478497 self .embedding_dim = embedding_dim
479498 self .vsa = vsa
499+ self .vsa_kwargs = kwargs
480500 self .low = low
481501 self .high = high
482502 self .randomness = randomness
@@ -493,6 +513,7 @@ def __init__(
493513 self .vsa ,
494514 randomness = randomness ,
495515 ** factory_kwargs ,
516+ ** self .vsa_kwargs ,
496517 )
497518 # Have to provide requires grad at the creation of the parameters to
498519 # prevent errors when instantiating a non-float embedding
@@ -508,6 +529,7 @@ def reset_parameters(self) -> None:
508529 self .vsa ,
509530 randomness = self .randomness ,
510531 ** factory_kwargs ,
532+ ** self .vsa_kwargs ,
511533 )
512534 self .weight .copy_ (embeddings )
513535
@@ -592,6 +614,7 @@ def __init__(
592614 sparse : bool = False ,
593615 device = None ,
594616 dtype = None ,
617+ ** kwargs ,
595618 ) -> None :
596619 factory_kwargs = {"device" : device , "dtype" : dtype }
597620 # Have to call Module init explicitly in order not to use the Embedding init
@@ -600,6 +623,7 @@ def __init__(
600623 self .num_embeddings = num_embeddings
601624 self .embedding_dim = embedding_dim
602625 self .vsa = vsa
626+ self .vsa_kwargs = kwargs
603627 self .low = low
604628 self .high = high
605629
@@ -610,7 +634,7 @@ def __init__(
610634 self .sparse = sparse
611635
612636 embeddings = functional .thermometer (
613- num_embeddings , embedding_dim , self .vsa , ** factory_kwargs
637+ num_embeddings , embedding_dim , self .vsa , ** factory_kwargs , ** self . vsa_kwargs
614638 )
615639 # Have to provide requires grad at the creation of the parameters to
616640 # prevent errors when instantiating a non-float embedding
@@ -621,7 +645,11 @@ def reset_parameters(self) -> None:
621645
622646 with torch .no_grad ():
623647 embeddings = functional .thermometer (
624- self .num_embeddings , self .embedding_dim , self .vsa , ** factory_kwargs
648+ self .num_embeddings ,
649+ self .embedding_dim ,
650+ self .vsa ,
651+ ** factory_kwargs ,
652+ ** self .vsa_kwargs ,
625653 )
626654 self .weight .copy_ (embeddings )
627655
@@ -704,6 +732,7 @@ def __init__(
704732 sparse : bool = False ,
705733 device = None ,
706734 dtype = None ,
735+ ** kwargs ,
707736 ) -> None :
708737 factory_kwargs = {"device" : device , "dtype" : dtype }
709738 # Have to call Module init explicitly in order not to use the Embedding init
@@ -712,6 +741,7 @@ def __init__(
712741 self .num_embeddings = num_embeddings
713742 self .embedding_dim = embedding_dim
714743 self .vsa = vsa
744+ self .vsa_kwargs = kwargs
715745 self .phase = phase
716746 self .period = period
717747 self .randomness = randomness
@@ -728,6 +758,7 @@ def __init__(
728758 self .vsa ,
729759 randomness = randomness ,
730760 ** factory_kwargs ,
761+ ** self .vsa_kwargs ,
731762 )
732763 # Have to provide requires grad at the creation of the parameters to
733764 # prevent errors when instantiating a non-float embedding
@@ -743,6 +774,7 @@ def reset_parameters(self) -> None:
743774 self .vsa ,
744775 randomness = self .randomness ,
745776 ** factory_kwargs ,
777+ ** self .vsa_kwargs ,
746778 )
747779 self .weight .copy_ (embeddings )
748780
@@ -945,6 +977,7 @@ def __init__(
945977 device = None ,
946978 dtype = None ,
947979 requires_grad : bool = False ,
980+ ** kwargs ,
948981 ):
949982 factory_kwargs = {
950983 "device" : device ,
@@ -954,10 +987,16 @@ def __init__(
954987 super (Density , self ).__init__ ()
955988
956989 # A set of random vectors used as unique IDs for features of the dataset.
957- self .key = Random (in_features , out_features , vsa , ** factory_kwargs )
990+ self .key = Random (in_features , out_features , vsa , ** factory_kwargs , ** kwargs )
958991 # Thermometer encoding used for transforming input data.
959992 self .density_encoding = Thermometer (
960- out_features + 1 , out_features , vsa , low = low , high = high , ** factory_kwargs
993+ out_features + 1 ,
994+ out_features ,
995+ vsa ,
996+ low = low ,
997+ high = high ,
998+ ** factory_kwargs ,
999+ ** kwargs ,
9611000 )
9621001
9631002 def reset_parameters (self ) -> None :
0 commit comments