From 0d8fbbc3c77154432cf9473e11f35e15b0ccee7d Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Fri, 7 Nov 2025 16:24:48 -0600 Subject: [PATCH 1/3] Add AutoReactant{AutoEnzyme} --- Project.toml | 2 +- docs/src/index.md | 2 ++ src/ADTypes.jl | 3 ++- src/dense.jl | 39 +++++++++++++++++++++++++++++++++++++++ src/symbols.jl | 2 +- test/dense.jl | 28 ++++++++++++++++++++++++++++ test/symbols.jl | 1 + 7 files changed, 74 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index 6860184..5218fb3 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ADTypes" uuid = "47edcb42-4c32-4615-8424-f2b9edc5f35b" authors = ["Vaibhav Dixit , Guillaume Dalle and contributors"] -version = "1.18.0" +version = "1.19.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/docs/src/index.md b/docs/src/index.md index 7ae4eed..526b1a7 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -53,6 +53,7 @@ AutoZygote AutoEnzyme AutoChainRules AutoDiffractor +AutoReactant{<:AutoEnzyme} ``` ### Symbolic mode @@ -67,6 +68,7 @@ AutoSymbolics ```@docs AutoSparse ADTypes.dense_ad +AutoReactant{<:AutoEnzyme} ``` ### Sparsity detector diff --git a/src/ADTypes.jl b/src/ADTypes.jl index 379b02f..609087e 100644 --- a/src/ADTypes.jl +++ b/src/ADTypes.jl @@ -47,7 +47,8 @@ export AutoChainRules, AutoTracker, AutoZygote, NoAutoDiff, - NoAutoDiffSelectedError + NoAutoDiffSelectedError, + AutoReactant @public AbstractMode @public ForwardMode, ReverseMode, ForwardOrReverseMode, SymbolicMode @public mode diff --git a/src/dense.jl b/src/dense.jl index 0a3534d..ea09cdd 100644 --- a/src/dense.jl +++ b/src/dense.jl @@ -82,6 +82,45 @@ function Base.show(io::IO, backend::AutoEnzyme{M, A}) where {M, A} print(io, ")") end + +""" + AutoReactant{M<:AutoEnzyme} + +Struct used to select the [Reactant.jl](https://github.com/EnzymeAD/Reactant.jl) compilation atop Enzyme for automatic differentiation. + +Defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl). + +# Constructors + + AutoReactant(; mode::Union{AutoEnzyme,Nothing}=nothing) + +# Fields + + - `mode::M` specifies the Enzyme mode of differentiation + + + an object subtyping `EnzymeCore.Mode` (like `EnzymeCore.Forward` or `EnzymeCore.Reverse`) if a specific mode is required + + `nothing` to choose the best mode automatically +""" +struct AutoReactant{M<:AutoEnzyme} <: AbstractADType + mode::M +end + +function AutoReactant(; + mode::Union{AutoEnzyme,Nothing} = nothing) + if mode == nothing + mode = AutoEnzyme() + end + return AutoReactant(mode) +end + +mode(r::AutoReactant) = mode(r.mode) + +function Base.show(io::IO, backend::AutoReactant) + print(io, AutoReactant, "(") + print(io, "mode=", repr(backend.mode; context = io)) + print(io, ")") +end + """ AutoFastDifferentiation diff --git a/src/symbols.jl b/src/symbols.jl index 58199f3..b02901f 100644 --- a/src/symbols.jl +++ b/src/symbols.jl @@ -22,7 +22,7 @@ ADTypes.AutoZygote() """ Auto(package::Symbol, args...; kws...) = Auto(Val(package), args...; kws...) -for backend in (:ChainRules, :Diffractor, :Enzyme, :FastDifferentiation, +for backend in (:ChainRules, :Diffractor, :Enzyme, :Reactant, :FastDifferentiation, :FiniteDiff, :FiniteDifferences, :ForwardDiff, :Mooncake, :PolyesterForwardDiff, :ReverseDiff, :Symbolics, :Tapir, :Tracker, :Zygote) @eval Auto(::Val{$(QuoteNode(backend))}, args...; kws...) = $(Symbol(:Auto, backend))( diff --git a/test/dense.jl b/test/dense.jl index 28f3aa4..2eee7e1 100644 --- a/test/dense.jl +++ b/test/dense.jl @@ -52,6 +52,34 @@ end @test ad.mode == EnzymeCore.Reverse end +@testset "AutoReactant" begin + ad = AutoReactant() + @test ad isa AbstractADType + @test ad isa AutoReactant{<:AutoEnzyme} + @test ad.mode isa AutoEnzyme + @test ad.mode.mode === nothing + @test mode(ad) isa ForwardOrReverseMode + + ad = AutoReactant(; mode=AutoEnzyme(; mode = EnzymeCore.Forward)) + @test ad isa AbstractADType + @test ad isa AutoReactant{<:AutoEnzyme{typeof(EnzymeCore.Forward), Nothing}} + @test mode(ad) isa ForwardMode + @test ad.mode.mode == EnzymeCore.Forward + + ad = AutoReactant(; mode=AutoEnzyme(; function_annotation = EnzymeCore.Const)) + @test ad isa AbstractADType + @test ad isa AutoReactant{<:AutoEnzyme{Nothing, EnzymeCore.Const}} + @test mode(ad) isa ForwardOrReverseMode + @test ad.mode.mode === nothing + + ad = AutoReactant(; mode=AutoEnzyme(; + mode = EnzymeCore.Reverse, function_annotation = EnzymeCore.Duplicated)) + @test ad isa AbstractADType + @test ad isa AutoReactant{<:AutoEnzyme{typeof(EnzymeCore.Reverse), EnzymeCore.Duplicated}} + @test mode(ad) isa ReverseMode + @test ad.mode.mode == EnzymeCore.Reverse +end + @testset "AutoFastDifferentiation" begin ad = AutoFastDifferentiation() @test ad isa AbstractADType diff --git a/test/symbols.jl b/test/symbols.jl index 3ebf0e2..10802cc 100644 --- a/test/symbols.jl +++ b/test/symbols.jl @@ -11,6 +11,7 @@ using Test @test ADTypes.Auto(:Mooncake) isa AutoMooncake @test ADTypes.Auto(:PolyesterForwardDiff) isa AutoPolyesterForwardDiff @test ADTypes.Auto(:ReverseDiff) isa AutoReverseDiff +@test ADTypes.Auto(:Reactant) isa AutoReactant @test ADTypes.Auto(:Symbolics) isa AutoSymbolics @test ADTypes.Auto(:Tapir) isa AutoTapir @test ADTypes.Auto(:Tracker) isa AutoTracker From 7434348cf2d8fbdd63e3d0eb460839889ab29c2e Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 7 Nov 2025 16:51:03 -0600 Subject: [PATCH 2/3] Update index.md with new mode documentation Added documentation for forward, reverse, or sparse mode. --- docs/src/index.md | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/docs/src/index.md b/docs/src/index.md index 526b1a7..2bed335 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -53,6 +53,11 @@ AutoZygote AutoEnzyme AutoChainRules AutoDiffractor +``` + +### Forward, reverse, or sparse mode + +```@docs AutoReactant{<:AutoEnzyme} ``` @@ -68,7 +73,6 @@ AutoSymbolics ```@docs AutoSparse ADTypes.dense_ad -AutoReactant{<:AutoEnzyme} ``` ### Sparsity detector From 48288e6f6da3bc51a7c1aad0f5d0ccb0f5ffbeb3 Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 7 Nov 2025 17:11:23 -0600 Subject: [PATCH 3/3] Update src/dense.jl Co-authored-by: Avik Pal --- src/dense.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dense.jl b/src/dense.jl index ea09cdd..edd7d55 100644 --- a/src/dense.jl +++ b/src/dense.jl @@ -107,7 +107,7 @@ end function AutoReactant(; mode::Union{AutoEnzyme,Nothing} = nothing) - if mode == nothing + if mode === nothing mode = AutoEnzyme() end return AutoReactant(mode)