@@ -331,10 +331,9 @@ def test_inv(x):
 
     # TODO: Test that the result is actually the inverse
 
-@given(
-    *two_mutual_arrays(dh.real_dtypes)
-)
-def test_matmul(x1, x2):
+def _test_matmul(namespace, x1, x2):
+    matmul = namespace.matmul
+
     # TODO: Make this also test the @ operator
     if (x1.shape == () or x2.shape == ()
         or len(x1.shape) == len(x2.shape) == 1 and x1.shape != x2.shape
@@ -347,7 +346,7 @@ def test_matmul(x1, x2):
                   "matmul did not raise an exception for invalid shapes")
         return
     else:
-        res = _array_module.matmul(x1, x2)
+        res = matmul(x1, x2)
 
     ph.assert_dtype("matmul", in_dtype=[x1.dtype, x2.dtype], out_dtype=res.dtype)
 
@@ -358,19 +357,32 @@ def test_matmul(x1, x2):
         ph.assert_result_shape("matmul", in_shapes=[x1.shape, x2.shape],
                                out_shape=res.shape,
                                expected=x2.shape[:-2] + x2.shape[-1:])
-        _test_stacks(_array_module.matmul, x1, x2, res=res, dims=1,
+        _test_stacks(matmul, x1, x2, res=res, dims=1,
                      matrix_axes=[(0,), (-2, -1)], res_axes=[-1])
     elif len(x2.shape) == 1:
         ph.assert_result_shape("matmul", in_shapes=[x1.shape, x2.shape],
                                out_shape=res.shape, expected=x1.shape[:-1])
-        _test_stacks(_array_module.matmul, x1, x2, res=res, dims=1,
+        _test_stacks(matmul, x1, x2, res=res, dims=1,
                      matrix_axes=[(-2, -1), (0,)], res_axes=[-1])
     else:
         stack_shape = sh.broadcast_shapes(x1.shape[:-2], x2.shape[:-2])
         ph.assert_result_shape("matmul", in_shapes=[x1.shape, x2.shape],
                                out_shape=res.shape,
                                expected=stack_shape + (x1.shape[-2], x2.shape[-1]))
-        _test_stacks(_array_module.matmul, x1, x2, res=res)
+        _test_stacks(matmul, x1, x2, res=res)
+
+@pytest.mark.xp_extension('linalg')
+@given(
+    *two_mutual_arrays(dh.real_dtypes)
+)
+def test_linalg_matmul(x1, x2):
+    return _test_matmul(linalg, x1, x2)
+
+@given(
+    *two_mutual_arrays(dh.real_dtypes)
+)
+def test_matmul(x1, x2):
+    return _test_matmul(_array_module, x1, x2)
 
 @pytest.mark.xp_extension('linalg')
 @given(
@@ -428,11 +440,9 @@ def test_matrix_power(x, n):
 def test_matrix_rank(x, kw):
     linalg.matrix_rank(x, **kw)
 
-@given(
-    x=arrays(dtype=xps.scalar_dtypes(), shape=matrix_shapes()),
-)
-def test_matrix_transpose(x):
-    res = _array_module.matrix_transpose(x)
+def _test_matrix_transpose(namespace, x):
+    matrix_transpose = namespace.matrix_transpose
+    res = matrix_transpose(x)
     true_val = lambda a: _array_module.asarray([[a[i, j] for i in
                                                  range(a.shape[0])] for j in
                                                  range(a.shape[1])],
@@ -444,7 +454,20 @@ def test_matrix_transpose(x):
     ph.assert_result_shape("matrix_transpose", in_shapes=[x.shape],
                            out_shape=res.shape, expected=shape)
 
-    _test_stacks(_array_module.matrix_transpose, x, res=res, true_val=true_val)
+    _test_stacks(matrix_transpose, x, res=res, true_val=true_val)
+
+@pytest.mark.xp_extension('linalg')
+@given(
+    x=arrays(dtype=xps.scalar_dtypes(), shape=matrix_shapes()),
+)
+def test_linalg_matrix_transpose(x):
+    return _test_matrix_transpose(linalg, x)
+
+@given(
+    x=arrays(dtype=xps.scalar_dtypes(), shape=matrix_shapes()),
+)
+def test_matrix_transpose(x):
+    return _test_matrix_transpose(_array_module, x)
 
 @pytest.mark.xp_extension('linalg')
 @given(
@@ -759,12 +782,9 @@ def _test_tensordot_stacks(x1, x2, kw, res):
         decomp_res_stack = xp.tensordot(x1_stack, x2_stack, axes=res_axes)
         assert_equal(res_stack, decomp_res_stack)
 
-@given(
-    *two_mutual_arrays(dh.numeric_dtypes, two_shapes=tensordot_shapes()),
-    tensordot_kw,
-)
-def test_tensordot(x1, x2, kw):
-    res = xp.tensordot(x1, x2, **kw)
+def _test_tensordot(namespace, x1, x2, kw):
+    tensordot = namespace.tensordot
+    res = tensordot(x1, x2, **kw)
 
     ph.assert_dtype("tensordot", in_dtype=[x1.dtype, x2.dtype],
                     out_dtype=res.dtype)
@@ -785,6 +805,21 @@ def test_tensordot(x1, x2, kw):
                            expected=result_shape)
     _test_tensordot_stacks(x1, x2, kw, res)
 
+@pytest.mark.xp_extension('linalg')
+@given(
+    *two_mutual_arrays(dh.numeric_dtypes, two_shapes=tensordot_shapes()),
+    tensordot_kw,
+)
+def test_linalg_tensordot(x1, x2, kw):
+    _test_tensordot(linalg, x1, x2, kw)
+
+@given(
+    *two_mutual_arrays(dh.numeric_dtypes, two_shapes=tensordot_shapes()),
+    tensordot_kw,
+)
+def test_tensordot(x1, x2, kw):
+    _test_tensordot(_array_module, x1, x2, kw)
+
 @pytest.mark.xp_extension('linalg')
 @given(
     x=arrays(dtype=xps.numeric_dtypes(), shape=matrix_shapes()),
@@ -828,12 +863,8 @@ def true_trace(x_stack, offset=0):
 
     _test_stacks(linalg.trace, x, **kw, res=res, dims=0, true_val=true_trace)
 
-
-@given(
-    *two_mutual_arrays(dh.numeric_dtypes, mutually_broadcastable_shapes(2, min_dims=1)),
-    data(),
-)
-def test_vecdot(x1, x2, data):
+def _test_vecdot(namespace, x1, x2, data):
+    vecdot = namespace.vecdot
     broadcasted_shape = sh.broadcast_shapes(x1.shape, x2.shape)
     min_ndim = min(x1.ndim, x2.ndim)
     ndim = len(broadcasted_shape)
@@ -842,14 +873,14 @@ def test_vecdot(x1, x2, data):
     x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape)
     x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape)
     if x1_shape[axis] != x2_shape[axis]:
-        ph.raises(Exception, lambda: xp.vecdot(x1, x2, **kw),
+        ph.raises(Exception, lambda: vecdot(x1, x2, **kw),
                   "vecdot did not raise an exception for invalid shapes")
         return
     expected_shape = list(broadcasted_shape)
     expected_shape.pop(axis)
     expected_shape = tuple(expected_shape)
 
-    res = xp.vecdot(x1, x2, **kw)
+    res = vecdot(x1, x2, **kw)
 
     ph.assert_dtype("vecdot", in_dtype=[x1.dtype, x2.dtype],
                     out_dtype=res.dtype)
@@ -862,9 +893,25 @@ def true_val(x, y, axis=-1):
     else:
         true_val = None
 
-    _test_stacks(linalg.vecdot, x1, x2, res=res, dims=0,
+    _test_stacks(vecdot, x1, x2, res=res, dims=0,
                  matrix_axes=(axis,), true_val=true_val)
 
+
+@pytest.mark.xp_extension('linalg')
+@given(
+    *two_mutual_arrays(dh.numeric_dtypes, mutually_broadcastable_shapes(2, min_dims=1)),
+    data(),
+)
+def test_linalg_vecdot(x1, x2, data):
+    _test_vecdot(linalg, x1, x2, data)
+
+@given(
+    *two_mutual_arrays(dh.numeric_dtypes, mutually_broadcastable_shapes(2, min_dims=1)),
+    data(),
+)
+def test_vecdot(x1, x2, data):
+    _test_vecdot(_array_module, x1, x2, data)
+
 # Insanely large orders might not work. There isn't a limit specified in the
 # spec, so we just limit to reasonable values here.
 max_ord = 100
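
The refactoring above follows one pattern throughout: each test body is hoisted into a _test_*(namespace, ...) helper, and two thin wrappers are registered, a test_linalg_* variant marked with pytest.mark.xp_extension('linalg') that passes the linalg extension namespace, and a plain test_* variant that passes the top-level array module. Keeping a single body means the two variants cannot drift apart, while the marker stays on the wrapper so extension-only skipping still works. The sketch below shows the idea in isolation; it is a hypothetical, self-contained illustration, and fake_module, fake_linalg, and _test_abs are made-up stand-ins rather than anything in array-api-tests.

# Minimal sketch of the namespace-parametrized test pattern used in this diff.
# fake_module, fake_linalg, and _test_abs are hypothetical stand-ins; the real
# suite passes the array module under test and its linalg extension namespace.
from types import SimpleNamespace

import pytest

fake_module = SimpleNamespace(abs=abs)   # stand-in for the top-level namespace
fake_linalg = SimpleNamespace(abs=abs)   # stand-in for the extension namespace


def _test_abs(namespace, x):
    # Shared body: look the function up on whichever namespace was passed in,
    # then run identical checks regardless of where it came from.
    abs_ = namespace.abs
    res = abs_(x)
    assert res >= 0
    assert res == (x if x >= 0 else -x)


@pytest.mark.parametrize("x", [-3, 0, 7])
def test_linalg_abs(x):
    # Extension variant; in the real suite this wrapper also carries
    # @pytest.mark.xp_extension('linalg') so it is skipped when the
    # extension is not available.
    _test_abs(fake_linalg, x)


@pytest.mark.parametrize("x", [-3, 0, 7])
def test_abs(x):
    # Top-level variant exercising the same body against the main namespace.
    _test_abs(fake_module, x)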