1414# See the License for the specific language governing permissions and
1515# limitations under the License.
1616
17- import itertools
18- import os
19- import re
20-
2117import numpy as np
2218import pytest
2319from numpy .testing import assert_allclose
3430 (np .arctanh , dpt .atanh ),
3531]
3632_all_funcs = _hyper_funcs + _inv_hyper_funcs
37- _dpt_funcs = [t [1 ] for t in _all_funcs ]
3833
3934
4035@pytest .mark .parametrize ("np_call, dpt_call" , _all_funcs )
@@ -45,17 +40,10 @@ def test_hyper_out_type(np_call, dpt_call, dtype):
4540
4641 a = 1 if np_call == np .arccosh else 0
4742
48- X = dpt .asarray (a , dtype = dtype , sycl_queue = q )
49- expected_dtype = np_call (np .array (a , dtype = dtype )).dtype
50- expected_dtype = _map_to_device_dtype (expected_dtype , q .sycl_device )
51- assert dpt_call (X ).dtype == expected_dtype
52-
53- X = dpt .asarray (a , dtype = dtype , sycl_queue = q )
43+ x = dpt .asarray (a , dtype = dtype , sycl_queue = q )
5444 expected_dtype = np_call (np .array (a , dtype = dtype )).dtype
5545 expected_dtype = _map_to_device_dtype (expected_dtype , q .sycl_device )
56- Y = dpt .empty_like (X , dtype = expected_dtype )
57- dpt_call (X , out = Y )
58- assert_allclose (dpt .asnumpy (dpt_call (X )), dpt .asnumpy (Y ))
46+ assert dpt_call (x ).dtype == expected_dtype
5947
6048
6149@pytest .mark .parametrize ("np_call, dpt_call" , _all_funcs )
@@ -119,79 +107,6 @@ def test_hyper_complex_contig(np_call, dpt_call, dtype):
119107 )
120108
121109
122- @pytest .mark .parametrize ("np_call, dpt_call" , _all_funcs )
123- @pytest .mark .parametrize ("usm_type" , ["device" , "shared" , "host" ])
124- def test_hyper_usm_type (np_call , dpt_call , usm_type ):
125- q = get_queue_or_skip ()
126-
127- arg_dt = np .dtype ("f4" )
128- input_shape = (10 , 10 , 10 , 10 )
129- X = dpt .empty (input_shape , dtype = arg_dt , usm_type = usm_type , sycl_queue = q )
130- if np_call == np .arctanh :
131- X [..., 0 ::2 ] = - 0.4
132- X [..., 1 ::2 ] = 0.3
133- elif np_call == np .arccosh :
134- X [..., 0 ::2 ] = 2.2
135- X [..., 1 ::2 ] = 5.5
136- else :
137- X [..., 0 ::2 ] = - 4.4
138- X [..., 1 ::2 ] = 5.5
139-
140- Y = dpt_call (X )
141- assert Y .usm_type == X .usm_type
142- assert Y .sycl_queue == X .sycl_queue
143- assert Y .flags .c_contiguous
144-
145- expected_Y = np_call (dpt .asnumpy (X ))
146- tol = 8 * dpt .finfo (Y .dtype ).resolution
147- assert_allclose (dpt .asnumpy (Y ), expected_Y , atol = tol , rtol = tol )
148-
149-
150- @pytest .mark .parametrize ("np_call, dpt_call" , _all_funcs )
151- @pytest .mark .parametrize ("dtype" , _all_dtypes )
152- def test_hyper_order (np_call , dpt_call , dtype ):
153- q = get_queue_or_skip ()
154- skip_if_dtype_not_supported (dtype , q )
155-
156- arg_dt = np .dtype (dtype )
157- input_shape = (4 , 4 , 4 , 4 )
158- X = dpt .empty (input_shape , dtype = arg_dt , sycl_queue = q )
159- if np_call == np .arctanh :
160- X [..., 0 ::2 ] = - 0.4
161- X [..., 1 ::2 ] = 0.3
162- elif np_call == np .arccosh :
163- X [..., 0 ::2 ] = 2.2
164- X [..., 1 ::2 ] = 5.5
165- else :
166- X [..., 0 ::2 ] = - 4.4
167- X [..., 1 ::2 ] = 5.5
168-
169- for perms in itertools .permutations (range (4 )):
170- U = dpt .permute_dims (X [:, ::- 1 , ::- 1 , :], perms )
171- with np .errstate (all = "ignore" ):
172- expected_Y = np_call (dpt .asnumpy (U ))
173- for ord in ["C" , "F" , "A" , "K" ]:
174- Y = dpt_call (U , order = ord )
175- tol = 8 * max (
176- dpt .finfo (Y .dtype ).resolution ,
177- np .finfo (expected_Y .dtype ).resolution ,
178- )
179- assert_allclose (dpt .asnumpy (Y ), expected_Y , atol = tol , rtol = tol )
180-
181-
182- @pytest .mark .parametrize ("callable" , _dpt_funcs )
183- @pytest .mark .parametrize ("dtype" , _all_dtypes )
184- def test_hyper_error_dtype (callable , dtype ):
185- q = get_queue_or_skip ()
186- skip_if_dtype_not_supported (dtype , q )
187-
188- x = dpt .ones (5 , dtype = dtype )
189- y = dpt .empty_like (x , dtype = "int16" )
190- with pytest .raises (ValueError ) as excinfo :
191- callable (x , out = y )
192- assert re .match ("Output array of type.*is needed" , str (excinfo .value ))
193-
194-
195110@pytest .mark .parametrize ("np_call, dpt_call" , _all_funcs )
196111@pytest .mark .parametrize ("dtype" , ["f2" , "f4" , "f8" ])
197112def test_hyper_real_strided (np_call , dpt_call , dtype ):
@@ -270,46 +185,3 @@ def test_hyper_real_special_cases(np_call, dpt_call, dtype):
270185
271186 tol = 8 * dpt .finfo (dtype ).resolution
272187 assert_allclose (dpt .asnumpy (dpt_call (yf )), Y_np , atol = tol , rtol = tol )
273-
274-
275- @pytest .mark .parametrize ("np_call, dpt_call" , _all_funcs )
276- @pytest .mark .parametrize ("dtype" , ["c8" , "c16" ])
277- def test_hyper_complex_special_cases_conj_property (np_call , dpt_call , dtype ):
278- q = get_queue_or_skip ()
279- skip_if_dtype_not_supported (dtype , q )
280-
281- x = [np .nan , np .inf , - np .inf , + 0.0 , - 0.0 , + 1.0 , - 1.0 ]
282- xc = [complex (* val ) for val in itertools .product (x , repeat = 2 )]
283-
284- Xc_np = np .array (xc , dtype = dtype )
285- Xc = dpt .asarray (Xc_np , dtype = dtype , sycl_queue = q )
286-
287- tol = 50 * dpt .finfo (dtype ).resolution
288- Y = dpt_call (Xc )
289- Yc = dpt_call (dpt .conj (Xc ))
290-
291- dpt .allclose (Y , dpt .conj (Yc ), atol = tol , rtol = tol )
292-
293-
294- @pytest .mark .skipif (
295- os .name != "posix" , reason = "Known to fail on Windows due to bug in NumPy"
296- )
297- @pytest .mark .parametrize ("np_call, dpt_call" , _all_funcs )
298- @pytest .mark .parametrize ("dtype" , ["c8" , "c16" ])
299- def test_hyper_complex_special_cases (np_call , dpt_call , dtype ):
300- q = get_queue_or_skip ()
301- skip_if_dtype_not_supported (dtype , q )
302-
303- x = [np .nan , np .inf , - np .inf , + 0.0 , - 0.0 , + 1.0 , - 1.0 ]
304- xc = [complex (* val ) for val in itertools .product (x , repeat = 2 )]
305-
306- Xc_np = np .array (xc , dtype = dtype )
307- Xc = dpt .asarray (Xc_np , dtype = dtype , sycl_queue = q )
308-
309- with np .errstate (all = "ignore" ):
310- Ynp = np_call (Xc_np )
311-
312- tol = 50 * dpt .finfo (dtype ).resolution
313- Y = dpt_call (Xc )
314- assert_allclose (dpt .asnumpy (dpt .real (Y )), np .real (Ynp ), atol = tol , rtol = tol )
315- assert_allclose (dpt .asnumpy (dpt .imag (Y )), np .imag (Ynp ), atol = tol , rtol = tol )
0 commit comments