Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
ArrayInterfaceCore = "0.1.3"
Compat = "3, 4"
IfElse = "0.1"
Static = "0.7"
Static = "0.8"
julia = "1.6"

[extras]
Expand Down
2 changes: 1 addition & 1 deletion lib/ArrayInterfaceOffsetArrays/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
[compat]
ArrayInterface = "5, 6"
OffsetArrays = "1.11"
Static = "0.7"
Static = "0.7, 0.8"
julia = "1.6"

[extras]
Expand Down
2 changes: 1 addition & 1 deletion lib/ArrayInterfaceStaticArrays/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ Adapt = "3"
ArrayInterface = "6"
ArrayInterfaceCore = "0.1.21"
ArrayInterfaceStaticArraysCore = "0.1"
Static = "0.7"
Static = "0.8"
StaticArrays = "1.2.5, 1.3, 1.4"
julia = "1.6"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ import ArrayInterfaceStaticArraysCore

const CanonicalInt = Union{Int,StaticInt}

function Static.OptionallyStaticUnitRange(::StaticArrays.SOneTo{N}) where {N}
Static.OptionallyStaticUnitRange(StaticInt(1), StaticInt(N))
end
ArrayInterface.known_first(::Type{<:StaticArrays.SOneTo}) = 1
ArrayInterface.known_last(::Type{StaticArrays.SOneTo{N}}) where {N} = N
ArrayInterface.known_length(::Type{StaticArrays.SOneTo{N}}) where {N} = N
Expand Down
12 changes: 5 additions & 7 deletions src/ArrayInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ import ArrayInterfaceCore: known_first, known_step, known_last

using Static
using Static: Zero, One, nstatic, eq, ne, gt, ge, lt, le, eachop, eachop_tuple,
permute, invariant_permutation, field_type, reduce_tup, find_first_eq
permute, invariant_permutation, field_type, reduce_tup, find_first_eq,
OptionallyStaticUnitRange, OptionallyStaticStepRange, OptionallyStaticRange, IntType,
SOneTo, SUnitRange

using IfElse

Expand All @@ -43,10 +45,6 @@ _sub1(@nospecialize x) = x - oneunit(x)
Tuple{X.parameters...,Y.parameters...}
end

const CanonicalInt = Union{Int,StaticInt}
canonicalize(x::Integer) = Int(x)
canonicalize(@nospecialize(x::StaticInt)) = x

abstract type AbstractArray2{T,N} <: AbstractArray{T,N} end

Base.size(A::AbstractArray2) = map(Int, ArrayInterface.size(A))
Expand Down Expand Up @@ -93,10 +91,10 @@ end
@inline static_last(x) = Static.maybe_static(known_last, last, x)
@inline static_step(x) = Static.maybe_static(known_step, step, x)

@inline function _to_cartesian(a, i::CanonicalInt)
@inline function _to_cartesian(a, i::IntType)
@inbounds(CartesianIndices(ntuple(dim -> indices(a, dim), Val(ndims(a))))[i])
end
@inline function _to_linear(a, i::Tuple{CanonicalInt,Vararg{CanonicalInt}})
@inline function _to_linear(a, i::Tuple{IntType,Vararg{IntType}})
_strides2int(offsets(a), size_to_strides(size(a), static(1)), i) + static(1)
end

Expand Down
2 changes: 1 addition & 1 deletion src/array_index.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ struct StrideIndex{N,R,C,S,O} <: ArrayIndex{N}
end

## getindex
@propagate_inbounds Base.getindex(x::ArrayIndex, i::CanonicalInt, ii::CanonicalInt...) = x[NDIndex(i, ii...)]
@propagate_inbounds Base.getindex(x::ArrayIndex, i::IntType, ii::IntType...) = x[NDIndex(i, ii...)]

@inline function Base.getindex(x::StrideIndex{N}, i::AbstractCartesianIndex) where {N}
return _strides2int(offsets(x), strides(x), Tuple(i)) + static(1)
Expand Down
5 changes: 4 additions & 1 deletion src/axes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,9 @@ Base.keys(x::LazyAxis) = keys(parent(x))

Base.IndexStyle(T::Type{<:LazyAxis}) = IndexStyle(parent_type(T))

function Static.OptionallyStaticUnitRange(x::LazyAxis)
OptionallyStaticUnitRange(static_first(x), static_last(x))
end
ArrayInterfaceCore.can_change_size(@nospecialize T::Type{<:LazyAxis}) = can_change_size(fieldtype(T, :parent))

