@@ -198,6 +198,37 @@ def test_dtype(self, vsa, dtype):
198198 else :
199199 assert similarity .dtype == torch .get_default_dtype ()
200200
201+ def test_custom_dtype (self ):
202+ hv = functional .random (3 , 100 , "BSBC" , block_size = 1024 )
203+ similarity = functional .dot_similarity (hv , hv )
204+ assert similarity .dtype == torch .get_default_dtype ()
205+
206+ similarity = functional .dot_similarity (hv , hv , dtype = torch .float64 )
207+ assert similarity .dtype == torch .float64
208+
209+ similarity = functional .dot_similarity (hv , hv , dtype = torch .int16 )
210+ assert similarity .dtype == torch .int16
211+
212+ hv = functional .random (3 , 100 , "MAP" )
213+ similarity = functional .dot_similarity (hv , hv )
214+ assert similarity .dtype == torch .get_default_dtype ()
215+
216+ similarity = functional .dot_similarity (hv , hv , dtype = torch .float64 )
217+ assert similarity .dtype == torch .float64
218+
219+ similarity = functional .dot_similarity (hv , hv , dtype = torch .int16 )
220+ assert similarity .dtype == torch .int16
221+
222+ hv = functional .random (3 , 100 , "BSC" )
223+ similarity = functional .dot_similarity (hv , hv )
224+ assert similarity .dtype == torch .get_default_dtype ()
225+
226+ similarity = functional .dot_similarity (hv , hv , dtype = torch .float64 )
227+ assert similarity .dtype == torch .float64
228+
229+ similarity = functional .dot_similarity (hv , hv , dtype = torch .int16 )
230+ assert similarity .dtype == torch .int16
231+
201232 @pytest .mark .parametrize ("vsa" , vsa_tensors )
202233 @pytest .mark .parametrize ("dtype" , torch_dtypes )
203234 def test_device (self , vsa , dtype ):
0 commit comments