77import dpctl
88import dpctl .tensor as dpt
99from dpctl .tests .helper import get_queue_or_skip , skip_if_dtype_not_supported
10+ from dpctl .utils import ExecutionPlacementError
1011
1112from .utils import _all_dtypes , _compare_dtypes , _usm_types
1213
@@ -32,6 +33,10 @@ def test_add_dtype_matrix(op1_dtype, op2_dtype):
3233 assert (dpt .asnumpy (r ) == np .full (r .shape , 2 , dtype = r .dtype )).all ()
3334 assert r .sycl_queue == ar1 .sycl_queue
3435
36+ out = dpt .empty_like (ar1 , dtype = expected_dtype )
37+ dpt .add (ar1 , ar2 , out )
38+ assert (dpt .asnumpy (out ) == np .full (out .shape , 2 , dtype = out .dtype )).all ()
39+
3540 ar3 = dpt .ones (sz , dtype = op1_dtype )
3641 ar4 = dpt .ones (2 * sz , dtype = op2_dtype )
3742
@@ -44,6 +49,10 @@ def test_add_dtype_matrix(op1_dtype, op2_dtype):
4449 assert r .shape == ar3 .shape
4550 assert (dpt .asnumpy (r ) == np .full (r .shape , 2 , dtype = r .dtype )).all ()
4651
52+ out = dpt .empty_like (ar1 , dtype = expected_dtype )
53+ dpt .add (ar3 [::- 1 ], ar4 [::2 ], out )
54+ assert (dpt .asnumpy (out ) == np .full (out .shape , 2 , dtype = out .dtype )).all ()
55+
4756
4857@pytest .mark .parametrize ("op1_usm_type" , _usm_types )
4958@pytest .mark .parametrize ("op2_usm_type" , _usm_types )
@@ -105,7 +114,6 @@ def test_add_broadcasting():
105114 v = dpt .arange (5 , dtype = "i4" )
106115
107116 r = dpt .add (m , v )
108-
109117 assert (dpt .asnumpy (r ) == np .arange (1 , 6 , dtype = "i4" )[np .newaxis , :]).all ()
110118
111119 r2 = dpt .add (v , m )
@@ -180,26 +188,8 @@ def __sycl_usm_array_interface__(self):
180188 dpt .add (a , c )
181189
182190
183- @pytest .mark .parametrize ("op1_dtype" , _all_dtypes )
184- @pytest .mark .parametrize ("op2_dtype" , _all_dtypes )
185- def test_add_dtype_out_keyword (op1_dtype , op2_dtype ):
186- q = get_queue_or_skip ()
187- skip_if_dtype_not_supported (op1_dtype , q )
188- skip_if_dtype_not_supported (op2_dtype , q )
189-
190- sz = 127
191- ar1 = dpt .ones (sz , dtype = op1_dtype )
192- ar2 = dpt .ones_like (ar1 , dtype = op2_dtype )
193-
194- r = dpt .add (ar1 , ar2 )
195-
196- y = dpt .zeros_like (ar1 , dtype = r .dtype )
197- dpt .add (ar1 , ar2 , y )
198-
199- assert np .array_equal (dpt .asnumpy (r ), dpt .asnumpy (y ))
200-
201-
202191def test_add_errors ():
192+ get_queue_or_skip ()
203193 try :
204194 gpu_queue = dpctl .SyclQueue ("gpu" )
205195 except dpctl .SyclQueueCreationError :
@@ -245,11 +235,14 @@ def test_add_errors():
245235 y ,
246236 )
247237
248- ar1 = dpt .ones (2 , dtype = "float32" )
249- ar2 = dpt .ones_like (ar1 , dtype = "int32" )
250- y = dpt .empty_like (ar1 , dtype = "int32" )
238+ ar1 = np .ones (2 , dtype = "float32" )
239+ ar2 = np .ones_like (ar1 , dtype = "int32" )
251240 assert_raises_regex (
252- TypeError , "Output array of type.*is needed" , dpt .add , ar1 , ar2 , y
241+ ExecutionPlacementError ,
242+ "Execution placement can not be unambiguously inferred.*" ,
243+ dpt .add ,
244+ ar1 ,
245+ ar2 ,
253246 )
254247
255248 ar1 = dpt .ones (2 , dtype = "float32" )
@@ -263,3 +256,19 @@ def test_add_errors():
263256 ar2 ,
264257 y ,
265258 )
259+
260+
261+ @pytest .mark .parametrize ("dtype" , _all_dtypes )
262+ def test_add_dtype_error (
263+ dtype ,
264+ ):
265+ q = get_queue_or_skip ()
266+ skip_if_dtype_not_supported (dtype , q )
267+
268+ ar1 = dpt .ones (5 , dtype = dtype )
269+ ar2 = dpt .ones_like (ar1 , dtype = "f8" )
270+
271+ y = dpt .zeros_like (ar1 , dtype = "int8" )
272+ assert_raises_regex (
273+ TypeError , "Output array of type.*is needed" , dpt .add , ar1 , ar2 , y
274+ )
0 commit comments