1414# See the License for the specific language governing permissions and
1515# limitations under the License.
1616
17- import itertools
18- import os
19- import re
20-
2117import numpy as np
2218import pytest
2319from numpy .testing import assert_allclose
3430 (np .arctan , dpt .atan ),
3531]
3632_all_funcs = _trig_funcs + _inv_trig_funcs
37- _dpt_funcs = [t [1 ] for t in _all_funcs ]
3833
3934
4035@pytest .mark .parametrize ("np_call, dpt_call" , _all_funcs )
@@ -43,17 +38,10 @@ def test_trig_out_type(np_call, dpt_call, dtype):
4338 q = get_queue_or_skip ()
4439 skip_if_dtype_not_supported (dtype , q )
4540
46- X = dpt .asarray (0 , dtype = dtype , sycl_queue = q )
47- expected_dtype = np_call (np .array (0 , dtype = dtype )).dtype
48- expected_dtype = _map_to_device_dtype (expected_dtype , q .sycl_device )
49- assert dpt_call (X ).dtype == expected_dtype
50-
51- X = dpt .asarray (0 , dtype = dtype , sycl_queue = q )
41+ x = dpt .asarray (0 , dtype = dtype , sycl_queue = q )
5242 expected_dtype = np_call (np .array (0 , dtype = dtype )).dtype
5343 expected_dtype = _map_to_device_dtype (expected_dtype , q .sycl_device )
54- Y = dpt .empty_like (X , dtype = expected_dtype )
55- dpt_call (X , out = Y )
56- assert_allclose (dpt .asnumpy (dpt_call (X )), dpt .asnumpy (Y ))
44+ assert dpt_call (x ).dtype == expected_dtype
5745
5846
5947@pytest .mark .parametrize ("np_call, dpt_call" , _all_funcs )
@@ -127,78 +115,6 @@ def test_trig_complex_contig(np_call, dpt_call, dtype):
127115 assert_allclose (dpt .asnumpy (Z ), expected , atol = tol , rtol = tol )
128116
129117
130- @pytest .mark .parametrize ("np_call, dpt_call" , _all_funcs )
131- @pytest .mark .parametrize ("usm_type" , ["device" , "shared" , "host" ])
132- def test_trig_usm_type (np_call , dpt_call , usm_type ):
133- q = get_queue_or_skip ()
134-
135- arg_dt = np .dtype ("f4" )
136- input_shape = (10 , 10 , 10 , 10 )
137- X = dpt .empty (input_shape , dtype = arg_dt , usm_type = usm_type , sycl_queue = q )
138- if np_call in _trig_funcs :
139- X [..., 0 ::2 ] = np .pi / 6
140- X [..., 1 ::2 ] = np .pi / 3
141- if np_call == np .arctan :
142- X [..., 0 ::2 ] = - 2.2
143- X [..., 1 ::2 ] = 3.3
144- else :
145- X [..., 0 ::2 ] = - 0.3
146- X [..., 1 ::2 ] = 0.7
147-
148- Y = dpt_call (X )
149- assert Y .usm_type == X .usm_type
150- assert Y .sycl_queue == X .sycl_queue
151- assert Y .flags .c_contiguous
152-
153- expected_Y = np_call (dpt .asnumpy (X ))
154- tol = 8 * dpt .finfo (Y .dtype ).resolution
155- assert_allclose (dpt .asnumpy (Y ), expected_Y , atol = tol , rtol = tol )
156-
157-
158- @pytest .mark .parametrize ("np_call, dpt_call" , _all_funcs )
159- @pytest .mark .parametrize ("dtype" , _all_dtypes )
160- def test_trig_order (np_call , dpt_call , dtype ):
161- q = get_queue_or_skip ()
162- skip_if_dtype_not_supported (dtype , q )
163-
164- arg_dt = np .dtype (dtype )
165- input_shape = (4 , 4 , 4 , 4 )
166- X = dpt .empty (input_shape , dtype = arg_dt , sycl_queue = q )
167- if np_call in _trig_funcs :
168- X [..., 0 ::2 ] = np .pi / 6
169- X [..., 1 ::2 ] = np .pi / 3
170- if np_call == np .arctan :
171- X [..., 0 ::2 ] = - 2.2
172- X [..., 1 ::2 ] = 3.3
173- else :
174- X [..., 0 ::2 ] = - 0.3
175- X [..., 1 ::2 ] = 0.7
176-
177- for perms in itertools .permutations (range (4 )):
178- U = dpt .permute_dims (X [:, ::- 1 , ::- 1 , :], perms )
179- expected_Y = np_call (dpt .asnumpy (U ))
180- for ord in ["C" , "F" , "A" , "K" ]:
181- Y = dpt_call (U , order = ord )
182- tol = 8 * max (
183- dpt .finfo (Y .dtype ).resolution ,
184- np .finfo (expected_Y .dtype ).resolution ,
185- )
186- assert_allclose (dpt .asnumpy (Y ), expected_Y , atol = tol , rtol = tol )
187-
188-
189- @pytest .mark .parametrize ("callable" , _dpt_funcs )
190- @pytest .mark .parametrize ("dtype" , _all_dtypes )
191- def test_trig_error_dtype (callable , dtype ):
192- q = get_queue_or_skip ()
193- skip_if_dtype_not_supported (dtype , q )
194-
195- x = dpt .zeros (5 , dtype = dtype )
196- y = dpt .empty_like (x , dtype = "int16" )
197- with pytest .raises (ValueError ) as excinfo :
198- callable (x , out = y )
199- assert re .match ("Output array of type.*is needed" , str (excinfo .value ))
200-
201-
202118@pytest .mark .parametrize ("np_call, dpt_call" , _all_funcs )
203119@pytest .mark .parametrize ("dtype" , ["f2" , "f4" , "f8" ])
204120def test_trig_real_strided (np_call , dpt_call , dtype ):
@@ -298,47 +214,3 @@ def test_trig_real_special_cases(np_call, dpt_call, dtype):
298214 tol = 8 * dpt .finfo (dtype ).resolution
299215 Y = dpt_call (yf )
300216 assert_allclose (dpt .asnumpy (Y ), Y_np , atol = tol , rtol = tol )
301-
302-
303- @pytest .mark .parametrize ("np_call, dpt_call" , _all_funcs )
304- @pytest .mark .parametrize ("dtype" , ["c8" , "c16" ])
305- def test_trig_complex_special_cases_conj_property (np_call , dpt_call , dtype ):
306- q = get_queue_or_skip ()
307- skip_if_dtype_not_supported (dtype , q )
308-
309- x = [np .nan , np .inf , - np .inf , + 0.0 , - 0.0 , + 1.0 , - 1.0 ]
310- xc = [complex (* val ) for val in itertools .product (x , repeat = 2 )]
311-
312- Xc_np = np .array (xc , dtype = dtype )
313- Xc = dpt .asarray (Xc_np , dtype = dtype , sycl_queue = q )
314-
315- tol = 50 * dpt .finfo (dtype ).resolution
316- Y = dpt_call (Xc )
317- Yc = dpt_call (dpt .conj (Xc ))
318-
319- dpt .allclose (Y , dpt .conj (Yc ), atol = tol , rtol = tol )
320-
321-
322- @pytest .mark .skipif (
323- os .name != "posix" , reason = "Known to fail on Windows due to bug in NumPy"
324- )
325- @pytest .mark .parametrize ("np_call, dpt_call" , _all_funcs )
326- @pytest .mark .parametrize ("dtype" , ["c8" , "c16" ])
327- def test_trig_complex_special_cases (np_call , dpt_call , dtype ):
328-
329- q = get_queue_or_skip ()
330- skip_if_dtype_not_supported (dtype , q )
331-
332- x = [np .nan , np .inf , - np .inf , + 0.0 , - 0.0 , + 1.0 , - 1.0 ]
333- xc = [complex (* val ) for val in itertools .product (x , repeat = 2 )]
334-
335- Xc_np = np .array (xc , dtype = dtype )
336- Xc = dpt .asarray (Xc_np , dtype = dtype , sycl_queue = q )
337-
338- with np .errstate (all = "ignore" ):
339- Ynp = np_call (Xc_np )
340-
341- tol = 50 * dpt .finfo (dtype ).resolution
342- Y = dpt_call (Xc )
343- assert_allclose (dpt .asnumpy (dpt .real (Y )), np .real (Ynp ), atol = tol , rtol = tol )
344- assert_allclose (dpt .asnumpy (dpt .imag (Y )), np .imag (Ynp ), atol = tol , rtol = tol )
0 commit comments