@@ -189,6 +189,7 @@ for f ∈ [ # groupedstridedpointer support
189189 :(ArrayInterface. contiguous_axis),
190190 :(ArrayInterface. contiguous_batch_size),
191191 :(ArrayInterface. device),
192+ :(ArrayInterface. dense_dims),
192193 :(ArrayInterface. stride_rank),
193194 :(VectorizationBase. val_dense_dims),
194195 :(ArrayInterface. offsets),
@@ -204,7 +205,9 @@ function is_column_major(x)
204205 true
205206end
206207is_row_major (x) = is_column_major (reverse (x))
207- # @inline _bytestrides(s,paren) = VectorizationBase.bytestrides(paren)
208+ _find_arg_least_greater (r:: Vector{Int} , i) =
209+ findmin (x -> x > i ? x : typemax (Int), r)
210+
208211function _strides_expr (
209212 @nospecialize (s),
210213 @nospecialize (x),
@@ -214,20 +217,19 @@ function _strides_expr(
214217 N = length (R)
215218 q = Expr (:block , Expr (:meta , :inline ))
216219 strd_tup = Expr (:tuple )
220+ resize! (strd_tup. args, N)
217221 ifel = GlobalRef (Core, :ifelse )
218- Nrange = 1 : 1 : N # type stability w/ respect to reverse
222+ Nrange = 1 : N # type stability w/ respect to reverse
223+ # Nrange = 1:1:N # type stability w/ respect to reverse
219224 use_stride_acc = true
220225 stride_acc:: Int = 1
221- if is_column_major (R)
222- # elseif is_row_major(R)
223- # Nrange = reverse(Nrange)
224- else # not worth my time optimizing this case at the moment...
225- # will write something generic stride-rank agnostic eventually
226+ next, n = _find_arg_least_greater (R, 0 )
227+ if ! D[n]
226228 use_stride_acc = false
227229 stride_acc = 0
228230 end
229231 sₙ_value:: Int = 0
230- for n ∈ Nrange
232+ for _n ∈ Nrange
231233 xₙ_type = x[n]
232234 xₙ_static = xₙ_type <: StaticInt
233235 xₙ_value:: Int = xₙ_static ? (xₙ_type. parameters[1 ]):: Int : 0
@@ -236,38 +238,38 @@ function _strides_expr(
236238 if sₙ_static
237239 sₙ_value = s_type. parameters[1 ]
238240 if s_type === One
239- push! ( strd_tup. args, Expr (:call , lv (:Zero ) ))
241+ strd_tup. args[n] = Expr (:call , lv (:Zero ))
240242 elseif stride_acc ≠ 0
241- push! ( strd_tup. args, staticexpr (stride_acc) )
243+ strd_tup. args[n] = staticexpr (stride_acc)
242244 else
243- push! ( strd_tup. args, :($ getfield (x, $ n) ))
245+ strd_tup. args[n] = :($ getfield (x, $ n))
244246 end
245247 else
246248 if xₙ_static
247- push! ( strd_tup. args, staticexpr (xₙ_value) )
249+ strd_tup. args[n] = staticexpr (xₙ_value)
248250 elseif stride_acc ≠ 0
249- push! ( strd_tup. args, staticexpr (stride_acc) )
251+ strd_tup. args[n] = staticexpr (stride_acc)
250252 else
251- push! (
252- strd_tup. args,
253+ strd_tup. args[n] =
253254 :($ ifel (isone ($ getfield (s, $ n)), zero ($ xₙ_type), $ getfield (x, $ n)))
254- )
255255 end
256256 end
257- if (n ≠ last (Nrange)) && use_stride_acc
258- nnext = n + step (Nrange)
259- if D[nnext]
260- if xₙ_static & sₙ_static
261- stride_acc = xₙ_value * sₙ_value
262- elseif sₙ_static
263- if stride_acc ≠ 0
264- stride_acc *= sₙ_value
257+ if (_n ≠ N)
258+ next, n = _find_arg_least_greater (R, next)
259+ if use_stride_acc
260+ if D[n]
261+ if xₙ_static & sₙ_static
262+ stride_acc = xₙ_value * sₙ_value
263+ elseif sₙ_static
264+ if stride_acc ≠ 0
265+ stride_acc *= sₙ_value
266+ end
267+ else
268+ stride_acc = 0
265269 end
266270 else
267271 stride_acc = 0
268272 end
269- else
270- stride_acc = 0
271273 end
272274 end
273275 end
0 commit comments