@@ -39,36 +39,46 @@ struct AutoDiffractor <: AbstractADType end
3939mode (:: AutoDiffractor ) = ForwardOrReverseMode ()
4040
4141"""
42- AutoEnzyme{M}
42+ AutoEnzyme{M,A }
4343
4444Struct used to select the [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl) backend for automatic differentiation.
4545
4646Defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl).
4747
4848# Constructors
4949
50- AutoEnzyme(; mode=nothing)
50+ AutoEnzyme(; mode::M=nothing, function_annotation::Type{A}=Nothing)
51+
52+ # Type parameters
53+
54+ - `A` determines how the function `f` to differentiate is passed to Enzyme. It can be:
55+
56+ + a subtype of `EnzymeCore.Annotation` (like `EnzymeCore.Const` or `EnzymeCore.Duplicated`) to enforce a given annotation
57+ + `Nothing` to simply pass `f` and let Enzyme choose the most appropriate annotation
5158
5259# Fields
5360
54- - `mode::M`: can be either
61+ - `mode::M` determines the autodiff mode (forward or reverse). It can be:
5562
5663 + an object subtyping `EnzymeCore.Mode` (like `EnzymeCore.Forward` or `EnzymeCore.Reverse`) if a specific mode is required
5764 + `nothing` to choose the best mode automatically
5865"""
59- struct AutoEnzyme{M} <: AbstractADType
66+ struct AutoEnzyme{M, A } <: AbstractADType
6067 mode:: M
6168end
6269
63- function AutoEnzyme (; mode:: M = nothing ) where {M}
64- return AutoEnzyme {M} (mode)
70+ function AutoEnzyme (;
71+ mode:: M = nothing , function_annotation:: Type{A} = Nothing) where {M, A}
72+ return AutoEnzyme {M, A} (mode)
6573end
6674
6775mode (:: AutoEnzyme ) = ForwardOrReverseMode () # specialized in the extension
6876
69- function Base. show (io:: IO , backend:: AutoEnzyme )
77+ function Base. show (io:: IO , backend:: AutoEnzyme{M, A} ) where {M, A}
7078 print (io, AutoEnzyme, " (" )
7179 ! isnothing (backend. mode) && print (io, " mode=" , repr (backend. mode; context = io))
80+ ! isnothing (backend. mode) && ! (A <: Nothing ) && print (io, " , " )
81+ ! (A <: Nothing ) && print (io, " function_annotation=" , repr (A; context = io))
7282 print (io, " )" )
7383end
7484
0 commit comments