Skip to content

Commit 7a91fea

Browse files
committed
remove unnecessary rounding function tests
1 parent 06b577e commit 7a91fea

File tree

2 files changed

+5
-177
lines changed

2 files changed

+5
-177
lines changed

dpctl/tests/elementwise/test_floor_ceil_trunc.py

Lines changed: 2 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,14 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616

17-
import itertools
18-
import re
19-
2017
import numpy as np
2118
import pytest
2219
from numpy.testing import assert_allclose, assert_array_equal
2320

2421
import dpctl.tensor as dpt
2522
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported
2623

27-
from .utils import _map_to_device_dtype, _no_complex_dtypes, _real_value_dtypes
24+
from .utils import _map_to_device_dtype, _no_complex_dtypes
2825

2926
_all_funcs = [(np.floor, dpt.floor), (np.ceil, dpt.ceil), (np.trunc, dpt.trunc)]
3027

@@ -47,63 +44,9 @@ def test_floor_ceil_trunc_out_type(dpt_call, dtype):
4744
assert_allclose(dpt.asnumpy(dpt_call(X)), dpt.asnumpy(Y))
4845

4946

50-
@pytest.mark.parametrize("np_call, dpt_call", _all_funcs)
51-
@pytest.mark.parametrize("usm_type", ["device", "shared", "host"])
52-
def test_floor_ceil_trunc_usm_type(np_call, dpt_call, usm_type):
53-
q = get_queue_or_skip()
54-
55-
arg_dt = np.dtype("f4")
56-
input_shape = (10, 10, 10, 10)
57-
X = dpt.empty(input_shape, dtype=arg_dt, usm_type=usm_type, sycl_queue=q)
58-
X[..., 0::2] = -0.4
59-
X[..., 1::2] = 0.7
60-
61-
Y = dpt_call(X)
62-
assert Y.usm_type == X.usm_type
63-
assert Y.sycl_queue == X.sycl_queue
64-
assert Y.flags.c_contiguous
65-
66-
expected_Y = np_call(dpt.asnumpy(X))
67-
tol = 8 * dpt.finfo(Y.dtype).resolution
68-
assert_allclose(dpt.asnumpy(Y), expected_Y, atol=tol, rtol=tol)
69-
70-
7147
@pytest.mark.parametrize("np_call, dpt_call", _all_funcs)
7248
@pytest.mark.parametrize("dtype", _no_complex_dtypes)
73-
def test_floor_ceil_trunc_order(np_call, dpt_call, dtype):
74-
q = get_queue_or_skip()
75-
skip_if_dtype_not_supported(dtype, q)
76-
77-
arg_dt = np.dtype(dtype)
78-
input_shape = (4, 4, 4, 4)
79-
X = dpt.empty(input_shape, dtype=arg_dt, sycl_queue=q)
80-
X[..., 0::2] = -0.4
81-
X[..., 1::2] = 0.7
82-
83-
for perms in itertools.permutations(range(4)):
84-
U = dpt.permute_dims(X[:, ::-1, ::-1, :], perms)
85-
expected_Y = np_call(dpt.asnumpy(U))
86-
for ord in ["C", "F", "A", "K"]:
87-
Y = dpt_call(U, order=ord)
88-
assert_allclose(dpt.asnumpy(Y), expected_Y)
89-
90-
91-
@pytest.mark.parametrize("dpt_call", [dpt.floor, dpt.ceil, dpt.trunc])
92-
@pytest.mark.parametrize("dtype", _real_value_dtypes)
93-
def test_floor_ceil_trunc_error_dtype(dpt_call, dtype):
94-
q = get_queue_or_skip()
95-
skip_if_dtype_not_supported(dtype, q)
96-
97-
x = dpt.zeros(5, dtype=dtype)
98-
y = dpt.empty_like(x, dtype="b1")
99-
with pytest.raises(ValueError) as excinfo:
100-
dpt_call(x, out=y)
101-
assert re.match("Output array of type.*is needed", str(excinfo.value))
102-
103-
104-
@pytest.mark.parametrize("np_call, dpt_call", _all_funcs)
105-
@pytest.mark.parametrize("dtype", _no_complex_dtypes)
106-
def test_floor_ceil_trunc_contig(np_call, dpt_call, dtype):
49+
def test_floor_ceil_trunc_basic(np_call, dpt_call, dtype):
10750
q = get_queue_or_skip()
10851
skip_if_dtype_not_supported(dtype, q)
10952

