@@ -22,16 +22,6 @@ function Base.last(x::AbstractVector, n::StaticInt)
2222 @inbounds x[max (offset1 (x), (stop + one (stop)) - n): stop]
2323end
2424
25- function _is_splat (:: Type{I} , i:: StaticInt ) where {I}
26- if dynamic (is_splat_index (field_type (I, i)))
27- True ()
28- else
29- False ()
30- end
31- end
32-
33- _ndims_index (:: Type{I} , i:: StaticInt ) where {I} = StaticInt (ndims_index (field_type (I, i)))
34-
3525"""
3626 to_indices(A, I::Tuple) -> Tuple
3727
@@ -91,69 +81,92 @@ This implementation differs from that of `Base.to_indices` in the following ways
9181"""
9282to_indices (A, :: Tuple{} ) = ()
9383@inline function to_indices (a:: A , inds:: I ) where {A,I}
94- _to_indices (
95- a,
96- inds,
97- IndexStyle (A),
98- static (ndims (A)),
99- eachop (_ndims_index, nstatic (Val (known_length (I))), I),
100- eachop (_is_splat, nstatic (Val (known_length (I))), I)
101- )
102- end
103- @generated function _to_indices (A, inds:: I , :: S , :: StaticInt{N} , :: NDI , :: IS ) where {I,S,N,NDI,IS}
104- cnt = zeros (Int, known_length (NDI))
105- splat_position = 0
106- remaining = N
107- for i in 1 : known_length (NDI)
108- ndi = known (NDI. parameters[i])
109- splat = known (IS. parameters[i])
110- if splat && splat_position === 0
111- splat_position = i
112- else
113- remaining -= ndi
114- cnt[i] = ndi
115- end
116- end
117- if splat_position != = 0
118- cnt[splat_position] = max (0 , remaining)
84+ _to_indices (a, inds, IndexStyle (A), static (ndims (A)), IndicesInfo (I))
85+ end
86+ @generated function _to_indices (a, inds, :: S , :: StaticInt{N} , :: IndicesInfo{NI,NS,IS} ) where {S,N,NI,NS,IS}
87+ _to_indices_expr (S, N, NI, NS, IS)
88+ end
89+ function _to_indices_expr (S:: DataType , N:: Int , ni, ns, is)
90+ blk = Expr (:block , Expr (:meta , :inline ))
91+ # check to see if we are dealing with linear indexing over a multidimensional array
92+ if length (ni) == 1 && ni[1 ] === 1
93+ push! (blk. args, :((to_index (LazyAxis {:} (a), getfield (inds, 1 )),)))
11994 else
120- # if there are additional trailing dimensions not consumed by the index then we have
121- # to assume it's linear indexing or that these are trailing dimensions.
122- cnt[end ] += max (0 , remaining)
123- end
95+ indsexpr = Expr (:tuple )
96+ ndi = Int[]
97+ nds = Int[]
98+ isi = Bool[]
99+ # 1. unwrap AbstractCartesianIndex, CartesianIndices, Indices
100+ for i in 1 : length (ns)
101+ ns_i = ns[i]
102+ if ns_i isa Tuple
103+ for j in 1 : length (ns_i)
104+ push! (ndi, 1 )
105+ push! (nds, ns_i[j])
106+ push! (isi, false )
107+ push! (indsexpr. args, :(getfield (getfield (getfield (inds, $ i), 1 ), $ j)))
108+ end
109+ else
110+ push! (indsexpr. args, :(getfield (inds, $ i)))
111+ push! (ndi, ni[i])
112+ push! (nds, ns_i)
113+ push! (isi, is[i])
114+ end
115+ end
124116
125- t = Expr (:tuple )
126- dim = 0
127- for i in 1 : known_length (NDI)
128- if i === known_length (NDI) && S <: IndexLinear
129- ICall = :LinearIndices
130- else
131- ICall = :CartesianIndices
117+ # 2. find splat indices
118+ splat_position = 0
119+ remaining = N
120+ for i in eachindex (ndi, nds, isi)
121+ if isi[i] && splat_position == 0
122+ splat_position = i
123+ else
124+ remaining -= ndi[i]
125+ end
132126 end
133- c = cnt[i]
134- iexpr = :(@inbounds (getfield (inds, $ i)):: $ (I. parameters[i]))
135- if dim === N
136- push! (t. args, :(to_index ($ (ICall)(()), $ iexpr)))
137- elseif c === 1
138- dim += 1
139- push! (t. args, :(to_index (@inbounds (getfield (axs, $ dim)), $ iexpr)))
140- else
141- subaxs = Expr (:tuple )
142- for _ in 1 : c
143- if dim < N
127+ if splat_position != = 0
128+ for _ in 2 : remaining
129+ insert! (ndi, splat_position, 1 )
130+ insert! (nds, splat_position, 1 )
131+ insert! (indsexpr. args, splat_position, indsexpr. args[splat_position])
132+ end
133+ end
134+
135+ # 3. insert `to_index` calls
136+ dim = 0
137+ nndi = length (ndi)
138+ for i in 1 : nndi
139+ ndi_i = ndi[i]
140+ if ndi_i == 1
141+ dim += 1
142+ indsexpr. args[i] = :(to_index ($ (_axis_expr (N, dim)), $ (indsexpr. args[i])))
143+ else
144+ subaxs = Expr (:tuple )
145+ for _ in 1 : ndi_i
144146 dim += 1
145- push! (subaxs. args, :(@inbounds (getfield (axs, $ dim))))
147+ push! (subaxs. args, _axis_expr (N, dim))
148+ end
149+ if i == nndi && S <: IndexLinear
150+ indsexpr. args[i] = :(to_index (LinearIndices ($ (subaxs)), $ (indsexpr. args[i])))
151+ else
152+ indsexpr. args[i] = :(to_index (CartesianIndices ($ (subaxs)), $ (indsexpr. args[i])))
146153 end
147154 end
148- push! (t. args, :(to_index ($ (ICall)($ subaxs), $ iexpr)))
149155 end
156+ push! (blk. args, Expr (:(= ), :axs , :(lazy_axes (a))))
157+ push! (blk. args, :(_flatten_tuples ($ (indsexpr))))
158+ end
159+ return blk
160+ end
161+
162+ function _axis_expr (N:: Int , d:: Int )
163+ if d <= N
164+ :(getfield (axs, $ d))
165+ else # ndims(a)+ can only have indices 1:1
166+ :($ (SOneTo (1 )))
150167 end
151- Expr (:block ,
152- Expr (:meta , :inline ),
153- Expr (:(= ), :axs , :(lazy_axes (A))),
154- :(_flatten_tuples ($ t))
155- )
156168end
169+
157170@generated function _flatten_tuples (inds:: I ) where {I}
158171 t = Expr (:tuple )
159172 for i in 1 : known_length (I)
@@ -409,7 +422,7 @@ _output_shape(x::AbstractRange) = (length(x),)
409422end
410423_known_first_isone (ind) = known_first (ind) != = nothing && isone (known_first (ind))
411424@inline function unsafe_get_collection (A:: LinearIndices{N} , inds) where {N}
412- if Base. length (inds) === 1 && isone ( _ndims_index ( typeof (inds), static ( 1 )))
425+ if Base. length (inds) === 1 && ndims_index ( typeof (first ( inds))) === 1
413426 return @inbounds (eachindex (A)[first (inds)])
414427 elseif stride_preserving_index (typeof (inds)) === True () &&
415428 reduce_tup (& , map (_known_first_isone, inds))
@@ -464,7 +477,6 @@ function unsafe_setindex!(a::A, v, i::CanonicalInt, ii::Vararg{CanonicalInt}) wh
464477 end
465478end
466479
467-
468480function unsafe_setindex! (A:: Array{T} , v) where {T}
469481 Base. arrayset (false , A, convert (T, v):: T , 1 )
470482end
0 commit comments