44
55import arrayfire_wrapper .dtypes as dtype
66import arrayfire_wrapper .lib as wrapper
7-
8- dtype_map = {
9- "int16" : dtype .s16 ,
10- "int32" : dtype .s32 ,
11- "int64" : dtype .s64 ,
12- "uint8" : dtype .u8 ,
13- "uint16" : dtype .u16 ,
14- "uint32" : dtype .u32 ,
15- "uint64" : dtype .u64 ,
16- "float16" : dtype .f16 ,
17- "float32" : dtype .f32 ,
18- # 'float64': dtype.f64,
19- # 'complex64': dtype.c64,
20- # 'complex32': dtype.c32,
21- "bool" : dtype .b8 ,
22- "s16" : dtype .s16 ,
23- "s32" : dtype .s32 ,
24- "s64" : dtype .s64 ,
25- "u8" : dtype .u8 ,
26- "u16" : dtype .u16 ,
27- "u32" : dtype .u32 ,
28- "u64" : dtype .u64 ,
29- "f16" : dtype .f16 ,
30- "f32" : dtype .f32 ,
31- # 'f64': dtype.f64,
32- # 'c32': dtype.c32,
33- # 'c64': dtype.c64,
34- "b8" : dtype .b8 ,
35- }
7+ from tests .utility_functions import check_type_supported , get_all_types , get_float_types
368
379
3810@pytest .mark .parametrize (
3911 "shape" ,
4012 [
4113 (),
4214 (random .randint (1 , 10 ),),
43- (random .randint (1 , 10 ),),
44- (random .randint (1 , 10 ),),
4515 (random .randint (1 , 10 ), random .randint (1 , 10 )),
4616 (random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
4717 (random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
4818 ],
4919)
50- @pytest .mark .parametrize ("dtype_name" , dtype_map . values ())
20+ @pytest .mark .parametrize ("dtype_name" , get_all_types ())
5121def test_asinh_shape_dtypes (shape : tuple , dtype_name : dtype .Dtype ) -> None :
5222 """Test inverse hyperbolic sine operation across all supported data types."""
23+ check_type_supported (dtype_name )
5324 values = wrapper .randu (shape , dtype_name )
5425 result = wrapper .asinh (values )
5526 assert wrapper .get_dims (result )[0 : len (shape )] == shape , f"failed for shape: { shape } " # noqa
@@ -73,16 +44,15 @@ def test_asinh_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
7344 [
7445 (),
7546 (random .randint (1 , 10 ),),
76- (random .randint (1 , 10 ),),
77- (random .randint (1 , 10 ),),
7847 (random .randint (1 , 10 ), random .randint (1 , 10 )),
7948 (random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
8049 (random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
8150 ],
8251)
83- @pytest .mark .parametrize ("dtype_name" , dtype_map . values ())
52+ @pytest .mark .parametrize ("dtype_name" , get_all_types ())
8453def test_acosh_shape_dtypes (shape : tuple , dtype_name : dtype .Dtype ) -> None :
8554 """Test inverse hyperbolic cosine operation across all supported data types."""
55+ check_type_supported (dtype_name )
8656 values = wrapper .randu (shape , dtype_name )
8757 result = wrapper .acosh (values )
8858 assert wrapper .get_dims (result )[0 : len (shape )] == shape , f"failed for shape: { shape } " # noqa
@@ -106,16 +76,15 @@ def test_acosh_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
10676 [
10777 (),
10878 (random .randint (1 , 10 ),),
109- (random .randint (1 , 10 ),),
110- (random .randint (1 , 10 ),),
11179 (random .randint (1 , 10 ), random .randint (1 , 10 )),
11280 (random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
11381 (random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
11482 ],
11583)
116- @pytest .mark .parametrize ("dtype_name" , dtype_map . values ())
84+ @pytest .mark .parametrize ("dtype_name" , get_all_types ())
11785def test_atanh_shape_dtypes (shape : tuple , dtype_name : dtype .Dtype ) -> None :
11886 """Test inverse hyperbolic tan operation across all supported data types."""
87+ check_type_supported (dtype_name )
11988 values = wrapper .randu (shape , dtype_name )
12089 result = wrapper .atanh (values )
12190 assert wrapper .get_dims (result )[0 : len (shape )] == shape , f"failed for shape: { shape } " # noqa
@@ -139,16 +108,15 @@ def test_atanh_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
139108 [
140109 (),
141110 (random .randint (1 , 10 ),),
142- (random .randint (1 , 10 ),),
143- (random .randint (1 , 10 ),),
144111 (random .randint (1 , 10 ), random .randint (1 , 10 )),
145112 (random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
146113 (random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
147114 ],
148115)
149- @pytest .mark .parametrize ("dtype_name" , dtype_map . values ())
116+ @pytest .mark .parametrize ("dtype_name" , get_all_types ())
150117def test_cosh_shape_dtypes (shape : tuple , dtype_name : dtype .Dtype ) -> None :
151118 """Test hyperbolic cosine operation across all supported data types."""
119+ check_type_supported (dtype_name )
152120 values = wrapper .randu (shape , dtype_name )
153121 result = wrapper .cosh (values )
154122 assert wrapper .get_dims (result )[0 : len (shape )] == shape , f"failed for shape: { shape } " # noqa
@@ -172,16 +140,15 @@ def test_cosh_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
172140 [
173141 (),
174142 (random .randint (1 , 10 ),),
175- (random .randint (1 , 10 ),),
176- (random .randint (1 , 10 ),),
177143 (random .randint (1 , 10 ), random .randint (1 , 10 )),
178144 (random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
179145 (random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
180146 ],
181147)
182- @pytest .mark .parametrize ("dtype_name" , dtype_map . values ())
148+ @pytest .mark .parametrize ("dtype_name" , get_all_types ())
183149def test_sinh_shape_dtypes (shape : tuple , dtype_name : dtype .Dtype ) -> None :
184150 """Test hyberbolic sin operation across all supported data types."""
151+ check_type_supported (dtype_name )
185152 values = wrapper .randu (shape , dtype_name )
186153 result = wrapper .sinh (values )
187154 assert wrapper .get_dims (result )[0 : len (shape )] == shape , f"failed for shape: { shape } " # noqa
@@ -205,16 +172,15 @@ def test_sinh_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
205172 [
206173 (),
207174 (random .randint (1 , 10 ),),
208- (random .randint (1 , 10 ),),
209- (random .randint (1 , 10 ),),
210175 (random .randint (1 , 10 ), random .randint (1 , 10 )),
211176 (random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
212177 (random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
213178 ],
214179)
215- @pytest .mark .parametrize ("dtype_name" , dtype_map . values ())
180+ @pytest .mark .parametrize ("dtype_name" , get_all_types ())
216181def test_tanh_shape_dtypes (shape : tuple , dtype_name : dtype .Dtype ) -> None :
217182 """Test hyberbolic tan operation across all supported data types."""
183+ check_type_supported (dtype_name )
218184 values = wrapper .randu (shape , dtype_name )
219185 result = wrapper .tanh (values )
220186 assert wrapper .get_dims (result )[0 : len (shape )] == shape , f"failed for shape: { shape } " # noqa
0 commit comments