2323import dpctl .tensor as dpt
2424from dpctl .tests .helper import get_queue_or_skip , skip_if_dtype_not_supported
2525
26- from .utils import _all_dtypes , _map_to_device_dtype , _usm_types
26+ from .utils import _all_dtypes , _map_to_device_dtype
2727
2828
2929@pytest .mark .parametrize ("dtype" , _all_dtypes [1 :])
@@ -38,7 +38,7 @@ def test_round_out_type(dtype):
3838
3939
4040@pytest .mark .parametrize ("dtype" , ["f2" , "f4" , "f8" ])
41- def test_round_real_contig (dtype ):
41+ def test_round_real_basic (dtype ):
4242 q = get_queue_or_skip ()
4343 skip_if_dtype_not_supported (dtype , q )
4444
@@ -59,7 +59,7 @@ def test_round_real_contig(dtype):
5959
6060
6161@pytest .mark .parametrize ("dtype" , ["c8" , "c16" ])
62- def test_round_complex_contig (dtype ):
62+ def test_round_complex_basic (dtype ):
6363 q = get_queue_or_skip ()
6464 skip_if_dtype_not_supported (dtype , q )
6565
@@ -87,48 +87,6 @@ def test_round_complex_contig(dtype):
8787 )
8888
8989
90- @pytest .mark .parametrize ("usm_type" , _usm_types )
91- def test_round_usm_type (usm_type ):
92- q = get_queue_or_skip ()
93-
94- arg_dt = np .dtype ("f4" )
95- input_shape = (10 , 10 , 10 , 10 )
96- X = dpt .empty (input_shape , dtype = arg_dt , usm_type = usm_type , sycl_queue = q )
97- X [..., 0 ::2 ] = 16.2
98- X [..., 1 ::2 ] = 23.7
99-
100- Y = dpt .round (X )
101- assert Y .usm_type == X .usm_type
102- assert Y .sycl_queue == X .sycl_queue
103- assert Y .flags .c_contiguous
104-
105- expected_Y = np .empty (input_shape , dtype = arg_dt )
106- expected_Y [..., 0 ::2 ] = np .round (np .float32 (16.2 ))
107- expected_Y [..., 1 ::2 ] = np .round (np .float32 (23.7 ))
108- tol = 8 * dpt .finfo (Y .dtype ).resolution
109-
110- assert_allclose (dpt .asnumpy (Y ), expected_Y , atol = tol , rtol = tol )
111-
112-
113- @pytest .mark .parametrize ("dtype" , _all_dtypes )
114- def test_round_order (dtype ):
115- q = get_queue_or_skip ()
116- skip_if_dtype_not_supported (dtype , q )
117-
118- arg_dt = np .dtype (dtype )
119- input_shape = (10 , 10 , 10 , 10 )
120- X = dpt .empty (input_shape , dtype = arg_dt , sycl_queue = q )
121- X [..., 0 ::2 ] = 8.8
122- X [..., 1 ::2 ] = 11.3
123-
124- for perms in itertools .permutations (range (4 )):
125- U = dpt .permute_dims (X [:, ::- 1 , ::- 1 , :], perms )
126- expected_Y = np .round (dpt .asnumpy (U ))
127- for ord in ["C" , "F" , "A" , "K" ]:
128- Y = dpt .round (U , order = ord )
129- assert_allclose (dpt .asnumpy (Y ), expected_Y )
130-
131-
13290@pytest .mark .parametrize ("dtype" , ["f2" , "f4" , "f8" ])
13391def test_round_real_special_cases (dtype ):
13492 q = get_queue_or_skip ()
@@ -145,57 +103,6 @@ def test_round_real_special_cases(dtype):
145103 assert_array_equal (np .signbit (Y ), np .signbit (Ynp ))
146104
147105
148- @pytest .mark .parametrize ("dtype" , ["f2" , "f4" , "f8" ])
149- def test_round_real_strided (dtype ):
150- q = get_queue_or_skip ()
151- skip_if_dtype_not_supported (dtype , q )
152-
153- np .random .seed (42 )
154- strides = np .array ([- 4 , - 3 , - 2 , - 1 , 1 , 2 , 3 , 4 ])
155- sizes = [2 , 4 , 6 , 8 , 9 , 24 , 72 ]
156- tol = 8 * dpt .finfo (dtype ).resolution
157-
158- for ii in sizes :
159- Xnp = np .random .uniform (low = 0.01 , high = 88.1 , size = ii )
160- Xnp .astype (dtype )
161- X = dpt .asarray (Xnp )
162- Ynp = np .round (Xnp )
163- for jj in strides :
164- assert_allclose (
165- dpt .asnumpy (dpt .round (X [::jj ])),
166- Ynp [::jj ],
167- atol = tol ,
168- rtol = tol ,
169- )
170-
171-
172- @pytest .mark .parametrize ("dtype" , ["c8" , "c16" ])
173- def test_round_complex_strided (dtype ):
174- q = get_queue_or_skip ()
175- skip_if_dtype_not_supported (dtype , q )
176-
177- np .random .seed (42 )
178- strides = np .array ([- 4 , - 3 , - 2 , - 1 , 1 , 2 , 3 , 4 ])
179- sizes = [2 , 4 , 6 , 8 , 9 , 24 , 72 ]
180- tol = 8 * dpt .finfo (dtype ).resolution
181-
182- low = - 88.0
183- high = 88.0
184- for ii in sizes :
185- x1 = np .random .uniform (low = low , high = high , size = ii )
186- x2 = np .random .uniform (low = low , high = high , size = ii )
187- Xnp = np .array ([complex (v1 , v2 ) for v1 , v2 in zip (x1 , x2 )], dtype = dtype )
188- X = dpt .asarray (Xnp )
189- Ynp = np .round (Xnp )
190- for jj in strides :
191- assert_allclose (
192- dpt .asnumpy (dpt .round (X [::jj ])),
193- Ynp [::jj ],
194- atol = tol ,
195- rtol = tol ,
196- )
197-
198-
199106@pytest .mark .parametrize ("dtype" , ["c8" , "c16" ])
200107def test_round_complex_special_cases (dtype ):
201108 q = get_queue_or_skip ()
0 commit comments