@@ -206,11 +206,11 @@ def test_cross(x1_x2_kw):
206206
207207 def exact_cross (a , b ):
208208 assert a .shape == b .shape == (3 ,), "Invalid cross() stack shapes. This indicates a bug in the test suite."
209- return asarray ([
209+ return asarray (xp . stack ( [
210210 a [1 ]* b [2 ] - a [2 ]* b [1 ],
211211 a [2 ]* b [0 ] - a [0 ]* b [2 ],
212212 a [0 ]* b [1 ] - a [1 ]* b [0 ],
213- ], dtype = res .dtype )
213+ ]) , dtype = res .dtype )
214214
215215 # We don't want to pass in **kw here because that would pass axis to
216216 # cross() on a single stack, but the axis is not meaningful on unstacked
@@ -267,7 +267,7 @@ def true_diag(x_stack, offset=0):
267267 x_stack_diag = [x_stack [i , i + offset ] for i in range (diag_size )]
268268 else :
269269 x_stack_diag = [x_stack [i - offset , i ] for i in range (diag_size )]
270- return asarray (x_stack_diag , dtype = x .dtype )
270+ return asarray (xp . stack ( x_stack_diag ) if x_stack_diag else [] , dtype = x .dtype )
271271
272272 _test_stacks (linalg .diagonal , x , ** kw , res = res , dims = 1 , true_val = true_diag )
273273
@@ -901,7 +901,9 @@ def true_trace(x_stack, offset=0):
901901 x_stack_diag = [x_stack [i , i + offset ] for i in range (diag_size )]
902902 else :
903903 x_stack_diag = [x_stack [i - offset , i ] for i in range (diag_size )]
904- return _array_module .sum (asarray (x_stack_diag , dtype = x .dtype ))
904+ result = xp .asarray (xp .stack (x_stack_diag ) if x_stack_diag else [], dtype = x .dtype )
905+ return _array_module .sum (result )
906+
905907
906908 _test_stacks (linalg .trace , x , ** kw , res = res , dims = 0 , true_val = true_trace )
907909
0 commit comments