2424import dpctl .tensor as dpt
2525from dpctl .tests .helper import get_queue_or_skip , skip_if_dtype_not_supported
2626
27- from .utils import _map_to_device_dtype , _no_complex_dtypes
27+ from .utils import _map_to_device_dtype , _real_value_dtypes
2828
2929_all_funcs = [(np .floor , dpt .floor ), (np .ceil , dpt .ceil ), (np .trunc , dpt .trunc )]
3030
3131
32- @pytest .mark .parametrize ("np_call, dpt_call" , _all_funcs )
33- @pytest .mark .parametrize ("dtype" , _no_complex_dtypes )
34- def test_floor_ceil_trunc_out_type (np_call , dpt_call , dtype ):
32+ @pytest .mark .parametrize ("dpt_call" , [ dpt . floor , dpt . ceil , dpt . trunc ] )
33+ @pytest .mark .parametrize ("dtype" , _real_value_dtypes )
34+ def test_floor_ceil_trunc_out_type (dpt_call , dtype ):
3535 q = get_queue_or_skip ()
3636 skip_if_dtype_not_supported (dtype , q )
37- if dtype == "b1" :
38- skip_if_dtype_not_supported ("f2" , q )
3937
40- X = dpt . asarray ( 0.1 , dtype = dtype , sycl_queue = q )
41- expected_dtype = np_call ( np . array (0.1 , dtype = dtype )). dtype
42- expected_dtype = _map_to_device_dtype (expected_dtype , q .sycl_device )
38+ arg_dt = np . dtype ( dtype )
39+ X = dpt . asarray (0.1 , dtype = arg_dt , sycl_queue = q )
40+ expected_dtype = _map_to_device_dtype (arg_dt , q .sycl_device )
4341 assert dpt_call (X ).dtype == expected_dtype
4442
4543 X = dpt .asarray (0.1 , dtype = dtype , sycl_queue = q )
46- expected_dtype = np_call (np .array (0.1 , dtype = dtype )).dtype
47- expected_dtype = _map_to_device_dtype (expected_dtype , q .sycl_device )
44+ expected_dtype = _map_to_device_dtype (arg_dt , q .sycl_device )
4845 Y = dpt .empty_like (X , dtype = expected_dtype )
4946 dpt_call (X , out = Y )
5047 assert_allclose (dpt .asnumpy (dpt_call (X )), dpt .asnumpy (Y ))
@@ -73,12 +70,10 @@ def test_floor_ceil_trunc_usm_type(np_call, dpt_call, usm_type):
7370
7471
7572@pytest .mark .parametrize ("np_call, dpt_call" , _all_funcs )
76- @pytest .mark .parametrize ("dtype" , _no_complex_dtypes )
73+ @pytest .mark .parametrize ("dtype" , _real_value_dtypes )
7774def test_floor_ceil_trunc_order (np_call , dpt_call , dtype ):
7875 q = get_queue_or_skip ()
7976 skip_if_dtype_not_supported (dtype , q )
80- if dtype == "b1" :
81- skip_if_dtype_not_supported ("f2" , q )
8277
8378 arg_dt = np .dtype (dtype )
8479 input_shape = (10 , 10 , 10 , 10 )
@@ -90,17 +85,12 @@ def test_floor_ceil_trunc_order(np_call, dpt_call, dtype):
9085 for perms in itertools .permutations (range (4 )):
9186 U = dpt .permute_dims (X [:, ::- 1 , ::- 1 , :], perms )
9287 Y = dpt_call (U , order = ord )
93- with np .errstate (all = "ignore" ):
94- expected_Y = np_call (dpt .asnumpy (U ))
95- tol = 8 * max (
96- dpt .finfo (Y .dtype ).resolution ,
97- np .finfo (expected_Y .dtype ).resolution ,
98- )
99- assert_allclose (dpt .asnumpy (Y ), expected_Y , atol = tol , rtol = tol )
88+ expected_Y = np_call (dpt .asnumpy (U ))
89+ assert_allclose (dpt .asnumpy (Y ), expected_Y )
10090
10191
102- @pytest .mark .parametrize ("callable " , [dpt .floor , dpt .ceil , dpt .trunc ])
103- def test_floor_ceil_trunc_errors (callable ):
92+ @pytest .mark .parametrize ("dpt_call " , [dpt .floor , dpt .ceil , dpt .trunc ])
93+ def test_floor_ceil_trunc_errors (dpt_call ):
10494 get_queue_or_skip ()
10595 try :
10696 gpu_queue = dpctl .SyclQueue ("gpu" )
@@ -116,7 +106,7 @@ def test_floor_ceil_trunc_errors(callable):
116106 assert_raises_regex (
117107 TypeError ,
118108 "Input and output allocation queues are not compatible" ,
119- callable ,
109+ dpt_call ,
120110 x ,
121111 y ,
122112 )
@@ -126,41 +116,39 @@ def test_floor_ceil_trunc_errors(callable):
126116 assert_raises_regex (
127117 TypeError ,
128118 "The shape of input and output arrays are inconsistent" ,
129- callable ,
119+ dpt_call ,
130120 x ,
131121 y ,
132122 )
133123
134124 x = dpt .zeros (2 )
135125 y = x
136126 assert_raises_regex (
137- TypeError , "Input and output arrays have memory overlap" , callable , x , y
127+ TypeError , "Input and output arrays have memory overlap" , dpt_call , x , y
138128 )
139129
140130 x = dpt .zeros (2 , dtype = "float32" )
141131 y = np .empty_like (x )
142132 assert_raises_regex (
143- TypeError , "output array must be of usm_ndarray type" , callable , x , y
133+ TypeError , "output array must be of usm_ndarray type" , dpt_call , x , y
144134 )
145135
146136
147- @pytest .mark .parametrize ("callable " , [dpt .floor , dpt .ceil , dpt .trunc ])
148- @pytest .mark .parametrize ("dtype" , _no_complex_dtypes )
149- def test_floor_ceil_trunc_error_dtype (callable , dtype ):
137+ @pytest .mark .parametrize ("dpt_call " , [dpt .floor , dpt .ceil , dpt .trunc ])
138+ @pytest .mark .parametrize ("dtype" , _real_value_dtypes )
139+ def test_floor_ceil_trunc_error_dtype (dpt_call , dtype ):
150140 q = get_queue_or_skip ()
151141 skip_if_dtype_not_supported (dtype , q )
152- if dtype == "b1" :
153- skip_if_dtype_not_supported ("f2" , q )
154142
155143 x = dpt .zeros (5 , dtype = dtype )
156- y = dpt .empty_like (x , dtype = "int16 " )
144+ y = dpt .empty_like (x , dtype = "b1 " )
157145 assert_raises_regex (
158- TypeError , "Output array of type.*is needed" , callable , x , y
146+ TypeError , "Output array of type.*is needed" , dpt_call , x , y
159147 )
160148
161149
162150@pytest .mark .parametrize ("np_call, dpt_call" , _all_funcs )
163- @pytest .mark .parametrize ("dtype" , [ "f2" , "f4" , "f8" ] )
151+ @pytest .mark .parametrize ("dtype" , _real_value_dtypes )
164152def test_floor_ceil_trunc_contig (np_call , dpt_call , dtype ):
165153 q = get_queue_or_skip ()
166154 skip_if_dtype_not_supported (dtype , q )
@@ -172,29 +160,23 @@ def test_floor_ceil_trunc_contig(np_call, dpt_call, dtype):
172160 X = dpt .asarray (np .repeat (Xnp , n_rep ), dtype = dtype , sycl_queue = q )
173161 Y = dpt_call (X )
174162
175- tol = 8 * dpt .finfo (Y .dtype ).resolution
176- assert_allclose (
177- dpt .asnumpy (Y ), np .repeat (np_call (Xnp ), n_rep ), atol = tol , rtol = tol
178- )
163+ assert_allclose (dpt .asnumpy (Y ), np .repeat (np_call (Xnp ), n_rep ))
179164
180165 Z = dpt .empty_like (X , dtype = dtype )
181166 dpt_call (X , out = Z )
182167
183- assert_allclose (
184- dpt .asnumpy (Z ), np .repeat (np_call (Xnp ), n_rep ), atol = tol , rtol = tol
185- )
168+ assert_allclose (dpt .asnumpy (Z ), np .repeat (np_call (Xnp ), n_rep ))
186169
187170
188171@pytest .mark .parametrize ("np_call, dpt_call" , _all_funcs )
189- @pytest .mark .parametrize ("dtype" , [ "f2" , "f4" , "f8" ] )
172+ @pytest .mark .parametrize ("dtype" , _real_value_dtypes )
190173def test_floor_ceil_trunc_strided (np_call , dpt_call , dtype ):
191174 q = get_queue_or_skip ()
192175 skip_if_dtype_not_supported (dtype , q )
193176
194177 np .random .seed (42 )
195178 strides = np .array ([- 4 , - 3 , - 2 , - 1 , 1 , 2 , 3 , 4 ])
196179 sizes = np .arange (2 , 100 )
197- tol = 8 * dpt .finfo (dtype ).resolution
198180
199181 for ii in sizes :
200182 Xnp = np .random .uniform (low = - 99.9 , high = 99.9 , size = ii )
@@ -205,8 +187,6 @@ def test_floor_ceil_trunc_strided(np_call, dpt_call, dtype):
205187 assert_allclose (
206188 dpt .asnumpy (dpt_call (X [::jj ])),
207189 Ynp [::jj ],
208- atol = tol ,
209- rtol = tol ,
210190 )
211191
212192
@@ -221,8 +201,7 @@ def test_floor_ceil_trunc_special_cases(np_call, dpt_call, dtype):
221201 xf = np .array (x , dtype = dtype )
222202 yf = dpt .asarray (xf , dtype = dtype , sycl_queue = q )
223203
224- with np .errstate (all = "ignore" ):
225- Y_np = np_call (xf )
204+ Y_np = np_call (xf )
226205
227206 tol = 8 * dpt .finfo (dtype ).resolution
228207 assert_allclose (dpt .asnumpy (dpt_call (yf )), Y_np , atol = tol , rtol = tol )
0 commit comments