55import arrayfire_wrapper .dtypes as dtype
66import arrayfire_wrapper .lib as wrapper
77
8-
98dtype_map = {
10- 'int16' : dtype .s16 ,
11- 'int32' : dtype .s32 ,
12- 'int64' : dtype .s64 ,
13- 'uint8' : dtype .u8 ,
14- 'uint16' : dtype .u16 ,
15- 'uint32' : dtype .u32 ,
16- 'uint64' : dtype .u64 ,
17- 'float16' : dtype .f16 ,
18- 'float32' : dtype .f32 ,
199 "int16" : dtype .s16 ,
2010 "int32" : dtype .s32 ,
2111 "int64" : dtype .s64 ,
2818 # 'float64': dtype.f64,
2919 # 'complex64': dtype.c64,
3020 # 'complex32': dtype.c32,
31- 'bool' : dtype .b8 ,
32- 's16' : dtype .s16 ,
33- 's32' : dtype .s32 ,
34- 's64' : dtype .s64 ,
35- 'u8' : dtype .u8 ,
36- 'u16' : dtype .u16 ,
37- 'u32' : dtype .u32 ,
38- 'u64' : dtype .u64 ,
39- 'f16' : dtype .f16 ,
40- 'f32' : dtype .f32 ,
4121 "bool" : dtype .b8 ,
4222 "s16" : dtype .s16 ,
4323 "s32" : dtype .s32 ,
5131 # 'f64': dtype.f64,
5232 # 'c32': dtype.c32,
5333 # 'c64': dtype.c64,
54- 'b8' : dtype .b8 ,
5534 "b8" : dtype .b8 ,
5635}
5736
6039 "shape" ,
6140 [
6241 (),
63- (random .randint (1 , 10 ), ),
42+ (random .randint (1 , 10 ),),
6443 (random .randint (1 , 10 ),),
6544 (random .randint (1 , 10 ),),
6645 (random .randint (1 , 10 ), random .randint (1 , 10 )),
@@ -87,11 +66,13 @@ def test_asinh_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
8766 """Test inverse hyperbolic sine operation for unsupported data types."""
8867 with pytest .raises (RuntimeError ):
8968 wrapper .asinh (wrapper .randu ((10 , 10 ), invdtypes ))
69+
70+
9071@pytest .mark .parametrize (
9172 "shape" ,
9273 [
9374 (),
94- (random .randint (1 , 10 ), ),
75+ (random .randint (1 , 10 ),),
9576 (random .randint (1 , 10 ),),
9677 (random .randint (1 , 10 ),),
9778 (random .randint (1 , 10 ), random .randint (1 , 10 )),
@@ -118,11 +99,13 @@ def test_acosh_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
11899 """Test inverse hyperbolic cosine operation for unsupported data types."""
119100 with pytest .raises (RuntimeError ):
120101 wrapper .acosh (wrapper .randu ((10 , 10 ), invdtypes ))
102+
103+
121104@pytest .mark .parametrize (
122105 "shape" ,
123106 [
124107 (),
125- (random .randint (1 , 10 ), ),
108+ (random .randint (1 , 10 ),),
126109 (random .randint (1 , 10 ),),
127110 (random .randint (1 , 10 ),),
128111 (random .randint (1 , 10 ), random .randint (1 , 10 )),
@@ -150,11 +133,12 @@ def test_atanh_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
150133 with pytest .raises (RuntimeError ):
151134 wrapper .atanh (wrapper .randu ((10 , 10 ), invdtypes ))
152135
136+
153137@pytest .mark .parametrize (
154138 "shape" ,
155139 [
156140 (),
157- (random .randint (1 , 10 ), ),
141+ (random .randint (1 , 10 ),),
158142 (random .randint (1 , 10 ),),
159143 (random .randint (1 , 10 ),),
160144 (random .randint (1 , 10 ), random .randint (1 , 10 )),
@@ -182,11 +166,12 @@ def test_cosh_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
182166 with pytest .raises (RuntimeError ):
183167 wrapper .cosh (wrapper .randu ((10 , 10 ), invdtypes ))
184168
169+
185170@pytest .mark .parametrize (
186171 "shape" ,
187172 [
188173 (),
189- (random .randint (1 , 10 ), ),
174+ (random .randint (1 , 10 ),),
190175 (random .randint (1 , 10 ),),
191176 (random .randint (1 , 10 ),),
192177 (random .randint (1 , 10 ), random .randint (1 , 10 )),
@@ -214,11 +199,12 @@ def test_sinh_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
214199 with pytest .raises (RuntimeError ):
215200 wrapper .sinh (wrapper .randu ((10 , 10 ), invdtypes ))
216201
202+
217203@pytest .mark .parametrize (
218204 "shape" ,
219205 [
220206 (),
221- (random .randint (1 , 10 ), ),
207+ (random .randint (1 , 10 ),),
222208 (random .randint (1 , 10 ),),
223209 (random .randint (1 , 10 ),),
224210 (random .randint (1 , 10 ), random .randint (1 , 10 )),
0 commit comments