55from arrayfire_wrapper .lib .create_and_modify_array .helper_functions import array_to_string
66
77dtype_map = {
8- ' int16' : dtype .s16 ,
9- ' int32' : dtype .s32 ,
10- ' int64' : dtype .s64 ,
11- ' uint8' : dtype .u8 ,
12- ' uint16' : dtype .u16 ,
13- ' uint32' : dtype .u32 ,
14- ' uint64' : dtype .u64 ,
8+ " int16" : dtype .s16 ,
9+ " int32" : dtype .s32 ,
10+ " int64" : dtype .s64 ,
11+ " uint8" : dtype .u8 ,
12+ " uint16" : dtype .u16 ,
13+ " uint32" : dtype .u32 ,
14+ " uint64" : dtype .u64 ,
1515 # 'float16': dtype.f16,
1616 # 'float32': dtype.f32,
1717 # 'float64': dtype.f64,
1818 # 'complex64': dtype.c64,
1919 # 'complex32': dtype.c32,
20- ' bool' : dtype .b8 ,
21- ' s16' : dtype .s16 ,
22- ' s32' : dtype .s32 ,
23- ' s64' : dtype .s64 ,
24- 'u8' : dtype .u8 ,
25- ' u16' : dtype .u16 ,
26- ' u32' : dtype .u32 ,
27- ' u64' : dtype .u64 ,
20+ " bool" : dtype .b8 ,
21+ " s16" : dtype .s16 ,
22+ " s32" : dtype .s32 ,
23+ " s64" : dtype .s64 ,
24+ "u8" : dtype .u8 ,
25+ " u16" : dtype .u16 ,
26+ " u32" : dtype .u32 ,
27+ " u64" : dtype .u64 ,
2828 # 'f16': dtype.f16,
2929 # 'f32': dtype.f32,
3030 # 'f64': dtype.f64,
3131 # 'c32': dtype.c32,
3232 # 'c64': dtype.c64,
33- 'b8' : dtype .b8 ,
33+ "b8" : dtype .b8 ,
3434}
35+
36+
3537@pytest .mark .parametrize ("dtype_name" , dtype_map .values ())
3638def test_bitshiftl_dtypes (dtype_name : dtype .Dtype ) -> None :
3739 """Test bit shift operation across all supported data types."""
3840 shape = (5 , 5 )
3941 values = wrapper .randu (shape , dtype_name )
4042 bits_to_shift = wrapper .constant (1 , shape , dtype_name )
41-
43+
4244 result = wrapper .bitshiftl (values , bits_to_shift )
4345
4446 assert dtype .c_api_value_to_dtype (wrapper .get_type (result )) == dtype_name , f"Failed for dtype: { dtype_name } "
47+
48+
4549@pytest .mark .parametrize (
4650 "invdtypes" ,
4751 [
@@ -54,25 +58,28 @@ def test_bitshiftl_supported_dtypes(invdtypes: dtype.Dtype) -> None:
5458 shape = (5 , 5 )
5559 with pytest .raises (RuntimeError ):
5660 value = wrapper .randu (shape , invdtypes )
57- shift_amount = 1
61+ bits_to_shift = wrapper . constant ( 1 , shape , invdtypes )
5862
59- result = wrapper .bitshiftl (value , shift_amount )
63+ result = wrapper .bitshiftl (value , bits_to_shift )
6064 assert dtype .c_api_value_to_dtype (wrapper .get_type (result )) == invdtypes , f"Failed for dtype: { invdtypes } "
65+
66+
6167@pytest .mark .parametrize ("input_size" , [8 , 10 , 12 ])
62- def test_bitshiftl_varying_input_size (input_size ) :
68+ def test_bitshiftl_varying_input_size (input_size : int ) -> None :
6369 """Test bitshift left operation with varying input sizes"""
6470 shape = (input_size , input_size )
6571 value = wrapper .randu (shape , dtype .int16 )
6672 shift_amount = wrapper .constant (1 , shape , dtype .int16 ) # Fixed shift amount for simplicity
6773
6874 result = wrapper .bitshiftl (value , shift_amount )
6975
70- assert wrapper .get_dims (result )[0 : len (shape )] == shape
76+ assert (wrapper .get_dims (result )[0 ], wrapper .get_dims (result )[1 ]) == shape
77+
7178
7279@pytest .mark .parametrize (
7380 "shape" ,
7481 [
75- (10 , ),
82+ (10 ,),
7683 (5 , 5 ),
7784 (2 , 3 , 4 ),
7885 ],
@@ -81,21 +88,23 @@ def test_bitshiftl_varying_shapes(shape: tuple) -> None:
8188 """Test left bit shifting with arrays of varying shapes."""
8289 values = wrapper .randu (shape , dtype .int16 )
8390 bits_to_shift = wrapper .constant (1 , shape , dtype .int16 )
84-
91+
8592 result = wrapper .bitshiftl (values , bits_to_shift )
8693
87- assert wrapper .get_dims (result )[0 : len (shape )] == shape
94+ assert wrapper .get_dims (result )[0 : len (shape )] == shape # noqa
95+
8896
8997@pytest .mark .parametrize ("shift_amount" , [- 1 , 0 , 2 , 30 ])
90- def test_bitshift_left_varying_shift_amount (shift_amount ) :
98+ def test_bitshift_left_varying_shift_amount (shift_amount : int ) -> None :
9199 """Test bitshift left operation with varying shift amounts."""
92100 shape = (5 , 5 )
93101 value = wrapper .randu (shape , dtype .int16 )
94102 shift_amount_arr = wrapper .constant (shift_amount , shape , dtype .int16 )
95103
96104 result = wrapper .bitshiftl (value , shift_amount_arr )
97105
98- assert wrapper .get_dims (result )[0 : len (shape )] == shape
106+ assert (wrapper .get_dims (result )[0 ], wrapper .get_dims (result )[1 ]) == shape
107+
99108
100109@pytest .mark .parametrize (
101110 "shape_a, shape_b" ,
@@ -113,29 +122,35 @@ def test_bitshiftl_different_shapes(shape_a: tuple, shape_b: tuple) -> None:
113122 bits_to_shift = wrapper .constant (1 , shape_b , dtype .int16 )
114123 result = wrapper .bitshiftl (values , bits_to_shift )
115124 print (array_to_string ("" , result , 3 , False ))
116- assert wrapper .get_dims (result )[0 : len (shape_a )] == shape_a , f"Failed for shapes { shape_a } and { shape_b } "
125+ assert (
126+ wrapper .get_dims (result )[0 : len (shape_a )] == shape_a # noqa
127+ ), f"Failed for shapes { shape_a } and { shape_b } "
128+
117129
118130@pytest .mark .parametrize ("shift_amount" , [- 1 , 0 , 2 , 30 ])
119- def test_bitshift_right_varying_shift_amount (shift_amount ) :
131+ def test_bitshift_right_varying_shift_amount (shift_amount : int ) -> None :
120132 """Test bitshift right operation with varying shift amounts."""
121133 shape = (5 , 5 )
122134 value = wrapper .randu (shape , dtype .int16 )
123135 shift_amount_arr = wrapper .constant (shift_amount , shape , dtype .int16 )
124136
125137 result = wrapper .bitshiftr (value , shift_amount_arr )
126138
127- assert wrapper .get_dims (result )[0 : len (shape )] == shape
139+ assert (wrapper .get_dims (result )[0 ], wrapper .get_dims (result )[1 ]) == shape
140+
128141
129142@pytest .mark .parametrize ("dtype_name" , dtype_map .values ())
130143def test_bitshiftr_dtypes (dtype_name : dtype .Dtype ) -> None :
131144 """Test bit shift operation across all supported data types."""
132145 shape = (5 , 5 )
133146 values = wrapper .randu (shape , dtype_name )
134147 bits_to_shift = wrapper .constant (1 , shape , dtype_name )
135-
148+
136149 result = wrapper .bitshiftr (values , bits_to_shift )
137150
138151 assert dtype .c_api_value_to_dtype (wrapper .get_type (result )) == dtype_name , f"Failed for dtype: { dtype_name } "
152+
153+
139154@pytest .mark .parametrize (
140155 "invdtypes" ,
141156 [
@@ -148,25 +163,28 @@ def test_bitshiftr_supported_dtypes(invdtypes: dtype.Dtype) -> None:
148163 shape = (5 , 5 )
149164 with pytest .raises (RuntimeError ):
150165 value = wrapper .randu (shape , invdtypes )
151- shift_amount = 1
166+ shift_amount = wrapper . constant ( 1 , shape , invdtypes )
152167
153168 result = wrapper .bitshiftr (value , shift_amount )
154169 assert dtype .c_api_value_to_dtype (wrapper .get_type (result )) == invdtypes , f"Failed for dtype: { invdtypes } "
155170
171+
156172@pytest .mark .parametrize ("input_size" , [8 , 10 , 12 ])
157- def test_bitshift_right_varying_input_size (input_size ) :
173+ def test_bitshift_right_varying_input_size (input_size : int ) -> None :
158174 """Test bitshift right operation with varying input sizes"""
159175 shape = (input_size , input_size )
160176 value = wrapper .randu (shape , dtype .int16 )
161177 shift_amount = wrapper .constant (1 , shape , dtype .int16 ) # Fixed shift amount for simplicity
162178
163179 result = wrapper .bitshiftr (value , shift_amount )
164180
165- assert wrapper .get_dims (result )[0 : len (shape )] == shape
181+ assert (wrapper .get_dims (result )[0 ], wrapper .get_dims (result )[1 ]) == shape
182+
183+
166184@pytest .mark .parametrize (
167185 "shape" ,
168186 [
169- (10 , ),
187+ (10 ,),
170188 (5 , 5 ),
171189 (2 , 3 , 4 ),
172190 ],
@@ -175,11 +193,10 @@ def test_bitshiftr_varying_shapes(shape: tuple) -> None:
175193 """Test right bit shifting with arrays of varying shapes."""
176194 values = wrapper .randu (shape , dtype .int16 )
177195 bits_to_shift = wrapper .constant (1 , shape , dtype .int16 )
178-
179- result = wrapper .bitshiftr (values , bits_to_shift )
180196
181- assert wrapper .get_dims ( result )[ 0 : len ( shape )] == shape
197+ result = wrapper .bitshiftr ( values , bits_to_shift )
182198
199+ assert wrapper .get_dims (result )[0 : len (shape )] == shape # noqa
183200
184201
185202@pytest .mark .parametrize (
@@ -198,4 +215,6 @@ def test_bitshiftr_different_shapes(shape_a: tuple, shape_b: tuple) -> None:
198215 bits_to_shift = wrapper .constant (1 , shape_b , dtype .int16 )
199216 result = wrapper .bitshiftr (values , bits_to_shift )
200217 print (array_to_string ("" , result , 3 , False ))
201- assert wrapper .get_dims (result )[0 : len (shape_a )] == shape_a , f"Failed for shapes { shape_a } and { shape_b } "
218+ assert (
219+ wrapper .get_dims (result )[0 : len (shape_a )] == shape_a # noqa
220+ ), f"Failed for shapes { shape_a } and { shape_b } "
0 commit comments