@@ -126,8 +126,9 @@ def _test_namedtuple(res, fields, func_name):
126126def test_cholesky (x , kw ):
127127 res = linalg .cholesky (x , ** kw )
128128
129- assert res .shape == x .shape , "cholesky() did not return the correct shape"
130- assert res .dtype == x .dtype , "cholesky() did not return the correct dtype"
129+ ph .assert_dtype ("cholesky" , in_dtype = x .dtype , out_dtype = res .dtype )
130+ ph .assert_result_shape ("cholesky" , in_shapes = [x .shape ],
131+ out_shape = res .shape , expected = x .shape )
131132
132133 _test_stacks (linalg .cholesky , x , ** kw , res = res )
133134
@@ -192,7 +193,7 @@ def test_cross(x1_x2_kw):
192193
193194 ph .assert_dtype ("cross" , in_dtype = [x1 .dtype , x2 .dtype ],
194195 out_dtype = res .dtype )
195- ph .assert_shape ("cross" , out_shape = res .shape , expected = broadcasted_shape )
196+ ph .assert_result_shape ("cross" , in_shapes = [ x1 . shape , x2 . shape ] , out_shape = res .shape , expected = broadcasted_shape )
196197
197198 def exact_cross (a , b ):
198199 assert a .shape == b .shape == (3 ,), "Invalid cross() stack shapes. This indicates a bug in the test suite."
@@ -214,8 +215,9 @@ def exact_cross(a, b):
214215def test_det (x ):
215216 res = linalg .det (x )
216217
217- assert res .dtype == x .dtype , "det() did not return the correct dtype"
218- assert res .shape == x .shape [:- 2 ], "det() did not return the correct shape"
218+ ph .assert_dtype ("det" , in_dtype = x .dtype , out_dtype = res .dtype )
219+ ph .assert_result_shape ("det" , in_shapes = [x .shape ], out_shape = res .shape ,
220+ expected = x .shape [:- 2 ])
219221
220222 _test_stacks (linalg .det , x , res = res , dims = 0 )
221223
@@ -231,7 +233,7 @@ def test_det(x):
231233def test_diagonal (x , kw ):
232234 res = linalg .diagonal (x , ** kw )
233235
234- assert res . dtype == x .dtype , "diagonal() returned the wrong dtype"
236+ ph . assert_dtype ( "diagonal" , in_dtype = x .dtype , out_dtype = res . dtype )
235237
236238 n , m = x .shape [- 2 :]
237239 offset = kw .get ('offset' , 0 )
@@ -245,7 +247,9 @@ def test_diagonal(x, kw):
245247 else :
246248 diag_size = min (n , m , max (m - offset , 0 ))
247249
248- assert res .shape == (* x .shape [:- 2 ], diag_size ), "diagonal() returned the wrong shape"
250+ expected_shape = (* x .shape [:- 2 ], diag_size )
251+ ph .assert_result_shape ("diagonal" , in_shapes = [x .shape ],
252+ out_shape = res .shape , expected = expected_shape )
249253
250254 def true_diag (x_stack , offset = 0 ):
251255 if offset >= 0 :
@@ -266,11 +270,18 @@ def test_eigh(x):
266270 eigenvalues = res .eigenvalues
267271 eigenvectors = res .eigenvectors
268272
269- assert eigenvalues .dtype == x .dtype , "eigh().eigenvalues did not return the correct dtype"
270- assert eigenvalues .shape == x .shape [:- 1 ], "eigh().eigenvalues did not return the correct shape"
273+ ph .assert_dtype ("eigh" , in_dtype = x .dtype , out_dtype = eigenvalues .dtype ,
274+ expected = x .dtype , repr_name = "eigenvalues.dtype" )
275+ ph .assert_result_shape ("eigh" , in_shapes = [x .shape ],
276+ out_shape = eigenvalues .shape ,
277+ expected = x .shape [:- 1 ],
278+ repr_name = "eigenvalues.shape" )
271279
272- assert eigenvectors .dtype == x .dtype , "eigh().eigenvectors did not return the correct dtype"
273- assert eigenvectors .shape == x .shape , "eigh().eigenvectors did not return the correct shape"
280+ ph .assert_dtype ("eigh" , in_dtype = x .dtype , out_dtype = eigenvectors .dtype ,
281+ expected = x .dtype , repr_name = "eigenvectors.dtype" )
282+ ph .assert_result_shape ("eigh" , in_shapes = [x .shape ],
283+ out_shape = eigenvectors .shape , expected = x .shape ,
284+ repr_name = "eigenvectors.shape" )
274285
275286 # Note: _test_stacks here is only testing the shape and dtype. The actual
276287 # eigenvalues and eigenvectors may not be equal at all, since there is not
@@ -292,8 +303,9 @@ def test_eigh(x):
292303def test_eigvalsh (x ):
293304 res = linalg .eigvalsh (x )
294305
295- assert res .dtype == x .dtype , "eigvalsh() did not return the correct dtype"
296- assert res .shape == x .shape [:- 1 ], "eigvalsh() did not return the correct shape"
306+ ph .assert_dtype ("eigvalsh" , in_dtype = x .dtype , out_dtype = res .dtype )
307+ ph .assert_result_shape ("eigvalsh" , in_shapes = [x .shape ],
308+ out_shape = res .shape , expected = x .shape [:- 1 ])
297309
298310 # Note: _test_stacks here is only testing the shape and dtype. The actual
299311 # eigenvalues may not be equal at all, since there is not requirements or
@@ -311,8 +323,9 @@ def test_eigvalsh(x):
311323def test_inv (x ):
312324 res = linalg .inv (x )
313325
314- assert res .shape == x .shape , "inv() did not return the correct shape"
315- assert res .dtype == x .dtype , "inv() did not return the correct dtype"
326+ ph .assert_dtype ("inv" , in_dtype = x .dtype , out_dtype = res .dtype )
327+ ph .assert_result_shape ("inv" , in_shapes = [x .shape ], out_shape = res .shape ,
328+ expected = x .shape )
316329
317330 _test_stacks (linalg .inv , x , res = res )
318331
@@ -339,18 +352,24 @@ def test_matmul(x1, x2):
339352 ph .assert_dtype ("matmul" , in_dtype = [x1 .dtype , x2 .dtype ], out_dtype = res .dtype )
340353
341354 if len (x1 .shape ) == len (x2 .shape ) == 1 :
342- assert res .shape == ()
355+ ph .assert_result_shape ("matmul" , in_shapes = [x1 .shape , x2 .shape ],
356+ out_shape = res .shape , expected = ())
343357 elif len (x1 .shape ) == 1 :
344- assert res .shape == x2 .shape [:- 2 ] + x2 .shape [- 1 :]
358+ ph .assert_result_shape ("matmul" , in_shapes = [x1 .shape , x2 .shape ],
359+ out_shape = res .shape ,
360+ expected = x2 .shape [:- 2 ] + x2 .shape [- 1 :])
345361 _test_stacks (_array_module .matmul , x1 , x2 , res = res , dims = 1 ,
346362 matrix_axes = [(0 ,), (- 2 , - 1 )], res_axes = [- 1 ])
347363 elif len (x2 .shape ) == 1 :
348- assert res .shape == x1 .shape [:- 1 ]
364+ ph .assert_result_shape ("matmul" , in_shapes = [x1 .shape , x2 .shape ],
365+ out_shape = res .shape , expected = x1 .shape [:- 1 ])
349366 _test_stacks (_array_module .matmul , x1 , x2 , res = res , dims = 1 ,
350367 matrix_axes = [(- 2 , - 1 ), (0 ,)], res_axes = [- 1 ])
351368 else :
352369 stack_shape = sh .broadcast_shapes (x1 .shape [:- 2 ], x2 .shape [:- 2 ])
353- assert res .shape == stack_shape + (x1 .shape [- 2 ], x2 .shape [- 1 ])
370+ ph .assert_result_shape ("matmul" , in_shapes = [x1 .shape , x2 .shape ],
371+ out_shape = res .shape ,
372+ expected = stack_shape + (x1 .shape [- 2 ], x2 .shape [- 1 ]))
354373 _test_stacks (_array_module .matmul , x1 , x2 , res = res )
355374
356375@pytest .mark .xp_extension ('linalg' )
@@ -370,8 +389,9 @@ def test_matrix_norm(x, kw):
370389 expected_shape = x .shape [:- 2 ] + (1 , 1 )
371390 else :
372391 expected_shape = x .shape [:- 2 ]
373- assert res .shape == expected_shape , f"matrix_norm({ keepdims = } ) did not return the correct shape"
374- assert res .dtype == x .dtype , "matrix_norm() did not return the correct dtype"
392+ ph .assert_dtype ("matrix_norm" , in_dtype = x .dtype , out_dtype = res .dtype )
393+ ph .assert_result_shape ("matrix_norm" , in_shapes = [x .shape ],
394+ out_shape = res .shape , expected = expected_shape )
375395
376396 _test_stacks (linalg .matrix_norm , x , ** kw , dims = 2 if keepdims else 0 ,
377397 res = res )
@@ -388,8 +408,9 @@ def test_matrix_norm(x, kw):
388408def test_matrix_power (x , n ):
389409 res = linalg .matrix_power (x , n )
390410
391- assert res .shape == x .shape , "matrix_power() did not return the correct shape"
392- assert res .dtype == x .dtype , "matrix_power() did not return the correct dtype"
411+ ph .assert_dtype ("matrix_power" , in_dtype = x .dtype , out_dtype = res .dtype )
412+ ph .assert_result_shape ("matrix_power" , in_shapes = [x .shape ],
413+ out_shape = res .shape , expected = x .shape )
393414
394415 if n == 0 :
395416 true_val = lambda x : _array_module .eye (x .shape [0 ], dtype = x .dtype )
@@ -419,8 +440,9 @@ def test_matrix_transpose(x):
419440 shape = list (x .shape )
420441 shape [- 1 ], shape [- 2 ] = shape [- 2 ], shape [- 1 ]
421442 shape = tuple (shape )
422- assert res .shape == shape , "matrix_transpose() did not return the correct shape"
423- assert res .dtype == x .dtype , "matrix_transpose() did not return the correct dtype"
443+ ph .assert_dtype ("matrix_transpose" , in_dtype = x .dtype , out_dtype = res .dtype )
444+ ph .assert_result_shape ("matrix_transpose" , in_shapes = [x .shape ],
445+ out_shape = res .shape , expected = shape )
424446
425447 _test_stacks (_array_module .matrix_transpose , x , res = res , true_val = true_val )
426448
@@ -435,8 +457,9 @@ def test_outer(x1, x2):
435457 res = linalg .outer (x1 , x2 )
436458
437459 shape = (x1 .shape [0 ], x2 .shape [0 ])
438- assert res .shape == shape , "outer() did not return the correct shape"
439- assert res .dtype == dh .result_type (x1 .dtype , x2 .dtype ), "outer() did not return the correct dtype"
460+ ph .assert_dtype ("outer" , in_dtype = [x1 .dtype , x2 .dtype ], out_dtype = res .dtype )
461+ ph .assert_result_shape ("outer" , in_shapes = [x1 .shape , x2 .shape ],
462+ out_shape = res .shape , expected = shape )
440463
441464 if 0 in shape :
442465 true_res = _array_module .empty (shape , dtype = res .dtype )
@@ -472,17 +495,23 @@ def test_qr(x, kw):
472495 Q = res .Q
473496 R = res .R
474497
475- assert Q .dtype == x .dtype , "qr().Q did not return the correct dtype"
498+ ph .assert_dtype ("qr" , in_dtype = x .dtype , out_dtype = Q .dtype ,
499+ expected = x .dtype , repr_name = "Q.dtype" )
476500 if mode == 'complete' :
477- assert Q . shape == x .shape [:- 2 ] + (M , M ), "qr().Q did not return the correct shape"
501+ expected_Q_shape = x .shape [:- 2 ] + (M , M )
478502 else :
479- assert Q .shape == x .shape [:- 2 ] + (M , K ), "qr().Q did not return the correct shape"
503+ expected_Q_shape = x .shape [:- 2 ] + (M , K )
504+ ph .assert_result_shape ("qr" , in_shapes = [x .shape ], out_shape = Q .shape ,
505+ expected = expected_Q_shape , repr_name = "Q.shape" )
480506
481- assert R .dtype == x .dtype , "qr().R did not return the correct dtype"
507+ ph .assert_dtype ("qr" , in_dtype = x .dtype , out_dtype = R .dtype ,
508+ expected = x .dtype , repr_name = "R.dtype" )
482509 if mode == 'complete' :
483- assert R . shape == x .shape [:- 2 ] + (M , N ), "qr().R did not return the correct shape"
510+ expected_R_shape = x .shape [:- 2 ] + (M , N )
484511 else :
485- assert R .shape == x .shape [:- 2 ] + (K , N ), "qr().R did not return the correct shape"
512+ expected_R_shape = x .shape [:- 2 ] + (K , N )
513+ ph .assert_result_shape ("qr" , in_shapes = [x .shape ], out_shape = R .shape ,
514+ expected = expected_R_shape , repr_name = "R.shape" )
486515
487516 _test_stacks (lambda x : linalg .qr (x , ** kw ).Q , x , res = Q )
488517 _test_stacks (lambda x : linalg .qr (x , ** kw ).R , x , res = R )
@@ -505,14 +534,17 @@ def test_slogdet(x):
505534
506535 ph .assert_dtype ("slogdet" , in_dtype = x .dtype , out_dtype = sign .dtype ,
507536 expected = x .dtype , repr_name = "sign.dtype" )
508- ph .assert_shape ("slogdet" , out_shape = sign .shape , expected = x .shape [:- 2 ],
509- repr_name = "sign.shape" )
537+ ph .assert_result_shape ("slogdet" , in_shapes = [x .shape ],
538+ out_shape = sign .shape ,
539+ expected = x .shape [:- 2 ],
540+ repr_name = "sign.shape" )
510541 expected_dtype = dh .as_real_dtype (x .dtype )
511542 ph .assert_dtype ("slogdet" , in_dtype = x .dtype , out_dtype = logabsdet .dtype ,
512543 expected = expected_dtype , repr_name = "logabsdet.dtype" )
513- ph .assert_shape ("slogdet" , out_shape = logabsdet .shape ,
514- expected = x .shape [:- 2 ],
515- repr_name = "logabsdet.shape" )
544+ ph .assert_result_shape ("slogdet" , in_shapes = [x .shape ],
545+ out_shape = logabsdet .shape ,
546+ expected = x .shape [:- 2 ],
547+ repr_name = "logabsdet.shape" )
516548
517549 _test_stacks (lambda x : linalg .slogdet (x ).sign , x ,
518550 res = sign , dims = 0 )
@@ -584,17 +616,31 @@ def test_svd(x, kw):
584616
585617 U , S , Vh = res
586618
587- assert U .dtype == x .dtype , "svd().U did not return the correct dtype"
588- assert S .dtype == x .dtype , "svd().S did not return the correct dtype"
589- assert Vh .dtype == x .dtype , "svd().Vh did not return the correct dtype"
619+ ph .assert_dtype ("svd" , in_dtype = x .dtype , out_dtype = U .dtype ,
620+ expected = x .dtype , repr_name = "U.dtype" )
621+ ph .assert_dtype ("svd" , in_dtype = x .dtype , out_dtype = S .dtype ,
622+ expected = x .dtype , repr_name = "S.dtype" )
623+ ph .assert_dtype ("svd" , in_dtype = x .dtype , out_dtype = Vh .dtype ,
624+ expected = x .dtype , repr_name = "Vh.dtype" )
590625
591626 if full_matrices :
592- assert U . shape == (* stack , M , M ), "svd().U did not return the correct shape"
593- assert Vh . shape == (* stack , N , N ), "svd().Vh did not return the correct shape"
627+ expected_U_shape = (* stack , M , M )
628+ expected_Vh_shape = (* stack , N , N )
594629 else :
595- assert U .shape == (* stack , M , K ), "svd(full_matrices=False).U did not return the correct shape"
596- assert Vh .shape == (* stack , K , N ), "svd(full_matrices=False).Vh did not return the correct shape"
597- assert S .shape == (* stack , K ), "svd().S did not return the correct shape"
630+ expected_U_shape = (* stack , M , K )
631+ expected_Vh_shape = (* stack , K , N )
632+ ph .assert_result_shape ("svd" , in_shapes = [x .shape ],
633+ out_shape = U .shape ,
634+ expected = expected_U_shape ,
635+ repr_name = "U.shape" )
636+ ph .assert_result_shape ("svd" , in_shapes = [x .shape ],
637+ out_shape = Vh .shape ,
638+ expected = expected_Vh_shape ,
639+ repr_name = "Vh.shape" )
640+ ph .assert_result_shape ("svd" , in_shapes = [x .shape ],
641+ out_shape = S .shape ,
642+ expected = (* stack , K ),
643+ repr_name = "S.shape" )
598644
599645 # The values of s must be sorted from largest to smallest
600646 if K >= 1 :
@@ -614,8 +660,11 @@ def test_svdvals(x):
614660 * stack , M , N = x .shape
615661 K = min (M , N )
616662
617- assert res .dtype == x .dtype , "svdvals() did not return the correct dtype"
618- assert res .shape == (* stack , K ), "svdvals() did not return the correct shape"
663+ ph .assert_dtype ("svdvals" , in_dtype = x .dtype , out_dtype = res .dtype ,
664+ expected = x .dtype )
665+ ph .assert_result_shape ("svdvals" , in_shapes = [x .shape ],
666+ out_shape = res .shape ,
667+ expected = (* stack , K ))
619668
620669 # SVD values must be sorted from largest to smallest
621670 assert _array_module .all (res [..., :- 1 ] >= res [..., 1 :]), "svdvals() values are not sorted from largest to smallest"
@@ -753,7 +802,7 @@ def test_trace(x, kw):
753802 # assert res.dtype == x.dtype, "trace() returned the wrong dtype"
754803
755804 n , m = x .shape [- 2 :]
756- assert res .shape == x .shape [:- 2 ], "trace() returned the wrong shape"
805+ ph . assert_result_shape ( 'trace' , x . shape , res .shape , expected = x .shape [:- 2 ])
757806
758807 def true_trace (x_stack , offset = 0 ):
759808 # Note: the spec does not specify that offset must be within the
@@ -799,7 +848,8 @@ def test_vecdot(x1, x2, data):
799848
800849 ph .assert_dtype ("vecdot" , in_dtype = [x1 .dtype , x2 .dtype ],
801850 out_dtype = res .dtype )
802- ph .assert_shape ("vecdot" , out_shape = res .shape , expected = expected_shape )
851+ ph .assert_result_shape ("vecdot" , in_shapes = [x1 .shape , x2 .shape ],
852+ out_shape = res .shape , expected = expected_shape )
803853
804854 if x1 .dtype in dh .int_dtypes :
805855 def true_val (x , y , axis = - 1 ):
0 commit comments