@@ -43,7 +43,7 @@ known_step(::Type{<:AbstractUnitRange{T}}) where {T} = one(T)
4343# add methods to support ArrayInterface
4444
4545"""
46- OptionallyStaticUnitRange{T<:Integer} (start, stop) <: OrdinalRange{T,T }
46+ OptionallyStaticUnitRange(start, stop) <: AbstractUnitRange{Int }
4747
4848This range permits diverse representations of arrays to comunicate common information
4949about their indices. Each field may be an integer or `Val(<:Integer)` if it is known
@@ -67,21 +67,15 @@ struct OptionallyStaticUnitRange{F <: Integer, L <: Integer} <: AbstractUnitRang
6767 end
6868 end
6969
70- function OptionallyStaticUnitRange (x:: AbstractRange )
70+ function OptionallyStaticUnitRange (x:: AbstractRange )
7171 if step (x) == 1
72- fst = static_first (x)
73- lst = static_last (x)
74- return OptionallyStaticUnitRange (fst, lst)
72+ return OptionallyStaticUnitRange (static_first (x), static_last (x))
7573 else
7674 throw (ArgumentError (" step must be 1, got $(step (r)) " ))
7775 end
7876 end
7977end
8078
81- Base.:(:)(L:: Integer , :: StaticInt{U} ) where {U} = OptionallyStaticUnitRange (L, StaticInt (U))
82- Base.:(:)(:: StaticInt{L} , U:: Integer ) where {L} = OptionallyStaticUnitRange (StaticInt (L), U)
83- Base.:(:)(:: StaticInt{L} , :: StaticInt{U} ) where {L,U} = OptionallyStaticUnitRange (StaticInt (L), StaticInt (U))
84-
8579Base. first (r:: OptionallyStaticUnitRange ) = r. start
8680Base. step (:: OptionallyStaticUnitRange ) = StaticInt (1 )
8781Base. last (r:: OptionallyStaticUnitRange ) = r. stop
@@ -90,6 +84,110 @@ known_first(::Type{<:OptionallyStaticUnitRange{StaticInt{F}}}) where {F} = F
9084known_step (:: Type{<:OptionallyStaticUnitRange} ) = 1
9185known_last (:: Type{<:OptionallyStaticUnitRange{<:Any,StaticInt{L}}} ) where {L} = L
9286
87+ """
88+ OptionallyStaticStepRange(start, step, stop) <: OrdinalRange{Int,Int}
89+
90+ Similar to [`OptionallyStaticUnitRange`](@ref), `OptionallyStaticStepRange` permits
91+ a combination of static and standard primitive `Int`s to construct a range. It
92+ specifically enables the use of ranges without a step size of 1. It may be constructed
93+ through the use of `OptionallyStaticStepRange` directly or using static integers with
94+ the range operatore (i.e. `:`).
95+
96+ ```julia
97+ julia> using ArrayInterface
98+
99+ julia> x = ArrayInterface.StaticInt(2);
100+
101+ julia> x:x:10
102+ ArrayInterface.StaticInt{2}():ArrayInterface.StaticInt{2}():10
103+
104+ julia> ArrayInterface.OptionallyStaticStepRange(x, x, 10)
105+ ArrayInterface.StaticInt{2}():ArrayInterface.StaticInt{2}():10
106+ ```
107+ """
108+ struct OptionallyStaticStepRange{F <: Integer , S <: Integer , L <: Integer } <: OrdinalRange{Int,Int}
109+ start:: F
110+ step:: S
111+ stop:: L
112+
113+ function OptionallyStaticStepRange (start, step, stop)
114+ if eltype (start) <: Int
115+ if eltype (stop) <: Int
116+ lst = _steprange_last (start, step, stop)
117+ return new {typeof(start),typeof(step),typeof(lst)} (start, step, lst)
118+ else
119+ return OptionallyStaticStepRange (start, step, Int (stop))
120+ end
121+ else
122+ return OptionallyStaticStepRange (Int (start), step, stop)
123+ end
124+ end
125+
126+ function OptionallyStaticStepRange (x:: AbstractRange )
127+ return OptionallyStaticStepRange (static_first (x), static_step (x), static_last (x))
128+ end
129+ end
130+
131+ # to make StepRange constructor inlineable, so optimizer can see `step` value
132+ @inline function _steprange_last (start:: StaticInt , step:: StaticInt , stop:: StaticInt )
133+ return StaticInt (_steprange_last (Int (start), Int (step), Int (stop)))
134+ end
135+ @inline function _steprange_last (start:: Integer , step:: StaticInt , stop:: StaticInt )
136+ if step === one (step)
137+ # we don't need to check the `stop` if we know it acts like a unit range
138+ return stop
139+ else
140+ return _steprange_last (start, Int (step), Int (stop))
141+ end
142+ end
143+ @inline function _steprange_last (start:: Integer , step:: Integer , stop:: Integer )
144+ z = zero (step)
145+ if step === z
146+ throw (ArgumentError (" step cannot be zero" ))
147+ else
148+ if stop == start
149+ return Int (stop)
150+ else
151+ if step > z
152+ if stop > start
153+ return stop - Int (unsigned (stop - start) % step)
154+ else
155+ return Int (start - one (start))
156+ end
157+ else
158+ if stop > start
159+ return Int (start + one (start))
160+ else
161+ return stop + Int (unsigned (start - stop) % - step)
162+ end
163+ end
164+ end
165+ end
166+ end
167+ Base. first (r:: OptionallyStaticStepRange ) = r. start
168+ Base. step (r:: OptionallyStaticStepRange ) = r. step
169+ Base. last (r:: OptionallyStaticStepRange ) = r. stop
170+
171+ known_first (:: Type{<:OptionallyStaticStepRange{StaticInt{F}}} ) where {F} = F
172+ known_step (:: Type{<:OptionallyStaticStepRange{<:Any,StaticInt{S}}} ) where {S} = S
173+ known_last (:: Type{<:OptionallyStaticStepRange{<:Any,<:Any,StaticInt{L}}} ) where {L} = L
174+
175+ Base.:(:)(L:: Integer , :: StaticInt{U} ) where {U} = OptionallyStaticUnitRange (L, StaticInt (U))
176+ Base.:(:)(:: StaticInt{L} , U:: Integer ) where {L} = OptionallyStaticUnitRange (StaticInt (L), U)
177+ Base.:(:)(:: StaticInt{L} , :: StaticInt{U} ) where {L,U} = OptionallyStaticUnitRange (StaticInt (L), StaticInt (U))
178+ Base.:(:)(:: StaticInt{F} , :: StaticInt{S} , :: StaticInt{L} ) where {F,S,L} = OptionallyStaticStepRange (StaticInt (F), StaticInt (S), StaticInt (L))
179+ Base.:(:)(start:: Integer , :: StaticInt{S} , :: StaticInt{L} ) where {S,L} = OptionallyStaticStepRange (start, StaticInt (S), StaticInt (L))
180+ Base.:(:)(:: StaticInt{F} , :: StaticInt{S} , stop:: Integer ) where {F,S} = OptionallyStaticStepRange (StaticInt (F), StaticInt (S), stop)
181+ Base.:(:)(:: StaticInt{F} , step:: Integer , :: StaticInt{L} ) where {F,L} = OptionallyStaticStepRange (StaticInt (F), step, StaticInt (L))
182+ Base.:(:)(start:: Integer , step:: Integer , :: StaticInt{L} ) where {L} = OptionallyStaticStepRange (start, step, StaticInt (L))
183+ Base.:(:)(start:: Integer , :: StaticInt{S} , stop:: Integer ) where {S} = OptionallyStaticStepRange (start, StaticInt (S), stop)
184+ Base.:(:)(:: StaticInt{F} , step:: Integer , stop:: Integer ) where {F} = OptionallyStaticStepRange (StaticInt (F), step, stop)
185+ Base.:(:)(:: StaticInt{F} , :: StaticInt{1} , :: StaticInt{L} ) where {F,L} = OptionallyStaticUnitRange (StaticInt (F), StaticInt (L))
186+ Base.:(:)(start:: Integer , :: StaticInt{1} , :: StaticInt{L} ) where {L} = OptionallyStaticUnitRange (start, StaticInt (L))
187+ Base.:(:)(:: StaticInt{F} , :: StaticInt{1} , stop:: Integer ) where {F} = OptionallyStaticUnitRange (StaticInt (F), stop)
188+ Base.:(:)(start:: Integer , :: StaticInt{1} , stop:: Integer ) = OptionallyStaticUnitRange (start, stop)
189+
190+
93191function Base. isempty (r:: OptionallyStaticUnitRange )
94192 if known_first (r) === oneunit (eltype (r))
95193 return unsafe_isempty_one_to (last (r))
@@ -98,13 +196,29 @@ function Base.isempty(r::OptionallyStaticUnitRange)
98196 end
99197end
100198
199+ function Base. isempty (r:: OptionallyStaticStepRange )
200+ return (r. start != r. stop) & ((r. step > zero (r. step)) != (r. stop > r. start))
201+ end
202+
101203unsafe_isempty_one_to (lst) = lst <= zero (lst)
102204unsafe_isempty_unit_range (fst, lst) = fst > lst
103205
104206unsafe_length_one_to (lst:: Int ) = lst
105- unsafe_length_one_to (:: StaticInt{L} ) where {L} = lst
207+ unsafe_length_one_to (:: StaticInt{L} ) where {L} = L
208+
209+ @inline function unsafe_length_step_range (start:: Int , step:: Int , stop:: Int )
210+ if step > 1
211+ return Base. checked_add (Int (div (unsigned (stop - start), step)), 1 )
212+ elseif step < - 1
213+ return Base. checked_add (Int (div (unsigned (start - stop), - step)), 1 )
214+ elseif step > 0
215+ return Base. checked_add (Int (div (Base. checked_sub (stop, start), step)), 1 )
216+ else
217+ return Base. checked_add (Int (div (Base. checked_sub (rtart, stop), - step)), 1 )
218+ end
219+ end
106220
107- Base . @propagate_inbounds function Base. getindex (r:: OptionallyStaticUnitRange , i:: Integer )
221+ @propagate_inbounds function Base. getindex (r:: OptionallyStaticUnitRange , i:: Integer )
108222 if known_first (r) === oneunit (eltype (r))
109223 return get_index_one_to (r, i)
110224 else
121235
122236@inline function get_index_unit_range (r, i)
123237 val = first (r) + (i - 1 )
124- @boundscheck if (i < 1 ) || ( val > last (r) && val < first (r) )
238+ @boundscheck if (i < 1 ) || val > last (r)
125239 throw (BoundsError (r, i))
126240 end
127241 return convert (eltype (r), val)
@@ -130,28 +244,28 @@ end
130244@inline _try_static (:: StaticInt{N} , :: StaticInt{N} ) where {N} = StaticInt {N} ()
131245@inline _try_static (:: StaticInt{M} , :: StaticInt{N} ) where {M, N} = @assert false " Unequal Indices: StaticInt{$M }() != StaticInt{$N }()"
132246@propagate_inbounds function _try_static (:: StaticInt{N} , x) where {N}
133- @boundscheck begin
134- @assert N == x " Unequal Indices: StaticInt{$N }() != x == $x "
135- end
136- return StaticInt {N} ()
247+ @boundscheck begin
248+ @assert N == x " Unequal Indices: StaticInt{$N }() != x == $x "
249+ end
250+ return StaticInt {N} ()
137251end
138252@propagate_inbounds function _try_static (x, :: StaticInt{N} ) where {N}
139- @boundscheck begin
140- @assert N == x " Unequal Indices: x == $x != StaticInt{$N }()"
141- end
142- return StaticInt {N} ()
253+ @boundscheck begin
254+ @assert N == x " Unequal Indices: x == $x != StaticInt{$N }()"
255+ end
256+ return StaticInt {N} ()
143257end
144258@propagate_inbounds function _try_static (x, y)
145- @boundscheck begin
146- @assert x == y " Unequal Indicess: x == $x != $y == y"
147- end
148- return x
259+ @boundscheck begin
260+ @assert x == y " Unequal Indicess: x == $x != $y == y"
261+ end
262+ return x
149263end
150264
151265# ##
152266# ## length
153267# ##
154- @inline function known_length (:: Type{T} ) where {T<: AbstractUnitRange }
268+ @inline function known_length (:: Type{T} ) where {T<: OptionallyStaticUnitRange }
155269 fst = known_first (T)
156270 lst = known_last (T)
157271 if fst === nothing || lst === nothing
165279 end
166280end
167281
282+ @inline function known_length (:: Type{T} ) where {T<: OptionallyStaticStepRange }
283+ fst = known_first (T)
284+ stp = known_step (T)
285+ lst = known_last (T)
286+ if fst === nothing || stp === nothing || lst === nothing
287+ return nothing
288+ else
289+ if stp === 1
290+ if fst === oneunit (eltype (T))
291+ return unsafe_length_one_to (lst)
292+ else
293+ return unsafe_length_unit_range (fst, lst)
294+ end
295+ else
296+ return unsafe_length_step_range (fst, stp, lst)
297+ end
298+ end
299+ end
300+
168301function Base. length (r:: OptionallyStaticUnitRange )
169302 if isempty (r)
170303 return 0
@@ -177,6 +310,23 @@ function Base.length(r::OptionallyStaticUnitRange)
177310 end
178311end
179312
313+ function Base. length (r:: OptionallyStaticStepRange )
314+ if isempty (r)
315+ return 0
316+ else
317+ if known_step (r) === 1
318+ if known_first (r) === 1
319+ return unsafe_length_one_to (last (r))
320+ else
321+ return unsafe_length_unit_range (first (r), last (r))
322+ end
323+ else
324+ return unsafe_length_step_range (Int (first (r)), Int (step (r)), Int (last (r)))
325+ end
326+ end
327+ end
328+
329+
180330unsafe_length_unit_range (start:: Integer , stop:: Integer ) = Int ((stop - start) + 1 )
181331
182332"""
219369 lst = _try_static (static_last (x), static_last (y))
220370 return Base. Slice (OptionallyStaticUnitRange (fst, lst))
221371end
372+
0 commit comments