@@ -122,28 +65,6 @@ def test_floor_ceil_trunc_contig(np_call, dpt_call, dtype):
12265
assert_allclose(dpt.asnumpy(Z), np.repeat(np_call(Xnp), n_rep))
12366

12467

125-
@pytest.mark.parametrize("np_call, dpt_call", _all_funcs)
126-
@pytest.mark.parametrize("dtype", _no_complex_dtypes)
127-
def test_floor_ceil_trunc_strided(np_call, dpt_call, dtype):
128-
q = get_queue_or_skip()
129-
skip_if_dtype_not_supported(dtype, q)
130-
131-
np.random.seed(42)
132-
strides = np.array([-4, -3, -2, -1, 1, 2, 3, 4])
133-
sizes = [2, 4, 6, 8, 24, 32, 72]
134-
135-
for ii in sizes:
136-
Xnp = np.random.uniform(low=-99.9, high=99.9, size=ii)
137-
Xnp.astype(dtype)
138-
X = dpt.asarray(Xnp)
139-
Ynp = np_call(Xnp)
140-
for jj in strides:
141-
assert_allclose(
142-
dpt.asnumpy(dpt_call(X[::jj])),
143-
Ynp[::jj],
144-
)
145-
146-
14768
@pytest.mark.parametrize("np_call, dpt_call", _all_funcs)
14869
@pytest.mark.parametrize("dtype", ["f2", "f4", "f8"])
14970
def test_floor_ceil_trunc_special_cases(np_call, dpt_call, dtype):

dpctl/tests/elementwise/test_round.py

Lines changed: 3 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
import dpctl.tensor as dpt
2424
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported
2525

26-
from .utils import _all_dtypes, _map_to_device_dtype, _usm_types
26+
from .utils import _all_dtypes, _map_to_device_dtype
2727

2828

2929
@pytest.mark.parametrize("dtype", _all_dtypes[1:])
@@ -38,7 +38,7 @@ def test_round_out_type(dtype):
3838

3939

4040
@pytest.mark.parametrize("dtype", ["f2", "f4", "f8"])
41-
def test_round_real_contig(dtype):
41+
def test_round_real_basic(dtype):
4242
q = get_queue_or_skip()
4343
skip_if_dtype_not_supported(dtype, q)
4444

@@ -59,7 +59,7 @@ def test_round_real_contig(dtype):
5959

6060

6161
@pytest.mark.parametrize("dtype", ["c8", "c16"])
62-
def test_round_complex_contig(dtype):
62+
def test_round_complex_basic(dtype):
6363
q = get_queue_or_skip()
6464
skip_if_dtype_not_supported(dtype, q)
6565

@@ -87,48 +87,6 @@ def test_round_complex_contig(dtype):
8787
)
8888

8989

