From f9d99b8a2c12d1e8d36e922cd410baf4abd50eb9 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 1 Nov 2024 14:37:23 -0400 Subject: [PATCH 1/2] feat: specialize dispatches for faster concrete array generation --- src/Tracing.jl | 24 +++++++++++++++++++++++- test/tracing.jl | 14 ++++++++++++++ 2 files changed, 37 insertions(+), 1 deletion(-) diff --git a/src/Tracing.jl b/src/Tracing.jl index b6037d7c9b..84a73d4e55 100644 --- a/src/Tracing.jl +++ b/src/Tracing.jl @@ -563,7 +563,29 @@ end @inline function to_rarray(@nospecialize(x); track_numbers::Union{Bool,Tuple}=()) track_numbers isa Bool && (track_numbers = track_numbers ? (Number,) : ()) + return to_rarray_internal(x, track_numbers) +end + +@inline function to_rarray_internal(@nospecialize(x), track_numbers::Tuple) return make_tracer(OrderedIdDict(), x, (), Reactant.ArrayToConcrete; track_numbers) end -to_rarray(x::ReactantPrimitive) = ConcreteRArray(x) +function to_rarray_internal(@nospecialize(::TracedRArray), ::Tuple) + error("Cannot convert TracedRArray to ConcreteRArray") +end +@inline to_rarray_internal(@nospecialize(x::ConcreteRArray), ::Tuple) = x +@inline function to_rarray_internal( + @nospecialize(x::AbstractArray{<:ReactantPrimitive}), ::Tuple +) + return ConcreteRArray(x) +end + +@inline to_rarray_internal(@nospecialize(x::ConcreteRNumber), ::Tuple) = x +@inline function to_rarray_internal( + @nospecialize(x::ReactantPrimitive), track_numbers::Tuple +) + for T in track_numbers + typeof(x) <: T && return ConcreteRNumber(x) + end + return x +end diff --git a/test/tracing.jl b/test/tracing.jl index d75a435988..88b1e7662c 100644 --- a/test/tracing.jl +++ b/test/tracing.jl @@ -100,4 +100,18 @@ using Test end end end + + @testset "specialized dispatches" begin + @test @inferred Union{Float64,ConcreteRArray{Float64}} Reactant.to_rarray( + 1.0; track_numbers=(Number,) + ) isa ConcreteRNumber + @test @inferred Reactant.to_rarray(1.0) isa Float64 + @test @inferred Reactant.to_rarray(rand(3)) isa ConcreteRArray + + x_ra = Reactant.to_rarray(rand(3)) + @test @inferred Reactant.to_rarray(x_ra) isa ConcreteRArray + + x_ra = Reactant.to_rarray(1.0; track_numbers=(Number,)) + @test @inferred Reactant.to_rarray(x_ra) isa ConcreteRNumber + end end From b55f4430b6af05f41c637aea3396ae30a550f3cf Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 1 Nov 2024 14:46:48 -0400 Subject: [PATCH 2/2] chore: apply formatting suggestion Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/Tracing.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Tracing.jl b/src/Tracing.jl index 84a73d4e55..30a617cc38 100644 --- a/src/Tracing.jl +++ b/src/Tracing.jl @@ -571,7 +571,7 @@ end end function to_rarray_internal(@nospecialize(::TracedRArray), ::Tuple) - error("Cannot convert TracedRArray to ConcreteRArray") + return error("Cannot convert TracedRArray to ConcreteRArray") end @inline to_rarray_internal(@nospecialize(x::ConcreteRArray), ::Tuple) = x @inline function to_rarray_internal(