Skip to content

Commit 38777d1

Browse files
author
Christopher Doris
committed
nicer eltypes for PyArray
1 parent fbc0bc3 commit 38777d1

File tree

3 files changed

+57
-18
lines changed

3 files changed

+57
-18
lines changed

docs/src/releasenotes.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
# Release Notes
22

3+
## Unreleased
4+
* In `PyArray{T}(x)`, the eltype `T` no longer needs to exactly match the stored data type.
5+
If `x` has numeric elements, then any number type `T` is allowed. If `x` has string
6+
elements, then any string type `T` is allowed.
7+
38
## 0.9.10 (2022-12-02)
49
* Bug fixes.
510

src/pywrap/PyArray.jl

Lines changed: 47 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -614,22 +614,55 @@ pyarray_offset(x::PyArray{T,1,M,true}, i::Int) where {T,M} = (i - 1) .* x.stride
614614
pyarray_offset(x::PyArray{T,N}, i::Vararg{Int,N}) where {T,N} = sum((i .- 1) .* x.strides)
615615
pyarray_offset(x::PyArray{T,0}) where {T} = 0
616616

617-
pyarray_load(::Type{R}, p::Ptr{R}) where {R} = unsafe_load(p)
618-
pyarray_load(::Type{T}, p::Ptr{UnsafePyObject}) where {T} = begin
619-
u = unsafe_load(p)
620-
o = u.ptr == C_NULL ? pynew(Py(nothing)) : pynew(incref(u.ptr))
621-
T == Py ? o : pyconvert(T, o)
617+
function pyarray_load(::Type{T}, p::Ptr{R}) where {T,R}
618+
if R == T
619+
unsafe_load(p)
620+
elseif R == UnsafePyObject
621+
u = unsafe_load(p)
622+
o = u.ptr == C_NULL ? pynew(Py(nothing)) : pynew(incref(u.ptr))
623+
T == Py ? o : pyconvert(T, o)
624+
else
625+
convert(T, unsafe_load(p))
626+
end
622627
end
623628

624-
pyarray_store!(p::Ptr{R}, x::R) where {R} = unsafe_store!(p, x)
625-
pyarray_store!(p::Ptr{UnsafePyObject}, x::UnsafePyObject) = unsafe_store!(p, x)
626-
pyarray_store!(p::Ptr{UnsafePyObject}, x) = @autopy x begin
627-
decref(unsafe_load(p).ptr)
628-
unsafe_store!(p, UnsafePyObject(incref(getptr(x_))))
629+
function pyarray_store!(p::Ptr{R}, x::T) where {R,T}
630+
if R == T
631+
unsafe_store!(p, x)
632+
elseif R == UnsafePyObject
633+
@autopy x begin
634+
decref(unsafe_load(p).ptr)
635+
unsafe_store!(p, UnsafePyObject(incref(getptr(x_))))
636+
end
637+
else
638+
unsafe_store!(p, convert(R, x))
639+
end
629640
end
630641

631-
pyarray_get_T(::Type{R}, ::Type{T0}, ::Type{T1}) where {R,T0,T1} = T0 <: R <: T1 ? R : error("not possible")
632-
pyarray_get_T(::Type{UnsafePyObject}, ::Type{T0}, ::Type{T1}) where {T0,T1} = T0 <: Py <: T1 ? Py : T1
642+
function pyarray_get_T(::Type{R}, ::Type{T0}, ::Type{T1}) where {R,T0,T1}
643+
if R == UnsafePyObject
644+
if T0 <: Py <: T1
645+
Py
646+
else
647+
T1
648+
end
649+
elseif T0 <: R <: T1
650+
R
651+
else
652+
error("impossible")
653+
end
654+
end
633655

634-
pyarray_check_T(::Type{T}, ::Type{R}) where {T,R} = T == R ? nothing : error("invalid eltype T=$T for raw eltype R=$R")
635-
pyarray_check_T(::Type{T}, ::Type{UnsafePyObject}) where {T} = nothing
656+
function pyarray_check_T(::Type{T}, ::Type{R}) where {T,R}
657+
if R == UnsafePyObject
658+
nothing
659+
elseif T == R
660+
nothing
661+
elseif T <: Number && R <: Number
662+
nothing
663+
elseif T <: AbstractString && R <: AbstractString
664+
nothing
665+
else
666+
error("invalid eltype T=$T for raw eltype R=$R")
667+
end
668+
end

src/utils.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -169,10 +169,11 @@ module Utils
169169
StaticString{T,N}(codeunits::NTuple{N,T}) where {T,N} = new{T,N}(codeunits)
170170
end
171171

172-
function Base.print(io::IO, x::StaticString)
173-
cs = collect(x.codeunits)
174-
i = findfirst(==(0), cs)
175-
print(io, transcode(String, i===nothing ? cs : cs[1:i-1]))
172+
function Base.String(x::StaticString{T,N}) where {T,N}
173+
i = findfirst(iszero, x.codeunits)
174+
j = i === nothing ? N : i - 1
175+
cs = T[x.codeunits[i] for i in 1:j]
176+
transcode(String, cs)
176177
end
177178

178179
function Base.convert(::Type{StaticString{T,N}}, x::AbstractString) where {T,N}

0 commit comments

Comments
 (0)