44
55import arrayfire_wrapper .dtypes as dtype
66import arrayfire_wrapper .lib as wrapper
7-
8- dtype_map = {
9- "int16" : dtype .s16 ,
10- "int32" : dtype .s32 ,
11- "int64" : dtype .s64 ,
12- "uint8" : dtype .u8 ,
13- "uint16" : dtype .u16 ,
14- "uint32" : dtype .u32 ,
15- "uint64" : dtype .u64 ,
16- "float16" : dtype .f16 ,
17- "float32" : dtype .f32 ,
18- # 'float64': dtype.f64,
19- # 'complex64': dtype.c64,
20- # 'complex32': dtype.c32,
21- "bool" : dtype .b8 ,
22- "s16" : dtype .s16 ,
23- "s32" : dtype .s32 ,
24- "s64" : dtype .s64 ,
25- "u8" : dtype .u8 ,
26- "u16" : dtype .u16 ,
27- "u32" : dtype .u32 ,
28- "u64" : dtype .u64 ,
29- "f16" : dtype .f16 ,
30- "f32" : dtype .f32 ,
31- # 'f64': dtype.f64,
32- # 'c32': dtype.c32,
33- # 'c64': dtype.c64,
34- "b8" : dtype .b8 ,
35- }
7+ from tests .utility_functions import check_type_supported , get_all_types , get_float_types
368
379
3810@pytest .mark .parametrize (
3911 "shape" ,
4012 [
4113 (),
4214 (random .randint (1 , 10 ),),
43- (random .randint (1 , 10 ),),
4415 (random .randint (1 , 10 ), random .randint (1 , 10 )),
4516 (random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
4617 (random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
4718 ],
4819)
49- @pytest .mark .parametrize ("dtype_name" , dtype_map . values ())
20+ @pytest .mark .parametrize ("dtype_name" , get_all_types ())
5021def test_cbrt_shape_dtypes (shape : tuple , dtype_name : dtype .Dtype ) -> None :
5122 """Test cube root operation across all supported data types."""
23+ check_type_supported (dtype_name )
5224 values = wrapper .randu (shape , dtype_name )
5325 result = wrapper .cbrt (values )
5426 assert wrapper .get_dims (result )[0 : len (shape )] == shape , f"failed for shape: { shape } " # noqa
@@ -72,15 +44,15 @@ def test_cbrt_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
7244 [
7345 (),
7446 (random .randint (1 , 10 ),),
75- (random .randint (1 , 10 ),),
7647 (random .randint (1 , 10 ), random .randint (1 , 10 )),
7748 (random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
7849 (random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
7950 ],
8051)
81- @pytest .mark .parametrize ("dtype_name" , dtype_map . values ())
52+ @pytest .mark .parametrize ("dtype_name" , get_all_types ())
8253def test_erf_shape_dtypes (shape : tuple , dtype_name : dtype .Dtype ) -> None :
8354 """Test gaussian error operation across all supported data types."""
55+ check_type_supported (dtype_name )
8456 values = wrapper .randu (shape , dtype_name )
8557 result = wrapper .erf (values )
8658 assert wrapper .get_dims (result )[0 : len (shape )] == shape , f"failed for shape: { shape } " # noqa
@@ -104,15 +76,15 @@ def test_erf_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
10476 [
10577 (),
10678 (random .randint (1 , 10 ),),
107- (random .randint (1 , 10 ),),
10879 (random .randint (1 , 10 ), random .randint (1 , 10 )),
10980 (random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
11081 (random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
11182 ],
11283)
113- @pytest .mark .parametrize ("dtype_name" , dtype_map . values ())
84+ @pytest .mark .parametrize ("dtype_name" , get_all_types ())
11485def test_erfc_shape_dtypes (shape : tuple , dtype_name : dtype .Dtype ) -> None :
11586 """Test gaussian error complement operation across all supported data types."""
87+ check_type_supported (dtype_name )
11688 values = wrapper .randu (shape , dtype_name )
11789 result = wrapper .erfc (values )
11890 assert wrapper .get_dims (result )[0 : len (shape )] == shape , f"failed for shape: { shape } " # noqa
@@ -136,15 +108,15 @@ def test_erfc_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
136108 [
137109 (),
138110 (random .randint (1 , 10 ),),
139- (random .randint (1 , 10 ),),
140111 (random .randint (1 , 10 ), random .randint (1 , 10 )),
141112 (random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
142113 (random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
143114 ],
144115)
145- @pytest .mark .parametrize ("dtype_name" , dtype_map . values ())
116+ @pytest .mark .parametrize ("dtype_name" , get_all_types ())
146117def test_exp_shape_dtypes (shape : tuple , dtype_name : dtype .Dtype ) -> None :
147118 """Test exponent operation across all supported data types."""
119+ check_type_supported (dtype_name )
148120 values = wrapper .randu (shape , dtype_name )
149121 result = wrapper .exp (values )
150122 assert wrapper .get_dims (result )[0 : len (shape )] == shape , f"failed for shape: { shape } " # noqa
@@ -168,15 +140,15 @@ def test_exp_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
168140 [
169141 (),
170142 (random .randint (1 , 10 ),),
171- (random .randint (1 , 10 ),),
172143 (random .randint (1 , 10 ), random .randint (1 , 10 )),
173144 (random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
174145 (random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
175146 ],
176147)
177- @pytest .mark .parametrize ("dtype_name" , dtype_map . values ())
148+ @pytest .mark .parametrize ("dtype_name" , get_all_types ())
178149def test_exp1_shape_dtypes (shape : tuple , dtype_name : dtype .Dtype ) -> None :
179150 """Test exponent - 1 operation across all supported data types."""
151+ check_type_supported (dtype_name )
180152 values = wrapper .randu (shape , dtype_name )
181153 result = wrapper .expm1 (values )
182154 assert wrapper .get_dims (result )[0 : len (shape )] == shape , f"failed for shape: { shape } " # noqa
@@ -200,15 +172,15 @@ def test_expm1_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
200172 [
201173 (),
202174 (random .randint (1 , 10 ),),
203- (random .randint (1 , 10 ),),
204175 (random .randint (1 , 10 ), random .randint (1 , 10 )),
205176 (random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
206177 (random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
207178 ],
208179)
209- @pytest .mark .parametrize ("dtype_name" , dtype_map . values ())
180+ @pytest .mark .parametrize ("dtype_name" , get_all_types ())
210181def test_fac_shape_dtypes (shape : tuple , dtype_name : dtype .Dtype ) -> None :
211182 """Test exponent operation across all supported data types."""
183+ check_type_supported (dtype_name )
212184 values = wrapper .randu (shape , dtype_name )
213185 result = wrapper .factorial (values )
214186 assert wrapper .get_dims (result )[0 : len (shape )] == shape , f"failed for shape: { shape } " # noqa
@@ -238,9 +210,10 @@ def test_fac_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
238210 (random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
239211 ],
240212)
241- @pytest .mark .parametrize ("dtype_name" , dtype_map . values ())
213+ @pytest .mark .parametrize ("dtype_name" , get_all_types ())
242214def test_lgamma_shape_dtypes (shape : tuple , dtype_name : dtype .Dtype ) -> None :
243215 """Test lgamma operation across all supported data types."""
216+ check_type_supported (dtype_name )
244217 values = wrapper .randu (shape , dtype_name )
245218 result = wrapper .lgamma (values )
246219 assert wrapper .get_dims (result )[0 : len (shape )] == shape , f"failed for shape: { shape } " # noqa
@@ -264,15 +237,15 @@ def test_lgamma_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
264237 [
265238 (),
266239 (random .randint (1 , 10 ),),
267- (random .randint (1 , 10 ),),
268240 (random .randint (1 , 10 ), random .randint (1 , 10 )),
269241 (random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
270242 (random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
271243 ],
272244)
273- @pytest .mark .parametrize ("dtype_name" , dtype_map . values ())
245+ @pytest .mark .parametrize ("dtype_name" , get_all_types ())
274246def test_log_shape_dtypes (shape : tuple , dtype_name : dtype .Dtype ) -> None :
275247 """Test log operation across all supported data types."""
248+ check_type_supported (dtype_name )
276249 values = wrapper .randu (shape , dtype_name )
277250 result = wrapper .log (values )
278251 assert wrapper .get_dims (result )[0 : len (shape )] == shape , f"failed for shape: { shape } " # noqa
@@ -296,15 +269,15 @@ def test_log_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
296269 [
297270 (),
298271 (random .randint (1 , 10 ),),
299- (random .randint (1 , 10 ),),
300272 (random .randint (1 , 10 ), random .randint (1 , 10 )),
301273 (random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
302274 (random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
303275 ],
304276)
305- @pytest .mark .parametrize ("dtype_name" , dtype_map . values ())
277+ @pytest .mark .parametrize ("dtype_name" , get_all_types ())
306278def test_log10_shape_dtypes (shape : tuple , dtype_name : dtype .Dtype ) -> None :
307279 """Test log10 operation across all supported data types."""
280+ check_type_supported (dtype_name )
308281 values = wrapper .randu (shape , dtype_name )
309282 result = wrapper .log10 (values )
310283 assert wrapper .get_dims (result )[0 : len (shape )] == shape , f"failed for shape: { shape } " # noqa
@@ -328,15 +301,15 @@ def test_log10_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
328301 [
329302 (),
330303 (random .randint (1 , 10 ),),
331- (random .randint (1 , 10 ),),
332304 (random .randint (1 , 10 ), random .randint (1 , 10 )),
333305 (random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
334306 (random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
335307 ],
336308)
337- @pytest .mark .parametrize ("dtype_name" , dtype_map . values ())
309+ @pytest .mark .parametrize ("dtype_name" , get_all_types ())
338310def test_log1p_shape_dtypes (shape : tuple , dtype_name : dtype .Dtype ) -> None :
339311 """Test natural logarithm of 1 + input operation across all supported data types."""
312+ check_type_supported (dtype_name )
340313 values = wrapper .randu (shape , dtype_name )
341314 result = wrapper .log1p (values )
342315 assert wrapper .get_dims (result )[0 : len (shape )] == shape , f"failed for shape: { shape } " # noqa
@@ -360,15 +333,15 @@ def test_log1p_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
360333 [
361334 (),
362335 (random .randint (1 , 10 ),),
363- (random .randint (1 , 10 ),),
364336 (random .randint (1 , 10 ), random .randint (1 , 10 )),
365337 (random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
366338 (random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
367339 ],
368340)
369- @pytest .mark .parametrize ("dtype_name" , dtype_map . values ())
341+ @pytest .mark .parametrize ("dtype_name" , get_all_types ())
370342def test_log2_shape_dtypes (shape : tuple , dtype_name : dtype .Dtype ) -> None :
371343 """Test log2 operation across all supported data types."""
344+ check_type_supported (dtype_name )
372345 values = wrapper .randu (shape , dtype_name )
373346 result = wrapper .log2 (values )
374347 assert wrapper .get_dims (result )[0 : len (shape )] == shape , f"failed for shape: { shape } " # noqa
@@ -392,15 +365,15 @@ def test_log2_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
392365 [
393366 (),
394367 (random .randint (1 , 10 ),),
395- (random .randint (1 , 10 ),),
396368 (random .randint (1 , 10 ), random .randint (1 , 10 )),
397369 (random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
398370 (random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
399371 ],
400372)
401- @pytest .mark .parametrize ("dtype_name" , dtype_map . values ())
373+ @pytest .mark .parametrize ("dtype_name" , get_all_types ())
402374def test_pow_shape_dtypes (shape : tuple , dtype_name : dtype .Dtype ) -> None :
403375 """Test power operation across all supported data types."""
376+ check_type_supported (dtype_name )
404377 lhs = wrapper .randu (shape , dtype_name )
405378 rhs = wrapper .randu (shape , dtype_name )
406379 result = wrapper .pow (lhs , rhs )
@@ -425,15 +398,15 @@ def test_pow_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
425398 [
426399 (),
427400 (random .randint (1 , 10 ),),
428- (random .randint (1 , 10 ),),
429401 (random .randint (1 , 10 ), random .randint (1 , 10 )),
430402 (random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
431403 (random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
432404 ],
433405)
434- @pytest .mark .parametrize ("dtype_name" , dtype_map . values ())
406+ @pytest .mark .parametrize ("dtype_name" , get_all_types ())
435407def test_root_shape_dtypes (shape : tuple , dtype_name : dtype .Dtype ) -> None :
436408 """Test root operation across all supported data types."""
409+ check_type_supported (dtype_name )
437410 lhs = wrapper .randu (shape , dtype_name )
438411 rhs = wrapper .randu (shape , dtype_name )
439412 result = wrapper .root (lhs , rhs )
@@ -464,9 +437,10 @@ def test_root_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
464437 (random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
465438 ],
466439)
467- @pytest .mark .parametrize ("dtype_name" , dtype_map . values ())
440+ @pytest .mark .parametrize ("dtype_name" , get_all_types ())
468441def test_pow2_shape_dtypes (shape : tuple , dtype_name : dtype .Dtype ) -> None :
469442 """Test 2 to power operation across all supported data types."""
443+ check_type_supported (dtype_name )
470444 values = wrapper .randu (shape , dtype_name )
471445 result = wrapper .pow2 (values )
472446 assert wrapper .get_dims (result )[0 : len (shape )] == shape , f"failed for shape: { shape } " # noqa
@@ -490,15 +464,15 @@ def test_pow2_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
490464 [
491465 (),
492466 (random .randint (1 , 10 ),),
493- (random .randint (1 , 10 ),),
494467 (random .randint (1 , 10 ), random .randint (1 , 10 )),
495468 (random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
496469 (random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
497470 ],
498471)
499- @pytest .mark .parametrize ("dtype_name" , dtype_map . values ())
472+ @pytest .mark .parametrize ("dtype_name" , get_all_types ())
500473def test_rsqrt_shape_dtypes (shape : tuple , dtype_name : dtype .Dtype ) -> None :
501474 """Test reciprocal square root operation across all supported data types."""
475+ check_type_supported (dtype_name )
502476 values = wrapper .randu (shape , dtype_name )
503477 result = wrapper .rsqrt (values )
504478 assert wrapper .get_dims (result )[0 : len (shape )] == shape , f"failed for shape: { shape } " # noqa
@@ -522,15 +496,15 @@ def test_rsqrt_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
522496 [
523497 (),
524498 (random .randint (1 , 10 ),),
525- (random .randint (1 , 10 ),),
526499 (random .randint (1 , 10 ), random .randint (1 , 10 )),
527500 (random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
528501 (random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
529502 ],
530503)
531- @pytest .mark .parametrize ("dtype_name" , dtype_map . values ())
504+ @pytest .mark .parametrize ("dtype_name" , get_all_types ())
532505def test_sqrt_shape_dtypes (shape : tuple , dtype_name : dtype .Dtype ) -> None :
533506 """Test square root operation across all supported data types."""
507+ check_type_supported (dtype_name )
534508 values = wrapper .randu (shape , dtype_name )
535509 result = wrapper .sqrt (values )
536510 assert wrapper .get_dims (result )[0 : len (shape )] == shape , f"failed for shape: { shape } " # noqa
@@ -554,15 +528,15 @@ def test_sqrt_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
554528 [
555529 (),
556530 (random .randint (1 , 10 ),),
557- (random .randint (1 , 10 ),),
558531 (random .randint (1 , 10 ), random .randint (1 , 10 )),
559532 (random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
560533 (random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
561534 ],
562535)
563- @pytest .mark .parametrize ("dtype_name" , dtype_map . values ())
536+ @pytest .mark .parametrize ("dtype_name" , get_all_types ())
564537def test_tgamma_shape_dtypes (shape : tuple , dtype_name : dtype .Dtype ) -> None :
565538 """Test gamma operation across all supported data types."""
539+ check_type_supported (dtype_name )
566540 values = wrapper .randu (shape , dtype_name )
567541 result = wrapper .tgamma (values )
568542 assert wrapper .get_dims (result )[0 : len (shape )] == shape , f"failed for shape: { shape } " # noqa
@@ -586,15 +560,15 @@ def test_tgamma_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
586560 [
587561 (),
588562 (random .randint (1 , 10 ),),
589- (random .randint (1 , 10 ),),
590563 (random .randint (1 , 10 ), random .randint (1 , 10 )),
591564 (random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
592565 (random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
593566 ],
594567)
595- @pytest .mark .parametrize ("dtype_name" , dtype_map . values ())
568+ @pytest .mark .parametrize ("dtype_name" , get_all_types ())
596569def test_sigmoid_shape_dtypes (shape : tuple , dtype_name : dtype .Dtype ) -> None :
597570 """Test sigmoid operation across all supported data types."""
571+ check_type_supported (dtype_name )
598572 values = wrapper .randu (shape , dtype_name )
599573 result = wrapper .sigmoid (values )
600574 assert wrapper .get_dims (result )[0 : len (shape )] == shape , f"failed for shape: { shape } " # noqa
0 commit comments