@@ -189,6 +189,7 @@ for f ∈ [ # groupedstridedpointer support
189189 :(ArrayInterface. contiguous_axis),
190190 :(ArrayInterface. contiguous_batch_size),
191191 :(ArrayInterface. device),
192+ :(ArrayInterface. dense_dims),
192193 :(ArrayInterface. stride_rank),
193194 :(VectorizationBase. val_dense_dims),
194195 :(ArrayInterface. offsets),
@@ -204,6 +205,8 @@ function is_column_major(x)
204205 true
205206end
# Row-major test expressed through the column-major one: reverse `x` and ask
# `is_column_major` about the reversed value.
# NOTE(review): `x` is presumably a stride-rank tuple — confirm against callers.
function is_row_major(x)
  return is_column_major(reverse(x))
end
# Locate the smallest entry of `r` that is strictly greater than `i`.
#
# Returns the `(value, index)` tuple produced by `findmin`: `value` is that
# smallest qualifying entry and `index` is its position in `r`.
# Entries that are not greater than `i` are mapped to `typemax(Int)` so they can
# never win the minimum; if no entry qualifies, `value` is `typemax(Int)` and
# `index` is the first index of `r`.
function _find_arg_least_greater(r::Vector{Int}, i)
  findmin(r) do rₖ
    # Disqualify anything that does not exceed the threshold `i`.
    ifelse(rₖ > i, rₖ, typemax(Int))
  end
end
207210# @inline _bytestrides(s,paren) = VectorizationBase.bytestrides(paren)
208211function _strides_expr (
209212 @nospecialize (s),
@@ -215,19 +218,18 @@ function _strides_expr(
215218 q = Expr (:block , Expr (:meta , :inline ))
216219 strd_tup = Expr (:tuple )
217220 ifel = GlobalRef (Core, :ifelse )
218- Nrange = 1 : 1 : N # type stability w/ respect to reverse
221+ Nrange = 1 : N # type stability w/ respect to reverse
222+ # Nrange = 1:1:N # type stability w/ respect to reverse
219223 use_stride_acc = true
220224 stride_acc:: Int = 1
221- if is_column_major (R)
222- # elseif is_row_major(R)
223- # Nrange = reverse(Nrange)
224- else # not worth my time optimizing this case at the moment...
225- # will write something generic stride-rank agnostic eventually
225+ next, n = _find_arg_least_greater (R, 0 )
226+ n = findfirst (== (1 ), R)
227+ if ! D[n]
226228 use_stride_acc = false
227229 stride_acc = 0
228230 end
229231 sₙ_value:: Int = 0
230- for n ∈ Nrange
232+ for _n ∈ Nrange
231233 xₙ_type = x[n]
232234 xₙ_static = xₙ_type <: StaticInt
233235 xₙ_value:: Int = xₙ_static ? (xₙ_type. parameters[1 ]):: Int : 0
@@ -254,20 +256,22 @@ function _strides_expr(
254256 )
255257 end
256258 end
257- if (n ≠ last (Nrange)) && use_stride_acc
258- nnext = n + step (Nrange)
259- if D[nnext]
260- if xₙ_static & sₙ_static
261- stride_acc = xₙ_value * sₙ_value
262- elseif sₙ_static
263- if stride_acc ≠ 0
264- stride_acc *= sₙ_value
259+ if (n ≠ N)
260+ next, n = _find_arg_least_greater (R, next)
261+ if use_stride_acc
262+ if D[n]
263+ if xₙ_static & sₙ_static
264+ stride_acc = xₙ_value * sₙ_value
265+ elseif sₙ_static
266+ if stride_acc ≠ 0
267+ stride_acc *= sₙ_value
268+ end
269+ else
270+ stride_acc = 0
265271 end
266272 else
267273 stride_acc = 0
268274 end
269- else
270- stride_acc = 0
271275 end
272276 end
273277 end
675679 :: Val{UNROLL} ,
676680 :: Val{dontbc}
677681) where {T<: NativeTypes ,N,BC<: Union{Broadcasted,Product} ,Mod,UNROLL,dontbc}
682+ @show (dest) (BC)
678683 vmaterialize_fun (sizeof (T), N, BC, Mod, UNROLL, dontbc, false )
679684end
680685@generated function vmaterialize! (
0 commit comments