Skip to content

Commit d03a5ab

Browse files
author
Christopher Doris
committed
support constructing PyArray from __array_struct__
1 parent e00da91 commit d03a5ab

File tree

3 files changed

+163
-13
lines changed

3 files changed

+163
-13
lines changed

src/cpython/consts.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -329,3 +329,22 @@ end
329329
value::Int = 0
330330
weaklist::PyPtr = C_NULL
331331
end
332+
333+
@kwdef struct PyArrayInterface
334+
two::Cint = 0
335+
nd::Cint = 0
336+
typekind::Cchar = 0
337+
itemsize::Cint = 0
338+
flags::Cint = 0
339+
shape::Ptr{Cssize_t} = C_NULL
340+
strides::Ptr{Cssize_t} = C_NULL
341+
data::Ptr{Cvoid} = C_NULL
342+
descr::PyPtr = C_NULL
343+
end
344+
345+
const NPY_ARRAY_C_CONTIGUOUS = 0x0001
346+
const NPY_ARRAY_F_CONTIGUOUS = 0x0002
347+
const NPY_ARRAY_ALIGNED = 0x0100
348+
const NPY_ARRAY_NOTSWAPPED = 0x0200
349+
const NPY_ARRAY_WRITEABLE = 0x0400
350+
const NPY_ARR_HAS_DESCR = 0x0800

src/pywrap/PyArray.jl

