From 4a1be107d237256be43e380317a0dce1edd77653 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 2 Oct 2024 17:31:44 -0400 Subject: [PATCH 01/34] feat: TracedRScalar --- src/Reactant.jl | 1 + src/TracedRArray.jl | 74 ++++++++++++++++++++++++++++++++++++--------- src/utils.jl | 6 ++++ 3 files changed, 66 insertions(+), 15 deletions(-) diff --git a/src/Reactant.jl b/src/Reactant.jl index 92b4c9cdfd..affdea4eec 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -8,6 +8,7 @@ include("OrderedIdDict.jl") using Enzyme abstract type RArray{T,N} <: AbstractArray{T,N} end +abstract type RScalar{T} <: Number end function Base.reshape(A::RArray, dims::Tuple{Vararg{Union{Int,Colon}}}) return reshape(A, Base._reshape_uncolon(A, dims)) diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 9e46f567e4..12e7948467 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -19,9 +19,22 @@ end TracedRArray{T,N}(x::TracedRArray{T,N}) where {T,N} = x +mutable struct TracedRScalar{T} <: RScalar{T} + paths::Tuple + mlir_data::Union{Nothing,MLIR.IR.Value} + + function TracedRScalar{T}( + paths::Tuple, mlir_data::Union{Nothing,MLIR.IR.Value} + ) where {T} + if !isnothing(mlir_data) + @assert size(MLIR.IR.type(mlir_data)) == () + end + return new{T}(paths, mlir_data) + end +end + const WrappedTracedRArray{T,N} = WrappedArray{T,N,TracedRArray,TracedRArray{T,N}} const AnyTracedRArray{T,N} = Union{TracedRArray{T,N},WrappedTracedRArray{T,N}} -const AnyTracedRScalar{T} = AnyTracedRArray{T,0} const AnyTracedRVector{T} = AnyTracedRArray{T,1} const AnyTracedRMatrix{T} = AnyTracedRArray{T,2} const AnyTracedRVecOrMat{T} = Union{AnyTracedRVector{T},AnyTracedRMatrix{T}} @@ -40,12 +53,12 @@ function get_ancestor_indices(x::WrappedTracedRArray, indices...) return get_ancestor_indices(parent(x), Base.reindex(parentindices(x), indices)...) end -Base.getindex(a::AnyTracedRScalar{T}) where {T} = a +Base.getindex(a::TracedRScalar{T}) where {T} = a -Base.zero(::AnyTracedRScalar{T}) where {T} = promote_to(TracedRArray{T,0}, zero(T)) -Base.one(::AnyTracedRScalar{T}) where {T} = promote_to(TracedRArray{T,0}, one(T)) +Base.zero(::TracedRScalar{T}) where {T} = promote_to(TracedRScalar{T}, zero(T)) +Base.one(::TracedRScalar{T}) where {T} = promote_to(TracedRScalar{T}, one(T)) -function Base.convert(::Type{<:AnyTracedRScalar{T}}, x::Number) where {T} +function Base.convert(::Type{<:TracedRScalar{T}}, x::Number) where {T} return promote_to(TracedRArray{T,0}, T(x)) end @@ -73,7 +86,7 @@ and require expensive copies and synchronization each time and therefore should ), 1, ) - return TracedRArray{T,0}((), res2, ()) + return TracedRScalar{T}((), res2) end function Base.getindex(a::TracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N} @@ -133,7 +146,11 @@ function Base.show(io::IOty, X::TracedRArray{T,N}) where {T,N,IOty<:Union{IO,IOC # return print(io, X.mlir_data, ")") end -Base.only(A::AnyTracedRScalar{T}) where {T} = A +function Base.show(io::IOty, X::TracedRScalar{T}) where {T,IOty<:Union{IO,IOContext}} + return print(io, "TracedRScalar{", T, "}(", X.paths, ")") +end + +Base.only(A::TracedRScalar{T}) where {T} = A function Base.reshape(A::AnyTracedRArray{T,N}, dims::NTuple{NT,Int}) where {T,N,NT} if prod(dims) != prod(size(A)) @@ -207,9 +224,7 @@ end function promote_to(::Type{TracedRArray{T,N}}, rhs) where {T,N} if isa(rhs, TracedRArray) - if typeof(rhs) == TracedRArray{T,N} - return rhs - end + rhs isa TracedRArray{T,N} && return rhs return TracedRArray{T,N}( (), MLIR.IR.result( @@ -222,11 +237,8 @@ function promote_to(::Type{TracedRArray{T,N}}, rhs) where {T,N} ) end if isa(rhs, Number) - attr = fill(MLIR.IR.Attribute(T(rhs)), mlir_type(TracedRArray{T,N}, size(rhs))) - ta = TracedRArray{T,N}( - (), MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=attr), 1), size(rhs) - ) - return ta + throw(ArgumentError("Cannot promote number to `TracedRArray`. Use \ + `TracedRScalar` instead.")) end T0 = eltype(rhs) attr = MLIR.IR.DenseElementsAttribute(collect(rhs)) @@ -238,9 +250,41 @@ function promote_to(::Type{TracedRArray{T,N}}, rhs) where {T,N} ) end +function promote_to(::Type{TracedRScalar{T}}, rhs) where {T} + if isa(rhs, TracedRScalar) + rhs isa TracedRScalar{T} && return rhs + return TracedRScalar{T}( + (), + MLIR.IR.result( + MLIR.Dialects.stablehlo.convert( + rhs.mlir_data; result=mlir_type(TracedRScalar{T}) + ), + 1, + ), + ) + end + if isa(rhs, Number) + attr = fill(MLIR.IR.Attribute(T(rhs)), mlir_type(TracedRScalar{T})) + return TracedRScalar{T}( + (), MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=attr), 1) + ) + end + T0 = eltype(rhs) + attr = MLIR.IR.DenseElementsAttribute(collect(rhs)) + return promote_to( + TracedRScalar{T}, + TracedRScalar{T0}( + (), MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=attr), 1) + ), + ) +end + function promote_to(::TracedRArray{T,N}, rhs) where {T,N} return promote_to(TracedRArray{T,N}, rhs) end +function promote_to(::TracedRScalar{T}, rhs) where {T} + return promote_to(TracedRScalar{T}, rhs) +end for (jlop, hloop) in ( (:(Base.min), :minimum), diff --git a/src/utils.jl b/src/utils.jl index 4e137a3889..950792d4d7 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -2,11 +2,17 @@ function mlir_type(x::RArray{T,N}) where {T,N} return MLIR.IR.TensorType(size(x), MLIR.IR.Type(T)) end +mlir_type(::RScalar{T}) where {T} = MLIR.IR.TensorType((), MLIR.IR.Type(T)) + function mlir_type(::Type{<:RArray{T,N}}, shape) where {T,N} @assert length(shape) == N return MLIR.IR.TensorType(shape, MLIR.IR.Type(T)) end +function mlir_type(::Type{<:RScalar{T}}) where {T} + return MLIR.IR.TensorType((), MLIR.IR.Type(T)) +end + function transpose_ty(mlirty) return MLIR.IR.TensorType([reverse(size(mlirty))...], eltype(mlirty)) end From 978fdd9d83084dbf8af870703d26445705570003 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 2 Oct 2024 22:07:29 -0400 Subject: [PATCH 02/34] feat: partial progress on getting scalars to work --- ext/ReactantNNlibExt.jl | 15 ++-- src/TracedRArray.jl | 155 ++++++++++++++++------------------------ src/Tracing.jl | 41 ++++++++++- src/utils.jl | 8 +-- 4 files changed, 111 insertions(+), 108 deletions(-) diff --git a/ext/ReactantNNlibExt.jl b/ext/ReactantNNlibExt.jl index bac3ee75ff..f4830406ac 100644 --- a/ext/ReactantNNlibExt.jl +++ b/ext/ReactantNNlibExt.jl @@ -1,31 +1,24 @@ module ReactantNNlibExt using NNlib -using Reactant: Reactant, TracedRArray, AnyTracedRArray, materialize_traced_array, MLIR +using Reactant: Reactant, TracedRArray, AnyTracedRArray, materialize_traced_array, MLIR, + TracedRScalar for (jlop, hloop) in ( (:(NNlib.tanh_fast), :tanh), (:(NNlib.sigmoid_fast), :logistic), (:(NNlib.sigmoid), :logistic), ) - @eval function $(jlop)(x::TracedRArray{T,0}) where {T} - return TracedRArray{T,0}( + @eval function $(jlop)(x::TracedRScalar{T}) where {T} + return TracedRScalar{T}( (), Reactant.MLIR.IR.result( Reactant.MLIR.Dialects.stablehlo.$(hloop)(x.mlir_data), 1 ), - (), ) end end -# Don't confuse our poor scalar arrays, we no like numbers we like 0D arrays -for nnlib_op in setdiff(Tuple(NNlib.ACTIVATIONS), (:tanh_fast, :sigmoid_fast, :sigmoid, :σ)) - @eval function NNlib.$(nnlib_op)(x::TracedRArray{T,0}) where {T} - return invoke(NNlib.$(nnlib_op), Tuple{Any}, x) - end -end - # TODO handle non finite cases function NNlib.softmax!(out::TracedRArray{T,N}, x::AbstractArray; dims=1) where {T,N} max_ = NNlib.fast_maximum(x; dims) diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 12e7948467..72f1d5151a 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -19,6 +19,13 @@ end TracedRArray{T,N}(x::TracedRArray{T,N}) where {T,N} = x +function Base.setproperty!(x::TracedRArray, f::Symbol, v) + if f === :mlir_data && !isnothing(v) + @assert size(MLIR.IR.type(v)) == size(x) + end + return setfield!(x, f, v) +end + mutable struct TracedRScalar{T} <: RScalar{T} paths::Tuple mlir_data::Union{Nothing,MLIR.IR.Value} @@ -33,6 +40,15 @@ mutable struct TracedRScalar{T} <: RScalar{T} end end +function Base.setproperty!(x::TracedRScalar, f::Symbol, v) + if f === :mlir_data && !isnothing(v) + @assert size(MLIR.IR.type(v)) == () + end + return setfield!(x, f, v) +end + +Base.eltype(::Type{TracedRScalar{T}}) where {T} = T + const WrappedTracedRArray{T,N} = WrappedArray{T,N,TracedRArray,TracedRArray{T,N}} const AnyTracedRArray{T,N} = Union{TracedRArray{T,N},WrappedTracedRArray{T,N}} const AnyTracedRVector{T} = AnyTracedRArray{T,1} @@ -59,7 +75,7 @@ Base.zero(::TracedRScalar{T}) where {T} = promote_to(TracedRScalar{T}, zero(T)) Base.one(::TracedRScalar{T}) where {T} = promote_to(TracedRScalar{T}, one(T)) function Base.convert(::Type{<:TracedRScalar{T}}, x::Number) where {T} - return promote_to(TracedRArray{T,0}, T(x)) + return promote_to(TracedRScalar{T}, T(x)) end function Base.getindex(a::TracedRArray{T,N}, index::Vararg{Int,N}) where {T,N} @@ -121,7 +137,7 @@ function Base.setindex!( a::TracedRArray{T,N}, v, indices::Vararg{Union{Base.AbstractUnitRange,Colon},N} ) where {T,N} indices = [ - (promote_to(TracedRArray{Int,0}, i isa Colon ? 1 : first(i)) - 1).mlir_data for + (promote_to(TracedRScalar{Int}, i isa Colon ? 1 : first(i)) - 1).mlir_data for i in indices ] v = promote_to(TracedRArray{T,N}, v) @@ -222,6 +238,14 @@ function Base.promote_rule(::Type{T}, ::Type{TracedRArray{S,N}}) where {T,S,N} return TracedRArray{Base.promote_type(T, S),N} end +function Base.promote_rule(::Type{T}, ::Type{TracedRScalar{S}}) where {T,S} + return TracedRScalar{Base.promote_type(T, S)} +end + +function Base.convert(::Type{TracedRScalar{T}}, x::Number) where {T} + return promote_to(TracedRScalar{T}, x) +end + function promote_to(::Type{TracedRArray{T,N}}, rhs) where {T,N} if isa(rhs, TracedRArray) rhs isa TracedRArray{T,N} && return rhs @@ -279,12 +303,8 @@ function promote_to(::Type{TracedRScalar{T}}, rhs) where {T} ) end -function promote_to(::TracedRArray{T,N}, rhs) where {T,N} - return promote_to(TracedRArray{T,N}, rhs) -end -function promote_to(::TracedRScalar{T}, rhs) where {T} - return promote_to(TracedRScalar{T}, rhs) -end +promote_to(::TracedRArray{T,N}, rhs) where {T,N} = promote_to(TracedRArray{T,N}, rhs) +promote_to(::TracedRScalar{T}, rhs) where {T} = promote_to(TracedRScalar{T}, rhs) for (jlop, hloop) in ( (:(Base.min), :minimum), @@ -295,66 +315,35 @@ for (jlop, hloop) in ( (:(Base.:/), :divide), (:(Base.:^), :power), ) - @eval begin - function $(jlop)( - @nospecialize(lhs::TracedRArray{T,0}), @nospecialize(rhs::TracedRArray{T,0}) - ) where {T} - return TracedRArray{T,0}( - (), - MLIR.IR.result( - MLIR.Dialects.stablehlo.$(hloop)(lhs.mlir_data, rhs.mlir_data), 1 - ), - (), - ) - end - - function $(jlop)( - @nospecialize(lhs::TracedRArray{T1,0}), @nospecialize(rhs::TracedRArray{T2,0}) - ) where {T1,T2} - commonTy = TracedRArray{Base.promote_type(T1, T2),0} - lhs = promote_to(commonTy, lhs) - rhs = promote_to(commonTy, rhs) - return $(jlop)(lhs, rhs) - end - end - - for otherType in (Number, Any) - @eval begin - function $(jlop)( - @nospecialize(lhs::TracedRArray{T,0}), @nospecialize(rhs::$(otherType)) - ) where {T} - rhs = promote_to(lhs, rhs) - return $(jlop)(lhs, rhs) - end - - function $(jlop)( - @nospecialize(lhs::$(otherType)), @nospecialize(rhs::TracedRArray{T,0}) - ) where {T} - lhs = promote_to(rhs, lhs) - return $(jlop)(lhs, rhs) - end - end + @eval function $(jlop)( + @nospecialize(lhs::TracedRScalar{T}), @nospecialize(rhs::TracedRScalar{T}) + ) where {T} + return TracedRArray{T}( + (), + MLIR.IR.result( + MLIR.Dialects.stablehlo.$(hloop)(lhs.mlir_data, rhs.mlir_data), 1 + ), + ) end end function Base.ifelse( - @nospecialize(pred::TracedRArray{Bool,0}), - @nospecialize(x::TracedRArray{T1,0}), - @nospecialize(y::TracedRArray{T2,0}) + @nospecialize(pred::TracedRScalar{Bool}), + @nospecialize(x::TracedRScalar{T1}), + @nospecialize(y::TracedRScalar{T2}) ) where {T1,T2} - return TracedRArray{promote_type(T1, T2),0}( + return TracedRScalar{promote_type(T1, T2)}( (), MLIR.IR.result( MLIR.Dialects.stablehlo.select(pred.mlir_data, x.mlir_data, y.mlir_data), 1 ), - size(pred), ) end -Base.abs2(x::Reactant.TracedRArray{T,0}) where {T} = x * conj(x) +Base.abs2(x::Reactant.TracedRScalar{T}) where {T} = x * conj(x) function Base.literal_pow( - ::Base.RefValue{typeof(^)}, x::TracedRArray{T,0}, ::Base.RefValue{Val{P}} + ::Base.RefValue{typeof(^)}, x::TracedRScalar{T}, ::Base.RefValue{Val{P}} ) where {T,P} return Base.literal_pow(^, x, Val(P)) end @@ -371,14 +360,10 @@ for (jlop, hloop) in ( (:(Base.log), :log), (:(Base.sqrt), :sqrt), ) - @eval begin - function $jlop(@nospecialize(lhs::TracedRArray{T,0})) where {T} - return TracedRArray{T,0}( - (), - MLIR.IR.result(MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data), 1), - size(lhs), - ) - end + @eval function $(jlop)(@nospecialize(lhs::TracedRScalar{T})) where {T} + return TracedRScalar{T}( + (), MLIR.IR.result(MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data), 1) + ) end end @@ -445,6 +430,7 @@ function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs} residx = 1 for a in linear_results + @show a if has_residx(a) path = get_residx(a) set!(result, path[2:end], MLIR.IR.result(res, residx)) @@ -480,37 +466,22 @@ for (jlop, hloop, hlocomp, merge) in ( (:(Base.:(<=)), :compare, "LE", nothing), (:(Base.:(<)), :compare, "LT", nothing), ) - @eval begin - function $(jlop)( - @nospecialize(lhs::TracedRArray{T,0}), @nospecialize(rhs::TracedRArray{T,0}) - ) where {T} - return TracedRArray{Bool,0}( - (), - MLIR.IR.result( - MLIR.Dialects.stablehlo.$hloop( - lhs.mlir_data, - rhs.mlir_data; - comparison_direction=MLIR.API.stablehloComparisonDirectionAttrGet( - MLIR.IR.context(), $hlocomp - ), + @eval function $(jlop)( + @nospecialize(lhs::TracedRScalar{T}), @nospecialize(rhs::TracedRScalar{T}) + ) where {T} + return TracedRScalar{Bool}( + (), + MLIR.IR.result( + MLIR.Dialects.stablehlo.$(hloop)( + lhs.mlir_data, + rhs.mlir_data; + comparison_direction=MLIR.API.stablehloComparisonDirectionAttrGet( + MLIR.IR.context(), $hlocomp ), - 1, ), - size(lhs), - ) - end - - function $(jlop)( - @nospecialize(lhs::TracedRArray{T,0}), @nospecialize(rhs) - ) where {T} - return $(jlop)(lhs, promote_to(lhs, rhs)) - end - - function $(jlop)( - @nospecialize(lhs), @nospecialize(rhs::TracedRArray{T,0}) - ) where {T} - return $(jlop)(promote_to(rhs, lhs), rhs) - end + 1, + ), + ) end if merge !== nothing @@ -600,7 +571,7 @@ function Base.mapreduce( fnbody = MLIR.IR.Block(in_tys, [MLIR.IR.Location() for arg in in_tys]) args = ( - TracedRArray{T,0}((), MLIR.IR.argument(fnbody, i), ()) for + TracedRScalar{T}((), MLIR.IR.argument(fnbody, i), ()) for (i, ty) in enumerate(in_tys) ) diff --git a/src/Tracing.jl b/src/Tracing.jl index ae4f3b4c66..0387f21909 100644 --- a/src/Tracing.jl +++ b/src/Tracing.jl @@ -16,6 +16,7 @@ for T in ( Integer, AbstractString, RArray, + RScalar, ) @eval function traced_type(::Type{T}, seen, mode) where {T<:$T} return T @@ -330,7 +331,7 @@ function make_tracer( return seen[prev] end res = if toscalar - TracedRArray{T,0}((path,), nothing, ()) + TracedRScalar{T}((path,), nothing) elseif !isnothing(tobatch) TracedRArray{T,length(tobatch)}((path,), prev.mlir_data, tobatch) else @@ -352,6 +353,44 @@ function make_tracer( throw("Cannot Unknown trace mode $mode") end +function make_tracer( + seen, + @nospecialize(prev::TracedRScalar{T}), + @nospecialize(path), + mode; + kwargs... +) where {T} + if mode == ConcreteToTraced + throw("Cannot trace existing trace type") + end + if mode == TracedTrack + prev.paths = (prev.paths..., path) + if !haskey(seen, prev) + return seen[prev] = prev + end + return prev + end + if mode == TracedSetPath + if haskey(seen, prev) + return seen[prev] + end + res = TracedRScalar{T}((path,), prev.mlir_data) + seen[prev] = res + return res + end + + if mode == TracedToConcrete + if haskey(seen, prev) + return seen[prev]::ConcreteRArray{T,0} + end + res = ConcreteRArray{T,0}(XLA.AsyncEmptyBuffer, size(prev)) + seen[prev] = res + return res + end + + throw("Cannot Unknown trace mode $mode") +end + function make_tracer( seen, @nospecialize(prev::RT), @nospecialize(path), mode; kwargs... ) where {RT<:AbstractFloat} diff --git a/src/utils.jl b/src/utils.jl index 950792d4d7..7a3481add5 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -44,9 +44,9 @@ function make_mlir_fn(f, args, kwargs, name="main", concretein=true; toscalar=fa ) end - linear_args = TracedRArray[] + linear_args = Union{TracedRArray,TracedRScalar}[] for (k, v) in seen_args - if !(v isa TracedRArray) + if !(v isa TracedRArray) && !(v isa TracedRScalar) continue end push!(linear_args, v) @@ -127,10 +127,10 @@ function make_mlir_fn(f, args, kwargs, name="main", concretein=true; toscalar=fa ) end - linear_results = TracedRArray[] + linear_results = Union{TracedRArray,TracedRScalar}[] for (k, v) in seen_results - if !(v isa TracedRArray) + if !(v isa TracedRArray) && !(v isa TracedRScalar) continue end From 32af33207c1fce507c2a47109c159c0bb5d17654 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 4 Oct 2024 20:22:35 -0400 Subject: [PATCH 03/34] refactor: Scalar --> Number --- ext/ReactantNNlibExt.jl | 6 +-- src/Reactant.jl | 2 +- src/TracedRArray.jl | 82 ++++++++++++++++++++--------------------- src/Tracing.jl | 8 ++-- src/utils.jl | 12 +++--- 5 files changed, 55 insertions(+), 55 deletions(-) diff --git a/ext/ReactantNNlibExt.jl b/ext/ReactantNNlibExt.jl index f4830406ac..e3bd708598 100644 --- a/ext/ReactantNNlibExt.jl +++ b/ext/ReactantNNlibExt.jl @@ -2,15 +2,15 @@ module ReactantNNlibExt using NNlib using Reactant: Reactant, TracedRArray, AnyTracedRArray, materialize_traced_array, MLIR, - TracedRScalar + TracedRNumber for (jlop, hloop) in ( (:(NNlib.tanh_fast), :tanh), (:(NNlib.sigmoid_fast), :logistic), (:(NNlib.sigmoid), :logistic), ) - @eval function $(jlop)(x::TracedRScalar{T}) where {T} - return TracedRScalar{T}( + @eval function $(jlop)(x::TracedRNumber{T}) where {T} + return TracedRNumber{T}( (), Reactant.MLIR.IR.result( Reactant.MLIR.Dialects.stablehlo.$(hloop)(x.mlir_data), 1 diff --git a/src/Reactant.jl b/src/Reactant.jl index affdea4eec..a0b5866ff7 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -8,7 +8,7 @@ include("OrderedIdDict.jl") using Enzyme abstract type RArray{T,N} <: AbstractArray{T,N} end -abstract type RScalar{T} <: Number end +abstract type RNumber{T} <: Number end function Base.reshape(A::RArray, dims::Tuple{Vararg{Union{Int,Colon}}}) return reshape(A, Base._reshape_uncolon(A, dims)) diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 72f1d5151a..be21328100 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -26,11 +26,11 @@ function Base.setproperty!(x::TracedRArray, f::Symbol, v) return setfield!(x, f, v) end -mutable struct TracedRScalar{T} <: RScalar{T} +mutable struct TracedRNumber{T} <: RNumber{T} paths::Tuple mlir_data::Union{Nothing,MLIR.IR.Value} - function TracedRScalar{T}( + function TracedRNumber{T}( paths::Tuple, mlir_data::Union{Nothing,MLIR.IR.Value} ) where {T} if !isnothing(mlir_data) @@ -40,14 +40,14 @@ mutable struct TracedRScalar{T} <: RScalar{T} end end -function Base.setproperty!(x::TracedRScalar, f::Symbol, v) +function Base.setproperty!(x::TracedRNumber, f::Symbol, v) if f === :mlir_data && !isnothing(v) @assert size(MLIR.IR.type(v)) == () end return setfield!(x, f, v) end -Base.eltype(::Type{TracedRScalar{T}}) where {T} = T +Base.eltype(::Type{TracedRNumber{T}}) where {T} = T const WrappedTracedRArray{T,N} = WrappedArray{T,N,TracedRArray,TracedRArray{T,N}} const AnyTracedRArray{T,N} = Union{TracedRArray{T,N},WrappedTracedRArray{T,N}} @@ -69,13 +69,13 @@ function get_ancestor_indices(x::WrappedTracedRArray, indices...) return get_ancestor_indices(parent(x), Base.reindex(parentindices(x), indices)...) end -Base.getindex(a::TracedRScalar{T}) where {T} = a +Base.getindex(a::TracedRNumber{T}) where {T} = a -Base.zero(::TracedRScalar{T}) where {T} = promote_to(TracedRScalar{T}, zero(T)) -Base.one(::TracedRScalar{T}) where {T} = promote_to(TracedRScalar{T}, one(T)) +Base.zero(::TracedRNumber{T}) where {T} = promote_to(TracedRNumber{T}, zero(T)) +Base.one(::TracedRNumber{T}) where {T} = promote_to(TracedRNumber{T}, one(T)) -function Base.convert(::Type{<:TracedRScalar{T}}, x::Number) where {T} - return promote_to(TracedRScalar{T}, T(x)) +function Base.convert(::Type{<:TracedRNumber{T}}, x::Number) where {T} + return promote_to(TracedRNumber{T}, T(x)) end function Base.getindex(a::TracedRArray{T,N}, index::Vararg{Int,N}) where {T,N} @@ -102,7 +102,7 @@ and require expensive copies and synchronization each time and therefore should ), 1, ) - return TracedRScalar{T}((), res2) + return TracedRNumber{T}((), res2) end function Base.getindex(a::TracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N} @@ -137,7 +137,7 @@ function Base.setindex!( a::TracedRArray{T,N}, v, indices::Vararg{Union{Base.AbstractUnitRange,Colon},N} ) where {T,N} indices = [ - (promote_to(TracedRScalar{Int}, i isa Colon ? 1 : first(i)) - 1).mlir_data for + (promote_to(TracedRNumber{Int}, i isa Colon ? 1 : first(i)) - 1).mlir_data for i in indices ] v = promote_to(TracedRArray{T,N}, v) @@ -162,11 +162,11 @@ function Base.show(io::IOty, X::TracedRArray{T,N}) where {T,N,IOty<:Union{IO,IOC # return print(io, X.mlir_data, ")") end -function Base.show(io::IOty, X::TracedRScalar{T}) where {T,IOty<:Union{IO,IOContext}} - return print(io, "TracedRScalar{", T, "}(", X.paths, ")") +function Base.show(io::IOty, X::TracedRNumber{T}) where {T,IOty<:Union{IO,IOContext}} + return print(io, "TracedRNumber{", T, "}(", X.paths, ")") end -Base.only(A::TracedRScalar{T}) where {T} = A +Base.only(A::TracedRNumber{T}) where {T} = A function Base.reshape(A::AnyTracedRArray{T,N}, dims::NTuple{NT,Int}) where {T,N,NT} if prod(dims) != prod(size(A)) @@ -238,12 +238,12 @@ function Base.promote_rule(::Type{T}, ::Type{TracedRArray{S,N}}) where {T,S,N} return TracedRArray{Base.promote_type(T, S),N} end -function Base.promote_rule(::Type{T}, ::Type{TracedRScalar{S}}) where {T,S} - return TracedRScalar{Base.promote_type(T, S)} +function Base.promote_rule(::Type{T}, ::Type{TracedRNumber{S}}) where {T,S} + return TracedRNumber{Base.promote_type(T, S)} end -function Base.convert(::Type{TracedRScalar{T}}, x::Number) where {T} - return promote_to(TracedRScalar{T}, x) +function Base.convert(::Type{TracedRNumber{T}}, x::Number) where {T} + return promote_to(TracedRNumber{T}, x) end function promote_to(::Type{TracedRArray{T,N}}, rhs) where {T,N} @@ -262,7 +262,7 @@ function promote_to(::Type{TracedRArray{T,N}}, rhs) where {T,N} end if isa(rhs, Number) throw(ArgumentError("Cannot promote number to `TracedRArray`. Use \ - `TracedRScalar` instead.")) + `TracedRNumber` instead.")) end T0 = eltype(rhs) attr = MLIR.IR.DenseElementsAttribute(collect(rhs)) @@ -274,37 +274,37 @@ function promote_to(::Type{TracedRArray{T,N}}, rhs) where {T,N} ) end -function promote_to(::Type{TracedRScalar{T}}, rhs) where {T} - if isa(rhs, TracedRScalar) - rhs isa TracedRScalar{T} && return rhs - return TracedRScalar{T}( +function promote_to(::Type{TracedRNumber{T}}, rhs) where {T} + if isa(rhs, TracedRNumber) + rhs isa TracedRNumber{T} && return rhs + return TracedRNumber{T}( (), MLIR.IR.result( MLIR.Dialects.stablehlo.convert( - rhs.mlir_data; result=mlir_type(TracedRScalar{T}) + rhs.mlir_data; result=mlir_type(TracedRNumber{T}) ), 1, ), ) end if isa(rhs, Number) - attr = fill(MLIR.IR.Attribute(T(rhs)), mlir_type(TracedRScalar{T})) - return TracedRScalar{T}( + attr = fill(MLIR.IR.Attribute(T(rhs)), mlir_type(TracedRNumber{T})) + return TracedRNumber{T}( (), MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=attr), 1) ) end T0 = eltype(rhs) attr = MLIR.IR.DenseElementsAttribute(collect(rhs)) return promote_to( - TracedRScalar{T}, - TracedRScalar{T0}( + TracedRNumber{T}, + TracedRNumber{T0}( (), MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=attr), 1) ), ) end promote_to(::TracedRArray{T,N}, rhs) where {T,N} = promote_to(TracedRArray{T,N}, rhs) -promote_to(::TracedRScalar{T}, rhs) where {T} = promote_to(TracedRScalar{T}, rhs) +promote_to(::TracedRNumber{T}, rhs) where {T} = promote_to(TracedRNumber{T}, rhs) for (jlop, hloop) in ( (:(Base.min), :minimum), @@ -316,7 +316,7 @@ for (jlop, hloop) in ( (:(Base.:^), :power), ) @eval function $(jlop)( - @nospecialize(lhs::TracedRScalar{T}), @nospecialize(rhs::TracedRScalar{T}) + @nospecialize(lhs::TracedRNumber{T}), @nospecialize(rhs::TracedRNumber{T}) ) where {T} return TracedRArray{T}( (), @@ -328,11 +328,11 @@ for (jlop, hloop) in ( end function Base.ifelse( - @nospecialize(pred::TracedRScalar{Bool}), - @nospecialize(x::TracedRScalar{T1}), - @nospecialize(y::TracedRScalar{T2}) + @nospecialize(pred::TracedRNumber{Bool}), + @nospecialize(x::TracedRNumber{T1}), + @nospecialize(y::TracedRNumber{T2}) ) where {T1,T2} - return TracedRScalar{promote_type(T1, T2)}( + return TracedRNumber{promote_type(T1, T2)}( (), MLIR.IR.result( MLIR.Dialects.stablehlo.select(pred.mlir_data, x.mlir_data, y.mlir_data), 1 @@ -340,10 +340,10 @@ function Base.ifelse( ) end -Base.abs2(x::Reactant.TracedRScalar{T}) where {T} = x * conj(x) +Base.abs2(x::Reactant.TracedRNumber{T}) where {T} = x * conj(x) function Base.literal_pow( - ::Base.RefValue{typeof(^)}, x::TracedRScalar{T}, ::Base.RefValue{Val{P}} + ::Base.RefValue{typeof(^)}, x::TracedRNumber{T}, ::Base.RefValue{Val{P}} ) where {T,P} return Base.literal_pow(^, x, Val(P)) end @@ -360,8 +360,8 @@ for (jlop, hloop) in ( (:(Base.log), :log), (:(Base.sqrt), :sqrt), ) - @eval function $(jlop)(@nospecialize(lhs::TracedRScalar{T})) where {T} - return TracedRScalar{T}( + @eval function $(jlop)(@nospecialize(lhs::TracedRNumber{T})) where {T} + return TracedRNumber{T}( (), MLIR.IR.result(MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data), 1) ) end @@ -467,9 +467,9 @@ for (jlop, hloop, hlocomp, merge) in ( (:(Base.:(<)), :compare, "LT", nothing), ) @eval function $(jlop)( - @nospecialize(lhs::TracedRScalar{T}), @nospecialize(rhs::TracedRScalar{T}) + @nospecialize(lhs::TracedRNumber{T}), @nospecialize(rhs::TracedRNumber{T}) ) where {T} - return TracedRScalar{Bool}( + return TracedRNumber{Bool}( (), MLIR.IR.result( MLIR.Dialects.stablehlo.$(hloop)( @@ -571,7 +571,7 @@ function Base.mapreduce( fnbody = MLIR.IR.Block(in_tys, [MLIR.IR.Location() for arg in in_tys]) args = ( - TracedRScalar{T}((), MLIR.IR.argument(fnbody, i), ()) for + TracedRNumber{T}((), MLIR.IR.argument(fnbody, i), ()) for (i, ty) in enumerate(in_tys) ) diff --git a/src/Tracing.jl b/src/Tracing.jl index 0387f21909..2af91ac852 100644 --- a/src/Tracing.jl +++ b/src/Tracing.jl @@ -16,7 +16,7 @@ for T in ( Integer, AbstractString, RArray, - RScalar, + RNumber, ) @eval function traced_type(::Type{T}, seen, mode) where {T<:$T} return T @@ -331,7 +331,7 @@ function make_tracer( return seen[prev] end res = if toscalar - TracedRScalar{T}((path,), nothing) + TracedRNumber{T}((path,), nothing) elseif !isnothing(tobatch) TracedRArray{T,length(tobatch)}((path,), prev.mlir_data, tobatch) else @@ -355,7 +355,7 @@ end function make_tracer( seen, - @nospecialize(prev::TracedRScalar{T}), + @nospecialize(prev::TracedRNumber{T}), @nospecialize(path), mode; kwargs... @@ -374,7 +374,7 @@ function make_tracer( if haskey(seen, prev) return seen[prev] end - res = TracedRScalar{T}((path,), prev.mlir_data) + res = TracedRNumber{T}((path,), prev.mlir_data) seen[prev] = res return res end diff --git a/src/utils.jl b/src/utils.jl index 7a3481add5..dd5f834d56 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -2,14 +2,14 @@ function mlir_type(x::RArray{T,N}) where {T,N} return MLIR.IR.TensorType(size(x), MLIR.IR.Type(T)) end -mlir_type(::RScalar{T}) where {T} = MLIR.IR.TensorType((), MLIR.IR.Type(T)) +mlir_type(::RNumber{T}) where {T} = MLIR.IR.TensorType((), MLIR.IR.Type(T)) function mlir_type(::Type{<:RArray{T,N}}, shape) where {T,N} @assert length(shape) == N return MLIR.IR.TensorType(shape, MLIR.IR.Type(T)) end -function mlir_type(::Type{<:RScalar{T}}) where {T} +function mlir_type(::Type{<:RNumber{T}}) where {T} return MLIR.IR.TensorType((), MLIR.IR.Type(T)) end @@ -44,9 +44,9 @@ function make_mlir_fn(f, args, kwargs, name="main", concretein=true; toscalar=fa ) end - linear_args = Union{TracedRArray,TracedRScalar}[] + linear_args = Union{TracedRArray,TracedRNumber}[] for (k, v) in seen_args - if !(v isa TracedRArray) && !(v isa TracedRScalar) + if !(v isa TracedRArray) && !(v isa TracedRNumber) continue end push!(linear_args, v) @@ -127,10 +127,10 @@ function make_mlir_fn(f, args, kwargs, name="main", concretein=true; toscalar=fa ) end - linear_results = Union{TracedRArray,TracedRScalar}[] + linear_results = Union{TracedRArray,TracedRNumber}[] for (k, v) in seen_results - if !(v isa TracedRArray) && !(v isa TracedRScalar) + if !(v isa TracedRArray) && !(v isa TracedRNumber) continue end From 8c213cbcd3c5431505a516abd65d6f7224fe7d4a Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 4 Oct 2024 20:46:48 -0400 Subject: [PATCH 04/34] fix: batching --- src/TracedRArray.jl | 17 +---------------- src/Tracing.jl | 16 ++++++++++++---- 2 files changed, 13 insertions(+), 20 deletions(-) diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index be21328100..327bf719fb 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -19,13 +19,6 @@ end TracedRArray{T,N}(x::TracedRArray{T,N}) where {T,N} = x -function Base.setproperty!(x::TracedRArray, f::Symbol, v) - if f === :mlir_data && !isnothing(v) - @assert size(MLIR.IR.type(v)) == size(x) - end - return setfield!(x, f, v) -end - mutable struct TracedRNumber{T} <: RNumber{T} paths::Tuple mlir_data::Union{Nothing,MLIR.IR.Value} @@ -40,13 +33,6 @@ mutable struct TracedRNumber{T} <: RNumber{T} end end -function Base.setproperty!(x::TracedRNumber, f::Symbol, v) - if f === :mlir_data && !isnothing(v) - @assert size(MLIR.IR.type(v)) == () - end - return setfield!(x, f, v) -end - Base.eltype(::Type{TracedRNumber{T}}) where {T} = T const WrappedTracedRArray{T,N} = WrappedArray{T,N,TracedRArray,TracedRArray{T,N}} @@ -318,7 +304,7 @@ for (jlop, hloop) in ( @eval function $(jlop)( @nospecialize(lhs::TracedRNumber{T}), @nospecialize(rhs::TracedRNumber{T}) ) where {T} - return TracedRArray{T}( + return TracedRNumber{T}( (), MLIR.IR.result( MLIR.Dialects.stablehlo.$(hloop)(lhs.mlir_data, rhs.mlir_data), 1 @@ -430,7 +416,6 @@ function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs} residx = 1 for a in linear_results - @show a if has_residx(a) path = get_residx(a) set!(result, path[2:end], MLIR.IR.result(res, residx)) diff --git a/src/Tracing.jl b/src/Tracing.jl index 2af91ac852..a94f6e28fd 100644 --- a/src/Tracing.jl +++ b/src/Tracing.jl @@ -332,8 +332,8 @@ function make_tracer( end res = if toscalar TracedRNumber{T}((path,), nothing) - elseif !isnothing(tobatch) - TracedRArray{T,length(tobatch)}((path,), prev.mlir_data, tobatch) + elseif tobatch !== nothing + error("This should not happen...") else TracedRArray{T,N}((path,), prev.mlir_data, size(prev)) end @@ -358,7 +358,9 @@ function make_tracer( @nospecialize(prev::TracedRNumber{T}), @nospecialize(path), mode; - kwargs... + tobatch=nothing, + toscalar=false, + kwargs..., ) where {T} if mode == ConcreteToTraced throw("Cannot trace existing trace type") @@ -374,7 +376,13 @@ function make_tracer( if haskey(seen, prev) return seen[prev] end - res = TracedRNumber{T}((path,), prev.mlir_data) + res = if toscalar + TracedRNumber{T}((path,), nothing) + elseif tobatch !== nothing + TracedRArray{T,length(tobatch)}((path,), prev.mlir_data, tobatch) + else + TracedRNumber{T}((path,), prev.mlir_data) + end seen[prev] = res return res end From 49e124a6003b663426af66d7048b71e076ad64ff Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 4 Oct 2024 21:05:54 -0400 Subject: [PATCH 05/34] fix: promote_rule and introduce union over primitive types --- src/TracedRArray.jl | 59 +++++++++++++++++++++++++++++++++------------ 1 file changed, 44 insertions(+), 15 deletions(-) diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 327bf719fb..e331aa3160 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -19,7 +19,26 @@ end TracedRArray{T,N}(x::TracedRArray{T,N}) where {T,N} = x -mutable struct TracedRNumber{T} <: RNumber{T} +const ReactantPrimitives = Union{ + Bool, + Int8, + UInt8, + Int16, + UInt16, + Int32, + UInt32, + Int64, + UInt64, + Float16, + Float32, + # BFloat16, + Float64, + Complex{Float32}, + Complex{Float64}, +} + +# `<: ReactantPrimitives` ensures we don't end up with nested `TracedRNumber`s +mutable struct TracedRNumber{T<:ReactantPrimitives} <: RNumber{T} paths::Tuple mlir_data::Union{Nothing,MLIR.IR.Value} @@ -214,14 +233,8 @@ function Base.transpose(A::AnyTracedRVecOrMat) end Base.adjoint(A::AnyTracedRVecOrMat{<:Real}) = transpose(A) -function Base.promote_rule( - ::Type{TracedRArray{T,N}}, ::Type{TracedRArray{S,N}} -) where {T,S,N} - return TracedRArray{Base.promote_type(T, S),N} -end - -function Base.promote_rule(::Type{T}, ::Type{TracedRArray{S,N}}) where {T,S,N} - return TracedRArray{Base.promote_type(T, S),N} +function Base.promote_rule(::Type{TracedRNumber{T}}, ::Type{TracedRNumber{S}}) where {T,S} + return TracedRNumber{Base.promote_type(T, S)} end function Base.promote_rule(::Type{T}, ::Type{TracedRNumber{S}}) where {T,S} @@ -326,8 +339,6 @@ function Base.ifelse( ) end -Base.abs2(x::Reactant.TracedRNumber{T}) where {T} = x * conj(x) - function Base.literal_pow( ::Base.RefValue{typeof(^)}, x::TracedRNumber{T}, ::Base.RefValue{Val{P}} ) where {T,P} @@ -355,8 +366,8 @@ end struct TypeCast{T<:Number} <: Function end -function (::TypeCast{T})(x::TracedRArray{T2,0}) where {T,T2} - return promote_to(TracedRArray{T,0}, x) +function (::TypeCast{T})(x::TracedRNumber{T2}) where {T,T2} + return promote_to(TracedRNumber{T}, x) end elem_apply(::Type{T}, x::TracedRArray{T}) where {T<:Number} = x @@ -556,8 +567,7 @@ function Base.mapreduce( fnbody = MLIR.IR.Block(in_tys, [MLIR.IR.Location() for arg in in_tys]) args = ( - TracedRNumber{T}((), MLIR.IR.argument(fnbody, i), ()) for - (i, ty) in enumerate(in_tys) + TracedRNumber{T}((), MLIR.IR.argument(fnbody, i)) for (i, ty) in enumerate(in_tys) ) res = MLIR.IR.block!(fnbody) do @@ -708,6 +718,25 @@ function broadcast_to_size(arg::T, rsize) where {T<:Number} ) end +function broadcast_to_size(arg::TracedRNumber, rsize) + rsize == () && return arg + mlirty = MLIR.IR.type(arg.mlir_data) + return TracedRArray{eltype(arg),length(rsize)}( + (), + MLIR.IR.result( + MLIR.Dialects.stablehlo.broadcast_in_dim( + arg.mlir_data; + result_0=MLIR.IR.TensorType([t for t in rsize], eltype(mlirty)), + broadcast_dimensions=MLIR.IR.DenseArrayAttribute([ + Int64(i - 1) for i in rsize + ]), + ), + 1, + ), + rsize, + ) +end + function broadcast_to_size(arg::AnyTracedRArray, rsize) arg = materialize_traced_array(arg) size(arg) == rsize && return arg From 6dc36f130dd0cbcde60da173c1dc8f706e0a753e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 4 Oct 2024 21:08:09 -0400 Subject: [PATCH 06/34] chore: apply formatting --- ext/ReactantNNlibExt.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ext/ReactantNNlibExt.jl b/ext/ReactantNNlibExt.jl index e3bd708598..347b24e72a 100644 --- a/ext/ReactantNNlibExt.jl +++ b/ext/ReactantNNlibExt.jl @@ -1,8 +1,8 @@ module ReactantNNlibExt using NNlib -using Reactant: Reactant, TracedRArray, AnyTracedRArray, materialize_traced_array, MLIR, - TracedRNumber +using Reactant: + Reactant, TracedRArray, AnyTracedRArray, materialize_traced_array, MLIR, TracedRNumber for (jlop, hloop) in ( (:(NNlib.tanh_fast), :tanh), From 20de817b99148ada2b9269f8a034b8636baf8e87 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 4 Oct 2024 21:13:12 -0400 Subject: [PATCH 07/34] feat: type-restrict arrays --- src/Reactant.jl | 22 ++++++++++++++++++++-- src/TracedRArray.jl | 21 +-------------------- 2 files changed, 21 insertions(+), 22 deletions(-) diff --git a/src/Reactant.jl b/src/Reactant.jl index a0b5866ff7..f0fdfc9cb6 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -7,8 +7,26 @@ include("OrderedIdDict.jl") using Enzyme -abstract type RArray{T,N} <: AbstractArray{T,N} end -abstract type RNumber{T} <: Number end +const ReactantPrimitives = Union{ + Bool, + Int8, + UInt8, + Int16, + UInt16, + Int32, + UInt32, + Int64, + UInt64, + Float16, + Float32, + # BFloat16, + Float64, + Complex{Float32}, + Complex{Float64}, +} + +abstract type RArray{T<:ReactantPrimitives,N} <: AbstractArray{T,N} end +abstract type RNumber{T<:ReactantPrimitives} <: Number end function Base.reshape(A::RArray, dims::Tuple{Vararg{Union{Int,Colon}}}) return reshape(A, Base._reshape_uncolon(A, dims)) diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index e331aa3160..c3e54f080b 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -19,26 +19,7 @@ end TracedRArray{T,N}(x::TracedRArray{T,N}) where {T,N} = x -const ReactantPrimitives = Union{ - Bool, - Int8, - UInt8, - Int16, - UInt16, - Int32, - UInt32, - Int64, - UInt64, - Float16, - Float32, - # BFloat16, - Float64, - Complex{Float32}, - Complex{Float64}, -} - -# `<: ReactantPrimitives` ensures we don't end up with nested `TracedRNumber`s -mutable struct TracedRNumber{T<:ReactantPrimitives} <: RNumber{T} +mutable struct TracedRNumber{T} <: RNumber{T} paths::Tuple mlir_data::Union{Nothing,MLIR.IR.Value} From e04e3b62ca0c686fdddde81b2349aec3154cc33b Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 4 Oct 2024 21:19:00 -0400 Subject: [PATCH 08/34] refactor: move scalar ops to a separate file --- src/Reactant.jl | 1 + src/TracedRArray.jl | 178 ++----------------------------------------- src/TracedRNumber.jl | 161 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 169 insertions(+), 171 deletions(-) create mode 100644 src/TracedRNumber.jl diff --git a/src/Reactant.jl b/src/Reactant.jl index f0fdfc9cb6..10155b75c7 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -66,6 +66,7 @@ include("Interpreter.jl") include("utils.jl") include("ConcreteRArray.jl") include("TracedRArray.jl") +include("TracedRNumber.jl") include("Tracing.jl") include("Compiler.jl") diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index c3e54f080b..f68dd51c2c 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -19,22 +19,6 @@ end TracedRArray{T,N}(x::TracedRArray{T,N}) where {T,N} = x -mutable struct TracedRNumber{T} <: RNumber{T} - paths::Tuple - mlir_data::Union{Nothing,MLIR.IR.Value} - - function TracedRNumber{T}( - paths::Tuple, mlir_data::Union{Nothing,MLIR.IR.Value} - ) where {T} - if !isnothing(mlir_data) - @assert size(MLIR.IR.type(mlir_data)) == () - end - return new{T}(paths, mlir_data) - end -end - -Base.eltype(::Type{TracedRNumber{T}}) where {T} = T - const WrappedTracedRArray{T,N} = WrappedArray{T,N,TracedRArray,TracedRArray{T,N}} const AnyTracedRArray{T,N} = Union{TracedRArray{T,N},WrappedTracedRArray{T,N}} const AnyTracedRVector{T} = AnyTracedRArray{T,1} @@ -55,15 +39,6 @@ function get_ancestor_indices(x::WrappedTracedRArray, indices...) return get_ancestor_indices(parent(x), Base.reindex(parentindices(x), indices)...) end -Base.getindex(a::TracedRNumber{T}) where {T} = a - -Base.zero(::TracedRNumber{T}) where {T} = promote_to(TracedRNumber{T}, zero(T)) -Base.one(::TracedRNumber{T}) where {T} = promote_to(TracedRNumber{T}, one(T)) - -function Base.convert(::Type{<:TracedRNumber{T}}, x::Number) where {T} - return promote_to(TracedRNumber{T}, T(x)) -end - function Base.getindex(a::TracedRArray{T,N}, index::Vararg{Int,N}) where {T,N} @warn( """Performing scalar indexing on task $(current_task()). @@ -148,12 +123,6 @@ function Base.show(io::IOty, X::TracedRArray{T,N}) where {T,N,IOty<:Union{IO,IOC # return print(io, X.mlir_data, ")") end -function Base.show(io::IOty, X::TracedRNumber{T}) where {T,IOty<:Union{IO,IOContext}} - return print(io, "TracedRNumber{", T, "}(", X.paths, ")") -end - -Base.only(A::TracedRNumber{T}) where {T} = A - function Base.reshape(A::AnyTracedRArray{T,N}, dims::NTuple{NT,Int}) where {T,N,NT} if prod(dims) != prod(size(A)) throw( @@ -214,18 +183,6 @@ function Base.transpose(A::AnyTracedRVecOrMat) end Base.adjoint(A::AnyTracedRVecOrMat{<:Real}) = transpose(A) -function Base.promote_rule(::Type{TracedRNumber{T}}, ::Type{TracedRNumber{S}}) where {T,S} - return TracedRNumber{Base.promote_type(T, S)} -end - -function Base.promote_rule(::Type{T}, ::Type{TracedRNumber{S}}) where {T,S} - return TracedRNumber{Base.promote_type(T, S)} -end - -function Base.convert(::Type{TracedRNumber{T}}, x::Number) where {T} - return promote_to(TracedRNumber{T}, x) -end - function promote_to(::Type{TracedRArray{T,N}}, rhs) where {T,N} if isa(rhs, TracedRArray) rhs isa TracedRArray{T,N} && return rhs @@ -254,103 +211,10 @@ function promote_to(::Type{TracedRArray{T,N}}, rhs) where {T,N} ) end -function promote_to(::Type{TracedRNumber{T}}, rhs) where {T} - if isa(rhs, TracedRNumber) - rhs isa TracedRNumber{T} && return rhs - return TracedRNumber{T}( - (), - MLIR.IR.result( - MLIR.Dialects.stablehlo.convert( - rhs.mlir_data; result=mlir_type(TracedRNumber{T}) - ), - 1, - ), - ) - end - if isa(rhs, Number) - attr = fill(MLIR.IR.Attribute(T(rhs)), mlir_type(TracedRNumber{T})) - return TracedRNumber{T}( - (), MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=attr), 1) - ) - end - T0 = eltype(rhs) - attr = MLIR.IR.DenseElementsAttribute(collect(rhs)) - return promote_to( - TracedRNumber{T}, - TracedRNumber{T0}( - (), MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=attr), 1) - ), - ) -end - promote_to(::TracedRArray{T,N}, rhs) where {T,N} = promote_to(TracedRArray{T,N}, rhs) -promote_to(::TracedRNumber{T}, rhs) where {T} = promote_to(TracedRNumber{T}, rhs) - -for (jlop, hloop) in ( - (:(Base.min), :minimum), - (:(Base.max), :maximum), - (:(Base.:+), :add), - (:(Base.:-), :subtract), - (:(Base.:*), :multiply), - (:(Base.:/), :divide), - (:(Base.:^), :power), -) - @eval function $(jlop)( - @nospecialize(lhs::TracedRNumber{T}), @nospecialize(rhs::TracedRNumber{T}) - ) where {T} - return TracedRNumber{T}( - (), - MLIR.IR.result( - MLIR.Dialects.stablehlo.$(hloop)(lhs.mlir_data, rhs.mlir_data), 1 - ), - ) - end -end - -function Base.ifelse( - @nospecialize(pred::TracedRNumber{Bool}), - @nospecialize(x::TracedRNumber{T1}), - @nospecialize(y::TracedRNumber{T2}) -) where {T1,T2} - return TracedRNumber{promote_type(T1, T2)}( - (), - MLIR.IR.result( - MLIR.Dialects.stablehlo.select(pred.mlir_data, x.mlir_data, y.mlir_data), 1 - ), - ) -end - -function Base.literal_pow( - ::Base.RefValue{typeof(^)}, x::TracedRNumber{T}, ::Base.RefValue{Val{P}} -) where {T,P} - return Base.literal_pow(^, x, Val(P)) -end - -for (jlop, hloop) in ( - (:(Base.abs), :abs), - (:(Base.:-), :negate), - (:(Base.sin), :sine), - (:(Base.cos), :cosine), - (:(Base.tanh), :tanh), - (:(Base.FastMath.tanh_fast), :tanh), - (:(Base.exp), :exponential), - (:(Base.FastMath.exp_fast), :exponential), - (:(Base.log), :log), - (:(Base.sqrt), :sqrt), -) - @eval function $(jlop)(@nospecialize(lhs::TracedRNumber{T})) where {T} - return TracedRNumber{T}( - (), MLIR.IR.result(MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data), 1) - ) - end -end struct TypeCast{T<:Number} <: Function end -function (::TypeCast{T})(x::TracedRNumber{T2}) where {T,T2} - return promote_to(TracedRNumber{T}, x) -end - elem_apply(::Type{T}, x::TracedRArray{T}) where {T<:Number} = x function elem_apply(::Type{T}, x::TracedRArray{T2}) where {T<:Number,T2<:Number} # Special Path to prevent going down a despecialized path @@ -435,41 +299,13 @@ function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs} return traced2_result end -for (jlop, hloop, hlocomp, merge) in ( - (:(Base.:(==)), :compare, "EQ", :all), - (:(Base.:(!=)), :compare, "NE", :any), - (:(Base.:(>=)), :compare, "GE", nothing), - (:(Base.:(>)), :compare, "GT", nothing), - (:(Base.:(<=)), :compare, "LE", nothing), - (:(Base.:(<)), :compare, "LT", nothing), -) - @eval function $(jlop)( - @nospecialize(lhs::TracedRNumber{T}), @nospecialize(rhs::TracedRNumber{T}) - ) where {T} - return TracedRNumber{Bool}( - (), - MLIR.IR.result( - MLIR.Dialects.stablehlo.$(hloop)( - lhs.mlir_data, - rhs.mlir_data; - comparison_direction=MLIR.API.stablehloComparisonDirectionAttrGet( - MLIR.IR.context(), $hlocomp - ), - ), - 1, - ), - ) - end - - if merge !== nothing - @eval begin - function $jlop( - @nospecialize(lhs::TracedRArray{T,N}), @nospecialize(rhs::TracedRArray{T,N}) - ) where {T,N} - elems = $(jlop).(lhs, rhs) - return N == 0 ? elems : $(merge)(elems) - end - end +for (jlop, hloop, hlocomp, merge) in + ((:(Base.:(==)), :compare, "EQ", :all), (:(Base.:(!=)), :compare, "NE", :any)) + @eval function $jlop( + @nospecialize(lhs::TracedRArray{T,N}), @nospecialize(rhs::TracedRArray{T,N}) + ) where {T,N} + elems = $(jlop).(lhs, rhs) + return N == 0 ? elems : $(merge)(elems) end end diff --git a/src/TracedRNumber.jl b/src/TracedRNumber.jl new file mode 100644 index 0000000000..de3516840f --- /dev/null +++ b/src/TracedRNumber.jl @@ -0,0 +1,161 @@ +mutable struct TracedRNumber{T} <: RNumber{T} + paths::Tuple + mlir_data::Union{Nothing,MLIR.IR.Value} + + function TracedRNumber{T}( + paths::Tuple, mlir_data::Union{Nothing,MLIR.IR.Value} + ) where {T} + if !isnothing(mlir_data) + @assert size(MLIR.IR.type(mlir_data)) == () + end + return new{T}(paths, mlir_data) + end +end + +Base.eltype(::Type{TracedRNumber{T}}) where {T} = T + +Base.getindex(a::TracedRNumber{T}) where {T} = a + +Base.zero(::TracedRNumber{T}) where {T} = promote_to(TracedRNumber{T}, zero(T)) +Base.one(::TracedRNumber{T}) where {T} = promote_to(TracedRNumber{T}, one(T)) + +function Base.convert(::Type{<:TracedRNumber{T}}, x::Number) where {T} + return promote_to(TracedRNumber{T}, T(x)) +end + +function Base.show(io::IOty, X::TracedRNumber{T}) where {T,IOty<:Union{IO,IOContext}} + return print(io, "TracedRNumber{", T, "}(", X.paths, ")") +end + +Base.only(A::TracedRNumber{T}) where {T} = A + +function Base.promote_rule(::Type{TracedRNumber{T}}, ::Type{TracedRNumber{S}}) where {T,S} + return TracedRNumber{Base.promote_type(T, S)} +end + +function Base.promote_rule(::Type{T}, ::Type{TracedRNumber{S}}) where {T,S} + return TracedRNumber{Base.promote_type(T, S)} +end + +function Base.convert(::Type{TracedRNumber{T}}, x::Number) where {T} + return promote_to(TracedRNumber{T}, x) +end + +function promote_to(::Type{TracedRNumber{T}}, rhs) where {T} + if isa(rhs, TracedRNumber) + rhs isa TracedRNumber{T} && return rhs + return TracedRNumber{T}( + (), + MLIR.IR.result( + MLIR.Dialects.stablehlo.convert( + rhs.mlir_data; result=mlir_type(TracedRNumber{T}) + ), + 1, + ), + ) + end + if isa(rhs, Number) + attr = fill(MLIR.IR.Attribute(T(rhs)), mlir_type(TracedRNumber{T})) + return TracedRNumber{T}( + (), MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=attr), 1) + ) + end + T0 = eltype(rhs) + attr = MLIR.IR.DenseElementsAttribute(collect(rhs)) + return promote_to( + TracedRNumber{T}, + TracedRNumber{T0}( + (), MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=attr), 1) + ), + ) +end + +promote_to(::TracedRNumber{T}, rhs) where {T} = promote_to(TracedRNumber{T}, rhs) + +for (jlop, hloop) in ( + (:(Base.min), :minimum), + (:(Base.max), :maximum), + (:(Base.:+), :add), + (:(Base.:-), :subtract), + (:(Base.:*), :multiply), + (:(Base.:/), :divide), + (:(Base.:^), :power), +) + @eval function $(jlop)( + @nospecialize(lhs::TracedRNumber{T}), @nospecialize(rhs::TracedRNumber{T}) + ) where {T} + return TracedRNumber{T}( + (), + MLIR.IR.result( + MLIR.Dialects.stablehlo.$(hloop)(lhs.mlir_data, rhs.mlir_data), 1 + ), + ) + end +end + +for (jlop, hloop, hlocomp) in ( + (:(Base.:(==)), :compare, "EQ"), + (:(Base.:(!=)), :compare, "NE"), + (:(Base.:(>=)), :compare, "GE"), + (:(Base.:(>)), :compare, "GT"), + (:(Base.:(<=)), :compare, "LE"), + (:(Base.:(<)), :compare, "LT"), +) + @eval function $(jlop)( + @nospecialize(lhs::TracedRNumber{T}), @nospecialize(rhs::TracedRNumber{T}) + ) where {T} + return TracedRNumber{Bool}( + (), + MLIR.IR.result( + MLIR.Dialects.stablehlo.$(hloop)( + lhs.mlir_data, + rhs.mlir_data; + comparison_direction=MLIR.API.stablehloComparisonDirectionAttrGet( + MLIR.IR.context(), $hlocomp + ), + ), + 1, + ), + ) + end +end + +function Base.ifelse( + @nospecialize(pred::TracedRNumber{Bool}), + @nospecialize(x::TracedRNumber{T1}), + @nospecialize(y::TracedRNumber{T2}) +) where {T1,T2} + return TracedRNumber{promote_type(T1, T2)}( + (), + MLIR.IR.result( + MLIR.Dialects.stablehlo.select(pred.mlir_data, x.mlir_data, y.mlir_data), 1 + ), + ) +end + +function Base.literal_pow( + ::Base.RefValue{typeof(^)}, x::TracedRNumber{T}, ::Base.RefValue{Val{P}} +) where {T,P} + return Base.literal_pow(^, x, Val(P)) +end + +for (jlop, hloop) in ( + (:(Base.abs), :abs), + (:(Base.:-), :negate), + (:(Base.sin), :sine), + (:(Base.cos), :cosine), + (:(Base.tanh), :tanh), + (:(Base.FastMath.tanh_fast), :tanh), + (:(Base.exp), :exponential), + (:(Base.FastMath.exp_fast), :exponential), + (:(Base.log), :log), + (:(Base.sqrt), :sqrt), +) + @eval function $(jlop)(@nospecialize(lhs::TracedRNumber{T})) where {T} + return TracedRNumber{T}( + (), MLIR.IR.result(MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data), 1) + ) + end +end + +(::TypeCast{T})(x::TracedRNumber{T2}) where {T,T2} = promote_to(TracedRNumber{T}, x) From 2ca8c68284c195159fc4adb6a831cb578df99e23 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 4 Oct 2024 21:20:13 -0400 Subject: [PATCH 09/34] feat: support Base.float --- src/TracedRNumber.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/TracedRNumber.jl b/src/TracedRNumber.jl index de3516840f..e0f44ba84a 100644 --- a/src/TracedRNumber.jl +++ b/src/TracedRNumber.jl @@ -159,3 +159,5 @@ for (jlop, hloop) in ( end (::TypeCast{T})(x::TracedRNumber{T2}) where {T,T2} = promote_to(TracedRNumber{T}, x) + +Base.float(x::TracedRNumber{T}) where {T} = promote_to(TracedRNumber{float(T)}, x) From 29759534e521fd92dccc81eb4faa05b5cc3fe435 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 4 Oct 2024 21:21:40 -0400 Subject: [PATCH 10/34] fix: import ordering --- src/Reactant.jl | 2 +- src/TracedRArray.jl | 8 ++++---- src/TracedRNumber.jl | 2 ++ 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/Reactant.jl b/src/Reactant.jl index 10155b75c7..81a57eb2c6 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -65,8 +65,8 @@ include("XLA.jl") include("Interpreter.jl") include("utils.jl") include("ConcreteRArray.jl") -include("TracedRArray.jl") include("TracedRNumber.jl") +include("TracedRArray.jl") include("Tracing.jl") include("Compiler.jl") diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index f68dd51c2c..4b3dd861fd 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -213,10 +213,10 @@ end promote_to(::TracedRArray{T,N}, rhs) where {T,N} = promote_to(TracedRArray{T,N}, rhs) -struct TypeCast{T<:Number} <: Function end - -elem_apply(::Type{T}, x::TracedRArray{T}) where {T<:Number} = x -function elem_apply(::Type{T}, x::TracedRArray{T2}) where {T<:Number,T2<:Number} +elem_apply(::Type{T}, x::TracedRArray{T}) where {T<:ReactantPrimitives} = x +function elem_apply( + ::Type{T}, x::TracedRArray{T2} +) where {T<:ReactantPrimitives,T2<:ReactantPrimitives} # Special Path to prevent going down a despecialized path return elem_apply(TypeCast{T}(), x) end diff --git a/src/TracedRNumber.jl b/src/TracedRNumber.jl index e0f44ba84a..a0ff140c6d 100644 --- a/src/TracedRNumber.jl +++ b/src/TracedRNumber.jl @@ -158,6 +158,8 @@ for (jlop, hloop) in ( end end +struct TypeCast{T<:ReactantPrimitives} <: Function end + (::TypeCast{T})(x::TracedRNumber{T2}) where {T,T2} = promote_to(TracedRNumber{T}, x) Base.float(x::TracedRNumber{T}) where {T} = promote_to(TracedRNumber{float(T)}, x) From 8da3d20ba1136de91e18b8193be5f33afeb6b40b Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 4 Oct 2024 22:12:17 -0400 Subject: [PATCH 11/34] feat: handle `broadcast_preserving_zero_d` in a generic fashion --- src/ConcreteRArray.jl | 4 ---- src/TracedRArray.jl | 28 +++++++++++++++++++++++++--- src/TracedRNumber.jl | 11 +++++++++++ 3 files changed, 36 insertions(+), 7 deletions(-) diff --git a/src/ConcreteRArray.jl b/src/ConcreteRArray.jl index 8f70f324d3..88e5af30f2 100644 --- a/src/ConcreteRArray.jl +++ b/src/ConcreteRArray.jl @@ -74,10 +74,6 @@ function Base.convert(::Type{T}, x::ConcreteRArray{T,0}) where {T} return to_float(x) end -function Base.promote_rule(::Type{<:RArray{T1,0}}, ::Type{T2}) where {T1,T2} - return Base.promote_rule(T1, T2) -end - for jlop in (:(Base.isless), :(Base.:+), :(Base.:-), :(Base.:*), :(Base.:/), :(Base.:^)) @eval begin function $jlop(x::ConcreteRArray{T,0}, y::ConcreteRArray{U,0}) where {T,U} diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 4b3dd861fd..2f475482c7 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -66,6 +66,10 @@ and require expensive copies and synchronization each time and therefore should return TracedRNumber{T}((), res2) end +function Base.getindex(a::TracedRArray{T,0}) where {T} + return TracedRNumber{T}((), a.mlir_data) +end + function Base.getindex(a::TracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N} indices = [i isa Colon ? (1:size(a, idx)) : i for (idx, i) in enumerate(indices)] res = MLIR.IR.result( @@ -222,7 +226,12 @@ function elem_apply( end function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs} - all(iszero ∘ ndims, args) && return f(args...) + if all(iszero ∘ ndims, args) + scalar_args = map(args) do arg + return promote_to(TracedRNumber{eltype(arg)}, arg) + end + return f(scalar_args...) + end fnwrap, func2, traced_result, result, seen_args, ret, linear_args, in_tys, linear_results = make_mlir_fn( f, args, (), string(f) * "_broadcast_scalar", false; toscalar=true @@ -440,6 +449,12 @@ function Base.fill!(A::TracedRArray{T,N}, x) where {T,N} return A end +function Base.fill!(A::TracedRArray{T,N}, x::TracedRNumber{T2}) where {T,N,T2} + bcast = broadcast_to_size(promote_to(TracedRNumber{T}, x), size(A)) + A.mlir_data = bcast.mlir_data + return A +end + struct AbstractReactantArrayStyle{N} <: Base.Broadcast.AbstractArrayStyle{N} end AbstractReactantArrayStyle(::Val{N}) where {N} = AbstractReactantArrayStyle{N}() @@ -458,7 +473,14 @@ end function Base.similar( bc::Broadcasted{AbstractReactantArrayStyle{N}}, ::Type{T}, dims -) where {T,N} +) where {T<:ReactantPrimitives,N} + @assert N isa Int + return TracedRArray{T,N}((), nothing, map(length, dims)) +end + +function Base.similar( + bc::Broadcasted{AbstractReactantArrayStyle{N}}, ::Type{<:TracedRNumber{T}}, dims +) where {T<:ReactantPrimitives,N} @assert N isa Int return TracedRArray{T,N}((), nothing, map(length, dims)) end @@ -536,7 +558,7 @@ function broadcast_to_size(arg::T, rsize) where {T<:Number} end function broadcast_to_size(arg::TracedRNumber, rsize) - rsize == () && return arg + length(rsize) == 0 && return arg mlirty = MLIR.IR.type(arg.mlir_data) return TracedRArray{eltype(arg),length(rsize)}( (), diff --git a/src/TracedRNumber.jl b/src/TracedRNumber.jl index a0ff140c6d..c2a023c257 100644 --- a/src/TracedRNumber.jl +++ b/src/TracedRNumber.jl @@ -54,6 +54,17 @@ function promote_to(::Type{TracedRNumber{T}}, rhs) where {T} ), ) end + if isa(rhs, TracedRArray{<:Any,0}) + return TracedRNumber{T}( + (), + MLIR.IR.result( + MLIR.Dialects.stablehlo.convert( + rhs.mlir_data; result=mlir_type(TracedRNumber{T}) + ), + 1, + ), + ) + end if isa(rhs, Number) attr = fill(MLIR.IR.Attribute(T(rhs)), mlir_type(TracedRNumber{T})) return TracedRNumber{T}( From 194ee65707134edde0c207f0cbcdd0b321657aa1 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 4 Oct 2024 22:13:29 -0400 Subject: [PATCH 12/34] refactor: move code a bit --- src/TracedRArray.jl | 67 ++++++++++++++++++++------------------------- 1 file changed, 30 insertions(+), 37 deletions(-) diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 2f475482c7..06720ca9c0 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -455,18 +455,41 @@ function Base.fill!(A::TracedRArray{T,N}, x::TracedRNumber{T2}) where {T,N,T2} return A end +function Base._cat(dims::Val{D}, A::TracedRArray{T,N}, Bs::TracedRArray...) where {T,N,D} + @assert D isa Integer "Support for non-integer dimensions is not implemented yet." + + # MLIR expects the dimension `D` to be ≤ the rank of the input tensors + A = maybe_expand_dims(A, dims) + Bs = maybe_expand_dims.(Bs, (dims,)) + + catdims = Base.dims2cat(dims) + shape = Base.cat_size_shape(catdims, A, Bs...) + RT = Base.promote_eltype(A, Bs...) + Res = TracedRArray{RT,length(shape)}( + (), + MLIR.IR.result( + MLIR.Dialects.stablehlo.concatenate( + [A.mlir_data, [B.mlir_data for B in Bs]...]; + result_0=MLIR.IR.TensorType(shape, MLIR.IR.Type(RT)), + dimension=D - 1, # stablehlo expects this to be zero-indexed + ), + 1, + ), + shape, + ) + return Res +end + +function maybe_expand_dims(x::AbstractArray{T,N}, ::Val{D}) where {T,N,D} + D ≤ N && return x + return reshape(x, ntuple(i -> i ≤ N ? size(x, i) : 1, Val(D))) +end + struct AbstractReactantArrayStyle{N} <: Base.Broadcast.AbstractArrayStyle{N} end AbstractReactantArrayStyle(::Val{N}) where {N} = AbstractReactantArrayStyle{N}() AbstractReactantArrayStyle{M}(::Val{N}) where {N,M} = AbstractReactantArrayStyle{N}() -# function Broadcast.materialize(bc::Broadcasted) -# @show bc -# inst = instantiate(bc) -# @show inst -# copy(inst) -# end - function BroadcastStyle(::Type{<:AnyTracedRArray{T,N}}) where {T,N} return AbstractReactantArrayStyle{N}() end @@ -628,33 +651,3 @@ function _copyto!(dest::TracedRArray, bc::Broadcasted) dest.mlir_data = res.mlir_data return dest end - -function Base._cat(dims::Val{D}, A::TracedRArray{T,N}, Bs::TracedRArray...) where {T,N,D} - @assert D isa Integer "Support for non-integer dimensions is not implemented yet." - - # MLIR expects the dimension `D` to be ≤ the rank of the input tensors - A = maybe_expand_dims(A, dims) - Bs = maybe_expand_dims.(Bs, (dims,)) - - catdims = Base.dims2cat(dims) - shape = Base.cat_size_shape(catdims, A, Bs...) - RT = Base.promote_eltype(A, Bs...) - Res = TracedRArray{RT,length(shape)}( - (), - MLIR.IR.result( - MLIR.Dialects.stablehlo.concatenate( - [A.mlir_data, [B.mlir_data for B in Bs]...]; - result_0=MLIR.IR.TensorType(shape, MLIR.IR.Type(RT)), - dimension=D - 1, # stablehlo expects this to be zero-indexed - ), - 1, - ), - shape, - ) - return Res -end - -function maybe_expand_dims(x::AbstractArray{T,N}, ::Val{D}) where {T,N,D} - D ≤ N && return x - return reshape(x, ntuple(i -> i ≤ N ? size(x, i) : 1, Val(D))) -end From db5565b0a98a55864d466a35250e4dae3155054b Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 4 Oct 2024 22:27:44 -0400 Subject: [PATCH 13/34] test: more test fixes --- src/Compiler.jl | 3 ++- src/TracedRNumber.jl | 54 ++++++++++++++++++++++++++++++++------------ 2 files changed, 42 insertions(+), 15 deletions(-) diff --git a/src/Compiler.jl b/src/Compiler.jl index 94a4f35a35..147b76b6c4 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -6,6 +6,7 @@ import ..Reactant: XLA, ConcreteRArray, TracedRArray, + TracedRNumber, OrderedIdDict, make_tracer, TracedToConcrete, @@ -289,7 +290,7 @@ function compile_mlir!(mod, f, args; optimize=true) preserved_args = Tuple{TracedRArray,Int}[] results = [MLIR.IR.operand(ret, i) for i in 1:MLIR.IR.noperands(ret)] nresults = MLIR.IR.Value[] - linear_results2 = TracedRArray[] + linear_results2 = Union{TracedRArray,TracedRNumber}[] for (i, op) in enumerate(results) if !MLIR.IR.is_block_arg(op) push!(nresults, op) diff --git a/src/TracedRNumber.jl b/src/TracedRNumber.jl index c2a023c257..37eb04bb5b 100644 --- a/src/TracedRNumber.jl +++ b/src/TracedRNumber.jl @@ -112,22 +112,45 @@ for (jlop, hloop, hlocomp) in ( (:(Base.:(<=)), :compare, "LE"), (:(Base.:(<)), :compare, "LT"), ) - @eval function $(jlop)( - @nospecialize(lhs::TracedRNumber{T}), @nospecialize(rhs::TracedRNumber{T}) - ) where {T} - return TracedRNumber{Bool}( - (), - MLIR.IR.result( - MLIR.Dialects.stablehlo.$(hloop)( - lhs.mlir_data, - rhs.mlir_data; - comparison_direction=MLIR.API.stablehloComparisonDirectionAttrGet( - MLIR.IR.context(), $hlocomp + @eval begin + function $(jlop)( + @nospecialize(lhs::TracedRNumber{T}), @nospecialize(rhs::TracedRNumber{T}) + ) where {T} + return TracedRNumber{Bool}( + (), + MLIR.IR.result( + MLIR.Dialects.stablehlo.$(hloop)( + lhs.mlir_data, + rhs.mlir_data; + comparison_direction=MLIR.API.stablehloComparisonDirectionAttrGet( + MLIR.IR.context(), $hlocomp + ), ), + 1, ), - 1, - ), - ) + ) + end + + function $(jlop)( + @nospecialize(lhs::TracedRNumber{T}), @nospecialize(rhs) + ) where {T} + return $(jlop)(lhs, promote_to(lhs, rhs)) + end + + function $(jlop)( + @nospecialize(lhs), @nospecialize(rhs::TracedRNumber{T}) + ) where {T} + return $(jlop)(promote_to(rhs, lhs), rhs) + end + + function $(jlop)( + @nospecialize(lhs::TracedRNumber{T1}), @nospecialize(rhs::TracedRNumber{T2}) + ) where {T1,T2} + commonTy = TracedRNumber{Base.promote_type(T1, T2)} + lhs = promote_to(commonTy, lhs) + rhs = promote_to(commonTy, rhs) + return $(jlop)(lhs, rhs) + end end end @@ -169,6 +192,9 @@ for (jlop, hloop) in ( end end +# XXX: Enzyme-MLIR doesn't have `abs` adjoint defined +Base.abs2(x::TracedRNumber{<:Real}) = x^2 + struct TypeCast{T<:ReactantPrimitives} <: Function end (::TypeCast{T})(x::TracedRNumber{T2}) where {T,T2} = promote_to(TracedRNumber{T}, x) From d18aff8a42ad27f9ec6092d8f03b4aa27e3fc5d2 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 4 Oct 2024 22:31:52 -0400 Subject: [PATCH 14/34] chore: apply formatting --- src/TracedRNumber.jl | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/TracedRNumber.jl b/src/TracedRNumber.jl index 37eb04bb5b..1a4225f400 100644 --- a/src/TracedRNumber.jl +++ b/src/TracedRNumber.jl @@ -131,15 +131,11 @@ for (jlop, hloop, hlocomp) in ( ) end - function $(jlop)( - @nospecialize(lhs::TracedRNumber{T}), @nospecialize(rhs) - ) where {T} + function $(jlop)(@nospecialize(lhs::TracedRNumber{T}), @nospecialize(rhs)) where {T} return $(jlop)(lhs, promote_to(lhs, rhs)) end - function $(jlop)( - @nospecialize(lhs), @nospecialize(rhs::TracedRNumber{T}) - ) where {T} + function $(jlop)(@nospecialize(lhs), @nospecialize(rhs::TracedRNumber{T})) where {T} return $(jlop)(promote_to(rhs, lhs), rhs) end From 7fd269da7cca3b48ba431ddf6e6b126f5ecde218 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 4 Oct 2024 22:42:44 -0400 Subject: [PATCH 15/34] fix: setindex with scalars --- src/ConcreteRArray.jl | 2 +- src/TracedRArray.jl | 8 ++++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/ConcreteRArray.jl b/src/ConcreteRArray.jl index 88e5af30f2..9c2077b27b 100644 --- a/src/ConcreteRArray.jl +++ b/src/ConcreteRArray.jl @@ -154,7 +154,7 @@ function Base.getindex(a::ConcreteRArray{T}, args::Vararg{Int,N}) where {T,N} end function mysetindex!(a, v, args::Vararg{Int,N}) where {N} - Base.setindex!(a, v, args...) + setindex!(a, v, args...) return nothing end diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 06720ca9c0..37260f2a24 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -99,13 +99,17 @@ function Base.getindex(a::WrappedTracedRArray, indices...) end function Base.setindex!( - a::TracedRArray{T,N}, v, indices::Vararg{Union{Base.AbstractUnitRange,Colon},N} + a::TracedRArray{T,N}, v, indices::Vararg{Union{Base.AbstractUnitRange,Colon,Int},N} ) where {T,N} + indices = map(enumerate(indices)) do (idx, i) + i isa Int ? (i:i) : (i isa Colon ? (1:size(a, idx)) : i) + end + v = broadcast_to_size(v, length.(indices)) + v = promote_to(TracedRArray{T,N}, v) indices = [ (promote_to(TracedRNumber{Int}, i isa Colon ? 1 : first(i)) - 1).mlir_data for i in indices ] - v = promote_to(TracedRArray{T,N}, v) res = MLIR.IR.result( MLIR.Dialects.stablehlo.dynamic_update_slice(a.mlir_data, v.mlir_data, indices), 1 ) From 91a4a00fabfb6c72fe450f35633ae214791b4e48 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 4 Oct 2024 23:16:38 -0400 Subject: [PATCH 16/34] fix: scalar broadcasting case --- src/TracedRArray.jl | 16 ++-------------- 1 file changed, 2 insertions(+), 14 deletions(-) diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 37260f2a24..003a39a077 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -586,20 +586,8 @@ end function broadcast_to_size(arg::TracedRNumber, rsize) length(rsize) == 0 && return arg - mlirty = MLIR.IR.type(arg.mlir_data) - return TracedRArray{eltype(arg),length(rsize)}( - (), - MLIR.IR.result( - MLIR.Dialects.stablehlo.broadcast_in_dim( - arg.mlir_data; - result_0=MLIR.IR.TensorType([t for t in rsize], eltype(mlirty)), - broadcast_dimensions=MLIR.IR.DenseArrayAttribute([ - Int64(i - 1) for i in rsize - ]), - ), - 1, - ), - rsize, + return broadcast_to_size_internal( + TracedRArray{eltype(arg),0}((), arg.mlir_data, ()), rsize ) end From d82fb522ecc8065bce5ec1320aa6861403566380 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 5 Oct 2024 11:28:10 -0400 Subject: [PATCH 17/34] feat: support BFloat16 from Core (if available) --- src/Reactant.jl | 53 +++++++++++++++++++++++++++++++++---------------- src/XLA.jl | 4 +++- 2 files changed, 39 insertions(+), 18 deletions(-) diff --git a/src/Reactant.jl b/src/Reactant.jl index 81a57eb2c6..6fef0a7707 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -7,23 +7,42 @@ include("OrderedIdDict.jl") using Enzyme -const ReactantPrimitives = Union{ - Bool, - Int8, - UInt8, - Int16, - UInt16, - Int32, - UInt32, - Int64, - UInt64, - Float16, - Float32, - # BFloat16, - Float64, - Complex{Float32}, - Complex{Float64}, -} +@static if isdefined(Core, :BFloat16) + const ReactantPrimitives = Union{ + Bool, + Int8, + UInt8, + Int16, + UInt16, + Int32, + UInt32, + Int64, + UInt64, + Float16, + Core.BFloat16, + Float32, + Float64, + Complex{Float32}, + Complex{Float64}, + } +else + const ReactantPrimitives = Union{ + Bool, + Int8, + UInt8, + Int16, + UInt16, + Int32, + UInt32, + Int64, + UInt64, + Float16, + Float32, + Float64, + Complex{Float32}, + Complex{Float64}, + } +end abstract type RArray{T<:ReactantPrimitives,N} <: AbstractArray{T,N} end abstract type RNumber{T<:ReactantPrimitives} <: Number end diff --git a/src/XLA.jl b/src/XLA.jl index 9e77dac6df..0ea8bc539e 100644 --- a/src/XLA.jl +++ b/src/XLA.jl @@ -227,7 +227,9 @@ end @inline primitive_type(::Type{Float16}) = 10 @inline primitive_type(::Type{Float32}) = 11 -# @inline primitive_type(::Type{BFloat16}) = 16 +@static if isdefined(Core, :BFloat16) + @inline primitive_type(::Type{BFloat16}) = 16 +end @inline primitive_type(::Type{Float64}) = 12 From 45158bb7440551115121a871cd95d6a6ebbb7494 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 5 Oct 2024 11:36:33 -0400 Subject: [PATCH 18/34] test: more native lux functionality unblocked --- test/nn/lux.jl | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/test/nn/lux.jl b/test/nn/lux.jl index b4a4558783..49fa37f52c 100644 --- a/test/nn/lux.jl +++ b/test/nn/lux.jl @@ -1,15 +1,8 @@ using Reactant, Lux, Random, Statistics, Enzyme, Functors, OneHotArrays -function crossentropy(ŷ, y) - logŷ = log.(ŷ) - result = y .* logŷ - return -sum(result) -end - function loss_function(model, x, y, ps, st) y_hat, _ = model(x, ps, st) - # return CrossEntropyLoss()(y_hat, y) # <-- needs handling of xlogx xlogy from LuxOps - return crossentropy(y_hat, y) + return CrossEntropyLoss()(y_hat, y) end function gradient_loss_function(model, x, y, ps, st) From 4757cf921e3bb146a462756808b09881d3d6bd8d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 5 Oct 2024 11:39:51 -0400 Subject: [PATCH 19/34] refactor: use a union type for traced types --- src/Compiler.jl | 2 +- src/Reactant.jl | 4 ++++ src/utils.jl | 4 ++-- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/Compiler.jl b/src/Compiler.jl index 147b76b6c4..c357e5beb3 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -290,7 +290,7 @@ function compile_mlir!(mod, f, args; optimize=true) preserved_args = Tuple{TracedRArray,Int}[] results = [MLIR.IR.operand(ret, i) for i in 1:MLIR.IR.noperands(ret)] nresults = MLIR.IR.Value[] - linear_results2 = Union{TracedRArray,TracedRNumber}[] + linear_results2 = TracedTypes[] for (i, op) in enumerate(results) if !MLIR.IR.is_block_arg(op) push!(nresults, op) diff --git a/src/Reactant.jl b/src/Reactant.jl index 6fef0a7707..e70798302b 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -83,9 +83,13 @@ include("mlir/MLIR.jl") include("XLA.jl") include("Interpreter.jl") include("utils.jl") + include("ConcreteRArray.jl") include("TracedRNumber.jl") include("TracedRArray.jl") + +const TracedTypes = Union{TracedRArray,TracedRNumber} + include("Tracing.jl") include("Compiler.jl") diff --git a/src/utils.jl b/src/utils.jl index dd5f834d56..e136ebb095 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -44,7 +44,7 @@ function make_mlir_fn(f, args, kwargs, name="main", concretein=true; toscalar=fa ) end - linear_args = Union{TracedRArray,TracedRNumber}[] + linear_args = TracedTypes[] for (k, v) in seen_args if !(v isa TracedRArray) && !(v isa TracedRNumber) continue @@ -127,7 +127,7 @@ function make_mlir_fn(f, args, kwargs, name="main", concretein=true; toscalar=fa ) end - linear_results = Union{TracedRArray,TracedRNumber}[] + linear_results = TracedTypes[] for (k, v) in seen_results if !(v isa TracedRArray) && !(v isa TracedRNumber) From c85d3a11ba794db9e7384a60e6d9949803724ae6 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 5 Oct 2024 11:55:26 -0400 Subject: [PATCH 20/34] fix: check for reactant primitives --- src/Tracing.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Tracing.jl b/src/Tracing.jl index a94f6e28fd..0c70b4583d 100644 --- a/src/Tracing.jl +++ b/src/Tracing.jl @@ -427,7 +427,7 @@ function make_tracer( if haskey(seen, prev) return seen[prev] end - if mode == ArrayToConcrete && eltype(RT) <: AbstractFloat + if mode == ArrayToConcrete && eltype(RT) <: ReactantPrimitives return seen[prev] = ConcreteRArray(prev) end TT = traced_type(eltype(RT), (), Val(mode)) From d9cf4989cfb6c3aeeabd9a61cbc63dd399b170c3 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 5 Oct 2024 12:00:01 -0400 Subject: [PATCH 21/34] fix: missing import --- src/Compiler.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/Compiler.jl b/src/Compiler.jl index c357e5beb3..825d33101f 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -10,7 +10,8 @@ import ..Reactant: OrderedIdDict, make_tracer, TracedToConcrete, - append_path + append_path, + TracedTypes @inline traced_getfield(@nospecialize(obj), field) = Base.getfield(obj, field) From 0d7ad84b52d2877eed59bfa12e0e795e5bb37fdd Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 5 Oct 2024 16:11:15 -0400 Subject: [PATCH 22/34] fix: correct semantics for Colon mapreduce --- src/Compiler.jl | 2 +- src/TracedRArray.jl | 6 +++++- test/basic.jl | 8 ++++++++ 3 files changed, 14 insertions(+), 2 deletions(-) diff --git a/src/Compiler.jl b/src/Compiler.jl index 825d33101f..274cb2cbf8 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -575,7 +575,7 @@ end function compile_xla(f, args; client=nothing) # register MLIR dialects ctx = MLIR.IR.Context() - Base.append!(Reactant.registry[]; context=ctx) + append!(Reactant.registry[]; context=ctx) @ccall MLIR.API.mlir_c.RegisterDialects(ctx::MLIR.API.MlirContext)::Cvoid return MLIR.IR.context!(ctx) do diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 003a39a077..0b3c6a65a5 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -431,7 +431,11 @@ function Base.mapreduce( ) red = TracedRArray{T,length(toonedims)}((), red, (toonedims...,)) else - red = TracedRArray{T,length(outdims)}((), red, (outdims...,)) + if length(outdims) == 0 + red = TracedRNumber{T}((), red) + else + red = TracedRArray{T,length(outdims)}((), red, (outdims...,)) + end end return red end diff --git a/test/basic.jl b/test/basic.jl index 4467adf8f4..9116b3a7f8 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -65,6 +65,8 @@ end sumexp(x) = sum(exp, x) +sum_compare(x) = sum(x) > 0 + @testset "Basic mapreduce" begin x = rand(Float32, 10) a = Reactant.ConcreteRArray(x) @@ -74,6 +76,12 @@ sumexp(x) = sum(exp, x) f_res = f(a) @test f_res ≈ r_res + + # Ensure we are tracing as scalars. Else this will fail due to > not being defined on + # arrays + f = @compile sum_compare(a) + # We need to use [] to unwrap the scalar. We will fix this in the future. + @test f(a)[] == sum_compare(x) end function mysoftmax!(x) From d7337c92e2b857c7bc94df11d0b2bb88e638328e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 5 Oct 2024 16:26:54 -0400 Subject: [PATCH 23/34] fix: trace_type --- src/Tracing.jl | 4 ++-- test/compile.jl | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/Tracing.jl b/src/Tracing.jl index 0c70b4583d..3c30b3ba18 100644 --- a/src/Tracing.jl +++ b/src/Tracing.jl @@ -183,7 +183,7 @@ function traced_type(::Type{T}, seen, ::Val{mode}) where {T<:ConcreteRArray,mode end end -function traced_type(::Type{T}, seen::ST, ::Val{mode}) where {ST,T<:TracedRArray,mode} +function traced_type(::Type{T}, seen::ST, ::Val{mode}) where {ST,T<:TracedTypes,mode} if mode == ConcreteToTraced throw("TracedRArray $T cannot be traced") elseif mode == TracedToConcrete @@ -203,7 +203,7 @@ function traced_type(::Type{T}, seen, mode) where {T<:XLAArray} end function traced_type(::Type{A}, seen::ST, ::Val{mode}) where {T,N,A<:Array{T,N},ST,mode} - if mode == ArrayToConcrete && T <: AbstractFloat + if mode == ArrayToConcrete && T <: ReactantPrimitives return ConcreteRArray{T,N} else return Array{traced_type(T, seen, Val(mode)),N} diff --git a/test/compile.jl b/test/compile.jl index c5944e4c76..85692ad72b 100644 --- a/test/compile.jl +++ b/test/compile.jl @@ -7,7 +7,7 @@ Base.sum(x::NamedTuple{(:a,),Tuple{T}}) where {T<:Reactant.TracedRArray} = (; a= @testset "create_result" begin @testset "NamedTuple" begin x = (; a=rand(4, 3)) - x2 = (; a=Reactant.ConcreteRArray(x.a)) + x2 = Reactant.to_rarray(x) f = @compile sum(x2) From 6aab7f79eba3766e745804d4480c04397bd1c2d3 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 5 Oct 2024 16:48:40 -0400 Subject: [PATCH 24/34] fix: minor fixes --- src/TracedRNumber.jl | 2 ++ src/XLA.jl | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/TracedRNumber.jl b/src/TracedRNumber.jl index 1a4225f400..946895506c 100644 --- a/src/TracedRNumber.jl +++ b/src/TracedRNumber.jl @@ -19,6 +19,8 @@ Base.getindex(a::TracedRNumber{T}) where {T} = a Base.zero(::TracedRNumber{T}) where {T} = promote_to(TracedRNumber{T}, zero(T)) Base.one(::TracedRNumber{T}) where {T} = promote_to(TracedRNumber{T}, one(T)) +Base.eps(::Type{TracedRNumber{T}}) where {T} = promote_to(TracedRNumber{T}, eps(T)) + function Base.convert(::Type{<:TracedRNumber{T}}, x::Number) where {T} return promote_to(TracedRNumber{T}, T(x)) end diff --git a/src/XLA.jl b/src/XLA.jl index 0ea8bc539e..556b6ff33a 100644 --- a/src/XLA.jl +++ b/src/XLA.jl @@ -228,7 +228,7 @@ end @inline primitive_type(::Type{Float32}) = 11 @static if isdefined(Core, :BFloat16) - @inline primitive_type(::Type{BFloat16}) = 16 + @inline primitive_type(::Type{Core.BFloat16}) = 16 end @inline primitive_type(::Type{Float64}) = 12 From abc6a9e9b1f78d393812e5ff985d665b6ef5737a Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 5 Oct 2024 17:26:02 -0400 Subject: [PATCH 25/34] feat: support logsoftmax --- ext/ReactantNNlibExt.jl | 14 ++++++++++++++ src/TracedRNumber.jl | 2 ++ 2 files changed, 16 insertions(+) diff --git a/ext/ReactantNNlibExt.jl b/ext/ReactantNNlibExt.jl index 347b24e72a..f7ee82c204 100644 --- a/ext/ReactantNNlibExt.jl +++ b/ext/ReactantNNlibExt.jl @@ -32,6 +32,20 @@ function NNlib.softmax!(out::TracedRArray{T,N}, x::AbstractArray; dims=1) where return out ./= tmp end +function NNlib.logsoftmax!(out::TracedRArray{T}, x::AbstractArray; dims=1) where {T} + max_ = NNlib.fast_maximum(x; dims) + # if all(isfinite, max_) + @fastmath out .= x .- max_ + # else + # _zero, _minf, _inf = T(0), T(-Inf), T(Inf) + # @. out = ifelse( + # isequal(max_, _inf), ifelse(isequal(x, _inf), _zero, _minf), x - max_ + # ) + # end + @fastmath log_ = log.(sum(exp, out; dims)) + return out .-= log_ +end + function NNlib.conv( x::AnyTracedRArray{T,N}, W::AnyTracedRArray{T}, cdims::DenseConvDims ) where {T,N} diff --git a/src/TracedRNumber.jl b/src/TracedRNumber.jl index 946895506c..c5de72ba96 100644 --- a/src/TracedRNumber.jl +++ b/src/TracedRNumber.jl @@ -193,6 +193,8 @@ end # XXX: Enzyme-MLIR doesn't have `abs` adjoint defined Base.abs2(x::TracedRNumber{<:Real}) = x^2 +Base.log1p(x::TracedRNumber{T}) where {T} = log(x + one(T)) + struct TypeCast{T<:ReactantPrimitives} <: Function end (::TypeCast{T})(x::TracedRNumber{T2}) where {T,T2} = promote_to(TracedRNumber{T}, x) From 841376d6091c350a1dbb17c0d50e9b940c041734 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 5 Oct 2024 17:42:55 -0400 Subject: [PATCH 26/34] fix: bool promote rule --- src/TracedRNumber.jl | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/TracedRNumber.jl b/src/TracedRNumber.jl index c5de72ba96..c16b10a05b 100644 --- a/src/TracedRNumber.jl +++ b/src/TracedRNumber.jl @@ -35,6 +35,15 @@ function Base.promote_rule(::Type{TracedRNumber{T}}, ::Type{TracedRNumber{S}}) w return TracedRNumber{Base.promote_type(T, S)} end +# Bool has special promotion rules in Base +function Base.promote_rule(::Type{Bool}, ::Type{TracedRNumber{T}}) where {T} + return TracedRNumber{T} +end + +function Base.promote_rule(::Type{TracedRNumber{T}}, ::Type{Bool}) where {T} + return TracedRNumber{T} +end + function Base.promote_rule(::Type{T}, ::Type{TracedRNumber{S}}) where {T,S} return TracedRNumber{Base.promote_type(T, S)} end From eb3d1db70fa28054c4b693d4b5325d778d9bc2a8 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 5 Oct 2024 19:14:45 -0400 Subject: [PATCH 27/34] fix: broadcasting of closures --- src/TracedRArray.jl | 3 ++- src/utils.jl | 5 ++++- test/bcast.jl | 20 ++++++++++++++++++++ 3 files changed, 26 insertions(+), 2 deletions(-) diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 0b3c6a65a5..75fb609d64 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -246,7 +246,8 @@ function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs} invmap[v] = k end - input_shapes = size.(keys(seen_args)) + keys_seen = [k for k in keys(seen_args) if k isa TracedTypes] + input_shapes = size.(keys_seen) # by the time we reach here all args must have same size @assert allequal(input_shapes) "input shapes are $(input_shapes)" OutShape = isempty(seen_args) ? nothing : first(input_shapes) diff --git a/src/utils.jl b/src/utils.jl index e136ebb095..734630d692 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -29,7 +29,10 @@ end function make_mlir_fn(f, args, kwargs, name="main", concretein=true; toscalar=false) if sizeof(typeof(f)) != 0 || f isa BroadcastFunction - return (true, make_mlir_fn(apply, (f, args...), kwargs, name, concretein)[2:end]...) + return ( + true, + make_mlir_fn(apply, (f, args...), kwargs, name, concretein; toscalar)[2:end]..., + ) end N = length(args) diff --git a/test/bcast.jl b/test/bcast.jl index 3294aeaab9..4e56015894 100644 --- a/test/bcast.jl +++ b/test/bcast.jl @@ -112,3 +112,23 @@ pow(x, n) = x .^ n @test pow_compiled(x_ra) ≈ pow(x, 2) end + +struct CustomBCastFunction{X} + x::X +end + +(f::CustomBCastFunction)(x::Number) = f.x + x + +function custombcast(x) + fn = CustomBCastFunction(3.0) + return fn.(x) +end + +@testset "Broadcasting closures / functors" begin + x = rand(2, 3) + x_ra = Reactant.to_rarray(x) + + custombcast_compiled = @compile custombcast(x_ra) + + @test custombcast_compiled(x_ra) ≈ custombcast(x) +end From 944dca84e0592d648008a327d3aa1ddf46445c98 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 5 Oct 2024 20:45:37 -0400 Subject: [PATCH 28/34] refactor: use TracedTypes --- src/utils.jl | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 734630d692..ea2dae0de6 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -49,9 +49,7 @@ function make_mlir_fn(f, args, kwargs, name="main", concretein=true; toscalar=fa linear_args = TracedTypes[] for (k, v) in seen_args - if !(v isa TracedRArray) && !(v isa TracedRNumber) - continue - end + v isa TracedTypes || continue push!(linear_args, v) end @@ -133,10 +131,7 @@ function make_mlir_fn(f, args, kwargs, name="main", concretein=true; toscalar=fa linear_results = TracedTypes[] for (k, v) in seen_results - if !(v isa TracedRArray) && !(v isa TracedRNumber) - continue - end - + v isa TracedTypes || continue push!(linear_results, v) end From 3ecafefd3a2b86e8cfa64b0d75df1e57854a2056 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sat, 5 Oct 2024 21:49:14 -0400 Subject: [PATCH 29/34] Fix type of `preserved_args` --- src/Compiler.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Compiler.jl b/src/Compiler.jl index 274cb2cbf8..6c2b9b935d 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -288,7 +288,7 @@ function compile_mlir!(mod, f, args; optimize=true) ) end - preserved_args = Tuple{TracedRArray,Int}[] + preserved_args = Tuple{TracedTypes,Int}[] results = [MLIR.IR.operand(ret, i) for i in 1:MLIR.IR.noperands(ret)] nresults = MLIR.IR.Value[] linear_results2 = TracedTypes[] From c03b5e0dae9405da9c44d9020d5e411dcf0f9037 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sat, 5 Oct 2024 21:51:01 -0400 Subject: [PATCH 30/34] Rename `TracedTypes` to `TracedType` --- src/Compiler.jl | 6 ++-- src/Reactant.jl | 2 +- src/TracedRArray.jl | 2 +- src/Tracing.jl | 2 +- src/utils.jl | 8 ++--- test/basic.jl | 83 +++++++++++++++++++++++++++++++++++++++++++++ test/runtests.jl | 16 ++++----- 7 files changed, 101 insertions(+), 18 deletions(-) diff --git a/src/Compiler.jl b/src/Compiler.jl index 6c2b9b935d..4b07ab2b10 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -11,7 +11,7 @@ import ..Reactant: make_tracer, TracedToConcrete, append_path, - TracedTypes + TracedType @inline traced_getfield(@nospecialize(obj), field) = Base.getfield(obj, field) @@ -288,10 +288,10 @@ function compile_mlir!(mod, f, args; optimize=true) ) end - preserved_args = Tuple{TracedTypes,Int}[] + preserved_args = Tuple{TracedType,Int}[] results = [MLIR.IR.operand(ret, i) for i in 1:MLIR.IR.noperands(ret)] nresults = MLIR.IR.Value[] - linear_results2 = TracedTypes[] + linear_results2 = TracedType[] for (i, op) in enumerate(results) if !MLIR.IR.is_block_arg(op) push!(nresults, op) diff --git a/src/Reactant.jl b/src/Reactant.jl index e70798302b..3674050026 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -88,7 +88,7 @@ include("ConcreteRArray.jl") include("TracedRNumber.jl") include("TracedRArray.jl") -const TracedTypes = Union{TracedRArray,TracedRNumber} +const TracedType = Union{TracedRArray,TracedRNumber} include("Tracing.jl") include("Compiler.jl") diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 6d3346a6fb..e0d4cdc40c 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -246,7 +246,7 @@ function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs} invmap[v] = k end - keys_seen = [k for k in keys(seen_args) if k isa TracedTypes] + keys_seen = [k for k in keys(seen_args) if k isa TracedType] input_shapes = size.(keys_seen) # by the time we reach here all args must have same size @assert allequal(input_shapes) "input shapes are $(input_shapes)" diff --git a/src/Tracing.jl b/src/Tracing.jl index 3c30b3ba18..833ab830c7 100644 --- a/src/Tracing.jl +++ b/src/Tracing.jl @@ -183,7 +183,7 @@ function traced_type(::Type{T}, seen, ::Val{mode}) where {T<:ConcreteRArray,mode end end -function traced_type(::Type{T}, seen::ST, ::Val{mode}) where {ST,T<:TracedTypes,mode} +function traced_type(::Type{T}, seen::ST, ::Val{mode}) where {ST,T<:TracedType,mode} if mode == ConcreteToTraced throw("TracedRArray $T cannot be traced") elseif mode == TracedToConcrete diff --git a/src/utils.jl b/src/utils.jl index ea2dae0de6..b379366bc5 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -47,9 +47,9 @@ function make_mlir_fn(f, args, kwargs, name="main", concretein=true; toscalar=fa ) end - linear_args = TracedTypes[] + linear_args = TracedType[] for (k, v) in seen_args - v isa TracedTypes || continue + v isa TracedType || continue push!(linear_args, v) end @@ -128,10 +128,10 @@ function make_mlir_fn(f, args, kwargs, name="main", concretein=true; toscalar=fa ) end - linear_results = TracedTypes[] + linear_results = TracedType[] for (k, v) in seen_results - v isa TracedTypes || continue + v isa TracedType || continue push!(linear_results, v) end diff --git a/test/basic.jl b/test/basic.jl index de7cdf378b..5b7cb4d25f 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -218,6 +218,89 @@ end end @testset "concatenation" begin + @testset "Number" begin + x = fill(true) + x_concrete = Reactant.to_rarray(x) + + # NOTE [,,,] is a call to `vect`, not `*cat` + # f = Reactant.compile((x_concrete,)) do x + # return [x, x, x] + # end + # @test f(x_concrete) ≈ ones(3) + + # vcat + test_vcat(x) = begin + x = x[] # unwrap scalar + [x; x; x] + end + f = @compile test_vcat(x_concrete) + @test f(x_concrete) == test_vcat(x) + @test eltype(f(x_concrete)) === Bool + + # hcat + test_hcat(x) = begin + x = x[] # unwrap scalar + [x x x] + end + f = @compile test_hcat(x_concrete) + @test f(x_concrete) == test_hcat(x) + @test eltype(f(x_concrete)) === Bool + + # hvcat + test_hvcat(x) = begin + x = x[] # unwrap scalar + [x x x; x x x] + end + f = @compile test_hvcat(x_concrete) + @test f(x_concrete) == test_hvcat(x) + @test eltype(f(x_concrete)) === Bool + + # hvncat + test_hvncat(x) = begin + x = x[] # unwrap scalar + [x x x; x x x;;; x x x; x x x] + end + f = @compile test_hvncat(x_concrete) + @test f(x_concrete) == test_hvncat(x) + @test eltype(f(x_concrete)) === Bool + + # typed_vcat + test_typed_vcat(x) = begin + x = x[] # unwrap scalar + Int[x; x; x] + end + f = @compile test_typed_vcat(x_concrete) + @test f(x_concrete) == test_typed_vcat(x) + @test eltype(f(x_concrete)) === Int + + # typed_hcat + test_typed_hcat(x) = begin + x = x[] # unwrap scalar + Int[x x x] + end + f = @compile test_typed_hcat(x_concrete) + @test f(x_concrete) == test_typed_hcat(x) + @test eltype(f(x_concrete)) === Int + + # typed_hvcat + test_typed_hvcat(x) = begin + x = x[] # unwrap scalar + Int[x x x; x x x] + end + f = @compile test_typed_hvcat(x_concrete) + @test f(x_concrete) == test_typed_hvcat(x) + @test eltype(f(x_concrete)) === Int + + # typed_hvncat + test_typed_hvncat(x) = begin + x = x[] # unwrap scalar + Int[x x x; x x x;;; x x x; x x x] + end + f = @compile test_typed_hvncat(x_concrete) + @test f(x_concrete) == test_typed_hvncat(x) + @test eltype(f(x_concrete)) === Int + end + @testset "$(ndims(x))-dim" for x in [ fill(true), [true, false], diff --git a/test/runtests.jl b/test/runtests.jl index 9fe7423096..1fe0b311ea 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -43,15 +43,15 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all")) @testset "Reactant.jl Tests" begin if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "core" - @safetestset "Layout" include("layout.jl") - @safetestset "Tracing" include("tracing.jl") + # @safetestset "Layout" include("layout.jl") + # @safetestset "Tracing" include("tracing.jl") @safetestset "Basic" include("basic.jl") - @safetestset "Broadcast" include("bcast.jl") - @safetestset "Struct" include("struct.jl") - @safetestset "Closure" include("closure.jl") - @safetestset "Compile" include("compile.jl") - @safetestset "Buffer Donation" include("buffer_donation.jl") - @safetestset "Wrapped Arrays" include("wrapped_arrays.jl") + # @safetestset "Broadcast" include("bcast.jl") + # @safetestset "Struct" include("struct.jl") + # @safetestset "Closure" include("closure.jl") + # @safetestset "Compile" include("compile.jl") + # @safetestset "Buffer Donation" include("buffer_donation.jl") + # @safetestset "Wrapped Arrays" include("wrapped_arrays.jl") end if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "neural_networks" From 60b614b3bd540f2da44ee43a44aa1ce0bca52d89 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sat, 5 Oct 2024 21:57:21 -0400 Subject: [PATCH 31/34] small testset rename --- test/basic.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/basic.jl b/test/basic.jl index 5b7cb4d25f..56d4a121dc 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -301,7 +301,7 @@ end @test eltype(f(x_concrete)) === Int end - @testset "$(ndims(x))-dim" for x in [ + @testset "$(ndims(x))-dim Array" for x in [ fill(true), [true, false], [true false], From 8a9f06c19707cce996a1b19b1912c39f03ae73b3 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 6 Oct 2024 09:06:12 -0400 Subject: [PATCH 32/34] fix: special handling for concatenation of numbers --- src/TracedRNumber.jl | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/src/TracedRNumber.jl b/src/TracedRNumber.jl index c16b10a05b..0123c1982e 100644 --- a/src/TracedRNumber.jl +++ b/src/TracedRNumber.jl @@ -209,3 +209,34 @@ struct TypeCast{T<:ReactantPrimitives} <: Function end (::TypeCast{T})(x::TracedRNumber{T2}) where {T,T2} = promote_to(TracedRNumber{T}, x) Base.float(x::TracedRNumber{T}) where {T} = promote_to(TracedRNumber{float(T)}, x) + +# Concatenation. Numbers in Julia are handled in a much less generic fashion than arrays +Base.vcat(x::TracedRNumber...) = Base.typed_vcat(Base.promote_eltypeof(x...), x...) +function Base.typed_vcat(::Type{T}, x::TracedRNumber...) where {T} + return Base.typed_vcat(T, map(Base.Fix2(broadcast_to_size, (1,)), x)...) +end + +Base.hcat(x::TracedRNumber...) = Base.typed_hcat(Base.promote_eltypeof(x...), x...) +function Base.typed_hcat(::Type{T}, x::TracedRNumber...) where {T} + return Base.typed_hcat(T, map(Base.Fix2(broadcast_to_size, (1, 1)), x)...) +end + +function Base.hvcat(rows::Tuple{Vararg{Int}}, xs::TracedRNumber...) + return Base.typed_hvcat(Base.promote_eltypeof(xs...), rows, xs...) +end +function Base.typed_hvcat( + ::Type{T}, rows::Tuple{Vararg{Int}}, xs::TracedRNumber... +) where {T} + xs = map(Base.Fix2(broadcast_to_size, (1, 1)), xs) + return Base.typed_hvcat(T, rows, xs...) +end + +function Base.hvncat(dims::Tuple{Vararg{Int}}, row_first::Bool, xs::TracedRNumber...) + return Base.typed_hvncat(Base.promote_eltypeof(xs...), dims, row_first, xs...) +end +function Base.typed_hvncat( + ::Type{T}, dims::Tuple{Vararg{Int}}, row_first::Bool, xs::TracedRNumber... +) where {T} + xs = map(Base.Fix2(broadcast_to_size, (1, 1)), xs) + return Base.typed_hvncat(T, dims, row_first, xs...) +end From a35d7b7da38e4f366bae1d045a6000fbc13b43f5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sun, 6 Oct 2024 13:34:03 -0400 Subject: [PATCH 33/34] Reenable tests --- test/runtests.jl | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 1fe0b311ea..9fe7423096 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -43,15 +43,15 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all")) @testset "Reactant.jl Tests" begin if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "core" - # @safetestset "Layout" include("layout.jl") - # @safetestset "Tracing" include("tracing.jl") + @safetestset "Layout" include("layout.jl") + @safetestset "Tracing" include("tracing.jl") @safetestset "Basic" include("basic.jl") - # @safetestset "Broadcast" include("bcast.jl") - # @safetestset "Struct" include("struct.jl") - # @safetestset "Closure" include("closure.jl") - # @safetestset "Compile" include("compile.jl") - # @safetestset "Buffer Donation" include("buffer_donation.jl") - # @safetestset "Wrapped Arrays" include("wrapped_arrays.jl") + @safetestset "Broadcast" include("bcast.jl") + @safetestset "Struct" include("struct.jl") + @safetestset "Closure" include("closure.jl") + @safetestset "Compile" include("compile.jl") + @safetestset "Buffer Donation" include("buffer_donation.jl") + @safetestset "Wrapped Arrays" include("wrapped_arrays.jl") end if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "neural_networks" From 4a81556a8fc68510d103b3e51b0b01093964c64e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sun, 6 Oct 2024 13:35:28 -0400 Subject: [PATCH 34/34] Rename `ReactantPrimitives` to `ReactantPrimitive` --- src/Reactant.jl | 8 ++++---- src/TracedRArray.jl | 8 ++++---- src/TracedRNumber.jl | 2 +- src/Tracing.jl | 4 ++-- 4 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/Reactant.jl b/src/Reactant.jl index 3674050026..82bf06ea60 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -8,7 +8,7 @@ include("OrderedIdDict.jl") using Enzyme @static if isdefined(Core, :BFloat16) - const ReactantPrimitives = Union{ + const ReactantPrimitive = Union{ Bool, Int8, UInt8, @@ -26,7 +26,7 @@ using Enzyme Complex{Float64}, } else - const ReactantPrimitives = Union{ + const ReactantPrimitive = Union{ Bool, Int8, UInt8, @@ -44,8 +44,8 @@ else } end -abstract type RArray{T<:ReactantPrimitives,N} <: AbstractArray{T,N} end -abstract type RNumber{T<:ReactantPrimitives} <: Number end +abstract type RArray{T<:ReactantPrimitive,N} <: AbstractArray{T,N} end +abstract type RNumber{T<:ReactantPrimitive} <: Number end function Base.reshape(A::RArray, dims::Tuple{Vararg{Union{Int,Colon}}}) return reshape(A, Base._reshape_uncolon(A, dims)) diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index e0d4cdc40c..bafea85a9f 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -221,10 +221,10 @@ end promote_to(::TracedRArray{T,N}, rhs) where {T,N} = promote_to(TracedRArray{T,N}, rhs) -elem_apply(::Type{T}, x::TracedRArray{T}) where {T<:ReactantPrimitives} = x +elem_apply(::Type{T}, x::TracedRArray{T}) where {T<:ReactantPrimitive} = x function elem_apply( ::Type{T}, x::TracedRArray{T2} -) where {T<:ReactantPrimitives,T2<:ReactantPrimitives} +) where {T<:ReactantPrimitive,T2<:ReactantPrimitive} # Special Path to prevent going down a despecialized path return elem_apply(TypeCast{T}(), x) end @@ -475,14 +475,14 @@ end function Base.similar( bc::Broadcasted{AbstractReactantArrayStyle{N}}, ::Type{T}, dims -) where {T<:ReactantPrimitives,N} +) where {T<:ReactantPrimitive,N} @assert N isa Int return TracedRArray{T,N}((), nothing, map(length, dims)) end function Base.similar( bc::Broadcasted{AbstractReactantArrayStyle{N}}, ::Type{<:TracedRNumber{T}}, dims -) where {T<:ReactantPrimitives,N} +) where {T<:ReactantPrimitive,N} @assert N isa Int return TracedRArray{T,N}((), nothing, map(length, dims)) end diff --git a/src/TracedRNumber.jl b/src/TracedRNumber.jl index 0123c1982e..b41dfcb481 100644 --- a/src/TracedRNumber.jl +++ b/src/TracedRNumber.jl @@ -204,7 +204,7 @@ Base.abs2(x::TracedRNumber{<:Real}) = x^2 Base.log1p(x::TracedRNumber{T}) where {T} = log(x + one(T)) -struct TypeCast{T<:ReactantPrimitives} <: Function end +struct TypeCast{T<:ReactantPrimitive} <: Function end (::TypeCast{T})(x::TracedRNumber{T2}) where {T,T2} = promote_to(TracedRNumber{T}, x) diff --git a/src/Tracing.jl b/src/Tracing.jl index 833ab830c7..dd6c4e77cf 100644 --- a/src/Tracing.jl +++ b/src/Tracing.jl @@ -203,7 +203,7 @@ function traced_type(::Type{T}, seen, mode) where {T<:XLAArray} end function traced_type(::Type{A}, seen::ST, ::Val{mode}) where {T,N,A<:Array{T,N},ST,mode} - if mode == ArrayToConcrete && T <: ReactantPrimitives + if mode == ArrayToConcrete && T <: ReactantPrimitive return ConcreteRArray{T,N} else return Array{traced_type(T, seen, Val(mode)),N} @@ -427,7 +427,7 @@ function make_tracer( if haskey(seen, prev) return seen[prev] end - if mode == ArrayToConcrete && eltype(RT) <: ReactantPrimitives + if mode == ArrayToConcrete && eltype(RT) <: ReactantPrimitive return seen[prev] = ConcreteRArray(prev) end TT = traced_type(eltype(RT), (), Val(mode))