Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ADTypes"
uuid = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
authors = ["Vaibhav Dixit <vaibhavyashdixit@gmail.com>, Guillaume Dalle and contributors"]
version = "1.18.0"
version = "1.19.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
6 changes: 6 additions & 0 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,12 @@ AutoChainRules
AutoDiffractor
```

### Forward, reverse, or sparse mode

```@docs
AutoReactant{<:AutoEnzyme}
```

### Symbolic mode

```@docs
Expand Down
3 changes: 2 additions & 1 deletion src/ADTypes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ export AutoChainRules,
AutoTracker,
AutoZygote,
NoAutoDiff,
NoAutoDiffSelectedError
NoAutoDiffSelectedError,
AutoReactant
@public AbstractMode
@public ForwardMode, ReverseMode, ForwardOrReverseMode, SymbolicMode
@public mode
Expand Down
39 changes: 39 additions & 0 deletions src/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,45 @@ function Base.show(io::IO, backend::AutoEnzyme{M, A}) where {M, A}
print(io, ")")
end


"""
AutoReactant{M<:AutoEnzyme}

Struct used to select the [Reactant.jl](https://github.com/EnzymeAD/Reactant.jl) compilation atop Enzyme for automatic differentiation.

Defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl).

# Constructors

AutoReactant(; mode::Union{AutoEnzyme,Nothing}=nothing)

# Fields

- `mode::M` specifies the Enzyme mode of differentiation

+ an object subtyping `EnzymeCore.Mode` (like `EnzymeCore.Forward` or `EnzymeCore.Reverse`) if a specific mode is required
+ `nothing` to choose the best mode automatically
"""
struct AutoReactant{M<:AutoEnzyme} <: AbstractADType
mode::M
end

function AutoReactant(;
mode::Union{AutoEnzyme,Nothing} = nothing)
if mode === nothing
mode = AutoEnzyme()
end
return AutoReactant(mode)
end

mode(r::AutoReactant) = mode(r.mode)

function Base.show(io::IO, backend::AutoReactant)
print(io, AutoReactant, "(")
print(io, "mode=", repr(backend.mode; context = io))
print(io, ")")
end

"""
AutoFastDifferentiation

Expand Down
2 changes: 1 addition & 1 deletion src/symbols.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ ADTypes.AutoZygote()
"""
Auto(package::Symbol, args...; kws...) = Auto(Val(package), args...; kws...)

for backend in (:ChainRules, :Diffractor, :Enzyme, :FastDifferentiation,
for backend in (:ChainRules, :Diffractor, :Enzyme, :Reactant, :FastDifferentiation,
:FiniteDiff, :FiniteDifferences, :ForwardDiff, :Mooncake, :PolyesterForwardDiff,
:ReverseDiff, :Symbolics, :Tapir, :Tracker, :Zygote)
@eval Auto(::Val{$(QuoteNode(backend))}, args...; kws...) = $(Symbol(:Auto, backend))(
Expand Down
28 changes: 28 additions & 0 deletions test/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,34 @@ end
@test ad.mode == EnzymeCore.Reverse
end

@testset "AutoReactant" begin
ad = AutoReactant()
@test ad isa AbstractADType
@test ad isa AutoReactant{<:AutoEnzyme}
@test ad.mode isa AutoEnzyme
@test ad.mode.mode === nothing
@test mode(ad) isa ForwardOrReverseMode

ad = AutoReactant(; mode=AutoEnzyme(; mode = EnzymeCore.Forward))
@test ad isa AbstractADType
@test ad isa AutoReactant{<:AutoEnzyme{typeof(EnzymeCore.Forward), Nothing}}
@test mode(ad) isa ForwardMode
@test ad.mode.mode == EnzymeCore.Forward

ad = AutoReactant(; mode=AutoEnzyme(; function_annotation = EnzymeCore.Const))
@test ad isa AbstractADType
@test ad isa AutoReactant{<:AutoEnzyme{Nothing, EnzymeCore.Const}}
@test mode(ad) isa ForwardOrReverseMode
@test ad.mode.mode === nothing

ad = AutoReactant(; mode=AutoEnzyme(;
mode = EnzymeCore.Reverse, function_annotation = EnzymeCore.Duplicated))
@test ad isa AbstractADType
@test ad isa AutoReactant{<:AutoEnzyme{typeof(EnzymeCore.Reverse), EnzymeCore.Duplicated}}
@test mode(ad) isa ReverseMode
@test ad.mode.mode == EnzymeCore.Reverse
end

@testset "AutoFastDifferentiation" begin
ad = AutoFastDifferentiation()
@test ad isa AbstractADType
Expand Down
1 change: 1 addition & 0 deletions test/symbols.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ using Test
@test ADTypes.Auto(:Mooncake) isa AutoMooncake
@test ADTypes.Auto(:PolyesterForwardDiff) isa AutoPolyesterForwardDiff
@test ADTypes.Auto(:ReverseDiff) isa AutoReverseDiff
@test ADTypes.Auto(:Reactant) isa AutoReactant
@test ADTypes.Auto(:Symbolics) isa AutoSymbolics
@test ADTypes.Auto(:Tapir) isa AutoTapir
@test ADTypes.Auto(:Tracker) isa AutoTracker
Expand Down
Loading