Lines changed: 138 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -78,12 +78,16 @@ abstract type PyArraySource end
7878
function pyarray_make(::Type{A}, x::Py; array::Bool=true, buffer::Bool=true, copy::Bool=true) where {A<:PyArray}
7979
# TODO: try/catch is SLOW if an error is thrown, think about sending errors via return values instead
8080
A == Union{} && return pyconvert_unconverted()
81-
if array && pyhasattr(x, "__array_struct__")
82-
@debug "not implemented: creating PyArray from __array_struct__"
81+
if array && (xa = pygetattr(x, "__array_struct__", PyNULL); !pyisnull(xa))
82+
try
83+
return pyarray_make(A, x, PyArraySource_ArrayStruct(x, xa))
84+
catch exc
85+
@debug "failed to make PyArray from __array_struct__" exc=exc
86+
end
8387
end
84-
if array && pyhasattr(x, "__array_interface__")
88+
if array && (xi = pygetattr(x, "__array_interface__", PyNULL); !pyisnull(xi))
8589
try
86-
return pyarray_make(A, x, PyArraySource_ArrayInterface(x))
90+
return pyarray_make(A, x, PyArraySource_ArrayInterface(x, xi))
8791
catch exc
8892
@debug "failed to make PyArray from __array_interface__" exc=exc
8993
end
@@ -97,12 +101,16 @@ function pyarray_make(::Type{A}, x::Py; array::Bool=true, buffer::Bool=true, cop
97101
end
98102
if copy && array && pyhasattr(x, "__array__")
99103
y = x.__array__()
100-
if pyhasattr(y, "__array_struct__")
101-
@debug "not implemented: creating PyArray from __array__().__array_struct__"
104+
if (ya = pygetattr(y, "__array_struct__", PyNULL); !pyisnull(ya))
105+
try
106+
return pyarray_make(A, y, PyArraySource_ArrayStruct(y, ya))
107+
catch exc
108+
@debug "failed to make PyArray from __array__().__array_interface__" exc=exc
109+
end
102110
end
103-
if pyhasattr(y, "__array_interface__")
111+
if (yi = pygetattr(y, "__array_interface__", PyNULL); !pyisnull(yi))
104112
try
105-
return pyarray_make(A, y, PyArraySource_ArrayInterface(y))
113+
return pyarray_make(A, y, PyArraySource_ArrayInterface(y, yi))
106114
catch exc
107115
@debug "failed to make PyArray from __array__().__array_interface__" exc=exc
108116
end
@@ -182,8 +190,7 @@ struct PyArraySource_ArrayInterface <: PyArraySource
182190
readonly :: Bool
183191
handle :: Py
184192
end
185-
function PyArraySource_ArrayInterface(x::Py)
186-
d = x.__array_interface__
193+
function PyArraySource_ArrayInterface(x::Py, d::Py=x.__array_interface__)
187194
# offset
188195
# TODO: how is the offset measured?
189196
offset = pyconvert(Int, @py d.get("offset", 0))
@@ -288,15 +295,125 @@ pyarray_get_M(src::PyArraySource_ArrayInterface) = !src.readonly
288295

289296
pyarray_get_handle(src::PyArraySource_ArrayInterface) = src.handle
290297

291-
# TODO: array struct
298+
# array struct
292299

293300
struct PyArraySource_ArrayStruct <: PyArraySource
294301
obj :: Py
295302
capsule :: Py
303+
info :: C.PyArrayInterface
304+
end
305+
function PyArraySource_ArrayStruct(x::Py, capsule::Py=x.__array_struct__)
306+
name = C.PyCapsule_GetName(getptr(capsule))
307+
ptr = C.PyCapsule_GetPointer(getptr(capsule), name)
308+
info = unsafe_load(Ptr{C.PyArrayInterface}(ptr))
309+
@assert info.two == 2
310+
return PyArraySource_ArrayStruct(x, capsule, info)
311+
end
312+
313+
function pyarray_get_R(src::PyArraySource_ArrayStruct)
314+
swapped = !Utils.isflagset(src.info.flags, C.NPY_ARRAY_NOTSWAPPED)
315+
hasdescr = Utils.isflagset(src.info.flags, C.NPY_ARR_HAS_DESCR)
316+
swapped && error("byte-swapping not supported")
317+
kind = src.info.typekind
318+
size = src.info.itemsize
319+
if kind == 98 # b = bool
320+
if size == sizeof(Bool)
321+
return Bool
322+
else
323+
error("bool of this size not supported: $size")
324+
end
325+
elseif kind == 105 # i = int
326+
if size == 1
327+
return Int8
328+
elseif size == 2
329+
return Int16
330+
elseif size == 4
331+
return Int32
332+
elseif size == 8
333+
return Int64
334+
else
335+
error("int of this size not supported: $size")
336+
end
337+
elseif kind == 117 # u = uint
338+
if size == 1
339+
return UInt8
340+
elseif size == 2
341+
return UInt16
342+
elseif size == 4
343+
return UInt32
344+
elseif size == 8
345+
return UInt64
346+
else
347+
error("uint of this size not supported: $size")
348+
end
349+
elseif kind == 102 # f = float
350+
if size == 2
351+
return Float16
352+
elseif size == 4
353+
return Float32
354+
elseif size == 8
355+
return Float64
356+
else
357+
error("float of this size not supported: $size")
358+
end
359+
elseif kind == 99 # c = complex
360+
if size == 4
361+
return ComplexF16
362+
elseif size == 8
363+
return ComplexF32
364+
elseif size == 16
365+
return ComplexF64
366+
end
367+
elseif kind == 109 # m = timedelta
368+
error("timedelta not supported")
369+
elseif kind == 77 # M = datetime
370+
error("datetime not supported")
371+
elseif kind == 79 # O = object
372+
if size == sizeof(C.PyPtr)
373+
return UnsafePyObject
374+
else
375+
error("object pointer of this size not supported: $size")
376+
end
377+
elseif kind == 83 # S = byte string
378+
error("byte strings not supported")
379+
elseif kind == 85 # U = unicode string
380+
mod(size, 4) == 0 || error("unicode size must be a multiple of 4: $size")
381+
return Utils.StaticString{UInt32,div(size, 4)}
382+
elseif kind == 86 # V = void (should have descr)
383+
error("dtype not supported")
384+
else
385+
error("unexpected kind ($(Char(kind)))")
386+
end
387+
@assert false
388+
end
389+
390+
function pyarray_get_ptr(src::PyArraySource_ArrayStruct, ::Type{R}) where {R}
391+
return Ptr{R}(src.info.data)
392+
end
393+
394+
function pyarray_get_N(src::PyArraySource_ArrayStruct)
395+
return Int(src.info.nd)
396+
end
397+
398+
function pyarray_get_size(src::PyArraySource_ArrayStruct, ::Val{N}) where {N}
399+
ptr = src.info.shape
400+
return ntuple(i->Int(unsafe_load(ptr, i)), Val(N))
296401
end
297-
PyArraySource_ArrayStruct(x::Py) = PyArraySource_ArrayStruct(x, x.__array_struct__)
298402

299-
# TODO: buffer protocol
403+
function pyarray_get_strides(src::PyArraySource_ArrayStruct, ::Val{N}, ::Type{R}, size::NTuple{N,Int}) where {N,R}
404+
ptr = src.info.strides
405+
return ntuple(i->Int(unsafe_load(ptr, i)), Val(N))
406+
end
407+
408+
function pyarray_get_M(src::PyArraySource_ArrayStruct)
409+
return Utils.isflagset(src.info.flags, C.NPY_ARRAY_WRITEABLE)
410+
end
411+
412+
function pyarray_get_handle(src::PyArraySource_ArrayStruct)
413+
return src.capsule
414+
end
415+
416+
# buffer protocol
300417

301418
struct PyArraySource_Buffer <: PyArraySource
302419
obj :: Py
@@ -399,6 +516,14 @@ Base.strides(x::PyArray{T,N,M,L,R}) where {T,N,M,L,R} =
399516
error("strides are not a multiple of element size")
400517
end
401518

519+
function Base.showarg(io::IO, x::PyArray{T,N}, toplevel::Bool) where {T, N}
520+
toplevel || print(io, "::")
521+
print(io, "PyArray{")
522+
show(io, T)
523+
print(io, ", ", N, "}")
524+
return
525+
end
526+
402527
@propagate_inbounds Base.getindex(x::PyArray{T,N}, i::Vararg{Int,N}) where {T,N} = pyarray_getindex(x, i...)
403528
@propagate_inbounds Base.getindex(x::PyArray{T,N,M,true}, i::Int) where {T,N,M} = pyarray_getindex(x, i)
404529
@propagate_inbounds Base.getindex(x::PyArray{T,1,M,true}, i::Int) where {T,M} = pyarray_getindex(x, i)

src/utils.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,12 @@ module Utils
186186

187187
StaticString{T,N}(x::AbstractString) where {T,N} = convert(StaticString{T,N}, x)
188188

189+
Base.ncodeunits(x::StaticString{T,N}) where {T,N} = N
190+
191+
Base.codeunit(x::StaticString, i::Integer) = x.codeunits[i]
192+
193+
Base.codeunit(x::StaticString{T}) where {T} = T
194+
189195
function Base.iterate(x::StaticString, st::Union{Nothing,Tuple}=nothing)
190196
if st === nothing
191197
s = String(x)

0 commit comments

Comments
 (0)