|
7 | 7 | import arrayfire_wrapper.lib as wrapper |
8 | 8 |
|
9 | 9 | # import arrayfire_wrapper.lib.mathematical_functions as ops |
10 | | -from arrayfire_wrapper.lib.create_and_modify_array.helper_functions import array_to_string |
11 | | - |
12 | | -dtype_map = { |
13 | | - "int16": dtype.s16, |
14 | | - "int32": dtype.s32, |
15 | | - "int64": dtype.s64, |
16 | | - "uint8": dtype.u8, |
17 | | - "uint16": dtype.u16, |
18 | | - "uint32": dtype.u32, |
19 | | - "uint64": dtype.u64, |
20 | | - "float16": dtype.f16, |
21 | | - "float32": dtype.f32, |
22 | | - # 'float64': dtype.f64, |
23 | | - # 'complex64': dtype.c64, |
24 | | - "complex32": dtype.c32, |
25 | | - "bool": dtype.b8, |
26 | | - "s16": dtype.s16, |
27 | | - "s32": dtype.s32, |
28 | | - "s64": dtype.s64, |
29 | | - "u8": dtype.u8, |
30 | | - "u16": dtype.u16, |
31 | | - "u32": dtype.u32, |
32 | | - "u64": dtype.u64, |
33 | | - "f16": dtype.f16, |
34 | | - "f32": dtype.f32, |
35 | | - # 'f64': dtype.f64, |
36 | | - "c32": dtype.c32, |
37 | | - # 'c64': dtype.c64, |
38 | | - "b8": dtype.b8, |
39 | | -} |
| 10 | + |
| 11 | +from . import utility_functions as util |
40 | 12 |
|
41 | 13 |
|
42 | 14 | @pytest.mark.parametrize( |
@@ -87,9 +59,10 @@ def test_multiply_negative_shapes() -> None: |
87 | 59 | ), f"Failed for shapes {lhs_shape} and {rhs_shape}" |
88 | 60 |
|
89 | 61 |
|
90 | | -@pytest.mark.parametrize("dtype_name", dtype_map.values()) |
| 62 | +@pytest.mark.parametrize("dtype_name", util.get_all_types()) |
91 | 63 | def test_multiply_supported_dtypes(dtype_name: dtype.Dtype) -> None: |
92 | 64 | """Test multiplication operation across all supported data types.""" |
| 65 | + util.check_type_supported(dtype_name) |
93 | 66 | shape = (5, 5) |
94 | 67 | lhs = wrapper.randu(shape, dtype_name) |
95 | 68 | rhs = wrapper.randu(shape, dtype_name) |
@@ -201,9 +174,10 @@ def test_divide_negative_shapes() -> None: |
201 | 174 | ), f"Failed for shapes {lhs_shape} and {rhs_shape}" |
202 | 175 |
|
203 | 176 |
|
204 | | -@pytest.mark.parametrize("dtype_name", dtype_map.values()) |
| 177 | +@pytest.mark.parametrize("dtype_name", util.get_all_types()) |
205 | 178 | def test_divide_supported_dtypes(dtype_name: dtype.Dtype) -> None: |
206 | 179 | """Test division operation across all supported data types.""" |
| 180 | + util.check_type_supported(dtype_name) |
207 | 181 | shape = (5, 5) |
208 | 182 | lhs = wrapper.randu(shape, dtype_name) |
209 | 183 | rhs = wrapper.randu(shape, dtype_name) |
|
0 commit comments