Skip to content

Commit 6875bc3

Browse files
committed
FEAT: Adding the ability to get, lock, and unlock raw device pointers
1 parent 60fc4ac commit 6875bc3

File tree

3 files changed

+72
-6
lines changed

3 files changed

+72
-6
lines changed

arrayfire/array.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,15 @@
2020
from .index import *
2121
from .index import _Index4
2222

23-
def _create_array(buf, numdims, idims, dtype):
23+
def _create_array(buf, numdims, idims, dtype, is_device):
2424
out_arr = ct.c_void_p(0)
2525
c_dims = dim4(idims[0], idims[1], idims[2], idims[3])
26-
safe_call(backend.get().af_create_array(ct.pointer(out_arr), ct.c_void_p(buf),
27-
numdims, ct.pointer(c_dims), dtype.value))
26+
if (not is_device):
27+
safe_call(backend.get().af_create_array(ct.pointer(out_arr), ct.c_void_p(buf),
28+
numdims, ct.pointer(c_dims), dtype.value))
29+
else:
30+
safe_call(backend.get().af_device_array(ct.pointer(out_arr), ct.c_void_p(buf),
31+
numdims, ct.pointer(c_dims), dtype.value))
2832
return out_arr
2933

3034
def _create_empty_array(numdims, idims, dtype):
@@ -348,7 +352,7 @@ class Array(BaseArray):
348352
349353
"""
350354

351-
def __init__(self, src=None, dims=(0,), dtype=None):
355+
def __init__(self, src=None, dims=(0,), dtype=None, is_device=False):
352356

353357
super(Array, self).__init__()
354358

@@ -385,7 +389,8 @@ def __init__(self, src=None, dims=(0,), dtype=None):
385389
_type_char = tmp.typecode
386390
numdims, idims = _get_info(dims, buf_len)
387391
elif isinstance(src, int) or isinstance(src, ct.c_void_p):
388-
buf = src
392+
buf = src if not isinstance(src, ct.c_void_p) else src.value
393+
389394
numdims, idims = _get_info(dims, buf_len)
390395

391396
elements = 1
@@ -407,7 +412,7 @@ def __init__(self, src=None, dims=(0,), dtype=None):
407412
type_char != _type_char):
408413
raise TypeError("Can not create array of requested type from input data type")
409414

410-
self.arr = _create_array(buf, numdims, idims, to_dtype[_type_char])
415+
self.arr = _create_array(buf, numdims, idims, to_dtype[_type_char], is_device)
411416

412417
else:
413418

arrayfire/device.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,3 +138,55 @@ def device_gc():
138138
Ask the garbage collector to free all unlocked memory
139139
"""
140140
safe_call(backend.get().af_device_gc())
141+
142+
def get_device_ptr(a):
143+
"""
144+
Get the raw device pointer of an array
145+
146+
Parameters
147+
----------
148+
a: af.Array
149+
- A multi dimensional arrayfire array.
150+
151+
Returns
152+
-------
153+
- internal device pointer held by a
154+
155+
Note
156+
-----
157+
- The device pointer of `a` is not freed by memory manager until `unlock_device_ptr()` is called.
158+
- This function enables the user to interoperate arrayfire with other CUDA/OpenCL/C libraries.
159+
160+
"""
161+
ptr = ct.c_void_p(0)
162+
safe_call(backend.get().af_get_device_ptr(ct.pointer(ptr), a.arr))
163+
return ptr
164+
165+
def lock_device_ptr(a):
166+
"""
167+
Ask arrayfire to not perform garbage collection on raw data held by an array.
168+
169+
Parameters
170+
----------
171+
a: af.Array
172+
- A multi dimensional arrayfire array.
173+
174+
Note
175+
-----
176+
- The device pointer of `a` is not freed by memory manager until `unlock_device_ptr()` is called.
177+
"""
178+
ptr = ct.c_void_p(0)
179+
safe_call(backend.get().af_lock_device_ptr(a.arr))
180+
181+
def unlock_device_ptr(a):
182+
"""
183+
Tell arrayfire to resume garbage collection on raw data held by an array.
184+
185+
Parameters
186+
----------
187+
a: af.Array
188+
- A multi dimensional arrayfire array.
189+
190+
"""
191+
ptr = ct.c_void_p(0)
192+
safe_call(backend.get().af_unlock_device_ptr(a.arr))

tests/simple/device.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,4 +40,13 @@ def simple_device(verbose=False):
4040

4141
af.set_device(dev)
4242

43+
a = af.randu(10,10)
44+
display_func(a)
45+
dev_ptr = af.get_device_ptr(a)
46+
print_func(dev_ptr)
47+
b = af.Array(src=dev_ptr, dims=a.dims(), dtype=a.dtype(), is_device=True)
48+
display_func(b)
49+
af.lock_device_ptr(b)
50+
af.unlock_device_ptr(b)
51+
4352
_util.tests['device'] = simple_device

0 commit comments

Comments
 (0)