@@ -31,6 +31,27 @@ def _create_array(buf, numdims, idims, dtype, is_device):
3131 numdims , ct .pointer (c_dims ), dtype .value ))
3232 return out_arr
3333
def _create_strided_array(buf, numdims, idims, dtype, is_device, offset, strides):
    """
    Create an array handle from existing memory using an explicit offset and strides.

    Parameters
    ----------
    buf
        Value wrapped in ``ct.c_void_p`` and passed to the backend as the data pointer.
    numdims
        Number of dimensions, forwarded unchanged to the backend.
    idims : sequence
        Dimensions of the array; the first four entries are read.
    dtype
        Element type; ``dtype.value`` is forwarded to the backend.
    is_device : bool
        Selects ``Source.device`` when true and ``Source.host`` otherwise.
    offset : int or None
        Offset of the first element. ``None`` is treated as 0.
    strides : sequence or None
        Stride of each dimension. ``None`` selects the strides of a densely
        packed array with dimensions ``idims``. A sequence shorter than four
        entries is padded by repeating its last entry; entries past the
        fourth are ignored.

    Returns
    -------
    out_arr : ct.c_void_p
        The handle filled in by ``af_create_strided_array``.
    """
    out_arr = ct.c_void_p(0)
    c_dims = dim4(idims[0], idims[1], idims[2], idims[3])

    # Marshal the offset as a signed 64-bit integer, the same type that
    # Array.offset() uses to read it back from the backend.
    c_offset = ct.c_longlong(0 if offset is None else offset)

    if strides is None:
        strides = (1, idims[0], idims[0] * idims[1], idims[0] * idims[1] * idims[2])
    else:
        # Accept any sequence (e.g. a list): the padding below concatenates
        # tuples, which fails with a TypeError for other sequence types.
        strides = tuple(strides)
    # Pad up to four entries by repeating the last stride (no-op when len >= 4).
    strides = strides + (strides[-1],) * (4 - len(strides))
    c_strides = dim4(strides[0], strides[1], strides[2], strides[3])

    location = Source.device if is_device else Source.host

    safe_call(backend.get().af_create_strided_array(ct.pointer(out_arr), ct.c_void_p(buf),
                                                    c_offset, numdims, ct.pointer(c_dims),
                                                    ct.pointer(c_strides), dtype.value,
                                                    location.value))
    return out_arr
54+
3455def _create_empty_array (numdims , idims , dtype ):
3556 out_arr = ct .c_void_p (0 )
3657 c_dims = dim4 (idims [0 ], idims [1 ], idims [2 ], idims [3 ])
@@ -352,7 +373,7 @@ class Array(BaseArray):
352373
353374 """
354375
355- def __init__ (self , src = None , dims = (0 ,), dtype = None , is_device = False ):
376+ def __init__ (self , src = None , dims = (0 ,), dtype = None , is_device = False , offset = None , strides = None ):
356377
357378 super (Array , self ).__init__ ()
358379
@@ -409,8 +430,10 @@ def __init__(self, src=None, dims=(0,), dtype=None, is_device=False):
409430 if (type_char is not None and
410431 type_char != _type_char ):
411432 raise TypeError ("Can not create array of requested type from input data type" )
412-
413- self .arr = _create_array (buf , numdims , idims , to_dtype [_type_char ], is_device )
433+ if (offset is None and strides is None ):
434+ self .arr = _create_array (buf , numdims , idims , to_dtype [_type_char ], is_device )
435+ else :
436+ self .arr = _create_strided_array (buf , numdims , idims , to_dtype [_type_char ], is_device , offset , strides )
414437
415438 else :
416439
@@ -454,6 +477,26 @@ def __del__(self):
454477 backend .get ().af_release_array (self .arr )
455478
def device_ptr(self):
    """
    Return a device pointer that is held exclusively by this array.

    Returns
    -------
    ptr : int
        Location of the device pointer.

    Note
    ----
    - Useful for interoperating with custom C code, PyCUDA or PyOpenCL.
    - The returned pointer is not shared with any other array.
    - If several arrays share the memory, or this array does not own it, the memory is copied.
    - After such a copy the returned pointer refers to the new allocation, which this array owns exclusively.
    """
    dev_ptr = ct.c_void_p(0)
    get_device_ptr = backend.get().af_get_device_ptr
    # The backend writes the address into dev_ptr through the pointer argument.
    get_device_ptr(ct.pointer(dev_ptr), self.arr)
    return dev_ptr.value
498+
499+ def raw_ptr (self ):
457500 """
458501 Return the device pointer held by the array.
459502
@@ -466,11 +509,45 @@ def device_ptr(self):
466509 ----
467510 - This can be used to integrate with custom C code and / or PyCUDA or PyOpenCL.
467510 - No mem copy is performed, this function returns the raw device pointer.
512+ - This pointer may be shared with other arrays. Use this function with caution.
513+ - In particular the JIT compiler will not be aware of the shared arrays.
514+ - This results in JITed operations not being immediately visible through the other array.
469515 """
470516 ptr = ct .c_void_p (0 )
471- backend .get ().af_get_device_ptr (ct .pointer (ptr ), self .arr )
517+ backend .get ().af_get_raw_ptr (ct .pointer (ptr ), self .arr )
472518 return ptr .value
473519
def offset(self):
    """
    Return the offset of the first element relative to the raw pointer.

    Returns
    -------
    offset : int
        The offset in number of elements.
    """
    first_elem = ct.c_longlong(0)
    err = backend.get().af_get_offset(ct.pointer(first_elem), self.arr)
    # safe_call inspects the backend's return code.
    safe_call(err)
    return first_elem.value
532+
def strides(self):
    """
    Return the stride of each dimension of the array.

    The stride of a dimension is the distance between consecutive elements
    along that dimension, counted in number of elements (not bytes).

    Returns
    -------
    strides : tuple
        The strides reported by the backend, truncated to the first
        ``self.numdims()`` dimensions.
    """
    # The backend always reports four strides, one per possible dimension.
    steps = [ct.c_longlong(0) for _ in range(4)]
    args = [ct.pointer(step) for step in steps] + [self.arr]
    safe_call(backend.get().af_get_strides(*args))
    all_strides = tuple(step.value for step in steps)
    return all_strides[:self.numdims()]
550+
474551 def elements (self ):
475552 """
476553 Return the number of elements in the array.
@@ -622,6 +699,22 @@ def is_bool(self):
622699 safe_call (backend .get ().af_is_bool (ct .pointer (res ), self .arr ))
623700 return res .value
624701
def is_linear(self):
    """
    Check whether every element of the array is stored contiguously.
    """
    flag = ct.c_bool(False)
    err = backend.get().af_is_linear(ct.pointer(flag), self.arr)
    safe_call(err)
    return flag.value
