|
24 | 24 | import dpctl.tensor as dpt |
25 | 25 | from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported |
26 | 26 |
|
27 | | -from .utils import _all_dtypes, _map_to_device_dtype, _usm_types |
| 27 | +from .utils import _all_dtypes, _map_to_device_dtype |
28 | 28 |
|
29 | 29 |
|
30 | 30 | @pytest.mark.parametrize("dtype", _all_dtypes) |
@@ -73,56 +73,6 @@ def test_complex_output(np_call, dpt_call, dtype): |
73 | 73 | assert_allclose(dpt.asnumpy(Z), np_call(Xnp), atol=tol, rtol=tol) |
74 | 74 |
|
75 | 75 |
|
76 | | -@pytest.mark.parametrize( |
77 | | - "np_call, dpt_call", |
78 | | - [(np.real, dpt.real), (np.imag, dpt.imag), (np.conj, dpt.conj)], |
79 | | -) |
80 | | -@pytest.mark.parametrize("usm_type", _usm_types) |
81 | | -def test_complex_usm_type(np_call, dpt_call, usm_type): |
82 | | - q = get_queue_or_skip() |
83 | | - |
84 | | - arg_dt = np.dtype("c8") |
85 | | - input_shape = (10, 10, 10, 10) |
86 | | - X = dpt.empty(input_shape, dtype=arg_dt, usm_type=usm_type, sycl_queue=q) |
87 | | - X[..., 0::2] = np.pi / 6 + 1j * np.pi / 3 |
88 | | - X[..., 1::2] = np.pi / 3 + 1j * np.pi / 6 |
89 | | - |
90 | | - Y = dpt_call(X) |
91 | | - assert Y.usm_type == X.usm_type |
92 | | - assert Y.sycl_queue == X.sycl_queue |
93 | | - assert Y.flags.c_contiguous |
94 | | - |
95 | | - expected_Y = np.empty(input_shape, dtype=arg_dt) |
96 | | - expected_Y[..., 0::2] = np_call(np.complex64(np.pi / 6 + 1j * np.pi / 3)) |
97 | | - expected_Y[..., 1::2] = np_call(np.complex64(np.pi / 3 + 1j * np.pi / 6)) |
98 | | - tol = 8 * dpt.finfo(Y.dtype).resolution |
99 | | - |
100 | | - assert_allclose(dpt.asnumpy(Y), expected_Y, atol=tol, rtol=tol) |
101 | | - |
102 | | - |
103 | | -@pytest.mark.parametrize( |
104 | | - "np_call, dpt_call", |
105 | | - [(np.real, dpt.real), (np.imag, dpt.imag), (np.conj, dpt.conj)], |
106 | | -) |
107 | | -@pytest.mark.parametrize("dtype", _all_dtypes) |
108 | | -def test_complex_order(np_call, dpt_call, dtype): |
109 | | - q = get_queue_or_skip() |
110 | | - skip_if_dtype_not_supported(dtype, q) |
111 | | - |
112 | | - arg_dt = np.dtype(dtype) |
113 | | - input_shape = (10, 10, 10, 10) |
114 | | - X = dpt.empty(input_shape, dtype=arg_dt, sycl_queue=q) |
115 | | - X[..., 0::2] = np.pi / 6 + 1j * np.pi / 3 |
116 | | - X[..., 1::2] = np.pi / 3 + 1j * np.pi / 6 |
117 | | - |
118 | | - for perms in itertools.permutations(range(4)): |
119 | | - U = dpt.permute_dims(X[:, ::-1, ::-1, :], perms) |
120 | | - expected_Y = np_call(dpt.asnumpy(U)) |
121 | | - for ord in ["C", "F", "A", "K"]: |
122 | | - Y = dpt_call(U, order=ord) |
123 | | - assert_allclose(dpt.asnumpy(Y), expected_Y) |
124 | | - |
125 | | - |
126 | 76 | @pytest.mark.parametrize("dtype", ["c8", "c16"]) |
127 | 77 | def test_projection_complex(dtype): |
128 | 78 | q = get_queue_or_skip() |
@@ -161,37 +111,6 @@ def test_projection(dtype): |
161 | 111 | assert_allclose(dpt.asnumpy(dpt.proj(Xf)), Yf, atol=tol, rtol=tol) |
162 | 112 |
|
163 | 113 |
|
164 | | -@pytest.mark.parametrize( |
165 | | - "np_call, dpt_call", |
166 | | - [(np.real, dpt.real), (np.imag, dpt.imag), (np.conj, dpt.conj)], |
167 | | -) |
168 | | -@pytest.mark.parametrize("dtype", ["c8", "c16"]) |
169 | | -def test_complex_strided(np_call, dpt_call, dtype): |
170 | | - q = get_queue_or_skip() |
171 | | - skip_if_dtype_not_supported(dtype, q) |
172 | | - |
173 | | - np.random.seed(42) |
174 | | - strides = np.array([-4, -3, -2, -1, 1, 2, 3, 4]) |
175 | | - sizes = [2, 4, 6, 8, 9, 24, 72] |
176 | | - tol = 8 * dpt.finfo(dtype).resolution |
177 | | - |
178 | | - low = -1000.0 |
179 | | - high = 1000.0 |
180 | | - for ii in sizes: |
181 | | - x1 = np.random.uniform(low=low, high=high, size=ii) |
182 | | - x2 = np.random.uniform(low=low, high=high, size=ii) |
183 | | - Xnp = np.array([complex(v1, v2) for v1, v2 in zip(x1, x2)], dtype=dtype) |
184 | | - X = dpt.asarray(Xnp) |
185 | | - Ynp = np_call(Xnp) |
186 | | - for jj in strides: |
187 | | - assert_allclose( |
188 | | - dpt.asnumpy(dpt_call(X[::jj])), |
189 | | - Ynp[::jj], |
190 | | - atol=tol, |
191 | | - rtol=tol, |
192 | | - ) |
193 | | - |
194 | | - |
195 | 114 | @pytest.mark.parametrize("dtype", ["c8", "c16"]) |
196 | 115 | def test_complex_special_cases(dtype): |
197 | 116 | q = get_queue_or_skip() |
|
0 commit comments