33import arrayfire_wrapper .dtypes as dtype
44import arrayfire_wrapper .lib as wrapper
55from arrayfire_wrapper .lib .create_and_modify_array .helper_functions import array_to_string
6+ from tests .utility_functions import check_type_supported , get_all_types , get_float_types , get_real_types
67
7- dtype_map = {
8- "int16" : dtype .s16 ,
9- "int32" : dtype .s32 ,
10- "int64" : dtype .s64 ,
11- "uint8" : dtype .u8 ,
12- "uint16" : dtype .u16 ,
13- "uint32" : dtype .u32 ,
14- "uint64" : dtype .u64 ,
15- # 'float16': dtype.f16,
16- # 'float32': dtype.f32,
17- # 'float64': dtype.f64,
18- # 'complex64': dtype.c64,
19- # 'complex32': dtype.c32,
20- "bool" : dtype .b8 ,
21- "s16" : dtype .s16 ,
22- "s32" : dtype .s32 ,
23- "s64" : dtype .s64 ,
24- "u8" : dtype .u8 ,
25- "u16" : dtype .u16 ,
26- "u32" : dtype .u32 ,
27- "u64" : dtype .u64 ,
28- # 'f16': dtype.f16,
29- # 'f32': dtype.f32,
30- # 'f64': dtype.f64,
31- # 'c32': dtype.c32,
32- # 'c64': dtype.c64,
33- "b8" : dtype .b8 ,
34- }
35-
36-
37- @pytest .mark .parametrize ("dtype_name" , dtype_map .values ())
8+ @pytest .mark .parametrize ("dtype_name" , get_real_types ())
389def test_bitshiftl_dtypes (dtype_name : dtype .Dtype ) -> None :
3910 """Test bit shift operation across all supported data types."""
11+ check_type_supported (dtype_name )
12+ if dtype_name == dtype .f16 or dtype_name == dtype .f32 :
13+ pytest .skip ()
4014 shape = (5 , 5 )
4115 values = wrapper .randu (shape , dtype_name )
4216 bits_to_shift = wrapper .constant (1 , shape , dtype_name )
@@ -49,7 +23,7 @@ def test_bitshiftl_dtypes(dtype_name: dtype.Dtype) -> None:
4923@pytest .mark .parametrize (
5024 "invdtypes" ,
5125 [
52- dtype .c64 ,
26+ dtype .c32 ,
5327 dtype .f64 ,
5428 ],
5529)
@@ -139,9 +113,12 @@ def test_bitshift_right_varying_shift_amount(shift_amount: int) -> None:
139113 assert (wrapper .get_dims (result )[0 ], wrapper .get_dims (result )[1 ]) == shape
140114
141115
142- @pytest .mark .parametrize ("dtype_name" , dtype_map . values ())
116+ @pytest .mark .parametrize ("dtype_name" , get_real_types ())
143117def test_bitshiftr_dtypes (dtype_name : dtype .Dtype ) -> None :
144118 """Test bit shift operation across all supported data types."""
119+ check_type_supported (dtype_name )
120+ if dtype_name == dtype .f16 or dtype_name == dtype .f32 :
121+ pytest .skip ()
145122 shape = (5 , 5 )
146123 values = wrapper .randu (shape , dtype_name )
147124 bits_to_shift = wrapper .constant (1 , shape , dtype_name )
@@ -154,7 +131,7 @@ def test_bitshiftr_dtypes(dtype_name: dtype.Dtype) -> None:
154131@pytest .mark .parametrize (
155132 "invdtypes" ,
156133 [
157- dtype .c64 ,
134+ dtype .c32 ,
158135 dtype .f64 ,
159136 ],
160137)
0 commit comments