2020from .index import *
2121from .index import _Index4
2222
23- def _create_array (buf , numdims , idims , dtype ):
23+ def _create_array (buf , numdims , idims , dtype , is_device ):
2424 out_arr = ct .c_void_p (0 )
2525 c_dims = dim4 (idims [0 ], idims [1 ], idims [2 ], idims [3 ])
26- safe_call (backend .get ().af_create_array (ct .pointer (out_arr ), ct .c_void_p (buf ),
27- numdims , ct .pointer (c_dims ), dtype .value ))
26+ if (not is_device ):
27+ safe_call (backend .get ().af_create_array (ct .pointer (out_arr ), ct .c_void_p (buf ),
28+ numdims , ct .pointer (c_dims ), dtype .value ))
29+ else :
30+ safe_call (backend .get ().af_device_array (ct .pointer (out_arr ), ct .c_void_p (buf ),
31+ numdims , ct .pointer (c_dims ), dtype .value ))
2832 return out_arr
2933
3034def _create_empty_array (numdims , idims , dtype ):
@@ -348,7 +352,7 @@ class Array(BaseArray):
348352
349353 """
350354
351- def __init__ (self , src = None , dims = (0 ,), dtype = None ):
355+ def __init__ (self , src = None , dims = (0 ,), dtype = None , is_device = False ):
352356
353357 super (Array , self ).__init__ ()
354358
@@ -385,7 +389,8 @@ def __init__(self, src=None, dims=(0,), dtype=None):
385389 _type_char = tmp .typecode
386390 numdims , idims = _get_info (dims , buf_len )
387391 elif isinstance (src , int ) or isinstance (src , ct .c_void_p ):
388- buf = src
392+ buf = src if not isinstance (src , ct .c_void_p ) else src .value
393+
389394 numdims , idims = _get_info (dims , buf_len )
390395
391396 elements = 1
@@ -407,7 +412,7 @@ def __init__(self, src=None, dims=(0,), dtype=None):
407412 type_char != _type_char ):
408413 raise TypeError ("Can not create array of requested type from input data type" )
409414
410- self .arr = _create_array (buf , numdims , idims , to_dtype [_type_char ])
415+ self .arr = _create_array (buf , numdims , idims , to_dtype [_type_char ], is_device )
411416
412417 else :
413418
0 commit comments