1+ import random
2+
13import numpy as np
24import pytest
35
46import arrayfire_wrapper .dtypes as dtype
57import arrayfire_wrapper .lib as wrapper
6- import arrayfire_wrapper .lib .mathematical_functions as ops
7- from arrayfire_wrapper .lib .create_and_modify_array .helper_functions import array_to_string
88
9-
10- import random
9+ # import arrayfire_wrapper.lib.mathematical_functions as ops
10+ from arrayfire_wrapper . lib . create_and_modify_array . helper_functions import array_to_string
1111
1212dtype_map = {
13- ' int16' : dtype .s16 ,
14- ' int32' : dtype .s32 ,
15- ' int64' : dtype .s64 ,
16- ' uint8' : dtype .u8 ,
17- ' uint16' : dtype .u16 ,
18- ' uint32' : dtype .u32 ,
19- ' uint64' : dtype .u64 ,
20- ' float16' : dtype .f16 ,
21- ' float32' : dtype .f32 ,
13+ " int16" : dtype .s16 ,
14+ " int32" : dtype .s32 ,
15+ " int64" : dtype .s64 ,
16+ " uint8" : dtype .u8 ,
17+ " uint16" : dtype .u16 ,
18+ " uint32" : dtype .u32 ,
19+ " uint64" : dtype .u64 ,
20+ " float16" : dtype .f16 ,
21+ " float32" : dtype .f32 ,
2222 # 'float64': dtype.f64,
2323 # 'complex64': dtype.c64,
24- ' complex32' : dtype .c32 ,
25- ' bool' : dtype .b8 ,
26- ' s16' : dtype .s16 ,
27- ' s32' : dtype .s32 ,
28- ' s64' : dtype .s64 ,
29- 'u8' : dtype .u8 ,
30- ' u16' : dtype .u16 ,
31- ' u32' : dtype .u32 ,
32- ' u64' : dtype .u64 ,
33- ' f16' : dtype .f16 ,
34- ' f32' : dtype .f32 ,
24+ " complex32" : dtype .c32 ,
25+ " bool" : dtype .b8 ,
26+ " s16" : dtype .s16 ,
27+ " s32" : dtype .s32 ,
28+ " s64" : dtype .s64 ,
29+ "u8" : dtype .u8 ,
30+ " u16" : dtype .u16 ,
31+ " u32" : dtype .u32 ,
32+ " u64" : dtype .u64 ,
33+ " f16" : dtype .f16 ,
34+ " f32" : dtype .f32 ,
3535 # 'f64': dtype.f64,
36- ' c32' : dtype .c32 ,
36+ " c32" : dtype .c32 ,
3737 # 'c64': dtype.c64,
38- 'b8' : dtype .b8 ,
38+ "b8" : dtype .b8 ,
3939}
4040
41+
4142@pytest .mark .parametrize (
4243 "shape" ,
4344 [
4445 (),
45- (random .randint (1 , 10 ), ),
46+ (random .randint (1 , 10 ),),
4647 (random .randint (1 , 10 ), random .randint (1 , 10 )),
4748 (random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
4849 (random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
@@ -55,7 +56,8 @@ def test_multiply_shapes(shape: tuple) -> None:
5556
5657 result = wrapper .mul (lhs , rhs )
5758
58- assert wrapper .get_dims (result )[0 : len (shape )] == shape
59+ assert wrapper .get_dims (result )[0 : len (shape )] == shape # noqa
60+
5961
6062def test_multiply_different_shapes () -> None :
6163 """Test if multiplication handles arrays of different shapes"""
@@ -66,8 +68,10 @@ def test_multiply_different_shapes() -> None:
6668 lhs = wrapper .randu (lhs_shape , dtypes )
6769 rhs = wrapper .randu (rhs_shape , dtypes )
6870 result = wrapper .mul (lhs , rhs )
69- expected_shape = np .broadcast (np .empty (lhs ), np .empty (rhs )).shape
70- assert wrapper .get_dims (result )[0 : len (expected_shape )] == expected_shape , f"Failed for shapes { lhs_shape } and { rhs_shape } "
71+ assert (
72+ wrapper .get_dims (result )[0 : len (lhs_shape )] == lhs_shape # noqa
73+ ), f"Failed for shapes { lhs_shape } and { rhs_shape } "
74+
7175
7276def test_multiply_negative_shapes () -> None :
7377 """Test if multiplication handles arrays of negative shapes"""
@@ -78,18 +82,21 @@ def test_multiply_negative_shapes() -> None:
7882 lhs = wrapper .randu (lhs_shape , dtypes )
7983 rhs = wrapper .randu (rhs_shape , dtypes )
8084 result = wrapper .mul (lhs , rhs )
81- expected_shape = np .broadcast (np .empty (lhs ), np .empty (rhs )).shape
82- assert wrapper .get_dims (result )[0 : len (expected_shape )] == expected_shape , f"Failed for shapes { lhs_shape } and { rhs_shape } "
85+ assert (
86+ wrapper .get_dims (result )[0 : len (lhs_shape )] == lhs_shape # noqa
87+ ), f"Failed for shapes { lhs_shape } and { rhs_shape } "
88+
8389
8490@pytest .mark .parametrize ("dtype_name" , dtype_map .values ())
85- def test_multiply_supported_dtypes (dtype_name : str ) -> None :
91+ def test_multiply_supported_dtypes (dtype_name : dtype . Dtype ) -> None :
8692 """Test multiplication operation across all supported data types."""
8793 shape = (5 , 5 )
8894 lhs = wrapper .randu (shape , dtype_name )
8995 rhs = wrapper .randu (shape , dtype_name )
9096 result = wrapper .mul (lhs , rhs )
9197 assert dtype .c_api_value_to_dtype (wrapper .get_type (result )) == dtype_name , f"Failed for dtype: { dtype_name } "
9298
99+
93100@pytest .mark .parametrize (
94101 "invdtypes" ,
95102 [
@@ -105,6 +112,7 @@ def test_multiply_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
105112 rhs = wrapper .randu (shape , invdtypes )
106113 wrapper .mul (lhs , rhs )
107114
115+
108116def test_multiply_zero_sized_arrays () -> None :
109117 """Test multiplication with arrays where at least one array has zero size."""
110118 with pytest .raises (RuntimeError ):
@@ -115,10 +123,11 @@ def test_multiply_zero_sized_arrays() -> None:
115123
116124 result_rhs_zero = wrapper .mul (normal_array , zero_array )
117125 assert wrapper .get_dims (result_rhs_zero ) == normal_shape
118-
126+
119127 result_lhs_zero = wrapper .mul (zero_array , normal_array )
120128 assert wrapper .get_dims (result_lhs_zero ) == zero_shape
121129
130+
122131@pytest .mark .parametrize (
123132 "shape_a, shape_b" ,
124133 [
@@ -136,13 +145,16 @@ def test_multiply_varying_dimensionality(shape_a: tuple, shape_b: tuple) -> None
136145
137146 result = wrapper .mul (lhs , rhs )
138147 expected_shape = np .broadcast (np .empty (shape_a ), np .empty (shape_b )).shape
139- assert wrapper .get_dims (result )[0 : len (expected_shape )] == expected_shape , f"Failed for shapes { shape_a } and { shape_b } "
148+ assert (
149+ wrapper .get_dims (result )[0 : len (expected_shape )] == expected_shape # noqa
150+ ), f"Failed for shapes { shape_a } and { shape_b } "
151+
140152
141153@pytest .mark .parametrize (
142154 "shape" ,
143155 [
144156 (),
145- (random .randint (1 , 10 ), ),
157+ (random .randint (1 , 10 ),),
146158 (random .randint (1 , 10 ), random .randint (1 , 10 )),
147159 (random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
148160 (random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
@@ -157,7 +169,8 @@ def test_divide_shapes(shape: tuple) -> None:
157169
158170 result = wrapper .div (lhs , rhs )
159171
160- assert wrapper .get_dims (result )[0 : len (shape )] == shape
172+ assert wrapper .get_dims (result )[0 : len (shape )] == shape # noqa
173+
161174
162175def test_divide_different_shapes () -> None :
163176 """Test if division handles arrays of different shapes"""
@@ -169,7 +182,10 @@ def test_divide_different_shapes() -> None:
169182 rhs = wrapper .randu (rhs_shape , dtypes )
170183 result = wrapper .div (lhs , rhs )
171184 expected_shape = np .broadcast (np .empty (lhs_shape ), np .empty (rhs_shape )).shape
172- assert wrapper .get_dims (result )[0 : len (expected_shape )] == expected_shape , f"Failed for shapes { lhs_shape } and { rhs_shape } "
185+ assert (
186+ wrapper .get_dims (result )[0 : len (expected_shape )] == expected_shape # noqa
187+ ), f"Failed for shapes { lhs_shape } and { rhs_shape } "
188+
173189
174190def test_divide_negative_shapes () -> None :
175191 """Test if division handles arrays of negative shapes"""
@@ -181,10 +197,13 @@ def test_divide_negative_shapes() -> None:
181197 rhs = wrapper .randu (rhs_shape , dtypes )
182198 result = wrapper .div (lhs , rhs )
183199 expected_shape = np .broadcast (np .empty (lhs_shape ), np .empty (rhs_shape )).shape
184- assert wrapper .get_dims (result )[0 : len (expected_shape )] == expected_shape , f"Failed for shapes { lhs_shape } and { rhs_shape } "
200+ assert (
201+ wrapper .get_dims (result )[0 : len (expected_shape )] == expected_shape # noqa
202+ ), f"Failed for shapes { lhs_shape } and { rhs_shape } "
203+
185204
186205@pytest .mark .parametrize ("dtype_name" , dtype_map .values ())
187- def test_divide_supported_dtypes (dtype_name : str ) -> None :
206+ def test_divide_supported_dtypes (dtype_name : dtype . Dtype ) -> None :
188207 """Test division operation across all supported data types."""
189208 shape = (5 , 5 )
190209 lhs = wrapper .randu (shape , dtype_name )
@@ -195,6 +214,7 @@ def test_divide_supported_dtypes(dtype_name: str) -> None:
195214 result = wrapper .div (lhs , rhs )
196215 assert dtype .c_api_value_to_dtype (wrapper .get_type (result )) == dtype_name , f"Failed for dtype: { dtype_name } "
197216
217+
198218def test_divide_by0 () -> None :
199219 """Test division operation for undefined error type."""
200220 shape = (2 , 2 )
@@ -208,10 +228,11 @@ def test_divide_by0() -> None:
208228 divOut = wrapper .div (lhs , rhs )
209229 print (array_to_string ("" , divOut , 3 , False ))
210230 wrapper .div (lhs , rhs )
211-
231+
212232 # result = wrapper.div(lhs, rhs)
213233 # assert dtype.c_api_value_to_dtype(wrapper.get_type(result)) == dtype_name, f"Failed for dtype: {dtype_name}"
214234
235+
215236@pytest .mark .parametrize (
216237 "invdtypes" ,
217238 [
@@ -230,6 +251,7 @@ def test_divide_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
230251
231252 wrapper .div (lhs , rhs )
232253
254+
233255def test_divide_zero_sized_arrays () -> None :
234256 """Test division with arrays where at least one array has zero size."""
235257 with pytest .raises (RuntimeError ):
@@ -240,10 +262,11 @@ def test_divide_zero_sized_arrays() -> None:
240262
241263 result_rhs_zero = wrapper .div (normal_array , zero_array )
242264 assert wrapper .get_dims (result_rhs_zero ) == normal_shape
243-
265+
244266 result_lhs_zero = wrapper .div (zero_array , normal_array )
245267 assert wrapper .get_dims (result_lhs_zero ) == zero_shape
246268
269+
247270@pytest .mark .parametrize (
248271 "shape_a, shape_b" ,
249272 [
@@ -263,4 +286,6 @@ def test_divide_varying_dimensionality(shape_a: tuple, shape_b: tuple) -> None:
263286
264287 result = wrapper .div (lhs , rhs )
265288 expected_shape = np .broadcast (np .empty (shape_a ), np .empty (shape_b )).shape
266- assert wrapper .get_dims (result )[0 : len (expected_shape )] == expected_shape , f"Failed for shapes { shape_a } and { shape_b } "
289+ assert (
290+ wrapper .get_dims (result )[0 : len (expected_shape )] == expected_shape # noqa
291+ ), f"Failed for shapes { shape_a } and { shape_b } "
0 commit comments