Skip to content

Commit 97296f1

Browse files
committed
remove unnecessary trig tests
1 parent 78d1246 commit 97296f1

File tree

2 files changed

+4
-148
lines changed

2 files changed

+4
-148
lines changed

dpctl/tests/elementwise/test_hyperbolic.py

Lines changed: 2 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def test_hyper_out_type(np_call, dpt_call, dtype):
4848

4949
@pytest.mark.parametrize("np_call, dpt_call", _all_funcs)
5050
@pytest.mark.parametrize("dtype", ["f2", "f4", "f8"])
51-
def test_hyper_real_contig(np_call, dpt_call, dtype):
51+
def test_hyper_basic(np_call, dpt_call, dtype):
5252
q = get_queue_or_skip()
5353
skip_if_dtype_not_supported(dtype, q)
5454

@@ -79,7 +79,7 @@ def test_hyper_real_contig(np_call, dpt_call, dtype):
7979

8080
@pytest.mark.parametrize("np_call, dpt_call", _all_funcs)
8181
@pytest.mark.parametrize("dtype", ["c8", "c16"])
82-
def test_hyper_complex_contig(np_call, dpt_call, dtype):
82+
def test_hyper_complex(np_call, dpt_call, dtype):
8383
q = get_queue_or_skip()
8484
skip_if_dtype_not_supported(dtype, q)
8585

@@ -107,68 +107,6 @@ def test_hyper_complex_contig(np_call, dpt_call, dtype):
107107
)
108108

109109

110-
@pytest.mark.parametrize("np_call, dpt_call", _all_funcs)
111-
@pytest.mark.parametrize("dtype", ["f2", "f4", "f8"])
112-
def test_hyper_real_strided(np_call, dpt_call, dtype):
113-
q = get_queue_or_skip()
114-
skip_if_dtype_not_supported(dtype, q)
115-
116-
np.random.seed(42)
117-
strides = np.array([-4, -3, -2, -1, 1, 2, 3, 4])
118-
sizes = [2, 4, 6, 8, 9, 24, 72]
119-
tol = 8 * dpt.finfo(dtype).resolution
120-
121-
low = -10.0
122-
high = 10.0
123-
if np_call == np.arctanh:
124-
low = -0.9
125-
high = 0.9
126-
elif np_call == np.arccosh:
127-
low = 1.01
128-
high = 100.0
129-
130-
for ii in sizes:
131-
Xnp = np.random.uniform(low=low, high=high, size=ii)
132-
Xnp.astype(dtype)
133-
X = dpt.asarray(Xnp)
134-
Ynp = np_call(Xnp)
135-
for jj in strides:
136-
assert_allclose(
137-
dpt.asnumpy(dpt_call(X[::jj])),
138-
Ynp[::jj],
139-
atol=tol,
140-
rtol=tol,
141-
)
142-
143-
144-
@pytest.mark.parametrize("np_call, dpt_call", _all_funcs)
145-
@pytest.mark.parametrize("dtype", ["c8", "c16"])
146-
def test_hyper_complex_strided(np_call, dpt_call, dtype):
147-
q = get_queue_or_skip()
148-
skip_if_dtype_not_supported(dtype, q)
149-
150-
np.random.seed(42)
151-
strides = np.array([-4, -3, -2, -1, 1, 2, 3, 4])
152-
sizes = [2, 4, 6, 8, 9, 24, 72]
153-
tol = 50 * dpt.finfo(dtype).resolution
154-
155-
low = -8.0
156-
high = 8.0
157-
for ii in sizes:
158-
x1 = np.random.uniform(low=low, high=high, size=ii)
159-
x2 = np.random.uniform(low=low, high=high, size=ii)
160-
Xnp = np.array([complex(v1, v2) for v1, v2 in zip(x1, x2)], dtype=dtype)
161-
X = dpt.asarray(Xnp)
162-
Ynp = np_call(Xnp)
163-
for jj in strides:
164-
assert_allclose(
165-
dpt.asnumpy(dpt_call(X[::jj])),
166-
Ynp[::jj],
167-
atol=tol,
168-
rtol=tol,
169-
)
170-
171-
172110
@pytest.mark.parametrize("np_call, dpt_call", _all_funcs)
173111
@pytest.mark.parametrize("dtype", ["f2", "f4", "f8"])
174112
def test_hyper_real_special_cases(np_call, dpt_call, dtype):

dpctl/tests/elementwise/test_trigonometric.py

Lines changed: 2 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def test_trig_out_type(np_call, dpt_call, dtype):
4646

