11import random
22
3- # import numpy as np
43import pytest
54
65import arrayfire_wrapper .dtypes as dtype
76import arrayfire_wrapper .lib as wrapper
8-
9- dtype_map = {
10- # "int16": dtype.s16,
11- # "int32": dtype.s32,
12- # "int64": dtype.s64,
13- # "uint8": dtype.u8,
14- # "uint16": dtype.u16,
15- # "uint32": dtype.u32,
16- # "uint64": dtype.u64,
17- # "float16": dtype.f16,
18- "float32" : dtype .f32 ,
19- # 'float64': dtype.f64,
20- # 'complex64': dtype.c64,
21- # "complex32": dtype.c32,
22- # "bool": dtype.b8,
23- # "s16": dtype.s16,
24- # "s32": dtype.s32,
25- # "s64": dtype.s64,
26- # "u8": dtype.u8,
27- # "u16": dtype.u16,
28- # "u32": dtype.u32,
29- # "u64": dtype.u64,
30- # "f16": dtype.f16,
31- "f32" : dtype .f32 ,
32- # 'f64': dtype.f64,
33- # "c32": dtype.c32,
34- # 'c64': dtype.c64,
35- # "b8": dtype.b8,
36- }
7+ from tests .utility_functions import check_type_supported , get_all_types , get_float_types , get_real_types
378
389
3910@pytest .mark .parametrize (
4617 (random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
4718 ],
4819)
49- @pytest .mark .parametrize ("dtype_name" , dtype_map . values ())
20+ @pytest .mark .parametrize ("dtype_name" , get_float_types ())
5021def test_complex_supported_dtypes (shape : tuple , dtype_name : dtype .Dtype ) -> None :
5122 """Test complex operation across all supported data types."""
23+ check_type_supported (dtype_name )
24+ if dtype_name == dtype .f16 :
25+ pytest .skip ()
5226 tester = wrapper .randu (shape , dtype_name )
5327 result = wrapper .cplx (tester )
5428 assert wrapper .is_complex (result ), f"Failed for dtype: { dtype_name } "
@@ -57,8 +31,8 @@ def test_complex_supported_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None
5731@pytest .mark .parametrize (
5832 "invdtypes" ,
5933 [
60- dtype .c64 ,
61- dtype .f64 ,
34+ dtype .int32 ,
35+ dtype .complex32 ,
6236 ],
6337)
6438def test_complex_unsupported_dtypes (invdtypes : dtype .Dtype ) -> None :
@@ -79,9 +53,10 @@ def test_complex_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
7953 (random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
8054 ],
8155)
82- @pytest .mark .parametrize ("dtype_name" , dtype_map . values ())
56+ @pytest .mark .parametrize ("dtype_name" , get_real_types ())
8357def test_complex2_supported_dtypes (shape : tuple , dtype_name : dtype .Dtype ) -> None :
8458 """Test complex2 operation across all supported data types."""
59+ check_type_supported (dtype_name )
8560 lhs = wrapper .randu (shape , dtype_name )
8661 rhs = wrapper .randu (shape , dtype_name )
8762 result = wrapper .cplx2 (lhs , rhs )
@@ -91,8 +66,7 @@ def test_complex2_supported_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> Non
9166@pytest .mark .parametrize (
9267 "invdtypes" ,
9368 [
94- dtype .c64 ,
95- dtype .f64 ,
69+ dtype .c32 ,
9670 ],
9771)
9872def test_complex2_unsupported_dtypes (invdtypes : dtype .Dtype ) -> None :
@@ -114,26 +88,13 @@ def test_complex2_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
11488 (random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
11589 ],
11690)
117- def test_conj_supported_dtypes (shape : tuple ) -> None :
91+ @pytest .mark .parametrize ("dtypes" , get_all_types ())
92+ def test_conj_supported_dtypes (shape : tuple , dtypes : dtype .Dtype ) -> None :
11893 """Test conjugate operation for supported data types."""
119- arr = wrapper .constant (7 , shape , dtype .c32 )
94+ check_type_supported (dtypes )
95+ arr = wrapper .constant (7 , shape , dtypes )
12096 result = wrapper .conjg (arr )
121- assert wrapper .is_complex (result ), f"Failed for shape: { shape } "
122-
123-
124- @pytest .mark .parametrize (
125- "invdtypes" ,
126- [
127- dtype .c64 ,
128- dtype .f64 ,
129- ],
130- )
131- def test_conj_unsupported_dtypes (invdtypes : dtype .Dtype ) -> None :
132- """Test conjugate operation for unsupported data types."""
133- with pytest .raises (RuntimeError ):
134- shape = (5 , 5 )
135- arr = wrapper .randu (shape , invdtypes )
136- wrapper .conjg (arr )
97+ assert wrapper .get_dims (result )[0 : len (shape )] == shape , f"Failed for shape: { shape } , and dtype: { dtypes } " # noqa
13798
13899
139100@pytest .mark .parametrize (
@@ -146,40 +107,29 @@ def test_conj_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
146107 (random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
147108 ],
148109)
149- def test_imag_real_supported_dtypes (shape : tuple ) -> None :
110+ @pytest .mark .parametrize ("dtypes" , get_all_types ())
111+ def test_imag_supported_dtypes (shape : tuple , dtypes : dtype .Dtype ) -> None :
150112 """Test imaginary and real operations for supported data types."""
151- arr = wrapper . randu ( shape , dtype . c32 )
152- imaginary = wrapper .imag ( arr )
113+ check_type_supported ( dtypes )
114+ arr = wrapper .randu ( shape , dtypes )
153115 real = wrapper .real (arr )
154- assert not wrapper .is_empty (imaginary ), f"Failed for shape: { shape } "
155- assert not wrapper .is_empty (real ), f"Failed for shape: { shape } "
156-
157-
158- @pytest .mark .parametrize (
159- "invdtypes" ,
160- [
161- dtype .c64 ,
162- dtype .f64 ,
163- ],
164- )
165- def test_imag_unsupported_dtypes (invdtypes : dtype .Dtype ) -> None :
166- """Test conjugate operation for unsupported data types."""
167- with pytest .raises (RuntimeError ):
168- shape = (5 , 5 )
169- arr = wrapper .randu (shape , invdtypes )
170- wrapper .imag (arr )
116+ assert wrapper .is_real (real ), f"Failed for shape: { shape } "
171117
172118
173119@pytest .mark .parametrize (
174- "invdtypes " ,
120+ "shape " ,
175121 [
176- dtype .c64 ,
177- dtype .f64 ,
122+ (),
123+ (random .randint (1 , 10 ),),
124+ (random .randint (1 , 10 ), random .randint (1 , 10 )),
125+ (random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
126+ (random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
178127 ],
179128)
180- def test_real_unsupported_dtypes (invdtypes : dtype .Dtype ) -> None :
181- """Test real operation for unsupported data types."""
182- with pytest .raises (RuntimeError ):
183- shape = (5 , 5 )
184- arr = wrapper .randu (shape , invdtypes )
185- wrapper .real (arr )
129+ @pytest .mark .parametrize ("dtypes" , get_all_types ())
130+ def test_real_supported_dtypes (shape : tuple , dtypes : dtype .Dtype ) -> None :
131+ """Test imaginary and real operations for supported data types."""
132+ check_type_supported (dtypes )
133+ arr = wrapper .randu (shape , dtypes )
134+ real = wrapper .real (arr )
135+ assert wrapper .is_real (real ), f"Failed for shape: { shape } "
0 commit comments