@@ -2,14 +2,13 @@ module DiffEqBaseEnzymeExt
22
33using DiffEqBase
44import DiffEqBase: value
5- isdefined (Base, :get_extension ) ? ( import Enzyme) : ( import .. Enzyme)
6-
5+ using Enzyme
6+ import Enzyme : Const
77using ChainRulesCore
8- using EnzymeCore
98
10- function EnzymeCore . EnzymeRules. augmented_primal (config:: EnzymeCore .EnzymeRules.ConfigWidth{1} , func:: Const{typeof(DiffEqBase.solve_up)} , :: Type{Duplicated{RT}} , prob, sensealg:: Union{Const{Nothing}, Const{<:DiffEqBase.AbstractSensitivityAlgorithm}} , u0, p, args... ; kwargs... ) where RT
9+ function Enzyme . EnzymeRules. augmented_primal (config:: Enzyme .EnzymeRules.ConfigWidth{1} , func:: Const{typeof(DiffEqBase.solve_up)} , :: Type{Duplicated{RT}} , prob, sensealg:: Union{Const{Nothing}, Const{<:DiffEqBase.AbstractSensitivityAlgorithm}} , u0, p, args... ; kwargs... ) where RT
1110 @inline function copy_or_reuse (val, idx)
12- if EnzymeCore . EnzymeRules. overwritten (config)[idx] && ismutable (val)
11+ if Enzyme . EnzymeRules. overwritten (config)[idx] && ismutable (val)
1312 return deepcopy (val)
1413 else
1514 return val
@@ -28,15 +27,15 @@ function EnzymeCore.EnzymeRules.augmented_primal(config::EnzymeCore.EnzymeRules.
2827 v.= 0
2928 end
3029 tup = (dres, res[2 ])
31- return EnzymeCore . EnzymeRules. AugmentedReturn {RT, RT, Any} (res[1 ], dres, tup:: Any )
30+ return Enzyme . EnzymeRules. AugmentedReturn {RT, RT, Any} (res[1 ], dres, tup:: Any )
3231end
3332
34- function EnzymeCore . EnzymeRules . reverse (config:: EnzymeCore .EnzymeRules.ConfigWidth{1} , func:: Const{typeof(DiffEqBase.solve_up)} , :: Type{<:Duplicated{RT}} , tape, prob, sensealg:: Union{Const{Nothing}, Const{<:DiffEqBase.AbstractSensitivityAlgorithm}} , u0, p, args... ; kwargs... ) where RT
33+ function Enzyme . reverse (config:: Enzyme .EnzymeRules.ConfigWidth{1} , func:: Const{typeof(DiffEqBase.solve_up)} , :: Type{<:Duplicated{RT}} , tape, prob, sensealg:: Union{Const{Nothing}, Const{<:DiffEqBase.AbstractSensitivityAlgorithm}} , u0, p, args... ; kwargs... ) where RT
3534 dres, clos = tape
3635 dres = dres:: RT
3736 dargs = clos (dres)
3837 for (darg, ptr) in zip (dargs, (func, prob, sensealg, u0, p, args... ))
39- if ptr isa EnzymeCore . Const
38+ if ptr isa Enzyme . Const
4039 continue
4140 end
4241 if darg == ChainRulesCore. NoTangent ()
0 commit comments