90-
@pytest.mark.parametrize("usm_type", _usm_types)
91-
def test_round_usm_type(usm_type):
92-
q = get_queue_or_skip()
93-
94-
arg_dt = np.dtype("f4")
95-
input_shape = (10, 10, 10, 10)
96-
X = dpt.empty(input_shape, dtype=arg_dt, usm_type=usm_type, sycl_queue=q)
97-
X[..., 0::2] = 16.2
98-
X[..., 1::2] = 23.7
99-
100-
Y = dpt.round(X)
101-
assert Y.usm_type == X.usm_type
102-
assert Y.sycl_queue == X.sycl_queue
103-
assert Y.flags.c_contiguous
104-
105-
expected_Y = np.empty(input_shape, dtype=arg_dt)
106-
expected_Y[..., 0::2] = np.round(np.float32(16.2))
107-
expected_Y[..., 1::2] = np.round(np.float32(23.7))
108-
tol = 8 * dpt.finfo(Y.dtype).resolution
109-
110-
assert_allclose(dpt.asnumpy(Y), expected_Y, atol=tol, rtol=tol)
111-
112-
113-
@pytest.mark.parametrize("dtype", _all_dtypes)
114-
def test_round_order(dtype):
115-
q = get_queue_or_skip()
116-
skip_if_dtype_not_supported(dtype, q)
117-
118-
arg_dt = np.dtype(dtype)
119-
input_shape = (10, 10, 10, 10)
120-
X = dpt.empty(input_shape, dtype=arg_dt, sycl_queue=q)
121-
X[..., 0::2] = 8.8
122-
X[..., 1::2] = 11.3
123-
124-
for perms in itertools.permutations(range(4)):
125-
U = dpt.permute_dims(X[:, ::-1, ::-1, :], perms)
126-
expected_Y = np.round(dpt.asnumpy(U))
127-
for ord in ["C", "F", "A", "K"]:
128-
Y = dpt.round(U, order=ord)
129-
assert_allclose(dpt.asnumpy(Y), expected_Y)
130-
131-
13290
@pytest.mark.parametrize("dtype", ["f2", "f4", "f8"])
13391
def test_round_real_special_cases(dtype):
13492
q = get_queue_or_skip()
@@ -145,57 +103,6 @@ def test_round_real_special_cases(dtype):
145103
assert_array_equal(np.signbit(Y), np.signbit(Ynp))
146104

147105

148-
@pytest.mark.parametrize("dtype", ["f2", "f4", "f8"])
149-
def test_round_real_strided(dtype):
150-
q = get_queue_or_skip()
151-
skip_if_dtype_not_supported(dtype, q)
152-
153-
np.random.seed(42)
154-
strides = np.array([-4, -3, -2, -1, 1, 2, 3, 4])
155-
sizes = [2, 4, 6, 8, 9, 24, 72]
156-
tol = 8 * dpt.finfo(dtype).resolution
157-
158-
for ii in sizes:
159-
Xnp = np.random.uniform(low=0.01, high=88.1, size=ii)
160-
Xnp.astype(dtype)
161-
X = dpt.asarray(Xnp)
162-
Ynp = np.round(Xnp)
163-
for jj in strides:
164-
assert_allclose(
165-
dpt.asnumpy(dpt.round(X[::jj])),
166-
Ynp[::jj],
167-
atol=tol,
168-
rtol=tol,
169-
)
170-
171-
172-
@pytest.mark.parametrize("dtype", ["c8", "c16"])
173-
def test_round_complex_strided(dtype):
174-
q = get_queue_or_skip()
175-
skip_if_dtype_not_supported(dtype, q)
176-
177-
np.random.seed(42)
178-
strides = np.array([-4, -3, -2, -1, 1, 2, 3, 4])
179-
sizes = [2, 4, 6, 8, 9, 24, 72]
180-
tol = 8 * dpt.finfo(dtype).resolution
181-
182-
low = -88.0
183-
high = 88.0
184-
for ii in sizes:
185-
x1 = np.random.uniform(low=low, high=high, size=ii)
186-
x2 = np.random.uniform(low=low, high=high, size=ii)
187-
Xnp = np.array([complex(v1, v2) for v1, v2 in zip(x1, x2)], dtype=dtype)
188-
X = dpt.asarray(Xnp)
189-
Ynp = np.round(Xnp)
190-
for jj in strides:
191-
assert_allclose(
192-
dpt.asnumpy(dpt.round(X[::jj])),
193-
Ynp[::jj],
194-
atol=tol,
195-
rtol=tol,
196-
)
197-
198-
199106
@pytest.mark.parametrize("dtype", ["c8", "c16"])
200107
def test_round_complex_special_cases(dtype):
201108
q = get_queue_or_skip()

0 commit comments

Comments
 (0)