From 29bc74b2bf3cb19cf2c194c31c0bd06729ea7288 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Tue, 11 Nov 2025 17:05:46 -0500
Subject: [PATCH 1/3] feat: prototype for dynamically sized arrays

---
 src/ConcreteRArray.jl | 25 ++++++++++++++++++++++++-
 src/TracedRArray.jl   | 21 +++++++++++++++++++--
 src/TracedUtils.jl    |  2 +-
 src/Types.jl          |  2 +-
 src/mlir/IR/Type.jl   |  7 +++++++
 5 files changed, 52 insertions(+), 5 deletions(-)

diff --git a/src/ConcreteRArray.jl b/src/ConcreteRArray.jl
index 6b74191866..131b886532 100644
--- a/src/ConcreteRArray.jl
+++ b/src/ConcreteRArray.jl
@@ -93,7 +93,24 @@ end
 
 Adapt.adapt_storage(::Type{T}, x::AbstractArray) where {T<:AbstractConcreteArray} = T(x)
 
-Base.size(x::AbstractConcreteArray) = x.shape
+function Base.size(x::AbstractConcreteArray)
+    # TODO: inside compile make it dynamic??
+    # return map(Base.Fix1(_size_static_or_dynamic, x), 1:ndims(x))
+    return x.shape
+end
+
+function _buffer_size(x::AbstractConcreteArray, dim::Integer)
+    sz = x.shape[dim]
+    return sz < 0 ? _size_dynamic(x, dim) : sz
+end
+
+function _size_dynamic(x::ConcretePJRTArray, dim::Integer)
+    @assert !Sharding.is_sharded(x.sharding) "TODO: support the sharded case. Use IFRT for \
+        now."
+    return size(x.data[1])[ndims(x) - dim + 1]
+end
+
+_size_dynamic(x::ConcreteIFRTArray, dim::Integer) = size(x.data)[ndims(x) - dim + 1]
 
 function Base.isempty(x::Union{AbstractConcreteArray,AbstractConcreteNumber})
     data = x.data
@@ -286,6 +303,12 @@ function Base.show(io::IO, X::Union{ConcretePJRTScalar,ConcreteIFRTScalar})
     return nothing
 end
 
+function Base.summary(io::IO, X::Union{AnyConcretePJRTArray,AnyConcreteIFRTArray})
+    shape = X.shape
+    shape_string = join(map(x -> x < 0 ? "?" : string(x), shape), "×")
+    print(io, "$(shape_string) $(typeof(X))")
+end
+
 function Base.print_array(io::IO, X::Union{AnyConcretePJRTArray,AnyConcreteIFRTArray})
     if isempty(X)
         print(io, "")
diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl
index 5ffc60c165..9b265c7fab 100644
--- a/src/TracedRArray.jl
+++ b/src/TracedRArray.jl
@@ -247,7 +247,24 @@ end
 
 Base.Tuple(x::TracedRArray) = ntuple(Base.Fix1(getindex, x), length(x))
 
-Base.size(x::TracedRArray) = x.shape
+struct DynamicDimensionSize{SZ} <: Integer
+    size::SZ
+end
+
+# collect(Int, ....)
+function Base.convert(::Type{Int64}, x::DynamicDimensionSize)
+    return Reactant.MLIR.IR.get_dynamic_size()
+end
+Base.Int64(x::DynamicDimensionSize) = convert(Int64, x)
+
+Base.size(x::TracedRArray) = ntuple(i -> size(x, i), ndims(x))
+function Base.size(x::TracedRArray, dim::Integer)
+    @assert 1 <= dim <= ndims(x) "dimension out of range"
+    if x.shape[dim] < 0 # assume dynamic size
+        return DynamicDimensionSize(@opcall get_dimension_size(x, dim))
+    end
+    return x.shape[dim]
+end
 
 Base.collect(x::TracedRArray) = copy(x) # XXX: Is this correct?
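A minimal sketch of what the hunk above sets up (illustrative only, not part of the diff; it assumes an active tracing context so `@opcall` can emit ops, and note that PATCH 3/3 below later drops the `DynamicDimensionSize` wrapper):

    # For a TracedRArray `x` with x.shape == (3, -1), i.e. a dynamic second dim:
    size(x, 1)            # 3, read straight from x.shape
    size(x, 2)            # DynamicDimensionSize wrapping a get_dimension_size op
    collect(Int, size(x)) # Int64[3, MLIR.IR.get_dynamic_size()] via the convert
                          # overload, so tensor-type construction receives the
                          # MLIR dynamic-size sentinel for that dimension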
@@ -1109,7 +1126,7 @@ function Base.accumulate_pairwise!(op, A::AnyTracedRVector, B::AnyTracedRVector)
     return accumulate!(op, A, B; dims=1)
 end
 
-if isdefined(Base, :_accumulate_promote_op)
+@static if isdefined(Base, :_accumulate_promote_op)
     function Base._accumulate_promote_op(op, A::AnyTracedRArray{T}; init=nothing) where {T}
         if init !== nothing
             init isa TracedRNumber && (init = zero(unwrapped_eltype(init)))
diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl
index 260e81e714..85c65bc317 100644
--- a/src/TracedUtils.jl
+++ b/src/TracedUtils.jl
@@ -480,7 +480,7 @@ function prepare_mlir_fn_args(
         if toscalar
             in_tys[i] = MLIR.IR.TensorType(Int[], elT)
         else
-            sz = collect(Int, size(arg))
+            sz = collect(Int, size(inv_map[arg]))
             if !optimize_then_pad
                 carg = inv_map[arg]
                 Reactant.has_padding(carg) && (sz .+= Reactant.get_padding(carg))
diff --git a/src/Types.jl b/src/Types.jl
index 6d932ecc0c..a34239b88a 100644
--- a/src/Types.jl
+++ b/src/Types.jl
@@ -72,7 +72,7 @@ mutable struct TracedRArray{T,N} <: RArray{TracedRNumber{T},N}
     function TracedRArray{T,N}(
         paths::Tuple, mlir_data::Union{Nothing,MLIR.IR.Value}, shape
     ) where {T,N}
-        shape = Tuple(shape)
+        shape = Tuple(collect(Int, shape))
         if !isnothing(mlir_data)
             @assert size(MLIR.IR.type(mlir_data)) == shape "Expected: $(shape), got: $(size(MLIR.IR.type(mlir_data)))"
         end
diff --git a/src/mlir/IR/Type.jl b/src/mlir/IR/Type.jl
index fd4519663d..80d020833d 100644
--- a/src/mlir/IR/Type.jl
+++ b/src/mlir/IR/Type.jl
@@ -460,6 +460,13 @@ Base.@nospecializeinfer function TensorType(
     )
 end
 
+"""
+    get_dynamic_size()
+
+Return the sentinel value used to indicate a dynamic size in a shaped type.
+"""
+get_dynamic_size() = API.mlirShapedTypeGetDynamicSize()
+
 """
     TensorType(elementType)

From bfc0db9626e4a0e9112ce2e40736a32bab169c5b Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Tue, 11 Nov 2025 17:08:34 -0500
Subject: [PATCH 2/3] Update src/ConcreteRArray.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
---
 src/ConcreteRArray.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/ConcreteRArray.jl b/src/ConcreteRArray.jl
index 131b886532..421aa15276 100644
--- a/src/ConcreteRArray.jl
+++ b/src/ConcreteRArray.jl
@@ -306,7 +306,7 @@ end
 function Base.summary(io::IO, X::Union{AnyConcretePJRTArray,AnyConcreteIFRTArray})
     shape = X.shape
     shape_string = join(map(x -> x < 0 ? "?" : string(x), shape), "×")
-    print(io, "$(shape_string) $(typeof(X))")
+    return print(io, "$(shape_string) $(typeof(X))")
 end
 
 function Base.print_array(io::IO, X::Union{AnyConcretePJRTArray,AnyConcreteIFRTArray})

From 496bab3572e4cb1c3748b15156362895f0e5123d Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Tue, 11 Nov 2025 19:27:16 -0500
Subject: [PATCH 3/3] fix: more direct usage of size

---
 ext/ReactantNNlibExt/Implementations.jl |  6 ++--
 src/Compiler.jl                         |  1 +
 src/Ops.jl                              | 44 ++++++++++++++++---------
 src/TracedRArray.jl                     | 14 ++------
 src/TracedUtils.jl                      | 14 ++++++--
 src/Types.jl                            |  2 +-
 6 files changed, 48 insertions(+), 33 deletions(-)

diff --git a/ext/ReactantNNlibExt/Implementations.jl b/ext/ReactantNNlibExt/Implementations.jl
index 4dda8652c7..644896ea3f 100644
--- a/ext/ReactantNNlibExt/Implementations.jl
+++ b/ext/ReactantNNlibExt/Implementations.jl
@@ -52,7 +52,7 @@ function overloaded_conv!(
     end
 
     result = @opcall convolution(
-        collect(Int64, size(y)),
+        TracedUtils.collect_dynamic_size(y),
         x,
         weight;
         window_strides=collect(Int64, NNlib.stride(cdims)),
@@ -113,7 +113,7 @@ function overloaded_∇conv_filter!(
     padding = reshape(padding, 2, :)
 
     result = @opcall convolution(
-        collect(Int64, size(dw)),
+        TracedUtils.collect_dynamic_size(dw),
         x,
         dy;
         window_strides=collect(Int64, NNlib.dilation(cdims)),
@@ -207,7 +207,7 @@ function overloaded_∇conv_data!(
     end
 
     result = @opcall convolution(
-        collect(Int64, size(dx)),
+        TracedUtils.collect_dynamic_size(dx),
         dy,
         w;
         input_batch_dim=N,
diff --git a/src/Compiler.jl b/src/Compiler.jl
index bea2f19472..69e43a038c 100644
--- a/src/Compiler.jl
+++ b/src/Compiler.jl
@@ -2047,6 +2047,7 @@ function compile_mlir!(
     for (i, arg) in enumerate(linear_args)
         if haskey(padded_inputs, arg)
             push!(input_arg_padded_idxs, i)
+            # TODO: use inverse_arg??
             in_tys_padded[i] = MLIR.IR.TensorType(
                 collect(Int, reverse(size(arg) .+ padded_inputs[arg])),
                 MLIR.IR.Type(Reactant.unwrapped_eltype(arg)),
diff --git a/src/Ops.jl b/src/Ops.jl
index 22bf67679b..1e961cc5a3 100644
--- a/src/Ops.jl
+++ b/src/Ops.jl
@@ -85,14 +85,18 @@ macro opcall(expr)
 end
 
 function mlir_type(x::Union{RNumber,RArray})::MLIR.IR.Type
-    return MLIR.IR.TensorType(collect(Int, size(x)), MLIR.IR.Type(unwrapped_eltype(x)))
+    return MLIR.IR.TensorType(
+        Reactant.TracedUtils.collect_dynamic_size(x), MLIR.IR.Type(unwrapped_eltype(x))
+    )
 end
 
 mlir_type(::MissingTracedValue) = MLIR.IR.TensorType((), MLIR.IR.Type(Bool))
 
 function mlir_type(RT::Type{<:RArray{T,N}}, shape) where {T,N}
     @assert length(shape) == N
-    return MLIR.IR.TensorType(collect(Int, shape), MLIR.IR.Type(unwrapped_eltype(RT)))
+    return MLIR.IR.TensorType(
+        Reactant.TracedUtils.collect_dynamic_size(shape), MLIR.IR.Type(unwrapped_eltype(RT))
+    )
 end
 
 function mlir_type(RT::Type{<:RNumber})::MLIR.IR.Type
@@ -589,7 +593,7 @@ end
     location=mlir_stacktrace("transpose", @__FILE__, @__LINE__),
 ) where {T,N}
     @assert length(permutation) == ndims(x)
-    rsize = permute!(collect(Int64, size(x)), permutation)
+    rsize = permute!(Reactant.TracedUtils.collect_dynamic_size(x), permutation)
     permutation = permutation .- 1
     result = mlir_type(TracedRArray{T,N}, rsize)
     permutation = MLIR.IR.DenseArrayAttribute(permutation)
@@ -774,7 +778,7 @@ end
     elseif type == "RFFT"
         @assert T <: Real
         Tout = Complex{T}
-        rsize = let rsize = collect(Int64, size(x))
+        rsize = let rsize = Reactant.TracedUtils.collect_dynamic_size(x)
             rsize[end] = rsize[end] == 0 ? 0 : rsize[end] ÷ 2 + 1
             Tuple(rsize)
         end
@@ -783,7 +787,7 @@
             x = complex(x, fill(T(0), size(x); location); location)
         end
         Tout = Base.real(T)
-        rsize = let rsize = collect(Int64, size(x))
+        rsize = let rsize = Reactant.TracedUtils.collect_dynamic_size(x)
             rsize[(end - Base.length(length) + 1):end] = length
             Tuple(rsize)
         end
@@ -1314,6 +1318,7 @@ end
         String(MLIR.IR.attr(func, String(MLIR.API.mlirSymbolTableGetSymbolAttributeName())))
     )
 
+    # TODO: we might need dynamic_iota
     iota_arg = iota(Int32, collect(Int64, size(x)); iota_dimension=dimension, location)
     init_arg = constant(Int32(-1); location)
     init_val = constant(init_val; location)
@@ -1329,7 +1334,7 @@
     )
     fallback && (backend_config["is_fallback"] = MLIR.IR.Attribute(true))
 
-    result_shape = collect(Int64, size(x))
+    result_shape = Reactant.TracedUtils.collect_dynamic_size(x)
     result_shape[dimension] = k
 
     out = stablehlo.custom_call(
@@ -1405,7 +1410,9 @@ end
 ) where {N}
     return reduce(
         TracedRArray[
-            x, iota(Int64, collect(Int64, size(x)); iota_dimension=dimension, location)
+            # TODO: dynamic_iota
+            x,
+            iota(Int64, collect(Int64, size(x)); iota_dimension=dimension, location),
         ],
         TracedRNumber[
             Reactant.promote_to(TracedRNumber, false),
@@ -1428,7 +1435,9 @@
 ) where {T,N}
     values, indices = reduce(
         TracedRArray[
-            x, iota(Int64, collect(Int64, size(x)); iota_dimension=dimension, location)
+            # TODO: dynamic_iota
+            x,
+            iota(Int64, collect(Int64, size(x)); iota_dimension=dimension, location),
         ],
         TracedRNumber[
             Reactant.promote_to(TracedRNumber{T}, typemin(T)),
@@ -1441,7 +1450,7 @@
         end;
         location,
     )
-    new_shape = collect(Int64, size(x))
+    new_shape = Reactant.TracedUtils.collect_dynamic_size(x)
     new_shape[dimension] = 1
     return (reshape(values, new_shape; location), reshape(indices, new_shape; location))
 end
@@ -3007,7 +3016,9 @@
 ) where {F}
     @assert allequal(size.(xs)) "All input arrays must have the same size."
-    reduced_shape = Tuple(deleteat!(collect(Int64, size(xs[1])), dimensions))
+    reduced_shape = Tuple(
+        deleteat!(Reactant.TracedUtils.collect_dynamic_size(xs[1]), dimensions)
+    )
 
     op = stablehlo.reduce(
         [x.mlir_data for x in xs],
@@ -3112,7 +3123,10 @@
     if ndims(res) != length(permutation)
         res = reshape(
             res,
-            vcat(collect(Int64, size(res)), ones(Int64, length(permutation) - ndims(res))),
+            vcat(
+                Reactant.TracedUtils.collect_dynamic_size(res),
+                ones(Int64, length(permutation) - ndims(res)),
+            ),
         )
     end
     return transpose(res, invperm(permutation); location)
@@ -3158,7 +3172,7 @@
             bcasted_arg = broadcast_in_dim(
                 v,
                 collect(Int64, (length(batch_shape) + 1):(ndims(v) + length(batch_shape))),
-                vcat(batch_shape, collect(Int64, size(v)));
+                vcat(batch_shape, Reactant.TracedUtils.collect_dynamic_size(v));
                 location,
             )
             push!(final_inputs, bcasted_arg)
@@ -3173,7 +3187,7 @@
         push!(
             output_types,
             MLIR.IR.TensorType(
-                vcat(batch_shape, collect(Int64, size(result))),
+                vcat(batch_shape, Reactant.TracedUtils.collect_dynamic_size(result)),
                 MLIR.IR.Type(unwrapped_eltype(result)),
             ),
         )
@@ -3285,7 +3299,7 @@ Compute the row maximum pivoted LU factorization of `x` and return the factors `
 ) where {T,pT}
     @assert ndims(x) >= 2
 
-    output_shape = collect(Int64, size(x))
+    output_shape = Reactant.TracedUtils.collect_dynamic_size(x)
     batch_shape = output_shape[1:(end - 2)]
     pivots_shape = vcat(batch_shape, min(size(x, ndims(x) - 1), size(x, ndims(x))))
     permutation_shape = vcat(batch_shape, size(x, ndims(x) - 1))
@@ -3531,7 +3545,7 @@ end
        $(size(input, dimension)) (got $(lhs))"
     @assert 0 ≤ rhs ≤ size(input, dimension) "rhs must be between 0 and \
        $(size(input, dimension)) (got $(rhs))"
-    sz = collect(Int64, size(input))
+    sz = Reactant.TracedUtils.collect_dynamic_size(input)
     sz[dimension] = sz[dimension] + lhs + rhs
     return TracedRArray{T,N}(
         (),
diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl
index 9b265c7fab..21fb9c65e8 100644
--- a/src/TracedRArray.jl
+++ b/src/TracedRArray.jl
@@ -247,21 +247,11 @@ end
 
 Base.Tuple(x::TracedRArray) = ntuple(Base.Fix1(getindex, x), length(x))
 
-struct DynamicDimensionSize{SZ} <: Integer
-    size::SZ
-end
-
-# collect(Int, ....)
-function Base.convert(::Type{Int64}, x::DynamicDimensionSize)
-    return Reactant.MLIR.IR.get_dynamic_size()
-end
-Base.Int64(x::DynamicDimensionSize) = convert(Int64, x)
-
 Base.size(x::TracedRArray) = ntuple(i -> size(x, i), ndims(x))
 function Base.size(x::TracedRArray, dim::Integer)
     @assert 1 <= dim <= ndims(x) "dimension out of range"
     if x.shape[dim] < 0 # assume dynamic size
-        return DynamicDimensionSize(@opcall get_dimension_size(x, dim))
+        return @opcall get_dimension_size(x, dim)
     end
     return x.shape[dim]
 end
@@ -1218,7 +1208,7 @@ function scan_impl!(
             window_dilations=ones(Int64, N),
             padding_low=padding_low,
             padding_high=zeros(Int64, N),
-            output_shape=collect(Int64, size(output)),
+            output_shape=TracedUtils.collect_dynamic_size(output),
         )
     )[1]
     copyto!(output, reduction_result)
diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl
index 85c65bc317..ec045a5ccd 100644
--- a/src/TracedUtils.jl
+++ b/src/TracedUtils.jl
@@ -74,6 +74,15 @@ function ReactantCore.materialize_traced_array(x::AbstractArray{TracedRNumber{T}
     return ReactantCore.materialize_traced_array(as)
 end
 
+collect_dynamic_size(x::AnyTracedRArray) = collect_dynamic_size(size(x))
+collect_dynamic_size(x::TracedRNumber) = collect_dynamic_size(size(x))
+function collect_dynamic_size(tup::Union{Tuple,Vector})
+    return Int64[
+        sz isa TracedRNumber ? Reactant.MLIR.IR.get_dynamic_size() : convert(Int64, sz) for
+        sz in tup
+    ]
+end
+
 get_mlir_data(x::TracedRNumber) = x.mlir_data
 set_mlir_data!(x::TracedRNumber, data) = (x.mlir_data = data; return x)
 get_paths(x::TracedRNumber) = x.paths
@@ -1178,7 +1187,8 @@ function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs}
     out_tys2 = MLIR.IR.Type[
         MLIR.IR.TensorType(
-            collect(Int, OutShape), MLIR.IR.Type(Reactant.unwrapped_eltype(arg))
+            TracedUtils.collect_dynamic_size(OutShape),
+            MLIR.IR.Type(Reactant.unwrapped_eltype(arg)),
         ) for arg in linear_results
     ]
@@ -1203,7 +1213,7 @@
         batch_inputs;
         outputs=out_tys2,
         fn=fname,
-        batch_shape=MLIR.IR.DenseArrayAttribute([Int64(i) for i in OutShape]),
+        batch_shape=MLIR.IR.DenseArrayAttribute(TracedUtils.collect_dynamic_size(OutShape)),
     )
 
     residx = 1
diff --git a/src/Types.jl b/src/Types.jl
index a34239b88a..c643f3df3f 100644
--- a/src/Types.jl
+++ b/src/Types.jl
@@ -72,7 +72,7 @@ mutable struct TracedRArray{T,N} <: RArray{TracedRNumber{T},N}
     function TracedRArray{T,N}(
         paths::Tuple, mlir_data::Union{Nothing,MLIR.IR.Value}, shape
     ) where {T,N}
-        shape = Tuple(collect(Int, shape))
+        shape = Tuple(TracedUtils.collect_dynamic_size(shape))
         if !isnothing(mlir_data)
             @assert size(MLIR.IR.type(mlir_data)) == shape "Expected: $(shape), got: $(size(MLIR.IR.type(mlir_data)))"
         end
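Taken together, the series makes the shape plumbing dynamic-size aware. A minimal usage sketch of the intended flow (illustrative only, not part of the patches; assumes an active tracing context and a traced array `x` whose second dimension is marked dynamic, i.e. x.shape == (3, -1)):

    # size() materializes dynamic extents as traced values at trace time...
    sz = size(x)  # (3, get_dimension_size(x, 2)); the second entry is a TracedRNumber
    # ...and collect_dynamic_size folds traced extents back into the MLIR
    # sentinel when building result types, so ops such as @opcall convolution
    # get a tensor type like tensor<3x?xf32>:
    Reactant.TracedUtils.collect_dynamic_size(sz)  # Int64[3, MLIR.IR.get_dynamic_size()]

The design point: static dimensions stay plain `Int`s end to end, and only the call sites that build MLIR types go through `collect_dynamic_size`, which is why PATCH 3/3 can drop the `DynamicDimensionSize` wrapper introduced in PATCH 1/3.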