Skip to content

Commit 3911bd3

Browse files
committed
remove redundant tests in test_complex
1 parent c652167 commit 3911bd3

File tree

1 file changed

+1
-82
lines changed

1 file changed

+1
-82
lines changed

dpctl/tests/elementwise/test_complex.py

Lines changed: 1 addition & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
import dpctl.tensor as dpt
2525
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported
2626

27-
from .utils import _all_dtypes, _map_to_device_dtype, _usm_types
27+
from .utils import _all_dtypes, _map_to_device_dtype
2828

2929

3030
@pytest.mark.parametrize("dtype", _all_dtypes)
@@ -73,56 +73,6 @@ def test_complex_output(np_call, dpt_call, dtype):
7373
assert_allclose(dpt.asnumpy(Z), np_call(Xnp), atol=tol, rtol=tol)
7474

7575

76-
@pytest.mark.parametrize(
77-
"np_call, dpt_call",
78-
[(np.real, dpt.real), (np.imag, dpt.imag), (np.conj, dpt.conj)],
79-
)
80-
@pytest.mark.parametrize("usm_type", _usm_types)
81-
def test_complex_usm_type(np_call, dpt_call, usm_type):
82-
q = get_queue_or_skip()
83-
84-
arg_dt = np.dtype("c8")
85-
input_shape = (10, 10, 10, 10)
86-
X = dpt.empty(input_shape, dtype=arg_dt, usm_type=usm_type, sycl_queue=q)
87-
X[..., 0::2] = np.pi / 6 + 1j * np.pi / 3
88-
X[..., 1::2] = np.pi / 3 + 1j * np.pi / 6
89-
90-
Y = dpt_call(X)
91-
assert Y.usm_type == X.usm_type
92-
assert Y.sycl_queue == X.sycl_queue
93-
assert Y.flags.c_contiguous
94-
95-
expected_Y = np.empty(input_shape, dtype=arg_dt)
96-
expected_Y[..., 0::2] = np_call(np.complex64(np.pi / 6 + 1j * np.pi / 3))
97-
expected_Y[..., 1::2] = np_call(np.complex64(np.pi / 3 + 1j * np.pi / 6))
98-
tol = 8 * dpt.finfo(Y.dtype).resolution
99-
100-
assert_allclose(dpt.asnumpy(Y), expected_Y, atol=tol, rtol=tol)
101-
102-
103-
@pytest.mark.parametrize(
104-
"np_call, dpt_call",
105-
[(np.real, dpt.real), (np.imag, dpt.imag), (np.conj, dpt.conj)],
106-
)
107-
@pytest.mark.parametrize("dtype", _all_dtypes)
108-
def test_complex_order(np_call, dpt_call, dtype):
109-
q = get_queue_or_skip()
110-
skip_if_dtype_not_supported(dtype, q)
111-
112-
arg_dt = np.dtype(dtype)
113-
input_shape = (10, 10, 10, 10)
114-
X = dpt.empty(input_shape, dtype=arg_dt, sycl_queue=q)
115-
X[..., 0::2] = np.pi / 6 + 1j * np.pi / 3
116-
X[..., 1::2] = np.pi / 3 + 1j * np.pi / 6
117-
118-
for perms in itertools.permutations(range(4)):
119-
U = dpt.permute_dims(X[:, ::-1, ::-1, :], perms)
120-
expected_Y = np_call(dpt.asnumpy(U))
121-
for ord in ["C", "F", "A", "K"]:
122-
Y = dpt_call(U, order=ord)
123-
assert_allclose(dpt.asnumpy(Y), expected_Y)
124-
125-
12676
@pytest.mark.parametrize("dtype", ["c8", "c16"])
12777
def test_projection_complex(dtype):
12878
q = get_queue_or_skip()
@@ -161,37 +111,6 @@ def test_projection(dtype):
161111
assert_allclose(dpt.asnumpy(dpt.proj(Xf)), Yf, atol=tol, rtol=tol)
162112

163113

164-
@pytest.mark.parametrize(
165-
"np_call, dpt_call",
166-
[(np.real, dpt.real), (np.imag, dpt.imag), (np.conj, dpt.conj)],
167-
)
168-
@pytest.mark.parametrize("dtype", ["c8", "c16"])
169-
def test_complex_strided(np_call, dpt_call, dtype):
170-
q = get_queue_or_skip()
171-
skip_if_dtype_not_supported(dtype, q)
172-
173-
np.random.seed(42)
174-
strides = np.array([-4, -3, -2, -1, 1, 2, 3, 4])
175-
sizes = [2, 4, 6, 8, 9, 24, 72]
176-
tol = 8 * dpt.finfo(dtype).resolution
177-
178-
low = -1000.0
179-
high = 1000.0
180-
for ii in sizes:
181-
x1 = np.random.uniform(low=low, high=high, size=ii)
182-
x2 = np.random.uniform(low=low, high=high, size=ii)
183-
Xnp = np.array([complex(v1, v2) for v1, v2 in zip(x1, x2)], dtype=dtype)
184-
X = dpt.asarray(Xnp)
185-
Ynp = np_call(Xnp)
186-
for jj in strides:
187-
assert_allclose(
188-
dpt.asnumpy(dpt_call(X[::jj])),
189-
Ynp[::jj],
190-
atol=tol,
191-
rtol=tol,
192-
)
193-
194-
195114
@pytest.mark.parametrize("dtype", ["c8", "c16"])
196115
def test_complex_special_cases(dtype):
197116
q = get_queue_or_skip()

0 commit comments

Comments
 (0)