1818
1919import numpy as np
2020import pytest
21- from numpy .testing import assert_equal
21+ from numpy .testing import assert_allclose , assert_equal
2222
2323import dpctl .tensor as dpt
2424from dpctl .tests .helper import get_queue_or_skip , skip_if_dtype_not_supported
@@ -50,7 +50,7 @@ def test_sqrt_output_contig(dtype):
5050 Y = dpt .sqrt (X )
5151 tol = 8 * dpt .finfo (Y .dtype ).resolution
5252
53- np . testing . assert_allclose (dpt .asnumpy (Y ), np .sqrt (Xnp ), atol = tol , rtol = tol )
53+ assert_allclose (dpt .asnumpy (Y ), np .sqrt (Xnp ), atol = tol , rtol = tol )
5454
5555
5656@pytest .mark .parametrize ("dtype" , ["f2" , "f4" , "f8" , "c8" , "c16" ])
@@ -66,7 +66,7 @@ def test_sqrt_output_strided(dtype):
6666 Y = dpt .sqrt (X )
6767 tol = 8 * dpt .finfo (Y .dtype ).resolution
6868
69- np . testing . assert_allclose (dpt .asnumpy (Y ), np .sqrt (Xnp ), atol = tol , rtol = tol )
69+ assert_allclose (dpt .asnumpy (Y ), np .sqrt (Xnp ), atol = tol , rtol = tol )
7070
7171
7272@pytest .mark .parametrize ("usm_type" , _usm_types )
@@ -89,7 +89,7 @@ def test_sqrt_usm_type(usm_type):
8989 expected_Y [..., 1 ::2 ] = np .sqrt (np .float32 (23.0 ))
9090 tol = 8 * dpt .finfo (Y .dtype ).resolution
9191
92- np . testing . assert_allclose (dpt .asnumpy (Y ), expected_Y , atol = tol , rtol = tol )
92+ assert_allclose (dpt .asnumpy (Y ), expected_Y , atol = tol , rtol = tol )
9393
9494
9595@pytest .mark .parametrize ("dtype" , _all_dtypes )
@@ -112,11 +112,10 @@ def test_sqrt_order(dtype):
112112 dpt .finfo (Y .dtype ).resolution ,
113113 np .finfo (expected_Y .dtype ).resolution ,
114114 )
115- np .testing .assert_allclose (
116- dpt .asnumpy (Y ), expected_Y , atol = tol , rtol = tol
117- )
115+ assert_allclose (dpt .asnumpy (Y ), expected_Y , atol = tol , rtol = tol )
118116
119117
118+ @pytest .mark .usefixtures ("suppress_invalid_numpy_warnings" )
120119def test_sqrt_special_cases ():
121120 q = get_queue_or_skip ()
122121
@@ -126,3 +125,27 @@ def test_sqrt_special_cases():
126125 Xnp = dpt .asnumpy (X )
127126
128127 assert_equal (dpt .asnumpy (dpt .sqrt (X )), np .sqrt (Xnp ))
128+
129+
130+ @pytest .mark .parametrize ("dtype" , ["f2" , "f4" , "f8" , "c8" , "c16" ])
131+ def test_sqrt_out_overlap (dtype ):
132+ q = get_queue_or_skip ()
133+ skip_if_dtype_not_supported (dtype , q )
134+
135+ X = dpt .linspace (0 , 35 , 60 , dtype = dtype , sycl_queue = q )
136+ X = dpt .reshape (X , (3 , 5 , 4 ))
137+
138+ Xnp = dpt .asnumpy (X )
139+ Ynp = np .sqrt (Xnp , out = Xnp )
140+
141+ Y = dpt .sqrt (X , out = X )
142+ assert Y is X
143+
144+ tol = 8 * dpt .finfo (Y .dtype ).resolution
145+ assert_allclose (dpt .asnumpy (X ), Xnp , atol = tol , rtol = tol )
146+
147+ Ynp = np .sqrt (Xnp , out = Xnp [::- 1 ])
148+ Y = dpt .sqrt (X , out = X [::- 1 ])
149+ assert Y is not X
150+ assert_allclose (dpt .asnumpy (X ), Xnp , atol = tol , rtol = tol )
151+ assert_allclose (dpt .asnumpy (Y ), Ynp , atol = tol , rtol = tol )
0 commit comments