diff --git a/Project.toml b/Project.toml index 6860184..5218fb3 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ADTypes" uuid = "47edcb42-4c32-4615-8424-f2b9edc5f35b" authors = ["Vaibhav Dixit , Guillaume Dalle and contributors"] -version = "1.18.0" +version = "1.19.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/docs/src/index.md b/docs/src/index.md index 7ae4eed..2bed335 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -55,6 +55,12 @@ AutoChainRules AutoDiffractor ``` +### Forward, reverse, or sparse mode + +```@docs +AutoReactant{<:AutoEnzyme} +``` + ### Symbolic mode ```@docs diff --git a/src/ADTypes.jl b/src/ADTypes.jl index 379b02f..609087e 100644 --- a/src/ADTypes.jl +++ b/src/ADTypes.jl @@ -47,7 +47,8 @@ export AutoChainRules, AutoTracker, AutoZygote, NoAutoDiff, - NoAutoDiffSelectedError + NoAutoDiffSelectedError, + AutoReactant @public AbstractMode @public ForwardMode, ReverseMode, ForwardOrReverseMode, SymbolicMode @public mode diff --git a/src/dense.jl b/src/dense.jl index 0a3534d..edd7d55 100644 --- a/src/dense.jl +++ b/src/dense.jl @@ -82,6 +82,45 @@ function Base.show(io::IO, backend::AutoEnzyme{M, A}) where {M, A} print(io, ")") end + +""" + AutoReactant{M<:AutoEnzyme} + +Struct used to select the [Reactant.jl](https://github.com/EnzymeAD/Reactant.jl) compilation atop Enzyme for automatic differentiation. + +Defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl). + +# Constructors + + AutoReactant(; mode::Union{AutoEnzyme,Nothing}=nothing) + +# Fields + + - `mode::M` specifies the Enzyme mode of differentiation + + + an object subtyping `EnzymeCore.Mode` (like `EnzymeCore.Forward` or `EnzymeCore.Reverse`) if a specific mode is required + + `nothing` to choose the best mode automatically +""" +struct AutoReactant{M<:AutoEnzyme} <: AbstractADType + mode::M +end + +function AutoReactant(; + mode::Union{AutoEnzyme,Nothing} = nothing) + if mode === nothing + mode = AutoEnzyme() + end + return AutoReactant(mode) +end + +mode(r::AutoReactant) = mode(r.mode) + +function Base.show(io::IO, backend::AutoReactant) + print(io, AutoReactant, "(") + print(io, "mode=", repr(backend.mode; context = io)) + print(io, ")") +end + """ AutoFastDifferentiation diff --git a/src/symbols.jl b/src/symbols.jl index 58199f3..b02901f 100644 --- a/src/symbols.jl +++ b/src/symbols.jl @@ -22,7 +22,7 @@ ADTypes.AutoZygote() """ Auto(package::Symbol, args...; kws...) = Auto(Val(package), args...; kws...) -for backend in (:ChainRules, :Diffractor, :Enzyme, :FastDifferentiation, +for backend in (:ChainRules, :Diffractor, :Enzyme, :Reactant, :FastDifferentiation, :FiniteDiff, :FiniteDifferences, :ForwardDiff, :Mooncake, :PolyesterForwardDiff, :ReverseDiff, :Symbolics, :Tapir, :Tracker, :Zygote) @eval Auto(::Val{$(QuoteNode(backend))}, args...; kws...) = $(Symbol(:Auto, backend))( diff --git a/test/dense.jl b/test/dense.jl index 28f3aa4..2eee7e1 100644 --- a/test/dense.jl +++ b/test/dense.jl @@ -52,6 +52,34 @@ end @test ad.mode == EnzymeCore.Reverse end +@testset "AutoReactant" begin + ad = AutoReactant() + @test ad isa AbstractADType + @test ad isa AutoReactant{<:AutoEnzyme} + @test ad.mode isa AutoEnzyme + @test ad.mode.mode === nothing + @test mode(ad) isa ForwardOrReverseMode + + ad = AutoReactant(; mode=AutoEnzyme(; mode = EnzymeCore.Forward)) + @test ad isa AbstractADType + @test ad isa AutoReactant{<:AutoEnzyme{typeof(EnzymeCore.Forward), Nothing}} + @test mode(ad) isa ForwardMode + @test ad.mode.mode == EnzymeCore.Forward + + ad = AutoReactant(; mode=AutoEnzyme(; function_annotation = EnzymeCore.Const)) + @test ad isa AbstractADType + @test ad isa AutoReactant{<:AutoEnzyme{Nothing, EnzymeCore.Const}} + @test mode(ad) isa ForwardOrReverseMode + @test ad.mode.mode === nothing + + ad = AutoReactant(; mode=AutoEnzyme(; + mode = EnzymeCore.Reverse, function_annotation = EnzymeCore.Duplicated)) + @test ad isa AbstractADType + @test ad isa AutoReactant{<:AutoEnzyme{typeof(EnzymeCore.Reverse), EnzymeCore.Duplicated}} + @test mode(ad) isa ReverseMode + @test ad.mode.mode == EnzymeCore.Reverse +end + @testset "AutoFastDifferentiation" begin ad = AutoFastDifferentiation() @test ad isa AbstractADType diff --git a/test/symbols.jl b/test/symbols.jl index 3ebf0e2..10802cc 100644 --- a/test/symbols.jl +++ b/test/symbols.jl @@ -11,6 +11,7 @@ using Test @test ADTypes.Auto(:Mooncake) isa AutoMooncake @test ADTypes.Auto(:PolyesterForwardDiff) isa AutoPolyesterForwardDiff @test ADTypes.Auto(:ReverseDiff) isa AutoReverseDiff +@test ADTypes.Auto(:Reactant) isa AutoReactant @test ADTypes.Auto(:Symbolics) isa AutoSymbolics @test ADTypes.Auto(:Tapir) isa AutoTapir @test ADTypes.Auto(:Tracker) isa AutoTracker