@@ -127,16 +127,24 @@ def test_projection_complex(dtype):
127127 q = get_queue_or_skip ()
128128 skip_if_dtype_not_supported (dtype , q )
129129
130- X = [complex (1 , 2 ), complex (dpt .inf , - 1 ), complex (0 , - dpt .inf )]
131- Y = [complex (1 , 2 ), complex (dpt .inf , - 0 ), complex (dpt .inf , - 0 )]
130+ X = [
131+ complex (1 , 2 ),
132+ complex (dpt .inf , - 1 ),
133+ complex (0 , - dpt .inf ),
134+ complex (- dpt .inf , dpt .nan ),
135+ ]
136+ Y = [
137+ complex (1 , 2 ),
138+ complex (np .inf , - 0.0 ),
139+ complex (np .inf , - 0.0 ),
140+ complex (np .inf , 0.0 ),
141+ ]
132142
133143 Xf = dpt .asarray (X , dtype = dtype , sycl_queue = q )
134- Yf = dpt . asarray (Y , dtype = dtype , sycl_queue = q )
144+ Yf = np . array (Y , dtype = dtype )
135145
136146 tol = 8 * dpt .finfo (Xf .dtype ).resolution
137- assert_allclose (
138- dpt .asnumpy (dpt .proj (Xf )), dpt .asnumpy (Yf ), atol = tol , rtol = tol
139- )
147+ assert_allclose (dpt .asnumpy (dpt .proj (Xf )), Yf , atol = tol , rtol = tol )
140148
141149
142150@pytest .mark .parametrize ("dtype" , _all_dtypes )
@@ -146,19 +154,17 @@ def test_projection(dtype):
146154
147155 Xf = dpt .asarray (1 , dtype = dtype , sycl_queue = q )
148156 out_dtype = dpt .proj (Xf ).dtype
149- Yf = dpt . asarray (complex (1 , 0 ), dtype = out_dtype , sycl_queue = q )
157+ Yf = np . array (complex (1 , 0 ), dtype = out_dtype )
150158
151159 tol = 8 * dpt .finfo (Yf .dtype ).resolution
152- assert_allclose (
153- dpt .asnumpy (dpt .proj (Xf )), dpt .asnumpy (Yf ), atol = tol , rtol = tol
154- )
160+ assert_allclose (dpt .asnumpy (dpt .proj (Xf )), Yf , atol = tol , rtol = tol )
155161
156162
157163@pytest .mark .parametrize (
158164 "np_call, dpt_call" ,
159165 [(np .real , dpt .real ), (np .imag , dpt .imag ), (np .conj , dpt .conj )],
160166)
161- @pytest .mark .parametrize ("dtype" , ["f " , "d " ])
167+ @pytest .mark .parametrize ("dtype" , ["f4 " , "f8 " ])
162168@pytest .mark .parametrize ("stride" , [- 1 , 1 , 2 , 4 , 5 ])
163169def test_complex_strided (np_call , dpt_call , dtype , stride ):
164170 q = get_queue_or_skip ()
@@ -176,7 +182,7 @@ def test_complex_strided(np_call, dpt_call, dtype, stride):
176182 assert_allclose (y , dpt .asnumpy (z ), atol = tol , rtol = tol )
177183
178184
179- @pytest .mark .parametrize ("dtype" , ["e " , "f " , "d " ])
185+ @pytest .mark .parametrize ("dtype" , ["f2 " , "f4 " , "f8 " ])
180186def test_complex_special_cases (dtype ):
181187 q = get_queue_or_skip ()
182188 skip_if_dtype_not_supported (dtype , q )
0 commit comments