1+ import pytest
2+
3+ import arrayfire_wrapper .dtypes as dtype
4+ import arrayfire_wrapper .lib as wrapper
5+ from arrayfire_wrapper .lib .create_and_modify_array .helper_functions import array_to_string
6+
7+ dtype_map = {
8+ 'int16' : dtype .s16 ,
9+ 'int32' : dtype .s32 ,
10+ 'int64' : dtype .s64 ,
11+ 'uint8' : dtype .u8 ,
12+ 'uint16' : dtype .u16 ,
13+ 'uint32' : dtype .u32 ,
14+ 'uint64' : dtype .u64 ,
15+ # 'float16': dtype.f16,
16+ # 'float32': dtype.f32,
17+ # 'float64': dtype.f64,
18+ # 'complex64': dtype.c64,
19+ # 'complex32': dtype.c32,
20+ 'bool' : dtype .b8 ,
21+ 's16' : dtype .s16 ,
22+ 's32' : dtype .s32 ,
23+ 's64' : dtype .s64 ,
24+ 'u8' : dtype .u8 ,
25+ 'u16' : dtype .u16 ,
26+ 'u32' : dtype .u32 ,
27+ 'u64' : dtype .u64 ,
28+ # 'f16': dtype.f16,
29+ # 'f32': dtype.f32,
30+ # 'f64': dtype.f64,
31+ # 'c32': dtype.c32,
32+ # 'c64': dtype.c64,
33+ 'b8' : dtype .b8 ,
34+ }
35+ @pytest .mark .parametrize ("dtype_name" , dtype_map .values ())
36+ def test_bitshiftl_dtypes (dtype_name : dtype .Dtype ) -> None :
37+ """Test bit shift operation across all supported data types."""
38+ shape = (5 , 5 )
39+ values = wrapper .randu (shape , dtype_name )
40+ bits_to_shift = wrapper .constant (1 , shape , dtype_name )
41+
42+ result = wrapper .bitshiftl (values , bits_to_shift )
43+
44+ assert dtype .c_api_value_to_dtype (wrapper .get_type (result )) == dtype_name , f"Failed for dtype: { dtype_name } "
45+ @pytest .mark .parametrize (
46+ "invdtypes" ,
47+ [
48+ dtype .c64 ,
49+ dtype .f64 ,
50+ ],
51+ )
52+ def test_bitshiftl_supported_dtypes (invdtypes : dtype .Dtype ) -> None :
53+ """Test bitshift operations for unsupported integer data types."""
54+ shape = (5 , 5 )
55+ with pytest .raises (RuntimeError ):
56+ value = wrapper .randu (shape , invdtypes )
57+ shift_amount = 1
58+
59+ result = wrapper .bitshiftl (value , shift_amount )
60+ assert dtype .c_api_value_to_dtype (wrapper .get_type (result )) == invdtypes , f"Failed for dtype: { invdtypes } "
61+ @pytest .mark .parametrize ("input_size" , [8 , 10 , 12 ])
62+ def test_bitshiftl_varying_input_size (input_size ):
63+ """Test bitshift left operation with varying input sizes"""
64+ shape = (input_size , input_size )
65+ value = wrapper .randu (shape , dtype .int16 )
66+ shift_amount = wrapper .constant (1 , shape , dtype .int16 ) # Fixed shift amount for simplicity
67+
68+ result = wrapper .bitshiftl (value , shift_amount )
69+
70+ assert wrapper .get_dims (result )[0 : len (shape )] == shape
71+
72+ @pytest .mark .parametrize (
73+ "shape" ,
74+ [
75+ (10 , ),
76+ (5 , 5 ),
77+ (2 , 3 , 4 ),
78+ ],
79+ )
80+ def test_bitshiftl_varying_shapes (shape : tuple ) -> None :
81+ """Test left bit shifting with arrays of varying shapes."""
82+ values = wrapper .randu (shape , dtype .int16 )
83+ bits_to_shift = wrapper .constant (1 , shape , dtype .int16 )
84+
85+ result = wrapper .bitshiftl (values , bits_to_shift )
86+
87+ assert wrapper .get_dims (result )[0 : len (shape )] == shape
88+
89+ @pytest .mark .parametrize ("shift_amount" , [0 , 2 , 30 ])
90+ def test_bitshift_left_varying_shift_amount (shift_amount ):
91+ """Test bitshift left operation with varying shift amounts."""
92+ shape = (5 , 5 )
93+ value = wrapper .randu (shape , dtype .int16 )
94+ shift_amount_arr = wrapper .constant (shift_amount , shape , dtype .int16 )
95+
96+ result = wrapper .bitshiftl (value , shift_amount_arr )
97+
98+ assert wrapper .get_dims (result )[0 : len (shape )] == shape
99+
100+ @pytest .mark .parametrize (
101+ "shape_a, shape_b" ,
102+ [
103+ ((1 , 5 ), (5 , 1 )), # 2D with 2D inverse
104+ ((5 , 5 ), (5 , 1 )), # 2D with 2D
105+ ((5 , 5 ), (1 , 1 )), # 2D with 2D
106+ ((1 , 1 , 1 ), (5 , 5 , 5 )), # 3D with 3D
107+ ],
108+ )
109+ def test_bitshiftl_different_shapes (shape_a : tuple , shape_b : tuple ) -> None :
110+ """Test if left bit shifting handles arrays of different shapes"""
111+ with pytest .raises (RuntimeError ):
112+ values = wrapper .randu (shape_a , dtype .int16 )
113+ bits_to_shift = wrapper .constant (1 , shape_b , dtype .int16 )
114+ result = wrapper .bitshiftl (values , bits_to_shift )
115+ print (array_to_string ("" , result , 3 , False ))
116+ assert wrapper .get_dims (result )[0 : len (shape_a )] == shape_a , f"Failed for shapes { shape_a } and { shape_b } "
117+
118+ @pytest .mark .parametrize ("shift_amount" , [0 , 2 , 30 ])
119+ def test_bitshift_right_varying_shift_amount (shift_amount ):
120+ """Test bitshift right operation with varying shift amounts."""
121+ shape = (5 , 5 )
122+ value = wrapper .randu (shape , dtype .int16 )
123+ shift_amount_arr = wrapper .constant (shift_amount , shape , dtype .int16 )
124+
125+ result = wrapper .bitshiftr (value , shift_amount_arr )
126+
127+ assert wrapper .get_dims (result )[0 : len (shape )] == shape
128+
129+ @pytest .mark .parametrize ("dtype_name" , dtype_map .values ())
130+ def test_bitshiftr_dtypes (dtype_name : dtype .Dtype ) -> None :
131+ """Test bit shift operation across all supported data types."""
132+ shape = (5 , 5 )
133+ values = wrapper .randu (shape , dtype_name )
134+ bits_to_shift = wrapper .constant (1 , shape , dtype_name )
135+
136+ result = wrapper .bitshiftr (values , bits_to_shift )
137+
138+ assert dtype .c_api_value_to_dtype (wrapper .get_type (result )) == dtype_name , f"Failed for dtype: { dtype_name } "
139+ @pytest .mark .parametrize (
140+ "invdtypes" ,
141+ [
142+ dtype .c64 ,
143+ dtype .f64 ,
144+ ],
145+ )
146+ def test_bitshiftr_supported_dtypes (invdtypes : dtype .Dtype ) -> None :
147+ """Test bitshift operations for unsupported integer data types."""
148+ shape = (5 , 5 )
149+ with pytest .raises (RuntimeError ):
150+ value = wrapper .randu (shape , invdtypes )
151+ shift_amount = 1
152+
153+ result = wrapper .bitshiftr (value , shift_amount )
154+ assert dtype .c_api_value_to_dtype (wrapper .get_type (result )) == invdtypes , f"Failed for dtype: { invdtypes } "
155+
156+ @pytest .mark .parametrize ("input_size" , [8 , 10 , 12 ])
157+ def test_bitshift_right_varying_input_size (input_size ):
158+ """Test bitshift right operation with varying input sizes"""
159+ shape = (input_size , input_size )
160+ value = wrapper .randu (shape , dtype .int16 )
161+ shift_amount = wrapper .constant (1 , shape , dtype .int16 ) # Fixed shift amount for simplicity
162+
163+ result = wrapper .bitshiftr (value , shift_amount )
164+
165+ assert wrapper .get_dims (result )[0 : len (shape )] == shape
166+ @pytest .mark .parametrize (
167+ "shape" ,
168+ [
169+ (10 , ),
170+ (5 , 5 ),
171+ (2 , 3 , 4 ),
172+ ],
173+ )
174+ def test_bitshiftr_varying_shapes (shape : tuple ) -> None :
175+ """Test right bit shifting with arrays of varying shapes."""
176+ values = wrapper .randu (shape , dtype .int16 )
177+ bits_to_shift = wrapper .constant (1 , shape , dtype .int16 )
178+
179+ result = wrapper .bitshiftr (values , bits_to_shift )
180+
181+ assert wrapper .get_dims (result )[0 : len (shape )] == shape
182+
183+
184+
185+ @pytest .mark .parametrize (
186+ "shape_a, shape_b" ,
187+ [
188+ ((1 , 5 ), (5 , 1 )), # 2D with 2D inverse
189+ ((5 , 5 ), (5 , 1 )), # 2D with 2D
190+ ((5 , 5 ), (1 , 1 )), # 2D with 2D
191+ ((1 , 1 , 1 ), (5 , 5 , 5 )), # 3D with 3D
192+ ],
193+ )
194+ def test_bitshiftr_different_shapes (shape_a : tuple , shape_b : tuple ) -> None :
195+ """Test if right bit shifting handles arrays of different shapes"""
196+ with pytest .raises (RuntimeError ):
197+ values = wrapper .randu (shape_a , dtype .int16 )
198+ bits_to_shift = wrapper .constant (1 , shape_b , dtype .int16 )
199+ result = wrapper .bitshiftr (values , bits_to_shift )
200+ print (array_to_string ("" , result , 3 , False ))
201+ assert wrapper .get_dims (result )[0 : len (shape_a )] == shape_a , f"Failed for shapes { shape_a } and { shape_b } "
0 commit comments