6 changes: 3 additions & 3 deletions ext/ReactantNNlibExt/Implementations.jl
@@ -52,7 +52,7 @@ function overloaded_conv!(
end

result = @opcall convolution(
-        collect(Int64, size(y)),
+        TracedUtils.collect_dynamic_size(y),
x,
weight;
window_strides=collect(Int64, NNlib.stride(cdims)),
@@ -113,7 +113,7 @@ function overloaded_∇conv_filter!(
padding = reshape(padding, 2, :)

result = @opcall convolution(
-        collect(Int64, size(dw)),
+        TracedUtils.collect_dynamic_size(dw),
x,
dy;
window_strides=collect(Int64, NNlib.dilation(cdims)),
@@ -207,7 +207,7 @@ function overloaded_∇conv_data!(
end

result = @opcall convolution(
-        collect(Int64, size(dx)),
+        TracedUtils.collect_dynamic_size(dx),
dy,
w;
input_batch_dim=N,
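
The recurring change in this PR replaces eagerly collected static shapes, `collect(Int64, size(x))`, with `TracedUtils.collect_dynamic_size(x)`, which maps extents that are only known at runtime to MLIR's dynamic-size sentinel instead of forcing them to concrete integers. A minimal standalone sketch of that contrast (not Reactant's implementation; `Traced` and `DYNAMIC` are illustrative stand-ins):

```julia
struct Traced end                # stand-in for a TracedRNumber extent
const DYNAMIC = typemin(Int64)   # illustrative stand-in for the MLIR sentinel

# Old behavior: every extent must already be a concrete integer.
static_shape(sz) = Int64[convert(Int64, s) for s in sz]

# New behavior: runtime-only extents become the dynamic-size sentinel.
dynamic_shape(sz) = Int64[s isa Traced ? DYNAMIC : convert(Int64, s) for s in sz]

dynamic_shape((28, 28, Traced(), 16))  # -> [28, 28, DYNAMIC, 16]
# static_shape((28, 28, Traced(), 16)) would throw, since a Traced extent
# cannot be converted to Int64
```
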
1 change: 1 addition & 0 deletions src/Compiler.jl
@@ -2047,6 +2047,7 @@ function compile_mlir!(
for (i, arg) in enumerate(linear_args)
if haskey(padded_inputs, arg)
push!(input_arg_padded_idxs, i)
+            # TODO: use inverse_arg??
in_tys_padded[i] = MLIR.IR.TensorType(
collect(Int, reverse(size(arg) .+ padded_inputs[arg])),
MLIR.IR.Type(Reactant.unwrapped_eltype(arg)),
25 changes: 24 additions & 1 deletion src/ConcreteRArray.jl
@@ -93,7 +93,24 @@ end

Adapt.adapt_storage(::Type{T}, x::AbstractArray) where {T<:AbstractConcreteArray} = T(x)

-Base.size(x::AbstractConcreteArray) = x.shape
+function Base.size(x::AbstractConcreteArray)
+    # TODO: inside compile make it dynamic??
+    # return map(Base.Fix1(_size_static_or_dynamic, x), 1:ndims(x))
+    return x.shape
+end
+
+function _buffer_size(x::AbstractConcreteArray, dim::Integer)
+    sz = x.shape[dim]
+    return sz < 0 ? _size_dynamic(x, dim) : sz
+end
+
+function _size_dynamic(x::ConcretePJRTArray, dim::Integer)
+    @assert !Sharding.is_sharded(x.sharding) "TODO: support the sharded case. Use IFRT for \
+        now."
+    return size(x.data[1])[ndims(x) - dim + 1]
+end
+
+_size_dynamic(x::ConcreteIFRTArray, dim::Integer) = size(x.data)[ndims(x) - dim + 1]

function Base.isempty(x::Union{AbstractConcreteArray,AbstractConcreteNumber})
data = x.data
@@ -286,6 +303,12 @@ function Base.show(io::IO, X::Union{ConcretePJRTScalar,ConcreteIFRTScalar})
return nothing
end

+function Base.summary(io::IO, X::Union{AnyConcretePJRTArray,AnyConcreteIFRTArray})
+    shape = X.shape
+    shape_string = join(map(x -> x < 0 ? "?" : string(x), shape), "×")
+    return print(io, "$(shape_string) $(typeof(X))")
+end

function Base.print_array(io::IO, X::Union{AnyConcretePJRTArray,AnyConcreteIFRTArray})
if isempty(X)
print(io, "<Empty Buffer eltype $(eltype(X)) of size $(size(X))>")
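
Two details of the additions above deserve a note: `_size_dynamic` reads the extent from the runtime buffer, indexing the buffer's dimensions in reverse (`ndims(x) - dim + 1`, consistent with a row-major buffer layout; that rationale is our reading, not stated in the diff), and the new `Base.summary` renders negative entries of `x.shape` as `?`. A standalone sketch of both, with illustrative names:

```julia
# (1) Reversed lookup: Julia's dim i maps to buffer dim ndims - i + 1.
buffer_extents = (5, 4, 3)  # extents as the runtime buffer reports them
size_dynamic(dim; N=3) = buffer_extents[N - dim + 1]

# (2) Dynamic extents, stored as negative numbers, print as "?".
format_shape(shape) = join(map(s -> s < 0 ? "?" : string(s), shape), "×")

size_dynamic(1)           # -> 3 (Julia's first dim is the buffer's last)
format_shape((3, -1, 5))  # -> "3×?×5"
```
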
44 changes: 29 additions & 15 deletions src/Ops.jl
@@ -85,14 +85,18 @@ macro opcall(expr)
end

function mlir_type(x::Union{RNumber,RArray})::MLIR.IR.Type
-    return MLIR.IR.TensorType(collect(Int, size(x)), MLIR.IR.Type(unwrapped_eltype(x)))
+    return MLIR.IR.TensorType(
+        Reactant.TracedUtils.collect_dynamic_size(x), MLIR.IR.Type(unwrapped_eltype(x))
+    )
end

mlir_type(::MissingTracedValue) = MLIR.IR.TensorType((), MLIR.IR.Type(Bool))

function mlir_type(RT::Type{<:RArray{T,N}}, shape) where {T,N}
@assert length(shape) == N
-    return MLIR.IR.TensorType(collect(Int, shape), MLIR.IR.Type(unwrapped_eltype(RT)))
+    return MLIR.IR.TensorType(
+        Reactant.TracedUtils.collect_dynamic_size(shape), MLIR.IR.Type(unwrapped_eltype(RT))
+    )
end

function mlir_type(RT::Type{<:RNumber})::MLIR.IR.Type
@@ -589,7 +593,7 @@ end
location=mlir_stacktrace("transpose", @__FILE__, @__LINE__),
) where {T,N}
@assert length(permutation) == ndims(x)
-    rsize = permute!(collect(Int64, size(x)), permutation)
+    rsize = permute!(Reactant.TracedUtils.collect_dynamic_size(x), permutation)
permutation = permutation .- 1
result = mlir_type(TracedRArray{T,N}, rsize)
permutation = MLIR.IR.DenseArrayAttribute(permutation)
@@ -774,7 +778,7 @@ end
elseif type == "RFFT"
@assert T <: Real
Tout = Complex{T}
-        rsize = let rsize = collect(Int64, size(x))
+        rsize = let rsize = Reactant.TracedUtils.collect_dynamic_size(x)
rsize[end] = rsize[end] == 0 ? 0 : rsize[end] ÷ 2 + 1
Tuple(rsize)
end
@@ -783,7 +787,7 @@ end
x = complex(x, fill(T(0), size(x); location); location)
end
Tout = Base.real(T)
-        rsize = let rsize = collect(Int64, size(x))
+        rsize = let rsize = Reactant.TracedUtils.collect_dynamic_size(x)
rsize[(end - Base.length(length) + 1):end] = length
Tuple(rsize)
end
@@ -1314,6 +1318,7 @@ end
String(MLIR.IR.attr(func, String(MLIR.API.mlirSymbolTableGetSymbolAttributeName())))
)

+    # TODO: we might need dynamic_iota
iota_arg = iota(Int32, collect(Int64, size(x)); iota_dimension=dimension, location)
init_arg = constant(Int32(-1); location)
init_val = constant(init_val; location)
@@ -1329,7 +1334,7 @@
)
fallback && (backend_config["is_fallback"] = MLIR.IR.Attribute(true))

-    result_shape = collect(Int64, size(x))
+    result_shape = Reactant.TracedUtils.collect_dynamic_size(x)
result_shape[dimension] = k

out = stablehlo.custom_call(
@@ -1405,7 +1410,9 @@
) where {N}
return reduce(
TracedRArray[
-            x, iota(Int64, collect(Int64, size(x)); iota_dimension=dimension, location)
+            # TODO: dynamic_iota
+            x,
+            iota(Int64, collect(Int64, size(x)); iota_dimension=dimension, location),
],
TracedRNumber[
Reactant.promote_to(TracedRNumber, false),
@@ -1428,7 +1435,9 @@
) where {T,N}
values, indices = reduce(
TracedRArray[
-            x, iota(Int64, collect(Int64, size(x)); iota_dimension=dimension, location)
+            # TODO: dynamic_iota
+            x,
+            iota(Int64, collect(Int64, size(x)); iota_dimension=dimension, location),
],
TracedRNumber[
Reactant.promote_to(TracedRNumber{T}, typemin(T)),
@@ -1441,7 +1450,7 @@
end;
location,
)
-    new_shape = collect(Int64, size(x))
+    new_shape = Reactant.TracedUtils.collect_dynamic_size(x)
new_shape[dimension] = 1
return (reshape(values, new_shape; location), reshape(indices, new_shape; location))
end
@@ -3007,7 +3016,9 @@ end
) where {F}
@assert allequal(size.(xs)) "All input arrays must have the same size."

-    reduced_shape = Tuple(deleteat!(collect(Int64, size(xs[1])), dimensions))
+    reduced_shape = Tuple(
+        deleteat!(Reactant.TracedUtils.collect_dynamic_size(xs[1]), dimensions)
+    )

op = stablehlo.reduce(
[x.mlir_data for x in xs],
@@ -3112,7 +3123,10 @@ end
if ndims(res) != length(permutation)
res = reshape(
res,
-            vcat(collect(Int64, size(res)), ones(Int64, length(permutation) - ndims(res))),
+            vcat(
+                Reactant.TracedUtils.collect_dynamic_size(res),
+                ones(Int64, length(permutation) - ndims(res)),
+            ),
)
end
return transpose(res, invperm(permutation); location)
@@ -3158,7 +3172,7 @@ end
bcasted_arg = broadcast_in_dim(
v,
collect(Int64, (length(batch_shape) + 1):(ndims(v) + length(batch_shape))),
-            vcat(batch_shape, collect(Int64, size(v)));
+            vcat(batch_shape, Reactant.TracedUtils.collect_dynamic_size(v));
location,
)
push!(final_inputs, bcasted_arg)
@@ -3173,7 +3187,7 @@ end
push!(
output_types,
MLIR.IR.TensorType(
-                vcat(batch_shape, collect(Int64, size(result))),
+                vcat(batch_shape, Reactant.TracedUtils.collect_dynamic_size(result)),
MLIR.IR.Type(unwrapped_eltype(result)),
),
)
@@ -3285,7 +3299,7 @@ Compute the row maximum pivoted LU factorization of `x` and return the factors `
) where {T,pT}
@assert ndims(x) >= 2

-    output_shape = collect(Int64, size(x))
+    output_shape = Reactant.TracedUtils.collect_dynamic_size(x)
batch_shape = output_shape[1:(end - 2)]
pivots_shape = vcat(batch_shape, min(size(x, ndims(x) - 1), size(x, ndims(x))))
permutation_shape = vcat(batch_shape, size(x, ndims(x) - 1))
@@ -3531,7 +3545,7 @@ end
$(size(input, dimension)) (got $(lhs))"
@assert 0 ≤ rhs ≤ size(input, dimension) "rhs must be between 0 and \
$(size(input, dimension)) (got $(rhs))"
-    sz = collect(Int64, size(input))
+    sz = Reactant.TracedUtils.collect_dynamic_size(input)
sz[dimension] = sz[dimension] + lhs + rhs
return TracedRArray{T,N}(
(),
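
Most of the Ops.jl call sites only need the shape vector to be dynamic-aware; the sentinel then travels through ordinary vector manipulation, such as `permute!` in `transpose` or the in-place `result_shape[dimension] = k` update in `top_k`. A standalone sketch, with `DYNAMIC` as an illustrative placeholder for `MLIR.IR.get_dynamic_size()`:

```julia
const DYNAMIC = typemin(Int64)  # illustrative placeholder for the MLIR sentinel

shape = Int64[4, DYNAMIC, 3]    # e.g. a 4×?×3 traced array

# transpose: the sentinel follows its dimension through the permutation.
permute!(shape, [3, 1, 2])      # shape is now [3, 4, DYNAMIC]

# top_k-style update: overwrite one (now static) dimension, keep the rest.
shape[2] = 5                    # shape is now [3, 5, DYNAMIC]
```
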
13 changes: 10 additions & 3 deletions src/TracedRArray.jl
@@ -247,7 +247,14 @@ end

Base.Tuple(x::TracedRArray) = ntuple(Base.Fix1(getindex, x), length(x))

-Base.size(x::TracedRArray) = x.shape
+Base.size(x::TracedRArray) = ntuple(i -> size(x, i), ndims(x))
+function Base.size(x::TracedRArray, dim::Integer)
+    @assert 1 <= dim <= ndims(x) "dimension out of range"
+    if x.shape[dim] < 0 # assume dynamic size
+        return @opcall get_dimension_size(x, dim)
+    end
+    return x.shape[dim]
+end

Base.collect(x::TracedRArray) = copy(x) # XXX: Is this correct?

@@ -1109,7 +1116,7 @@ function Base.accumulate_pairwise!(op, A::AnyTracedRVector, B::AnyTracedRVector)
return accumulate!(op, A, B; dims=1)
end

-if isdefined(Base, :_accumulate_promote_op)
+@static if isdefined(Base, :_accumulate_promote_op)
function Base._accumulate_promote_op(op, A::AnyTracedRArray{T}; init=nothing) where {T}
if init !== nothing
init isa TracedRNumber && (init = zero(unwrapped_eltype(init)))
@@ -1201,7 +1208,7 @@ function scan_impl!(
window_dilations=ones(Int64, N),
padding_low=padding_low,
padding_high=zeros(Int64, N),
-            output_shape=collect(Int64, size(output)),
+            output_shape=TracedUtils.collect_dynamic_size(output),
)
)[1]
copyto!(output, reduction_result)
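
With the two-argument `Base.size(x::TracedRArray, dim)` above, a dynamic dimension is no longer read from the stored shape; it is queried via a `get_dimension_size` op, so `size` can now return traced values that callers must be prepared to handle. A standalone mock of the dispatch logic (the `@opcall` is stubbed out; names are illustrative):

```julia
struct MockTraced
    shape::NTuple{3,Int}  # negative entries mark dynamic extents
end

# Stand-in for `@opcall get_dimension_size(x, dim)`, which would emit an op
# returning the extent at run time.
get_dimension_size_stub(x, dim) = :runtime_extent

function size_dim(x::MockTraced, dim::Integer)
    @assert 1 <= dim <= length(x.shape) "dimension out of range"
    x.shape[dim] < 0 && return get_dimension_size_stub(x, dim)  # dynamic
    return x.shape[dim]
end

size_dim(MockTraced((3, -1, 5)), 2)  # -> :runtime_extent
size_dim(MockTraced((3, -1, 5)), 1)  # -> 3
```
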
16 changes: 13 additions & 3 deletions src/TracedUtils.jl
@@ -74,6 +74,15 @@ function ReactantCore.materialize_traced_array(x::AbstractArray{TracedRNumber{T}
return ReactantCore.materialize_traced_array(as)
end

+collect_dynamic_size(x::AnyTracedRArray) = collect_dynamic_size(size(x))
+collect_dynamic_size(x::TracedRNumber) = collect_dynamic_size(size(x))
+function collect_dynamic_size(tup::Union{Tuple,Vector})
+    return Int64[
+        sz isa TracedRNumber ? Reactant.MLIR.IR.get_dynamic_size() : convert(Int64, sz) for
+        sz in tup
+    ]
+end

get_mlir_data(x::TracedRNumber) = x.mlir_data
set_mlir_data!(x::TracedRNumber, data) = (x.mlir_data = data; return x)
get_paths(x::TracedRNumber) = x.paths
@@ -480,7 +489,7 @@ function prepare_mlir_fn_args(
if toscalar
in_tys[i] = MLIR.IR.TensorType(Int[], elT)
else
-            sz = collect(Int, size(arg))
+            sz = collect(Int, size(inv_map[arg]))
if !optimize_then_pad
carg = inv_map[arg]
Reactant.has_padding(carg) && (sz .+= Reactant.get_padding(carg))
@@ -1178,7 +1187,8 @@ function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs}

out_tys2 = MLIR.IR.Type[
MLIR.IR.TensorType(
-            collect(Int, OutShape), MLIR.IR.Type(Reactant.unwrapped_eltype(arg))
+            TracedUtils.collect_dynamic_size(OutShape),
+            MLIR.IR.Type(Reactant.unwrapped_eltype(arg)),
) for arg in linear_results
]

@@ -1203,7 +1213,7 @@ function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs}
batch_inputs;
outputs=out_tys2,
fn=fname,
-        batch_shape=MLIR.IR.DenseArrayAttribute([Int64(i) for i in OutShape]),
+        batch_shape=MLIR.IR.DenseArrayAttribute(TracedUtils.collect_dynamic_size(OutShape)),
)

residx = 1
2 changes: 1 addition & 1 deletion src/Types.jl
@@ -72,7 +72,7 @@ mutable struct TracedRArray{T,N} <: RArray{TracedRNumber{T},N}
function TracedRArray{T,N}(
paths::Tuple, mlir_data::Union{Nothing,MLIR.IR.Value}, shape
) where {T,N}
-        shape = Tuple(shape)
+        shape = Tuple(TracedUtils.collect_dynamic_size(shape))
if !isnothing(mlir_data)
@assert size(MLIR.IR.type(mlir_data)) == shape "Expected: $(shape), got: $(size(MLIR.IR.type(mlir_data)))"
end
7 changes: 7 additions & 0 deletions src/mlir/IR/Type.jl
@@ -460,6 +460,13 @@ Base.@nospecializeinfer function TensorType(
)
end

+"""
+    get_dynamic_size()
+
+Returns the value indicating a dynamic size in a shaped type.
+"""
+get_dynamic_size() = API.mlirShapedTypeGetDynamicSize()

"""
TensorType(elementType)

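
`get_dynamic_size` wraps `mlirShapedTypeGetDynamicSize`, the sentinel value MLIR uses for unknown extents in ranked shaped types; in MLIR's textual format such extents print as `?`, e.g. `tensor<4x?xf32>`. A standalone sketch of that correspondence, with `DYNAMIC` as an illustrative placeholder for the C API value:

```julia
const DYNAMIC = typemin(Int64)  # illustrative placeholder, not the C API call

tensor_type_string(dims, elt) =
    "tensor<" * join(map(d -> d == DYNAMIC ? "?" : string(d), dims), "x") * "x" * elt * ">"

tensor_type_string(Int64[4, DYNAMIC], "f32")  # -> "tensor<4x?xf32>"
```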