Skip to content

Commit ab136a5

Browse files
committed
remove redundant sign function tests
1 parent 85b0f4c commit ab136a5

File tree

2 files changed

+3
-87
lines changed

2 files changed

+3
-87
lines changed

dpctl/tests/elementwise/test_sign.py

Lines changed: 1 addition & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import dpctl.tensor as dpt
2323
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported
2424

25-
from .utils import _all_dtypes, _no_complex_dtypes, _usm_types
25+
from .utils import _all_dtypes, _no_complex_dtypes
2626

2727

2828
@pytest.mark.parametrize("dtype", _all_dtypes[1:])
@@ -39,47 +39,6 @@ def test_sign_out_type(dtype):
3939
assert np.allclose(dpt.asnumpy(r), dpt.asnumpy(dpt.sign(X)))
4040

4141

42-
@pytest.mark.parametrize("usm_type", _usm_types)
43-
def test_sign_usm_type(usm_type):
44-
q = get_queue_or_skip()
45-
46-
arg_dt = np.dtype("i4")
47-
input_shape = (10, 10, 10, 10)
48-
X = dpt.empty(input_shape, dtype=arg_dt, usm_type=usm_type, sycl_queue=q)
49-
X[..., 0::2] = 1
50-
X[..., 1::2] = 0
51-
52-
Y = dpt.sign(X)
53-
assert Y.usm_type == X.usm_type
54-
assert Y.sycl_queue == X.sycl_queue
55-
assert Y.flags.c_contiguous
56-
57-
expected_Y = dpt.asnumpy(X)
58-
assert np.allclose(dpt.asnumpy(Y), expected_Y)
59-
60-
61-
@pytest.mark.parametrize("dtype", _all_dtypes[1:])
62-
def test_sign_order(dtype):
63-
q = get_queue_or_skip()
64-
skip_if_dtype_not_supported(dtype, q)
65-
66-
arg_dt = np.dtype(dtype)
67-
expected_dt = np.sign(np.ones(tuple(), dtype=arg_dt)).dtype
68-
input_shape = (10, 10, 10, 10)
69-
X = dpt.empty(input_shape, dtype=arg_dt, sycl_queue=q)
70-
X[..., 0::2] = 1
71-
X[..., 1::2] = 0
72-
73-
for perms in itertools.permutations(range(4)):
74-
U = dpt.permute_dims(X[:, ::-1, ::-1, :], perms)
75-
expected_Y = np.ones(U.shape, dtype=expected_dt)
76-
expected_Y[..., 1::2] = 0
77-
expected_Y = np.transpose(expected_Y, perms)
78-
for ord in ["C", "F", "A", "K"]:
79-
Y = dpt.sign(U, order=ord)
80-
assert np.allclose(dpt.asnumpy(Y), expected_Y)
81-
82-
8342
@pytest.mark.parametrize("dtype", ["c8", "c16"])
8443
def test_sign_complex(dtype):
8544
q = get_queue_or_skip()

dpctl/tests/elementwise/test_signbit.py

Lines changed: 2 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323

2424
@pytest.mark.parametrize("dtype", ["f2", "f4", "f8"])
25-
def test_signbit_out_type_contig(dtype):
25+
def test_signbit_out_type(dtype):
2626
q = get_queue_or_skip()
2727
skip_if_dtype_not_supported(dtype, q)
2828

@@ -39,24 +39,7 @@ def test_signbit_out_type_contig(dtype):
3939

4040

4141
@pytest.mark.parametrize("dtype", ["f2", "f4", "f8"])
42-
def test_signbit_out_type_strided(dtype):
43-
q = get_queue_or_skip()
44-
skip_if_dtype_not_supported(dtype, q)
45-
46-
arg_dt = np.dtype(dtype)
47-
x = dpt.linspace(1, 10, num=256, dtype=arg_dt)
48-
sb = dpt.signbit(x[::-3])
49-
assert sb.dtype == dpt.bool
50-
51-
assert not dpt.any(sb)
52-
53-
x2 = dpt.linspace(-10, -1, num=256, dtype=arg_dt)
54-
sb2 = dpt.signbit(x2[::-3])
55-
assert dpt.all(sb2)
56-
57-
58-
@pytest.mark.parametrize("dtype", ["f2", "f4", "f8"])
59-
def test_signbit_special_cases_contig(dtype):
42+
def test_signbit_special_cases(dtype):
6043
q = get_queue_or_skip()
6144
skip_if_dtype_not_supported(dtype, q)
6245

@@ -80,29 +63,3 @@ def test_signbit_special_cases_contig(dtype):
8063
)
8164

8265
assert dpt.all(dpt.equal(actual, expected))
83-
84-
85-
@pytest.mark.parametrize("dtype", ["f2", "f4", "f8"])
86-
def test_signbit_special_cases_strided(dtype):
87-
q = get_queue_or_skip()
88-
skip_if_dtype_not_supported(dtype, q)
89-
90-
arg_dt = np.dtype(dtype)
91-
x1 = dpt.full(63, -dpt.inf, dtype=arg_dt)
92-
x2 = dpt.full(63, -0.0, dtype=arg_dt)
93-
x3 = dpt.full(63, 0.0, dtype=arg_dt)
94-
x4 = dpt.full(63, dpt.inf, dtype=arg_dt)
95-
96-
x = dpt.concat((x1, x2, x3, x4))
97-
actual = dpt.signbit(x[::-1])
98-
99-
expected = dpt.concat(
100-
(
101-
dpt.full(x4.size, False),
102-
dpt.full(x3.size, False),
103-
dpt.full(x2.size, True),
104-
dpt.full(x1.size, True),
105-
)
106-
)
107-
108-
assert dpt.all(dpt.equal(actual, expected))

0 commit comments

Comments
 (0)