From 1db23ff69557ae4218cab81510f5c0a3f582cc73 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Mon, 10 Nov 2025 15:52:50 -0600 Subject: [PATCH 1/7] Enzyme: Fix runtime activity errors --- src/concrete_solve.jl | 28 +++++++++++++++++++++------- src/derivative_wrappers.jl | 30 ++++++++++++++++-------------- src/gauss_adjoint.jl | 15 ++++++++------- src/quadrature_adjoint.jl | 15 ++++++++------- src/sensitivity_algorithms.jl | 9 ++++++--- 5 files changed, 59 insertions(+), 38 deletions(-) diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index a2af3e485..6836797a5 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -5,6 +5,11 @@ const have_not_warned_vjp = Ref(true) const STACKTRACE_WITH_VJPWARN = Ref(false) + +function adfunc(out, u, _p, t, repack) + f(out, u, repack(_p), t) + nothing +end function inplace_vjp(prob, u0, p, verbose, repack) du = zero(u0) @@ -12,12 +17,21 @@ function inplace_vjp(prob, u0, p, verbose, repack) ez = try f = unwrapped_f(prob.f) - function adfunc(out, u, _p, t) - f(out, u, repack(_p), t) - nothing - end Enzyme.autodiff(Enzyme.Reverse, adfunc, Enzyme.Duplicated(du, copy(u0)), - Enzyme.Duplicated(copy(u0), zero(u0)), Enzyme.Duplicated(copy(p), zero(p)), Enzyme.Const(prob.tspan[1])) + Enzyme.Duplicated(copy(u0), zero(u0)), Enzyme.Duplicated(copy(p), zero(p)), Enzyme.Const(prob.tspan[1]), Enzyme.Const(repack)) + true + catch e + false + end + if ez + return EnzymeVJP() + end + + erz = try + f = unwrapped_f(prob.f) + + Enzyme.autodiff(Enzyme.set_runtime_activity(Enzyme.Reverse), adfunc, Enzyme.Duplicated(du, copy(u0)), + Enzyme.Duplicated(copy(u0), zero(u0)), Enzyme.Duplicated(copy(p), zero(p)), Enzyme.Const(prob.tspan[1]), Enzyme.Const(repack)) true catch e if verbose && have_not_warned_vjp[] @@ -28,8 +42,8 @@ function inplace_vjp(prob, u0, p, verbose, repack) end false end - if ez - return EnzymeVJP() + if erz + return EnzymeVJP(; mode=Enzyme.set_runtime_activity(Enzyme.Reverse)) end # Determine if we can compile ReverseDiff diff --git a/src/derivative_wrappers.jl b/src/derivative_wrappers.jl index 3e1cd782e..04e3128ff 100644 --- a/src/derivative_wrappers.jl +++ b/src/derivative_wrappers.jl @@ -664,6 +664,16 @@ function _vecjacobian(y, λ, p, t, S::TS, isautojacvec::ZygoteVJP, dgrad, dy, return dy, dλ, dgrad end +function gclosure1(f, du, u, p, t) + Base.copyto!(du, f(u, p, t)) + nothing +end + +function gclosure2(du, u, p, t, W) + Base.copyto!(du, f(u, p, t, W)) + nothing +end + function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::EnzymeVJP, dgrad, dy, W) where {TS <: SensitivityFunction} (; sensealg) = S @@ -732,13 +742,13 @@ function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::EnzymeVJP, dgrad, end if W === nothing - Enzyme.autodiff(Enzyme.Reverse, Enzyme.Duplicated(SciMLBase.Void(f), _tmp6), + Enzyme.autodiff(isautojacvec.mode, Enzyme.Duplicated(SciMLBase.Void(f), _tmp6), Enzyme.Const, Enzyme.Duplicated(tmp3, tmp4), Enzyme.Duplicated(ytmp, tmp1), dup, Enzyme.Const(t)) else - Enzyme.autodiff(Enzyme.Reverse, Enzyme.Duplicated(SciMLBase.Void(f), _tmp6), + Enzyme.autodiff(isautojacvec.mode, Enzyme.Duplicated(SciMLBase.Void(f), _tmp6), Enzyme.Const, Enzyme.Duplicated(tmp3, tmp4), Enzyme.Duplicated(ytmp, tmp1), dup, @@ -750,22 +760,14 @@ function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::EnzymeVJP, dgrad, dy !== nothing && recursive_copyto!(dy, tmp3) else if W === nothing - function g(du, u, p, t) - du .= f(u, p, t) - nothing - end - _tmp6 = Enzyme.make_zero(g) - Enzyme.autodiff(Enzyme.Reverse, Enzyme.Duplicated(g, _tmp6), + _tmp6 = Enzyme.make_zero(f) + Enzyme.autodiff(isautojacvec.mode, Enzyme.Const(gclosure1), Enzyme.Duplicated(f, _tmp6), Enzyme.Const, Enzyme.Duplicated(tmp3, tmp4), Enzyme.Duplicated(ytmp, tmp1), dup, Enzyme.Const(t)) else - function g(du, u, p, t, W) - du .= f(u, p, t, W) - nothing - end - _tmp6 = Enzyme.make_zero(g) - Enzyme.autodiff(Enzyme.Reverse, Enzyme.Duplicated(g, _tmp6), + _tmp6 = Enzyme.make_zero(f) + Enzyme.autodiff(isautojacvec.mode, Enzyme.Const(gclosure2), Enzyme.Duplicated(f, _tmp6), Enzyme.Const, Enzyme.Duplicated(tmp3, tmp4), Enzyme.Duplicated(ytmp, tmp1), dup, Enzyme.Const(t), Enzyme.Const(W)) diff --git a/src/gauss_adjoint.jl b/src/gauss_adjoint.jl index 2c176d85f..a3270e0c7 100644 --- a/src/gauss_adjoint.jl +++ b/src/gauss_adjoint.jl @@ -441,6 +441,11 @@ function GaussIntegrand(sol, sensealg, checkpoints, dgdp = nothing) sensealg, dgdp_cache, dgdp) end +function g(f, du, u, p, t) + Base.copyto!(du, f(u, p, t)) + nothing +end + # out = λ df(u, p, t)/dp at u=y, p=p, t=t function vec_pjac!(out, λ, y, t, S::GaussIntegrand) (; pJ, pf, p, f_cache, dgdp_cache, paramjac_config, sensealg, sol) = S @@ -500,17 +505,13 @@ function vec_pjac!(out, λ, y, t, S::GaussIntegrand) Enzyme.remake_zero!(tmp6) Enzyme.autodiff( - Enzyme.Reverse, Enzyme.Duplicated(pf, tmp6), Enzyme.Const, + sensealg.autojacvec.mode, Enzyme.Duplicated(pf, tmp6), Enzyme.Const, Enzyme.Duplicated(tmp3, tmp4), Enzyme.Const(y), Enzyme.Duplicated(p, out), Enzyme.Const(t)) else - function g(du, u, p, t) - du .= f(u, p, t) - nothing - end - tmp6 = Enzyme.make_zero(g) + tmp6 = Enzyme.make_zero(f) Enzyme.autodiff( - Enzyme.Reverse, Enzyme.Duplicated(g, tmp6), Enzyme.Const, + sensealg.autojacvec.mode, Enzyme.Const(gclosure3), Enzyme.Duplicated(f, tmp6), Enzyme.Const, Enzyme.Duplicated(tmp3, tmp4), Enzyme.Const(y), Enzyme.Duplicated(p, out), Enzyme.Const(t)) end diff --git a/src/quadrature_adjoint.jl b/src/quadrature_adjoint.jl index 696023e9c..45bc80ef0 100644 --- a/src/quadrature_adjoint.jl +++ b/src/quadrature_adjoint.jl @@ -232,6 +232,11 @@ function AdjointSensitivityIntegrand(sol, adj_sol, sensealg, dgdp = nothing) AdjointSensitivityIntegrand(sol, adj_sol, p, y, λ, pf, f_cache, pJ, paramjac_config, sensealg, dgdp_cache, dgdp) end + +function gclosure4(f, du, u, p, t) + Base.copyto!(du, f(u, p, t)) + nothing +end # out = λ df(u, p, t)/dp at u=y, p=p, t=t function vec_pjac!(out, λ, y, t, S::AdjointSensitivityIntegrand) @@ -295,17 +300,13 @@ function vec_pjac!(out, λ, y, t, S::AdjointSensitivityIntegrand) if SciMLBase.isinplace(sol.prob.f) Enzyme.remake_zero!(tmp6) Enzyme.autodiff( - Enzyme.Reverse, Enzyme.Duplicated(SciMLBase.Void(f), tmp6), Enzyme.Const, + sensealg.autojacvec.mode, Enzyme.Duplicated(SciMLBase.Void(f), tmp6), Enzyme.Const, Enzyme.Duplicated(tmp3, tmp4), Enzyme.Const(y), dup, Enzyme.Const(t)) else - function g(du, u, p, t) - du .= f(u, p, t) - nothing - end - tmp6 = Enzyme.make_zero(g) + tmp6 = Enzyme.make_zero(f) Enzyme.autodiff( - Enzyme.set_runtime_activity(Enzyme.Reverse), Enzyme.Duplicated(g, tmp6), Enzyme.Const, + sensealg.autojacvec.mode, Enzyme.Const(gclosure4), Enzyme.Duplicated(f, tmp6), Enzyme.Const, Enzyme.Duplicated(tmp3, tmp4), Enzyme.Const(y), dup, Enzyme.Const(t)) end diff --git a/src/sensitivity_algorithms.jl b/src/sensitivity_algorithms.jl index f63b8f0ec..ed5a104c9 100644 --- a/src/sensitivity_algorithms.jl +++ b/src/sensitivity_algorithms.jl @@ -1306,7 +1306,7 @@ like BLAS/LAPACK are used) and this will be the most efficient adjoint implement ## Constructor ```julia -EnzymeVJP(; chunksize = 0) +EnzymeVJP(; chunksize = 0, mode = EnzymeCore.Reverse) ``` ## Keyword Arguments @@ -1317,12 +1317,15 @@ EnzymeVJP(; chunksize = 0) should be set to the maximum chunksize that can occur during an integration to preallocate the `DualCaches` for PreallocationTools.jl. It defaults to 0, using `ForwardDiff.pickchunksize` but could be decreased if this value is known to be lower to conserve memory. + - `mode`: the parameterized Enzyme mode, default set to EnzymeCore.Reverse. Alternatively one + may want to pass Enzyme.set_runtime_activity(Enzyme.Reverse) """ -struct EnzymeVJP <: VJPChoice +struct EnzymeVJP{Mode<:Enzyme.ReverseMode} <: VJPChoice chunksize::Int + mode::Mode end -EnzymeVJP(; chunksize = 0) = EnzymeVJP(chunksize) +EnzymeVJP(; chunksize = 0, mode = Enzyme.Reverse) = EnzymeVJP(chunksize, mode) """ ```julia From 4fdc134ed811a62e223afad48aaa6977e85c8f44 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 11 Nov 2025 21:47:08 -0600 Subject: [PATCH 2/7] Update concrete_solve.jl --- src/concrete_solve.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index 6836797a5..76b833c9f 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -11,6 +11,7 @@ function adfunc(out, u, _p, t, repack) nothing end + function inplace_vjp(prob, u0, p, verbose, repack) du = zero(u0) From 0dd67d2671aebfeaf13b7e2e8a23acbfb49eaa5c Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 11 Nov 2025 23:55:56 -0600 Subject: [PATCH 3/7] Remove empty line before inplace_vjp function --- src/concrete_solve.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index 76b833c9f..6836797a5 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -11,7 +11,6 @@ function adfunc(out, u, _p, t, repack) nothing end - function inplace_vjp(prob, u0, p, verbose, repack) du = zero(u0) From 050fc644ca8b507033fca8ec533bb42af869dbbc Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 12 Nov 2025 11:13:47 -0600 Subject: [PATCH 4/7] Update concrete_solve.jl --- src/concrete_solve.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index 6836797a5..0f6f4d58c 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -5,7 +5,10 @@ const have_not_warned_vjp = Ref(true) const STACKTRACE_WITH_VJPWARN = Ref(false) - + +using Enzyme +Enzyme.Compiler.VERBOSE_ERRORS[] = true + function adfunc(out, u, _p, t, repack) f(out, u, repack(_p), t) nothing From 56878cc4fabc7414205769e0ca097eacadafba51 Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 12 Nov 2025 12:15:18 -0600 Subject: [PATCH 5/7] Add verbose error reporting for Enzyme Enable verbose error reporting for Enzyme. --- test/runtests.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/runtests.jl b/test/runtests.jl index cfb90641e..7ba91274e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -10,6 +10,9 @@ function activate_gpu_env() Pkg.instantiate() end +using Enzyme +Enzyme.Compiler.VERBOSE_ERRORS[] = true + @time @testset "SciMLSensitivity" begin if GROUP == "All" || GROUP == "Core1" || GROUP == "Downstream" @testset "Core1" begin From f1869c96359a051f4e887a54c3b81afa04381429 Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 12 Nov 2025 12:15:44 -0600 Subject: [PATCH 6/7] Add Enzyme with verbose error reporting Enable verbose error reporting for Enzyme. --- test/concrete_solve_derivatives.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/concrete_solve_derivatives.jl b/test/concrete_solve_derivatives.jl index 0a7b14c4a..c70a4c4c4 100644 --- a/test/concrete_solve_derivatives.jl +++ b/test/concrete_solve_derivatives.jl @@ -2,6 +2,9 @@ using SciMLSensitivity, OrdinaryDiffEq, Zygote using Test, ForwardDiff import Tracker, ReverseDiff, ChainRulesCore, Mooncake, Enzyme +using Enzyme +Enzyme.Compiler.VERBOSE_ERRORS[] = true + function fiip(du, u, p, t) du[1] = dx = p[1] * u[1] - p[2] * u[1] * u[2] du[2] = dy = -p[3] * u[2] + p[4] * u[1] * u[2] From a6b59f621a8e598e8fee46c6d7fb56dede4ba074 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Thu, 13 Nov 2025 15:08:10 -0600 Subject: [PATCH 7/7] fix --- src/concrete_solve.jl | 3 --- test/concrete_solve_derivatives.jl | 3 --- test/runtests.jl | 3 --- 3 files changed, 9 deletions(-) diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index 0f6f4d58c..447ed182c 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -6,9 +6,6 @@ const have_not_warned_vjp = Ref(true) const STACKTRACE_WITH_VJPWARN = Ref(false) -using Enzyme -Enzyme.Compiler.VERBOSE_ERRORS[] = true - function adfunc(out, u, _p, t, repack) f(out, u, repack(_p), t) nothing diff --git a/test/concrete_solve_derivatives.jl b/test/concrete_solve_derivatives.jl index c70a4c4c4..0a7b14c4a 100644 --- a/test/concrete_solve_derivatives.jl +++ b/test/concrete_solve_derivatives.jl @@ -2,9 +2,6 @@ using SciMLSensitivity, OrdinaryDiffEq, Zygote using Test, ForwardDiff import Tracker, ReverseDiff, ChainRulesCore, Mooncake, Enzyme -using Enzyme -Enzyme.Compiler.VERBOSE_ERRORS[] = true - function fiip(du, u, p, t) du[1] = dx = p[1] * u[1] - p[2] * u[1] * u[2] du[2] = dy = -p[3] * u[2] + p[4] * u[1] * u[2] diff --git a/test/runtests.jl b/test/runtests.jl index 7ba91274e..cfb90641e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -10,9 +10,6 @@ function activate_gpu_env() Pkg.instantiate() end -using Enzyme -Enzyme.Compiler.VERBOSE_ERRORS[] = true - @time @testset "SciMLSensitivity" begin if GROUP == "All" || GROUP == "Core1" || GROUP == "Downstream" @testset "Core1" begin