2929 _complex_fp_dtypes ,
3030 _map_to_device_dtype ,
3131 _real_fp_dtypes ,
32- _usm_types ,
3332)
3433
3534
@@ -45,7 +44,7 @@ def test_sqrt_out_type(dtype):
4544
4645
4746@pytest .mark .parametrize ("dtype" , ["f2" , "f4" , "f8" , "c8" , "c16" ])
48- def test_sqrt_output_contig (dtype ):
47+ def test_sqrt_basic (dtype ):
4948 q = get_queue_or_skip ()
5049 skip_if_dtype_not_supported (dtype , q )
5150
@@ -60,68 +59,6 @@ def test_sqrt_output_contig(dtype):
6059 assert_allclose (dpt .asnumpy (Y ), np .sqrt (Xnp ), atol = tol , rtol = tol )
6160
6261
63- @pytest .mark .parametrize ("dtype" , ["f2" , "f4" , "f8" , "c8" , "c16" ])
64- def test_sqrt_output_strided (dtype ):
65- q = get_queue_or_skip ()
66- skip_if_dtype_not_supported (dtype , q )
67-
68- n_seq = 2054
69-
70- X = dpt .linspace (0 , 13 , num = n_seq , dtype = dtype , sycl_queue = q )[::- 2 ]
71- Xnp = dpt .asnumpy (X )
72-
73- Y = dpt .sqrt (X )
74- tol = 8 * dpt .finfo (Y .dtype ).resolution
75-
76- assert_allclose (dpt .asnumpy (Y ), np .sqrt (Xnp ), atol = tol , rtol = tol )
77-
78-
79- @pytest .mark .parametrize ("usm_type" , _usm_types )
80- def test_sqrt_usm_type (usm_type ):
81- q = get_queue_or_skip ()
82-
83- arg_dt = np .dtype ("f4" )
84- input_shape = (10 , 10 , 10 , 10 )
85- X = dpt .empty (input_shape , dtype = arg_dt , usm_type = usm_type , sycl_queue = q )
86- X [..., 0 ::2 ] = 16.0
87- X [..., 1 ::2 ] = 23.0
88-
89- Y = dpt .sqrt (X )
90- assert Y .usm_type == X .usm_type
91- assert Y .sycl_queue == X .sycl_queue
92- assert Y .flags .c_contiguous
93-
94- expected_Y = np .empty (input_shape , dtype = arg_dt )
95- expected_Y [..., 0 ::2 ] = np .sqrt (np .float32 (16.0 ))
96- expected_Y [..., 1 ::2 ] = np .sqrt (np .float32 (23.0 ))
97- tol = 8 * dpt .finfo (Y .dtype ).resolution
98-
99- assert_allclose (dpt .asnumpy (Y ), expected_Y , atol = tol , rtol = tol )
100-
101-
102- @pytest .mark .parametrize ("dtype" , _all_dtypes )
103- def test_sqrt_order (dtype ):
104- q = get_queue_or_skip ()
105- skip_if_dtype_not_supported (dtype , q )
106-
107- arg_dt = np .dtype (dtype )
108- input_shape = (10 , 10 , 10 , 10 )
109- X = dpt .empty (input_shape , dtype = arg_dt , sycl_queue = q )
110- X [..., 0 ::2 ] = 16.0
111- X [..., 1 ::2 ] = 23.0
112-
113- for perms in itertools .permutations (range (4 )):
114- U = dpt .permute_dims (X [:, ::- 1 , ::- 1 , :], perms )
115- expected_Y = np .sqrt (dpt .asnumpy (U ))
116- for ord in ["C" , "F" , "A" , "K" ]:
117- Y = dpt .sqrt (U , order = ord )
118- tol = 8 * max (
119- dpt .finfo (Y .dtype ).resolution ,
120- np .finfo (expected_Y .dtype ).resolution ,
121- )
122- assert_allclose (dpt .asnumpy (Y ), expected_Y , atol = tol , rtol = tol )
123-
124-
12562@pytest .mark .usefixtures ("suppress_invalid_numpy_warnings" )
12663def test_sqrt_special_cases ():
12764 q = get_queue_or_skip ()
0 commit comments