4747
@pytest.mark.parametrize("np_call, dpt_call", _all_funcs)
4848
@pytest.mark.parametrize("dtype", ["f2", "f4", "f8"])
49-
def test_trig_real_contig(np_call, dpt_call, dtype):
49+
def test_trig_real_basic(np_call, dpt_call, dtype):
5050
q = get_queue_or_skip()
5151
skip_if_dtype_not_supported(dtype, q)
5252

@@ -79,7 +79,7 @@ def test_trig_real_contig(np_call, dpt_call, dtype):
7979

8080
@pytest.mark.parametrize("np_call, dpt_call", _all_funcs)
8181
@pytest.mark.parametrize("dtype", ["c8", "c16"])
82-
def test_trig_complex_contig(np_call, dpt_call, dtype):
82+
def test_trig_complex_basic(np_call, dpt_call, dtype):
8383
q = get_queue_or_skip()
8484
skip_if_dtype_not_supported(dtype, q)
8585

@@ -115,88 +115,6 @@ def test_trig_complex_contig(np_call, dpt_call, dtype):
115115
assert_allclose(dpt.asnumpy(Z), expected, atol=tol, rtol=tol)
116116

117117

118-
@pytest.mark.parametrize("np_call, dpt_call", _all_funcs)
119-
@pytest.mark.parametrize("dtype", ["f2", "f4", "f8"])
120-
def test_trig_real_strided(np_call, dpt_call, dtype):
121-
q = get_queue_or_skip()
122-
skip_if_dtype_not_supported(dtype, q)
123-
124-
np.random.seed(42)
125-
strides = np.array([-4, -3, -2, -1, 1, 2, 3, 4])
126-
sizes = [2, 3, 4, 6, 8, 9, 24, 50, 72]
127-
tol = 8 * dpt.finfo(dtype).resolution
128-
129-
low = -100.0
130-
high = 100.0
131-
if np_call in [np.arccos, np.arcsin]:
132-
low = -1.0
133-
high = 1.0
134-
elif np_call in [np.tan]:
135-
low = -np.pi / 2 * (0.99)
136-
high = np.pi / 2 * (0.99)
137-
138-
for ii in sizes:
139-
Xnp = np.random.uniform(low=low, high=high, size=ii)
140-
Xnp.astype(dtype)
141-
X = dpt.asarray(Xnp)
142-
Ynp = np_call(Xnp)
143-
for jj in strides:
144-
assert_allclose(
145-
dpt.asnumpy(dpt_call(X[::jj])),
146-
Ynp[::jj],
147-
atol=tol,
148-
rtol=tol,
149-
)
150-
151-
152-
@pytest.mark.parametrize("np_call, dpt_call", _all_funcs)
153-
@pytest.mark.parametrize("dtype", ["c8", "c16"])
154-
def test_trig_complex_strided(np_call, dpt_call, dtype):
155-
q = get_queue_or_skip()
156-
skip_if_dtype_not_supported(dtype, q)
157-
158-
np.random.seed(42)
159-
strides = np.array([-4, -3, -2, -1, 1, 2, 3, 4])
160-
sizes = [2, 4, 6, 8, 9, 24, 72]
161-
tol = 50 * dpt.finfo(dtype).resolution
162-
163-
low = -9.0
164-
high = 9.0
165-
while True:
166-
x1 = np.random.uniform(low=low, high=high, size=2 * sum(sizes))
167-
x2 = np.random.uniform(low=low, high=high, size=2 * sum(sizes))
168-
Xnp_all = np.array(
169-
[complex(v1, v2) for v1, v2 in zip(x1, x2)], dtype=dtype
170-
)
171-
172-
# stay away from poles and branch lines
173-
modulus = np.abs(Xnp_all)
174-
sel = np.logical_or(
175-
modulus < 0.9,
176-
np.logical_and(
177-
modulus > 1.2, np.minimum(np.abs(x2), np.abs(x1)) > 0.05
178-
),
179-
)
180-
Xnp_all = Xnp_all[sel]
181-
if Xnp_all.size > sum(sizes):
182-
break
183-
184-
pos = 0
185-
for ii in sizes:
186-
pos = pos + ii
187-
Xnp = Xnp_all[:pos]
188-
Xnp = Xnp[-ii:]
189-
X = dpt.asarray(Xnp)
190-
Ynp = np_call(Xnp)
191-
for jj in strides:
192-
assert_allclose(
193-
dpt.asnumpy(dpt_call(X[::jj])),
194-
Ynp[::jj],
195-
atol=tol,
196-
rtol=tol,
197-
)
198-
199-
200118
@pytest.mark.parametrize("np_call, dpt_call", _all_funcs)
201119
@pytest.mark.parametrize("dtype", ["f2", "f4", "f8"])
202120
def test_trig_real_special_cases(np_call, dpt_call, dtype):

0 commit comments

Comments
 (0)