709+
def is_owner(self):
    """
    Check whether the array owns its raw pointer, as opposed to being a derived array.
    """
    flag = ct.c_bool(False)
    err = backend.get().af_is_owner(ct.pointer(flag), self.arr)
    safe_call(err)
    return flag.value
717+
625718 def __add__ (self , other ):
626719 """
627720 Return self + other.
@@ -892,6 +985,12 @@ def __getitem__(self, key):
892985 try :
893986 out = Array ()
894987 n_dims = self .numdims ()
988+
989+ if (isinstance (key , Array ) and key .type () == Dtype .b8 .value ):
990+ n_dims = 1
991+ if (count (key ) == 0 ):
992+ return out
993+
895994 inds = _get_indices (key )
896995
897996 safe_call (backend .get ().af_index_gen (ct .pointer (out .arr ),
@@ -912,9 +1011,21 @@ def __setitem__(self, key, val):
9121011 try :
9131012 n_dims = self .numdims ()
9141013
1014+ is_boolean_idx = isinstance (key , Array ) and key .type () == Dtype .b8 .value
1015+
1016+ if (is_boolean_idx ):
1017+ n_dims = 1
1018+ num = count (key )
1019+ if (num == 0 ):
1020+ return
1021+
9151022 if (_is_number (val )):
9161023 tdims = _get_assign_dims (key , self .dims ())
917- other_arr = constant_array (val , tdims [0 ], tdims [1 ], tdims [2 ], tdims [3 ], self .type ())
1024+ if (is_boolean_idx ):
1025+ n_dims = 1
1026+ other_arr = constant_array (val , int (num ), dtype = self .type ())
1027+ else :
1028+ other_arr = constant_array (val , tdims [0 ] , tdims [1 ], tdims [2 ], tdims [3 ], self .type ())
9181029 del_other = True
9191030 else :
9201031 other_arr = val .arr
@@ -924,8 +1035,8 @@ def __setitem__(self, key, val):
9241035 inds = _get_indices (key )
9251036
9261037 safe_call (backend .get ().af_assign_gen (ct .pointer (out_arr ),
927- self .arr , ct .c_longlong (n_dims ), inds .pointer ,
928- other_arr ))
1038+ self .arr , ct .c_longlong (n_dims ), inds .pointer ,
1039+ other_arr ))
9291040 safe_call (backend .get ().af_release_array (self .arr ))
9301041 if del_other :
9311042 safe_call (backend .get ().af_release_array (other_arr ))
@@ -1142,5 +1253,5 @@ def read_array(filename, index=None, key=None):
11421253
11431254 return out
11441255
1145- from .algorithm import sum
1256+ from .algorithm import ( sum , count )
11461257from .arith import cast
0 commit comments