|
14 | 14 | # See the License for the specific language governing permissions and |
15 | 15 | # limitations under the License. |
16 | 16 |
|
17 | | -import ctypes |
18 | | - |
19 | 17 | import numpy as np |
20 | 18 | import pytest |
21 | 19 |
|
22 | | -import dpctl |
23 | 20 | import dpctl.tensor as dpt |
24 | 21 | from dpctl.tensor._type_utils import _can_cast |
25 | 22 | from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported |
26 | 23 |
|
27 | | -from .utils import _compare_dtypes, _no_complex_dtypes, _usm_types |
| 24 | +from .utils import _compare_dtypes, _no_complex_dtypes |
28 | 25 |
|
29 | 26 |
|
30 | 27 | @pytest.mark.parametrize("op1_dtype", _no_complex_dtypes) |
@@ -61,59 +58,6 @@ def test_remainder_dtype_matrix(op1_dtype, op2_dtype): |
61 | 58 | assert (dpt.asnumpy(r) == expected.astype(r.dtype)).all() |
62 | 59 |
|
63 | 60 |
|
64 | | -@pytest.mark.parametrize("op1_usm_type", _usm_types) |
65 | | -@pytest.mark.parametrize("op2_usm_type", _usm_types) |
66 | | -def test_remainder_usm_type_matrix(op1_usm_type, op2_usm_type): |
67 | | - get_queue_or_skip() |
68 | | - |
69 | | - sz = 128 |
70 | | - ar1 = dpt.ones(sz, dtype="i4", usm_type=op1_usm_type) |
71 | | - ar2 = dpt.ones_like(ar1, dtype="i4", usm_type=op2_usm_type) |
72 | | - |
73 | | - r = dpt.remainder(ar1, ar2) |
74 | | - assert isinstance(r, dpt.usm_ndarray) |
75 | | - expected_usm_type = dpctl.utils.get_coerced_usm_type( |
76 | | - (op1_usm_type, op2_usm_type) |
77 | | - ) |
78 | | - assert r.usm_type == expected_usm_type |
79 | | - |
80 | | - |
81 | | -def test_remainder_order(): |
82 | | - get_queue_or_skip() |
83 | | - |
84 | | - ar1 = dpt.ones((20, 20), dtype="i4", order="C") |
85 | | - ar2 = dpt.ones((20, 20), dtype="i4", order="C") |
86 | | - r1 = dpt.remainder(ar1, ar2, order="C") |
87 | | - assert r1.flags.c_contiguous |
88 | | - r2 = dpt.remainder(ar1, ar2, order="F") |
89 | | - assert r2.flags.f_contiguous |
90 | | - r3 = dpt.remainder(ar1, ar2, order="A") |
91 | | - assert r3.flags.c_contiguous |
92 | | - r4 = dpt.remainder(ar1, ar2, order="K") |
93 | | - assert r4.flags.c_contiguous |
94 | | - |
95 | | - ar1 = dpt.ones((20, 20), dtype="i4", order="F") |
96 | | - ar2 = dpt.ones((20, 20), dtype="i4", order="F") |
97 | | - r1 = dpt.remainder(ar1, ar2, order="C") |
98 | | - assert r1.flags.c_contiguous |
99 | | - r2 = dpt.remainder(ar1, ar2, order="F") |
100 | | - assert r2.flags.f_contiguous |
101 | | - r3 = dpt.remainder(ar1, ar2, order="A") |
102 | | - assert r3.flags.f_contiguous |
103 | | - r4 = dpt.remainder(ar1, ar2, order="K") |
104 | | - assert r4.flags.f_contiguous |
105 | | - |
106 | | - ar1 = dpt.ones((40, 40), dtype="i4", order="C")[:20, ::-2] |
107 | | - ar2 = dpt.ones((40, 40), dtype="i4", order="C")[:20, ::-2] |
108 | | - r4 = dpt.remainder(ar1, ar2, order="K") |
109 | | - assert r4.strides == (20, -1) |
110 | | - |
111 | | - ar1 = dpt.ones((40, 40), dtype="i4", order="C")[:20, ::-2].mT |
112 | | - ar2 = dpt.ones((40, 40), dtype="i4", order="C")[:20, ::-2].mT |
113 | | - r4 = dpt.remainder(ar1, ar2, order="K") |
114 | | - assert r4.strides == (-1, 20) |
115 | | - |
116 | | - |
117 | 61 | @pytest.mark.parametrize("dt", _no_complex_dtypes[1:8:2]) |
118 | 62 | def test_remainder_negative_integers(dt): |
119 | 63 | q = get_queue_or_skip() |
@@ -189,38 +133,6 @@ def test_remainder_special_cases(): |
189 | 133 | np.allclose(dpt.asnumpy(res), np.remainder(x_np, y_np)) |
190 | 134 |
|
191 | 135 |
|
192 | | -@pytest.mark.parametrize("arr_dt", _no_complex_dtypes) |
193 | | -def test_remainder_python_scalar(arr_dt): |
194 | | - q = get_queue_or_skip() |
195 | | - skip_if_dtype_not_supported(arr_dt, q) |
196 | | - |
197 | | - X = dpt.ones((10, 10), dtype=arr_dt, sycl_queue=q) |
198 | | - py_ones = ( |
199 | | - bool(1), |
200 | | - int(1), |
201 | | - float(1), |
202 | | - np.float32(1), |
203 | | - ctypes.c_int(1), |
204 | | - ) |
205 | | - for sc in py_ones: |
206 | | - R = dpt.remainder(X, sc) |
207 | | - assert isinstance(R, dpt.usm_ndarray) |
208 | | - R = dpt.remainder(sc, X) |
209 | | - assert isinstance(R, dpt.usm_ndarray) |
210 | | - |
211 | | - |
212 | | -@pytest.mark.parametrize("dtype", _no_complex_dtypes[1:]) |
213 | | -def test_remainder_inplace_python_scalar(dtype): |
214 | | - q = get_queue_or_skip() |
215 | | - skip_if_dtype_not_supported(dtype, q) |
216 | | - X = dpt.ones((10, 10), dtype=dtype, sycl_queue=q) |
217 | | - dt_kind = X.dtype.kind |
218 | | - if dt_kind in "ui": |
219 | | - X %= int(1) |
220 | | - elif dt_kind == "f": |
221 | | - X %= float(1) |
222 | | - |
223 | | - |
224 | 136 | @pytest.mark.parametrize("op1_dtype", _no_complex_dtypes[1:]) |
225 | 137 | @pytest.mark.parametrize("op2_dtype", _no_complex_dtypes[1:]) |
226 | 138 | def test_remainder_inplace_dtype_matrix(op1_dtype, op2_dtype): |
|
0 commit comments