@@ -597,32 +597,42 @@ def solve_args():
597597 of shape (..., M, M), and x2 is either shape (M,) or (..., M, K),
598598 where the ... parts of x1 and x2 are broadcast compatible.
599599 """
600+ mutual_dtypes = shared (mutually_promotable_dtypes (dtypes = dh .all_float_dtypes ))
601+
600602 stack_shapes = shared (two_mutually_broadcastable_shapes )
601603 # Don't worry about dtypes since all floating dtypes are type promotable
602604 # with each other.
603- x1 = shared (invertible_matrices (stack_shapes = stack_shapes .map (lambda pair :
604- pair [0 ])))
605+ x1 = shared (invertible_matrices (
606+ stack_shapes = stack_shapes .map (lambda pair : pair [0 ]),
607+ dtypes = mutual_dtypes .map (lambda pair : pair [0 ])))
605608
606609 @composite
607610 def _x2_shapes (draw ):
608611 end = draw (integers (0 , SQRT_MAX_ARRAY_SIZE ))
609612 return draw (stack_shapes )[1 ] + draw (x1 ).shape [- 1 :] + (end ,)
610613
611614 x2_shapes = one_of (x1 .map (lambda x : (x .shape [- 1 ],)), _x2_shapes ())
612- x2 = arrays (dtype = all_floating_dtypes (), shape = x2_shapes )
615+ x2 = arrays (shape = x2_shapes , dtype = mutual_dtypes . map ( lambda pair : pair [ 1 ]) )
613616 return x1 , x2
614617
615618@pytest .mark .xp_extension ('linalg' )
616619@given (* solve_args ())
617620def test_solve (x1 , x2 ):
618621 res = linalg .solve (x1 , x2 )
619622
623+ ph .assert_dtype ("solve" , in_dtype = [x1 .dtype , x2 .dtype ], out_dtype = res .dtype )
620624 if x2 .ndim == 1 :
625+ expected_shape = x1 .shape [:- 2 ] + x2 .shape [- 1 :]
621626 _test_stacks (linalg .solve , x1 , x2 , res = res , dims = 1 ,
622627 matrix_axes = [(- 2 , - 1 ), (0 ,)], res_axes = [- 1 ])
623628 else :
629+ stack_shape = sh .broadcast_shapes (x1 .shape [:- 2 ], x2 .shape [:- 2 ])
630+ expected_shape = stack_shape + x2 .shape [- 2 :]
624631 _test_stacks (linalg .solve , x1 , x2 , res = res , dims = 2 )
625632
633+ ph .assert_result_shape ("solve" , in_shapes = [x1 .shape , x2 .shape ],
634+ out_shape = res .shape , expected = expected_shape )
635+
626636@pytest .mark .xp_extension ('linalg' )
627637@given (
628638 x = finite_matrices (),
0 commit comments