diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index a2af3e485..447ed182c 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -6,18 +6,32 @@ const have_not_warned_vjp = Ref(true) const STACKTRACE_WITH_VJPWARN = Ref(false) +function adfunc(out, u, _p, t, repack) + f(out, u, repack(_p), t) + nothing +end + function inplace_vjp(prob, u0, p, verbose, repack) du = zero(u0) ez = try f = unwrapped_f(prob.f) - function adfunc(out, u, _p, t) - f(out, u, repack(_p), t) - nothing - end Enzyme.autodiff(Enzyme.Reverse, adfunc, Enzyme.Duplicated(du, copy(u0)), - Enzyme.Duplicated(copy(u0), zero(u0)), Enzyme.Duplicated(copy(p), zero(p)), Enzyme.Const(prob.tspan[1])) + Enzyme.Duplicated(copy(u0), zero(u0)), Enzyme.Duplicated(copy(p), zero(p)), Enzyme.Const(prob.tspan[1]), Enzyme.Const(repack)) + true + catch e + false + end + if ez + return EnzymeVJP() + end + + erz = try + f = unwrapped_f(prob.f) + + Enzyme.autodiff(Enzyme.set_runtime_activity(Enzyme.Reverse), adfunc, Enzyme.Duplicated(du, copy(u0)), + Enzyme.Duplicated(copy(u0), zero(u0)), Enzyme.Duplicated(copy(p), zero(p)), Enzyme.Const(prob.tspan[1]), Enzyme.Const(repack)) true catch e if verbose && have_not_warned_vjp[] @@ -28,8 +42,8 @@ function inplace_vjp(prob, u0, p, verbose, repack) end false end - if ez - return EnzymeVJP() + if erz + return EnzymeVJP(; mode=Enzyme.set_runtime_activity(Enzyme.Reverse)) end # Determine if we can compile ReverseDiff diff --git a/src/derivative_wrappers.jl b/src/derivative_wrappers.jl index 3e1cd782e..04e3128ff 100644 --- a/src/derivative_wrappers.jl +++ b/src/derivative_wrappers.jl @@ -664,6 +664,16 @@ function _vecjacobian(y, λ, p, t, S::TS, isautojacvec::ZygoteVJP, dgrad, dy, return dy, dλ, dgrad end +function gclosure1(f, du, u, p, t) + Base.copyto!(du, f(u, p, t)) + nothing +end + +function gclosure2(du, u, p, t, W) + Base.copyto!(du, f(u, p, t, W)) + nothing +end + function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::EnzymeVJP, dgrad, dy, W) where {TS <: SensitivityFunction} (; sensealg) = S @@ -732,13 +742,13 @@ function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::EnzymeVJP, dgrad, end if W === nothing - Enzyme.autodiff(Enzyme.Reverse, Enzyme.Duplicated(SciMLBase.Void(f), _tmp6), + Enzyme.autodiff(isautojacvec.mode, Enzyme.Duplicated(SciMLBase.Void(f), _tmp6), Enzyme.Const, Enzyme.Duplicated(tmp3, tmp4), Enzyme.Duplicated(ytmp, tmp1), dup, Enzyme.Const(t)) else - Enzyme.autodiff(Enzyme.Reverse, Enzyme.Duplicated(SciMLBase.Void(f), _tmp6), + Enzyme.autodiff(isautojacvec.mode, Enzyme.Duplicated(SciMLBase.Void(f), _tmp6), Enzyme.Const, Enzyme.Duplicated(tmp3, tmp4), Enzyme.Duplicated(ytmp, tmp1), dup, @@ -750,22 +760,14 @@ function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::EnzymeVJP, dgrad, dy !== nothing && recursive_copyto!(dy, tmp3) else if W === nothing - function g(du, u, p, t) - du .= f(u, p, t) - nothing - end - _tmp6 = Enzyme.make_zero(g) - Enzyme.autodiff(Enzyme.Reverse, Enzyme.Duplicated(g, _tmp6), + _tmp6 = Enzyme.make_zero(f) + Enzyme.autodiff(isautojacvec.mode, Enzyme.Const(gclosure1), Enzyme.Duplicated(f, _tmp6), Enzyme.Const, Enzyme.Duplicated(tmp3, tmp4), Enzyme.Duplicated(ytmp, tmp1), dup, Enzyme.Const(t)) else - function g(du, u, p, t, W) - du .= f(u, p, t, W) - nothing - end - _tmp6 = Enzyme.make_zero(g) - Enzyme.autodiff(Enzyme.Reverse, Enzyme.Duplicated(g, _tmp6), + _tmp6 = Enzyme.make_zero(f) + Enzyme.autodiff(isautojacvec.mode, Enzyme.Const(gclosure2), Enzyme.Duplicated(f, _tmp6), Enzyme.Const, Enzyme.Duplicated(tmp3, tmp4), Enzyme.Duplicated(ytmp, tmp1), dup, Enzyme.Const(t), Enzyme.Const(W)) diff --git a/src/gauss_adjoint.jl b/src/gauss_adjoint.jl index 2c176d85f..a3270e0c7 100644 --- a/src/gauss_adjoint.jl +++ b/src/gauss_adjoint.jl @@ -441,6 +441,11 @@ function GaussIntegrand(sol, sensealg, checkpoints, dgdp = nothing) sensealg, dgdp_cache, dgdp) end +function g(f, du, u, p, t) + Base.copyto!(du, f(u, p, t)) + nothing +end + # out = λ df(u, p, t)/dp at u=y, p=p, t=t function vec_pjac!(out, λ, y, t, S::GaussIntegrand) (; pJ, pf, p, f_cache, dgdp_cache, paramjac_config, sensealg, sol) = S @@ -500,17 +505,13 @@ function vec_pjac!(out, λ, y, t, S::GaussIntegrand) Enzyme.remake_zero!(tmp6) Enzyme.autodiff( - Enzyme.Reverse, Enzyme.Duplicated(pf, tmp6), Enzyme.Const, + sensealg.autojacvec.mode, Enzyme.Duplicated(pf, tmp6), Enzyme.Const, Enzyme.Duplicated(tmp3, tmp4), Enzyme.Const(y), Enzyme.Duplicated(p, out), Enzyme.Const(t)) else - function g(du, u, p, t) - du .= f(u, p, t) - nothing - end - tmp6 = Enzyme.make_zero(g) + tmp6 = Enzyme.make_zero(f) Enzyme.autodiff( - Enzyme.Reverse, Enzyme.Duplicated(g, tmp6), Enzyme.Const, + sensealg.autojacvec.mode, Enzyme.Const(gclosure3), Enzyme.Duplicated(f, tmp6), Enzyme.Const, Enzyme.Duplicated(tmp3, tmp4), Enzyme.Const(y), Enzyme.Duplicated(p, out), Enzyme.Const(t)) end diff --git a/src/quadrature_adjoint.jl b/src/quadrature_adjoint.jl index 696023e9c..45bc80ef0 100644 --- a/src/quadrature_adjoint.jl +++ b/src/quadrature_adjoint.jl @@ -232,6 +232,11 @@ function AdjointSensitivityIntegrand(sol, adj_sol, sensealg, dgdp = nothing) AdjointSensitivityIntegrand(sol, adj_sol, p, y, λ, pf, f_cache, pJ, paramjac_config, sensealg, dgdp_cache, dgdp) end + +function gclosure4(f, du, u, p, t) + Base.copyto!(du, f(u, p, t)) + nothing +end # out = λ df(u, p, t)/dp at u=y, p=p, t=t function vec_pjac!(out, λ, y, t, S::AdjointSensitivityIntegrand) @@ -295,17 +300,13 @@ function vec_pjac!(out, λ, y, t, S::AdjointSensitivityIntegrand) if SciMLBase.isinplace(sol.prob.f) Enzyme.remake_zero!(tmp6) Enzyme.autodiff( - Enzyme.Reverse, Enzyme.Duplicated(SciMLBase.Void(f), tmp6), Enzyme.Const, + sensealg.autojacvec.mode, Enzyme.Duplicated(SciMLBase.Void(f), tmp6), Enzyme.Const, Enzyme.Duplicated(tmp3, tmp4), Enzyme.Const(y), dup, Enzyme.Const(t)) else - function g(du, u, p, t) - du .= f(u, p, t) - nothing - end - tmp6 = Enzyme.make_zero(g) + tmp6 = Enzyme.make_zero(f) Enzyme.autodiff( - Enzyme.set_runtime_activity(Enzyme.Reverse), Enzyme.Duplicated(g, tmp6), Enzyme.Const, + sensealg.autojacvec.mode, Enzyme.Const(gclosure4), Enzyme.Duplicated(f, tmp6), Enzyme.Const, Enzyme.Duplicated(tmp3, tmp4), Enzyme.Const(y), dup, Enzyme.Const(t)) end diff --git a/src/sensitivity_algorithms.jl b/src/sensitivity_algorithms.jl index f63b8f0ec..ed5a104c9 100644 --- a/src/sensitivity_algorithms.jl +++ b/src/sensitivity_algorithms.jl @@ -1306,7 +1306,7 @@ like BLAS/LAPACK are used) and this will be the most efficient adjoint implement ## Constructor ```julia -EnzymeVJP(; chunksize = 0) +EnzymeVJP(; chunksize = 0, mode = EnzymeCore.Reverse) ``` ## Keyword Arguments @@ -1317,12 +1317,15 @@ EnzymeVJP(; chunksize = 0) should be set to the maximum chunksize that can occur during an integration to preallocate the `DualCaches` for PreallocationTools.jl. It defaults to 0, using `ForwardDiff.pickchunksize` but could be decreased if this value is known to be lower to conserve memory. + - `mode`: the parameterized Enzyme mode, default set to EnzymeCore.Reverse. Alternatively one + may want to pass Enzyme.set_runtime_activity(Enzyme.Reverse) """ -struct EnzymeVJP <: VJPChoice +struct EnzymeVJP{Mode<:Enzyme.ReverseMode} <: VJPChoice chunksize::Int + mode::Mode end -EnzymeVJP(; chunksize = 0) = EnzymeVJP(chunksize) +EnzymeVJP(; chunksize = 0, mode = Enzyme.Reverse) = EnzymeVJP(chunksize, mode) """ ```julia