Skip to content

Commit 711be50

Browse files
committed
remove redundant square and sqrt tests
1 parent 4e3b065 commit 711be50

File tree

2 files changed

+2
-107
lines changed

2 files changed

+2
-107
lines changed

dpctl/tests/elementwise/test_sqrt.py

Lines changed: 1 addition & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
_complex_fp_dtypes,
3030
_map_to_device_dtype,
3131
_real_fp_dtypes,
32-
_usm_types,
3332
)
3433

3534

@@ -45,7 +44,7 @@ def test_sqrt_out_type(dtype):
4544

4645

4746
@pytest.mark.parametrize("dtype", ["f2", "f4", "f8", "c8", "c16"])
48-
def test_sqrt_output_contig(dtype):
47+
def test_sqrt_basic(dtype):
4948
q = get_queue_or_skip()
5049
skip_if_dtype_not_supported(dtype, q)
5150

@@ -60,68 +59,6 @@ def test_sqrt_output_contig(dtype):
6059
assert_allclose(dpt.asnumpy(Y), np.sqrt(Xnp), atol=tol, rtol=tol)
6160

6261

63-
@pytest.mark.parametrize("dtype", ["f2", "f4", "f8", "c8", "c16"])
64-
def test_sqrt_output_strided(dtype):
65-
q = get_queue_or_skip()
66-
skip_if_dtype_not_supported(dtype, q)
67-
68-
n_seq = 2054
69-
70-
X = dpt.linspace(0, 13, num=n_seq, dtype=dtype, sycl_queue=q)[::-2]
71-
Xnp = dpt.asnumpy(X)
72-
73-
Y = dpt.sqrt(X)
74-
tol = 8 * dpt.finfo(Y.dtype).resolution
75-
76-
assert_allclose(dpt.asnumpy(Y), np.sqrt(Xnp), atol=tol, rtol=tol)
77-
78-
79-
@pytest.mark.parametrize("usm_type", _usm_types)
80-
def test_sqrt_usm_type(usm_type):
81-
q = get_queue_or_skip()
82-
83-
arg_dt = np.dtype("f4")
84-
input_shape = (10, 10, 10, 10)
85-
X = dpt.empty(input_shape, dtype=arg_dt, usm_type=usm_type, sycl_queue=q)
86-
X[..., 0::2] = 16.0
87-
X[..., 1::2] = 23.0
88-
89-
Y = dpt.sqrt(X)
90-
assert Y.usm_type == X.usm_type
91-
assert Y.sycl_queue == X.sycl_queue
92-
assert Y.flags.c_contiguous
93-
94-
expected_Y = np.empty(input_shape, dtype=arg_dt)
95-
expected_Y[..., 0::2] = np.sqrt(np.float32(16.0))
96-
expected_Y[..., 1::2] = np.sqrt(np.float32(23.0))
97-
tol = 8 * dpt.finfo(Y.dtype).resolution
98-
99-
assert_allclose(dpt.asnumpy(Y), expected_Y, atol=tol, rtol=tol)
100-
101-
102-
@pytest.mark.parametrize("dtype", _all_dtypes)
103-
def test_sqrt_order(dtype):
104-
q = get_queue_or_skip()
105-
skip_if_dtype_not_supported(dtype, q)
106-
107-
arg_dt = np.dtype(dtype)
108-
input_shape = (10, 10, 10, 10)
109-
X = dpt.empty(input_shape, dtype=arg_dt, sycl_queue=q)
110-
X[..., 0::2] = 16.0
111-
X[..., 1::2] = 23.0
112-
113-
for perms in itertools.permutations(range(4)):
114-
U = dpt.permute_dims(X[:, ::-1, ::-1, :], perms)
115-
expected_Y = np.sqrt(dpt.asnumpy(U))
116-
for ord in ["C", "F", "A", "K"]:
117-
Y = dpt.sqrt(U, order=ord)
118-
tol = 8 * max(
119-
dpt.finfo(Y.dtype).resolution,
120-
np.finfo(expected_Y.dtype).resolution,
121-
)
122-
assert_allclose(dpt.asnumpy(Y), expected_Y, atol=tol, rtol=tol)
123-
124-
12562
@pytest.mark.usefixtures("suppress_invalid_numpy_warnings")
12663
def test_sqrt_special_cases():
12764
q = get_queue_or_skip()

dpctl/tests/elementwise/test_square.py

Lines changed: 1 addition & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,13 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616

17-
import itertools
18-
1917
import numpy as np
2018
import pytest
2119

2220
import dpctl.tensor as dpt
2321
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported
2422

25-
from .utils import _all_dtypes, _usm_types
23+
from .utils import _all_dtypes
2624

2725

2826
@pytest.mark.parametrize("dtype", _all_dtypes[1:])
@@ -39,46 +37,6 @@ def test_square_out_type(dtype):
3937
assert np.allclose(dpt.asnumpy(r), dpt.asnumpy(dpt.square(X)))
4038

4139

42-
@pytest.mark.parametrize("usm_type", _usm_types)
43-
def test_square_usm_type(usm_type):
44-
q = get_queue_or_skip()
45-
46-
arg_dt = np.dtype("i4")
47-
input_shape = (10, 10, 10, 10)
48-
X = dpt.empty(input_shape, dtype=arg_dt, usm_type=usm_type, sycl_queue=q)
49-
X[..., 0::2] = 1
50-
X[..., 1::2] = 0
51-
52-
Y = dpt.square(X)
53-
assert Y.usm_type == X.usm_type
54-
assert Y.sycl_queue == X.sycl_queue
55-
assert Y.flags.c_contiguous
56-
57-
expected_Y = dpt.asnumpy(X)
58-
assert np.allclose(dpt.asnumpy(Y), expected_Y)
59-
60-
61-
@pytest.mark.parametrize("dtype", _all_dtypes[1:])
62-
def test_square_order(dtype):
63-
q = get_queue_or_skip()
64-
skip_if_dtype_not_supported(dtype, q)
65-
66-
arg_dt = np.dtype(dtype)
67-
input_shape = (10, 10, 10, 10)
68-
X = dpt.empty(input_shape, dtype=arg_dt, sycl_queue=q)
69-
X[..., 0::2] = 2
70-
X[..., 1::2] = 0
71-
72-
for perms in itertools.permutations(range(4)):
73-
U = dpt.permute_dims(X[:, ::-1, ::-1, :], perms)
74-
expected_Y = np.full(U.shape, 4, dtype=U.dtype)
75-
expected_Y[..., 1::2] = 0
76-
expected_Y = np.transpose(expected_Y, perms)
77-
for ord in ["C", "F", "A", "K"]:
78-
Y = dpt.square(U, order=ord)
79-
assert np.allclose(dpt.asnumpy(Y), expected_Y)
80-
81-
8240
@pytest.mark.parametrize("dtype", ["c8", "c16"])
8341
def test_square_special_cases(dtype):
8442
q = get_queue_or_skip()

0 commit comments

Comments
 (0)