@@ -22,24 +22,26 @@ def _create_array(buf, numdims, idims, dtype):
2222 out_arr = ct .c_void_p (0 )
2323 c_dims = dim4 (idims [0 ], idims [1 ], idims [2 ], idims [3 ])
2424 safe_call (backend .get ().af_create_array (ct .pointer (out_arr ), ct .c_void_p (buf ),
25- numdims , ct .pointer (c_dims ), dtype ))
25+ numdims , ct .pointer (c_dims ), dtype . value ))
2626 return out_arr
2727
2828def _create_empty_array (numdims , idims , dtype ):
2929 out_arr = ct .c_void_p (0 )
3030 c_dims = dim4 (idims [0 ], idims [1 ], idims [2 ], idims [3 ])
3131 safe_call (backend .get ().af_create_handle (ct .pointer (out_arr ),
32- numdims , ct .pointer (c_dims ), dtype ))
32+ numdims , ct .pointer (c_dims ), dtype . value ))
3333 return out_arr
3434
35- def constant_array (val , d0 , d1 = None , d2 = None , d3 = None , dtype = f32 ):
35+ def constant_array (val , d0 , d1 = None , d2 = None , d3 = None , dtype = Dtype . f32 ):
3636 """
3737 Internal function to create a C array. Should not be used externall.
3838 """
3939
4040 if not isinstance (dtype , ct .c_int ):
4141 if isinstance (dtype , int ):
4242 dtype = ct .c_int (dtype )
43+ elif isinstance (dtype , Dtype ):
44+ dtype = ct .c_int (dtype .value )
4345 else :
4446 raise TypeError ("Invalid dtype" )
4547
@@ -50,15 +52,15 @@ def constant_array(val, d0, d1=None, d2=None, d3=None, dtype=f32):
5052 c_real = ct .c_double (val .real )
5153 c_imag = ct .c_double (val .imag )
5254
53- if (dtype != c32 and dtype != c64 ):
54- dtype = c32
55+ if (dtype . value != Dtype . c32 . value and dtype . value != Dtype . c64 . value ):
56+ dtype = Dtype . c32 . value
5557
5658 safe_call (backend .get ().af_constant_complex (ct .pointer (out ), c_real , c_imag ,
57- 4 , ct .pointer (dims ), dtype ))
58- elif dtype == s64 :
59+ 4 , ct .pointer (dims ), dtype ))
60+ elif dtype . value == Dtype . s64 . value :
5961 c_val = ct .c_longlong (val .real )
6062 safe_call (backend .get ().af_constant_long (ct .pointer (out ), c_val , 4 , ct .pointer (dims )))
61- elif dtype == u64 :
63+ elif dtype . value == Dtype . u64 . value :
6264 c_val = ct .c_ulonglong (val .real )
6365 safe_call (backend .get ().af_constant_ulong (ct .pointer (out ), c_val , 4 , ct .pointer (dims )))
6466 else :
@@ -76,7 +78,7 @@ def _binary_func(lhs, rhs, c_func):
7678 ldims = dim4_to_tuple (lhs .dims ())
7779 rty = implicit_dtype (rhs , lhs .type ())
7880 other = Array ()
79- other .arr = constant_array (rhs , ldims [0 ], ldims [1 ], ldims [2 ], ldims [3 ], rty )
81+ other .arr = constant_array (rhs , ldims [0 ], ldims [1 ], ldims [2 ], ldims [3 ], rty . value )
8082 elif not isinstance (rhs , Array ):
8183 raise TypeError ("Invalid parameter to binary function" )
8284
@@ -92,7 +94,7 @@ def _binary_funcr(lhs, rhs, c_func):
9294 rdims = dim4_to_tuple (rhs .dims ())
9395 lty = implicit_dtype (lhs , rhs .type ())
9496 other = Array ()
95- other .arr = constant_array (lhs , rdims [0 ], rdims [1 ], rdims [2 ], rdims [3 ], lty )
97+ other .arr = constant_array (lhs , rdims [0 ], rdims [1 ], rdims [2 ], rdims [3 ], lty . value )
9698 elif not isinstance (lhs , Array ):
9799 raise TypeError ("Invalid parameter to binary function" )
98100
@@ -186,7 +188,7 @@ class Array(BaseArray):
186188 dims : optional: tuple of ints. default: (0,)
187189 - When using the default values of `dims`, the dims are caclulated as `len(src)`
188190
189- dtype: optional: str or ctypes.c_int . default: None.
191+ dtype: optional: str or arrayfire.Dtype . default: None.
190192 - if str, must be one of the following:
191193 - 'f' for float
192194 - 'd' for double
@@ -198,18 +200,18 @@ class Array(BaseArray):
198200 - 'L' for unsigned 64 bit integer
199201 - 'F' for 32 bit complex number
200202 - 'D' for 64 bit complex number
201- - if ctypes.c_int , must be one of the following:
202- - f32 for float
203- - f64 for double
204- - b8 for bool
205- - u8 for unsigned char
206- - s32 for signed 32 bit integer
207- - u32 for unsigned 32 bit integer
208- - s64 for signed 64 bit integer
209- - u64 for unsigned 64 bit integer
210- - c32 for 32 bit complex number
211- - c64 for 64 bit complex number
212- - if None, f32 is assumed
203+ - if arrayfire.Dtype , must be one of the following:
204+ - Dtype. f32 for float
205+ - Dtype. f64 for double
206+ - Dtype. b8 for bool
207+ - Dtype. u8 for unsigned char
208+ - Dtype. s32 for signed 32 bit integer
209+ - Dtype. u32 for unsigned 32 bit integer
210+ - Dtype. s64 for signed 64 bit integer
211+ - Dtype. u64 for unsigned 64 bit integer
212+ - Dtype. c32 for 32 bit complex number
213+ - Dtype. c64 for 64 bit complex number
214+ - if None, Dtype. f32 is assumed
213215
214216 Attributes
215217 -----------
@@ -281,7 +283,6 @@ def __init__(self, src=None, dims=(0,), dtype=None):
281283 type_char = None
282284
283285 _type_char = 'f'
284- dtype = f32
285286
286287 backend .lock ()
287288
@@ -318,8 +319,6 @@ def __init__(self, src=None, dims=(0,), dtype=None):
318319
319320 _type_char = type_char
320321
321- print (_type_char )
322-
323322 else :
324323 raise TypeError ("src is an object of unsupported class" )
325324
@@ -389,11 +388,11 @@ def elements(self):
389388
390389 def dtype (self ):
391390 """
392- Return the data type as a ctypes.c_int value.
391+ Return the data type as a arrayfire.Dtype enum value.
393392 """
394- dty = ct .c_int (f32 .value )
393+ dty = ct .c_int (Dtype . f32 .value )
395394 safe_call (backend .get ().af_get_type (ct .pointer (dty ), self .arr ))
396- return dty
395+ return Dtype ( dty . value )
397396
398397 def type (self ):
399398 """
0 commit comments