@@ -33,7 +33,7 @@ def test_add_dtype_matrix(op1_dtype, op2_dtype):
3333 assert (dpt .asnumpy (r ) == np .full (r .shape , 2 , dtype = r .dtype )).all ()
3434 assert r .sycl_queue == ar1 .sycl_queue
3535
36- out = dpt .empty_like (ar1 , dtype = expected_dtype )
36+ out = dpt .empty_like (ar1 , dtype = r . dtype )
3737 dpt .add (ar1 , ar2 , out )
3838 assert (dpt .asnumpy (out ) == np .full (out .shape , 2 , dtype = out .dtype )).all ()
3939
@@ -49,7 +49,7 @@ def test_add_dtype_matrix(op1_dtype, op2_dtype):
4949 assert r .shape == ar3 .shape
5050 assert (dpt .asnumpy (r ) == np .full (r .shape , 2 , dtype = r .dtype )).all ()
5151
52- out = dpt .empty_like (ar1 , dtype = expected_dtype )
52+ out = dpt .empty_like (ar1 , dtype = r . dtype )
5353 dpt .add (ar3 [::- 1 ], ar4 [::2 ], out )
5454 assert (dpt .asnumpy (out ) == np .full (out .shape , 2 , dtype = out .dtype )).all ()
5555
@@ -74,37 +74,49 @@ def test_add_usm_type_matrix(op1_usm_type, op2_usm_type):
7474def test_add_order ():
7575 get_queue_or_skip ()
7676
77- ar1 = dpt .ones ((20 , 20 ), dtype = "i4" , order = "C" )
78- ar2 = dpt .ones ((20 , 20 ), dtype = "i4" , order = "C" )
79- r1 = dpt .add (ar1 , ar2 , order = "C" )
80- assert r1 .flags .c_contiguous
81- r2 = dpt .add (ar1 , ar2 , order = "F" )
82- assert r2 .flags .f_contiguous
83- r3 = dpt .add (ar1 , ar2 , order = "A" )
84- assert r3 .flags .c_contiguous
85- r4 = dpt .add (ar1 , ar2 , order = "K" )
86- assert r4 .flags .c_contiguous
87-
88- ar1 = dpt .ones ((20 , 20 ), dtype = "i4" , order = "F" )
89- ar2 = dpt .ones ((20 , 20 ), dtype = "i4" , order = "F" )
90- r1 = dpt .add (ar1 , ar2 , order = "C" )
91- assert r1 .flags .c_contiguous
92- r2 = dpt .add (ar1 , ar2 , order = "F" )
93- assert r2 .flags .f_contiguous
94- r3 = dpt .add (ar1 , ar2 , order = "A" )
95- assert r3 .flags .f_contiguous
96- r4 = dpt .add (ar1 , ar2 , order = "K" )
97- assert r4 .flags .f_contiguous
98-
99- ar1 = dpt .ones ((40 , 40 ), dtype = "i4" , order = "C" )[:20 , ::- 2 ]
100- ar2 = dpt .ones ((40 , 40 ), dtype = "i4" , order = "C" )[:20 , ::- 2 ]
101- r4 = dpt .add (ar1 , ar2 , order = "K" )
102- assert r4 .strides == (20 , - 1 )
103-
104- ar1 = dpt .ones ((40 , 40 ), dtype = "i4" , order = "C" )[:20 , ::- 2 ].mT
105- ar2 = dpt .ones ((40 , 40 ), dtype = "i4" , order = "C" )[:20 , ::- 2 ].mT
106- r4 = dpt .add (ar1 , ar2 , order = "K" )
107- assert r4 .strides == (- 1 , 20 )
77+ test_shape = (
78+ 20 ,
79+ 20 ,
80+ )
81+ test_shape2 = tuple (2 * dim for dim in test_shape )
82+ n = test_shape [- 1 ]
83+
84+ for dt1 , dt2 in zip (["i4" , "i4" , "f4" ], ["i4" , "f4" , "i4" ]):
85+ ar1 = dpt .ones (test_shape , dtype = dt1 , order = "C" )
86+ ar2 = dpt .ones (test_shape , dtype = dt2 , order = "C" )
87+ r1 = dpt .add (ar1 , ar2 , order = "C" )
88+ assert r1 .flags .c_contiguous
89+ r2 = dpt .add (ar1 , ar2 , order = "F" )
90+ assert r2 .flags .f_contiguous
91+ r3 = dpt .add (ar1 , ar2 , order = "A" )
92+ assert r3 .flags .c_contiguous
93+ r4 = dpt .add (ar1 , ar2 , order = "K" )
94+ assert r4 .flags .c_contiguous
95+
96+ ar1 = dpt .ones (test_shape , dtype = dt1 , order = "F" )
97+ ar2 = dpt .ones (test_shape , dtype = dt2 , order = "F" )
98+ r1 = dpt .add (ar1 , ar2 , order = "C" )
99+ assert r1 .flags .c_contiguous
100+ r2 = dpt .add (ar1 , ar2 , order = "F" )
101+ assert r2 .flags .f_contiguous
102+ r3 = dpt .add (ar1 , ar2 , order = "A" )
103+ assert r3 .flags .f_contiguous
104+ r4 = dpt .add (ar1 , ar2 , order = "K" )
105+ assert r4 .flags .f_contiguous
106+
107+ ar1 = dpt .ones (test_shape2 , dtype = dt1 , order = "C" )[:20 , ::- 2 ]
108+ ar2 = dpt .ones (test_shape2 , dtype = dt2 , order = "C" )[:20 , ::- 2 ]
109+ r4 = dpt .add (ar1 , ar2 , order = "K" )
110+ assert r4 .strides == (n , - 1 )
111+ r5 = dpt .add (ar1 , ar2 , order = "C" )
112+ assert r5 .strides == (n , 1 )
113+
114+ ar1 = dpt .ones (test_shape2 , dtype = dt1 , order = "C" )[:20 , ::- 2 ].mT
115+ ar2 = dpt .ones (test_shape2 , dtype = dt2 , order = "C" )[:20 , ::- 2 ].mT
116+ r4 = dpt .add (ar1 , ar2 , order = "K" )
117+ assert r4 .strides == (- 1 , n )
118+ r5 = dpt .add (ar1 , ar2 , order = "C" )
119+ assert r5 .strides == (n , 1 )
108120
109121
110122def test_add_broadcasting ():
@@ -266,7 +278,7 @@ def test_add_dtype_error(
266278 skip_if_dtype_not_supported (dtype , q )
267279
268280 ar1 = dpt .ones (5 , dtype = dtype )
269- ar2 = dpt .ones_like (ar1 , dtype = "f8 " )
281+ ar2 = dpt .ones_like (ar1 , dtype = "f4 " )
270282
271283 y = dpt .zeros_like (ar1 , dtype = "int8" )
272284 assert_raises_regex (
0 commit comments