@@ -78,12 +78,16 @@ abstract type PyArraySource end
7878function pyarray_make (:: Type{A} , x:: Py ; array:: Bool = true , buffer:: Bool = true , copy:: Bool = true ) where {A<: PyArray }
7979 # TODO : try/catch is SLOW if an error is thrown, think about sending errors via return values instead
8080 A == Union{} && return pyconvert_unconverted ()
81- if array && pyhasattr (x, " __array_struct__" )
82- @debug " not implemented: creating PyArray from __array_struct__"
81+ if array && (xa = pygetattr (x, " __array_struct__" , PyNULL); ! pyisnull (xa))
82+ try
83+ return pyarray_make (A, x, PyArraySource_ArrayStruct (x, xa))
84+ catch exc
85+ @debug " failed to make PyArray from __array_struct__" exc= exc
86+ end
8387 end
84- if array && pyhasattr ( x, " __array_interface__" )
88+ if array && (xi = pygetattr ( x, " __array_interface__" , PyNULL); ! pyisnull (xi) )
8589 try
86- return pyarray_make (A, x, PyArraySource_ArrayInterface (x))
90+ return pyarray_make (A, x, PyArraySource_ArrayInterface (x, xi ))
8791 catch exc
8892 @debug " failed to make PyArray from __array_interface__" exc= exc
8993 end
@@ -97,12 +101,16 @@ function pyarray_make(::Type{A}, x::Py; array::Bool=true, buffer::Bool=true, cop
97101 end
98102 if copy && array && pyhasattr (x, " __array__" )
99103 y = x. __array__ ()
100- if pyhasattr (y, " __array_struct__" )
101- @debug " not implemented: creating PyArray from __array__().__array_struct__"
104+ if (ya = pygetattr (y, " __array_struct__" , PyNULL); ! pyisnull (ya))
105+ try
106+ return pyarray_make (A, y, PyArraySource_ArrayStruct (y, ya))
107+ catch exc
108+ @debug " failed to make PyArray from __array__().__array_interface__" exc= exc
109+ end
102110 end
103- if pyhasattr ( y, " __array_interface__" )
111+ if (yi = pygetattr ( y, " __array_interface__" , PyNULL); ! pyisnull (yi) )
104112 try
105- return pyarray_make (A, y, PyArraySource_ArrayInterface (y))
113+ return pyarray_make (A, y, PyArraySource_ArrayInterface (y, yi ))
106114 catch exc
107115 @debug " failed to make PyArray from __array__().__array_interface__" exc= exc
108116 end
@@ -182,8 +190,7 @@ struct PyArraySource_ArrayInterface <: PyArraySource
182190 readonly :: Bool
183191 handle :: Py
184192end
185- function PyArraySource_ArrayInterface (x:: Py )
186- d = x. __array_interface__
193+ function PyArraySource_ArrayInterface (x:: Py , d:: Py = x. __array_interface__)
187194 # offset
188195 # TODO : how is the offset measured?
189196 offset = pyconvert (Int, @py d. get (" offset" , 0 ))
@@ -288,15 +295,125 @@ pyarray_get_M(src::PyArraySource_ArrayInterface) = !src.readonly
288295
289296pyarray_get_handle (src:: PyArraySource_ArrayInterface ) = src. handle
290297
291- # TODO : array struct
298+ # array struct
292299
293300struct PyArraySource_ArrayStruct <: PyArraySource
294301 obj :: Py
295302 capsule :: Py
303+ info :: C.PyArrayInterface
304+ end
305+ function PyArraySource_ArrayStruct (x:: Py , capsule:: Py = x. __array_struct__)
306+ name = C. PyCapsule_GetName (getptr (capsule))
307+ ptr = C. PyCapsule_GetPointer (getptr (capsule), name)
308+ info = unsafe_load (Ptr {C.PyArrayInterface} (ptr))
309+ @assert info. two == 2
310+ return PyArraySource_ArrayStruct (x, capsule, info)
311+ end
312+
313+ function pyarray_get_R (src:: PyArraySource_ArrayStruct )
314+ swapped = ! Utils. isflagset (src. info. flags, C. NPY_ARRAY_NOTSWAPPED)
315+ hasdescr = Utils. isflagset (src. info. flags, C. NPY_ARR_HAS_DESCR)
316+ swapped && error (" byte-swapping not supported" )
317+ kind = src. info. typekind
318+ size = src. info. itemsize
319+ if kind == 98 # b = bool
320+ if size == sizeof (Bool)
321+ return Bool
322+ else
323+ error (" bool of this size not supported: $size " )
324+ end
325+ elseif kind == 105 # i = int
326+ if size == 1
327+ return Int8
328+ elseif size == 2
329+ return Int16
330+ elseif size == 4
331+ return Int32
332+ elseif size == 8
333+ return Int64
334+ else
335+ error (" int of this size not supported: $size " )
336+ end
337+ elseif kind == 117 # u = uint
338+ if size == 1
339+ return UInt8
340+ elseif size == 2
341+ return UInt16
342+ elseif size == 4
343+ return UInt32
344+ elseif size == 8
345+ return UInt64
346+ else
347+ error (" uint of this size not supported: $size " )
348+ end
349+ elseif kind == 102 # f = float
350+ if size == 2
351+ return Float16
352+ elseif size == 4
353+ return Float32
354+ elseif size == 8
355+ return Float64
356+ else
357+ error (" float of this size not supported: $size " )
358+ end
359+ elseif kind == 99 # c = complex
360+ if size == 4
361+ return ComplexF16
362+ elseif size == 8
363+ return ComplexF32
364+ elseif size == 16
365+ return ComplexF64
366+ end
367+ elseif kind == 109 # m = timedelta
368+ error (" timedelta not supported" )
369+ elseif kind == 77 # M = datetime
370+ error (" datetime not supported" )
371+ elseif kind == 79 # O = object
372+ if size == sizeof (C. PyPtr)
373+ return UnsafePyObject
374+ else
375+ error (" object pointer of this size not supported: $size " )
376+ end
377+ elseif kind == 83 # S = byte string
378+ error (" byte strings not supported" )
379+ elseif kind == 85 # U = unicode string
380+ mod (size, 4 ) == 0 || error (" unicode size must be a multiple of 4: $size " )
381+ return Utils. StaticString{UInt32,div (size, 4 )}
382+ elseif kind == 86 # V = void (should have descr)
383+ error (" dtype not supported" )
384+ else
385+ error (" unexpected kind ($(Char (kind)) )" )
386+ end
387+ @assert false
388+ end
389+
390+ function pyarray_get_ptr (src:: PyArraySource_ArrayStruct , :: Type{R} ) where {R}
391+ return Ptr {R} (src. info. data)
392+ end
393+
394+ function pyarray_get_N (src:: PyArraySource_ArrayStruct )
395+ return Int (src. info. nd)
396+ end
397+
398+ function pyarray_get_size (src:: PyArraySource_ArrayStruct , :: Val{N} ) where {N}
399+ ptr = src. info. shape
400+ return ntuple (i-> Int (unsafe_load (ptr, i)), Val (N))
296401end
297- PyArraySource_ArrayStruct (x:: Py ) = PyArraySource_ArrayStruct (x, x. __array_struct__)
298402
299- # TODO : buffer protocol
403+ function pyarray_get_strides (src:: PyArraySource_ArrayStruct , :: Val{N} , :: Type{R} , size:: NTuple{N,Int} ) where {N,R}
404+ ptr = src. info. strides
405+ return ntuple (i-> Int (unsafe_load (ptr, i)), Val (N))
406+ end
407+
408+ function pyarray_get_M (src:: PyArraySource_ArrayStruct )
409+ return Utils. isflagset (src. info. flags, C. NPY_ARRAY_WRITEABLE)
410+ end
411+
412+ function pyarray_get_handle (src:: PyArraySource_ArrayStruct )
413+ return src. capsule
414+ end
415+
416+ # buffer protocol
300417
301418struct PyArraySource_Buffer <: PyArraySource
302419 obj :: Py
@@ -399,6 +516,14 @@ Base.strides(x::PyArray{T,N,M,L,R}) where {T,N,M,L,R} =
399516 error (" strides are not a multiple of element size" )
400517 end
401518
519+ function Base. showarg (io:: IO , x:: PyArray{T,N} , toplevel:: Bool ) where {T, N}
520+ toplevel || print (io, " ::" )
521+ print (io, " PyArray{" )
522+ show (io, T)
523+ print (io, " , " , N, " }" )
524+ return
525+ end
526+
402527@propagate_inbounds Base. getindex (x:: PyArray{T,N} , i:: Vararg{Int,N} ) where {T,N} = pyarray_getindex (x, i... )
403528@propagate_inbounds Base. getindex (x:: PyArray{T,N,M,true} , i:: Int ) where {T,N,M} = pyarray_getindex (x, i)
404529@propagate_inbounds Base. getindex (x:: PyArray{T,1,M,true} , i:: Int ) where {T,M} = pyarray_getindex (x, i)
0 commit comments