ArrayInterfaceCore.known_first(::Type{<:LazyAxis{N,P}}) where {N,P} = known_offsets(P, static(N))
Expand Down Expand Up @@ -219,7 +222,7 @@ Base.axes1(x::Slice{LazyAxis{N,A}}) where {N,A} = indices(getfield(x.indices, :p
Base.axes1(x::Slice{LazyAxis{:,A}}) where {A} = indices(getfield(x.indices, :parent))
Base.to_shape(x::LazyAxis) = Base.length(x)

@propagate_inbounds function Base.getindex(x::LazyAxis, i::CanonicalInt)
@propagate_inbounds function Base.getindex(x::LazyAxis, i::IntType)
@boundscheck checkindex(Bool, x, i) || throw(BoundsError(x, i))
return Int(i)
end
Expand Down
10 changes: 5 additions & 5 deletions src/dimensions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ to `:_`, then `false` is returned.
Return the names of the dimensions for `x`. `:_` is used to indicate a dimension does not
have a name.
"""
@inline known_dimnames(x, dim) = _known_dimname(known_dimnames(x), canonicalize(dim))
@inline known_dimnames(x, dim) = _known_dimname(known_dimnames(x), IntType(dim))
known_dimnames(x) = known_dimnames(typeof(x))
function known_dimnames(@nospecialize T::Type{<:VecAdjTrans})
(:_, getfield(known_dimnames(parent_type(T)), 1))
Expand Down Expand Up @@ -159,7 +159,7 @@ end
_unknown_dimnames(::Base.HasShape{N}) where {N} = ntuple(Compat.Returns(:_), StaticInt(N))
_unknown_dimnames(::Any) = (:_,)

@inline function _known_dimname(x::Tuple{Vararg{Any,N}}, dim::CanonicalInt) where {N}
@inline function _known_dimname(x::Tuple{Vararg{Any,N}}, dim::IntType) where {N}
# we cannot have `@boundscheck`, else this will depend on bounds checking being enabled
(dim > N || dim < 1) && return :_
return @inbounds(x[dim])
Expand All @@ -173,7 +173,7 @@ end
Return the names of the dimensions for `x`. `:_` is used to indicate a dimension does not
have a name.
"""
@inline dimnames(x, dim) = _dimname(dimnames(x), canonicalize(dim))
@inline dimnames(x, dim) = _dimname(dimnames(x), IntType(dim))
@inline function dimnames(x::Union{PermutedDimsArray,MatAdjTrans})
map(GetIndex{false}(dimnames(parent(x))), to_parent_dims(x))
end
Expand Down Expand Up @@ -214,7 +214,7 @@ end
return ntuple(Compat.Returns(static(:_)), StaticInt(ndims(x)))
end
end
@inline function _dimname(x::Tuple{Vararg{Any,N}}, dim::CanonicalInt) where {N}
@inline function _dimname(x::Tuple{Vararg{Any,N}}, dim::IntType) where {N}
# we cannot have `@boundscheck`, else this will depend on bounds checking being enabled
# for calls such as `dimnames(view(x, :, 1, :))`
(dim > N || dim < 1) && return static(:_)
Expand All @@ -228,7 +228,7 @@ end
This returns the dimension(s) of `x` corresponding to `dim`.
"""
to_dims(x, dim::Colon) = dim
to_dims(x, @nospecialize(dim::CanonicalInt)) = dim
to_dims(x, @nospecialize(dim::IntType)) = dim
to_dims(x, dim::Integer) = Int(dim)
to_dims(x, dim::Union{StaticSymbol,Symbol}) = _to_dim(dimnames(x), dim)
function to_dims(x, dims::Tuple{Vararg{Any,N}}) where {N}
Expand Down
69 changes: 23 additions & 46 deletions src/indexing.jl
Original file line number Diff line number Diff line change
@@ -1,27 +1,4 @@

function known_lastindex(::Type{T}) where {T}
if known_offset1(T) === nothing || known_length(T) === nothing
return nothing
else
return known_length(T) - known_offset1(T) + 1
end
end
known_lastindex(@nospecialize x) = known_lastindex(typeof(x))

@inline static_lastindex(x) = Static.maybe_static(known_lastindex, lastindex, x)

function Base.first(x::AbstractVector, n::StaticInt)
@boundscheck n < 0 && throw(ArgumentError("Number of elements must be nonnegative"))
start = offset1(x)
@inbounds x[start:min((start - one(start)) + n, static_lastindex(x))]
end

function Base.last(x::AbstractVector, n::StaticInt)
@boundscheck n < 0 && throw(ArgumentError("Number of elements must be nonnegative"))
stop = static_lastindex(x)
@inbounds x[max(offset1(x), (stop + one(stop)) - n):stop]
end

"""
ArrayInterface.to_indices(A, I::Tuple) -> Tuple

Expand Down Expand Up @@ -162,16 +139,16 @@ to_index(::LinearIndices, i::AbstractArray{Bool}) = LogicalIndex{Int}(i)
@inline to_index(x, i::NDIndex) = getfield(i, 1)
@inline to_index(x, i::AbstractArray{<:AbstractCartesianIndex}) = i
@inline function to_index(x, i::Base.Fix2{<:Union{typeof(<),typeof(isless)},<:Union{Base.BitInteger,StaticInt}})
static_first(x):min(_sub1(canonicalize(i.x)), static_last(x))
static_first(x):min(_sub1(IntType(i.x)), static_last(x))
end
@inline function to_index(x, i::Base.Fix2{typeof(<=),<:Union{Base.BitInteger,StaticInt}})
static_first(x):min(canonicalize(i.x), static_last(x))
static_first(x):min(IntType(i.x), static_last(x))
end
@inline function to_index(x, i::Base.Fix2{typeof(>=),<:Union{Base.BitInteger,StaticInt}})
max(canonicalize(i.x), static_first(x)):static_last(x)
max(IntType(i.x), static_first(x)):static_last(x)
end
@inline function to_index(x, i::Base.Fix2{typeof(>),<:Union{Base.BitInteger,StaticInt}})
max(_add1(canonicalize(i.x)), static_first(x)):static_last(x)
max(_add1(IntType(i.x)), static_first(x)):static_last(x)
end
# integer indexing
to_index(x, i::AbstractArray{<:Integer}) = i
Expand Down Expand Up @@ -232,7 +209,7 @@ indices calling [`to_axis`](@ref).
end
end
# drop this dimension
to_axes(A, a::Tuple, i::Tuple{<:CanonicalInt,Vararg{Any}}) = to_axes(A, _maybe_tail(a), tail(i))
to_axes(A, a::Tuple, i::Tuple{<:IntType,Vararg{Any}}) = to_axes(A, _maybe_tail(a), tail(i))
to_axes(A, a::Tuple, i::Tuple{I,Vararg{Any}}) where {I} = _to_axes(StaticInt(ndims_index(I)), A, a, i)
function _to_axes(::StaticInt{1}, A, axs::Tuple, inds::Tuple)
return (to_axis(_maybe_first(axs), first(inds)), to_axes(A, _maybe_tail(axs), tail(inds))...)
Expand Down Expand Up @@ -309,15 +286,15 @@ function unsafe_getindex(a::A) where {A}
end

# TODO Need to manage index transformations between nested layers of arrays
function unsafe_getindex(a::A, i::CanonicalInt) where {A}
function unsafe_getindex(a::A, i::IntType) where {A}
if IndexStyle(A) === IndexLinear()
is_forwarding_wrapper(A) || throw(MethodError(unsafe_getindex, (A, i)))
return unsafe_getindex(parent(a), i)
else
return unsafe_getindex(a, _to_cartesian(a, i)...)
end
end
function unsafe_getindex(a::A, i::CanonicalInt, ii::Vararg{CanonicalInt}) where {A}
function unsafe_getindex(a::A, i::IntType, ii::Vararg{IntType}) where {A}
if IndexStyle(A) === IndexLinear()
return unsafe_getindex(a, _to_linear(a, (i, ii...)))
else
Expand All @@ -329,24 +306,24 @@ end
unsafe_getindex(a, i::Vararg{Any}) = unsafe_get_collection(a, i)

unsafe_getindex(A::Array) = Base.arrayref(false, A, 1)
unsafe_getindex(A::Array, i::CanonicalInt) = Base.arrayref(false, A, Int(i))
@inline function unsafe_getindex(A::Array, i::CanonicalInt, ii::Vararg{CanonicalInt})
unsafe_getindex(A::Array, i::IntType) = Base.arrayref(false, A, Int(i))
@inline function unsafe_getindex(A::Array, i::IntType, ii::Vararg{IntType})
unsafe_getindex(A, _to_linear(A, (i, ii...)))
end

unsafe_getindex(A::LinearIndices, i::CanonicalInt) = Int(i)
unsafe_getindex(A::CartesianIndices{N}, ii::Vararg{CanonicalInt,N}) where {N} = CartesianIndex(ii...)
unsafe_getindex(A::CartesianIndices, ii::Vararg{CanonicalInt}) =
unsafe_getindex(A::LinearIndices, i::IntType) = Int(i)
unsafe_getindex(A::CartesianIndices{N}, ii::Vararg{IntType,N}) where {N} = CartesianIndex(ii...)
unsafe_getindex(A::CartesianIndices, ii::Vararg{IntType}) =
unsafe_getindex(A, Base.front(ii)...)
unsafe_getindex(A::CartesianIndices, i::CanonicalInt) = @inbounds(A[i])
unsafe_getindex(A::CartesianIndices, i::IntType) = @inbounds(A[i])

unsafe_getindex(A::ReshapedArray, i::CanonicalInt) = @inbounds(parent(A)[i])
function unsafe_getindex(A::ReshapedArray, i::CanonicalInt, ii::Vararg{CanonicalInt})
unsafe_getindex(A::ReshapedArray, i::IntType) = @inbounds(parent(A)[i])
function unsafe_getindex(A::ReshapedArray, i::IntType, ii::Vararg{IntType})
@inbounds(parent(A)[_to_linear(A, (i, ii...))])
end

unsafe_getindex(A::SubArray, i::CanonicalInt) = @inbounds(A[i])
unsafe_getindex(A::SubArray, i::CanonicalInt, ii::Vararg{CanonicalInt}) = @inbounds(A[i, ii...])
unsafe_getindex(A::SubArray, i::IntType) = @inbounds(A[i])
unsafe_getindex(A::SubArray, i::IntType, ii::Vararg{IntType}) = @inbounds(A[i, ii...])

# This is based on Base._unsafe_getindex from https://github.com/JuliaLang/julia/blob/c5ede45829bf8eb09f2145bfd6f089459d77b2b1/base/multidimensional.jl#L755.
#=
Expand All @@ -364,17 +341,17 @@ function unsafe_get_collection(A, inds)
end
return dest
end
_ints2range(x::CanonicalInt) = x:x
_ints2range(x::IntType) = x:x
_ints2range(x::AbstractRange) = x
# apply _ints2range to front N elements
_ints2range_front(::Val{N}, ind, inds...) where {N} =
(_ints2range(ind), _ints2range_front(Val(N - 1), inds...)...)
_ints2range_front(::Val{0}, ind, inds...) = ()
_ints2range_front(::Val{0}) = ()
# get output shape with given indices
_output_shape(::CanonicalInt, inds...) = _output_shape(inds...)
_output_shape(::IntType, inds...) = _output_shape(inds...)
_output_shape(ind::AbstractRange, inds...) = (Base.length(ind), _output_shape(inds...)...)
_output_shape(::CanonicalInt) = ()
_output_shape(::IntType) = ()
_output_shape(x::AbstractRange) = (Base.length(x),)
@inline function unsafe_get_collection(A::CartesianIndices{N}, inds) where {N}
if (Base.length(inds) === 1 && N > 1) || stride_preserving_index(typeof(inds)) === False()
Expand Down Expand Up @@ -426,15 +403,15 @@ function unsafe_setindex!(a::A, v) where {A}
return unsafe_setindex!(parent(a), v)
end
# TODO Need to manage index transformations between nested layers of arrays
function unsafe_setindex!(a::A, v, i::CanonicalInt) where {A}
function unsafe_setindex!(a::A, v, i::IntType) where {A}
if IndexStyle(A) === IndexLinear()
is_forwarding_wrapper(A) || throw(MethodError(unsafe_setindex!, (A, v, i)))
return unsafe_setindex!(parent(a), v, i)
else
return unsafe_setindex!(a, v, _to_cartesian(a, i)...)
end
end
function unsafe_setindex!(a::A, v, i::CanonicalInt, ii::Vararg{CanonicalInt}) where {A}
function unsafe_setindex!(a::A, v, i::IntType, ii::Vararg{IntType}) where {A}
if IndexStyle(A) === IndexLinear()
return unsafe_setindex!(a, v, _to_linear(a, (i, ii...)))
else
Expand All @@ -446,7 +423,7 @@ end
function unsafe_setindex!(A::Array{T}, v) where {T}
Base.arrayset(false, A, convert(T, v)::T, 1)
end
function unsafe_setindex!(A::Array{T}, v, i::CanonicalInt) where {T}
function unsafe_setindex!(A::Array{T}, v, i::IntType) where {T}
return Base.arrayset(false, A, convert(T, v)::T, Int(i))
end

Expand Down
Loading