From 4b93dc5d39c3491dbb1aee1374f2678225254f7f Mon Sep 17 00:00:00 2001 From: ChrisRackauckas Date: Thu, 6 Nov 2025 06:55:13 -0500 Subject: [PATCH] Fix GaussAdjoint with ZygoteVJP for in-place ODE functions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fixes #1282 Previously, `GaussAdjoint` with `ZygoteVJP()` would fail when used with in-place ODE functions, throwing a `MethodError` because it was calling the in-place function `f(du, u, p, t)` with only 3 arguments as if it were out-of-place `f(u, p, t)`. This fix: - Checks if the ODE function is in-place using `SciMLBase.isinplace()` - For in-place functions, creates a `Zygote.Buffer` to allow mutation during the forward pass while remaining differentiable - For out-of-place functions, keeps the existing behavior The use of `Zygote.Buffer` enables Zygote to differentiate through in-place functions by allowing controlled mutation during the forward pass and returning an immutable copy for the backward pass. Added comprehensive tests for both in-place and out-of-place ODE functions with `GaussAdjoint(autojacvec = ZygoteVJP())`. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- src/gauss_adjoint.jl | 14 ++++++- test/gauss_zygote_inplace.jl | 73 ++++++++++++++++++++++++++++++++++++ 2 files changed, 85 insertions(+), 2 deletions(-) create mode 100644 test/gauss_zygote_inplace.jl diff --git a/src/gauss_adjoint.jl b/src/gauss_adjoint.jl index 2c176d85f..953c82f53 100644 --- a/src/gauss_adjoint.jl +++ b/src/gauss_adjoint.jl @@ -480,8 +480,18 @@ function vec_pjac!(out, λ, y, t, S::GaussIntegrand) ReverseDiff.reverse_pass!(tape) copyto!(vec(out), ReverseDiff.deriv(tp)) elseif sensealg.autojacvec isa ZygoteVJP - _dy, back = Zygote.pullback(tunables) do tunables - vec(f(y, repack(tunables), t)) + if SciMLBase.isinplace(sol.prob.f) + # For in-place functions, create an out-of-place wrapper using Zygote.Buffer + # to allow mutation during the forward pass while remaining differentiable + _dy, back = Zygote.pullback(tunables) do tunables + du_buf = Zygote.Buffer(y) + f(du_buf, y, repack(tunables), t) + vec(copy(du_buf)) + end + else + _dy, back = Zygote.pullback(tunables) do tunables + vec(f(y, repack(tunables), t)) + end end tmp = back(λ) if tmp[1] === nothing diff --git a/test/gauss_zygote_inplace.jl b/test/gauss_zygote_inplace.jl new file mode 100644 index 000000000..fb750a2ef --- /dev/null +++ b/test/gauss_zygote_inplace.jl @@ -0,0 +1,73 @@ +using SciMLSensitivity, DifferentialEquations, Zygote +using Test + +# Test for issue #1282: GaussAdjoint with ZygoteVJP should handle in-place ODE functions +@testset "GaussAdjoint with ZygoteVJP and in-place ODE" begin + function fiip(du, u, p, t) + du[1] = p[1] * u[1] - p[2] * u[1] * u[2] + du[2] = -p[3] * u[2] + p[4] * u[1] * u[2] + return nothing + end + + p = [1.5, 1.0, 3.0, 1.0] + u0 = [1.0; 1.0] + prob = ODEProblem{true, SciMLBase.FullSpecialize}( + ODEFunction{true, SciMLBase.FullSpecialize}(fiip), + u0, (0.0, 10.0), p + ) + + # Test that basic solve works + sol = solve(prob, KenCarp4(), sensealg = GaussAdjoint()) + @test sol.retcode == ReturnCode.Success + + # Test that gradient computation with ZygoteVJP works + loss(u0, + p) = sum(solve( + prob, KenCarp4(), u0 = u0, p = p, saveat = 0.1, + sensealg = GaussAdjoint(autojacvec = ZygoteVJP()) + )) + + # This should not throw MethodError anymore + du0, dp = Zygote.gradient(loss, u0, p) + + @test du0 !== nothing + @test dp !== nothing + @test length(du0) == 2 + @test length(dp) == 4 + + # Test with explicit ZygoteVJP specification + (dp2,) = Zygote.gradient(p) do p + sum(solve(prob, KenCarp4(), p = p, saveat = 0.1, + sensealg = GaussAdjoint(autojacvec = ZygoteVJP()))) + end + + @test dp2 !== nothing + @test length(dp2) == 4 +end + +# Test out-of-place still works +@testset "GaussAdjoint with ZygoteVJP and out-of-place ODE" begin + function foop(u, p, t) + dx = p[1] * u[1] - p[2] * u[1] * u[2] + dy = -p[3] * u[2] + p[4] * u[1] * u[2] + [dx, dy] + end + + p = [1.5, 1.0, 3.0, 1.0] + u0 = [1.0; 1.0] + prob = ODEProblem(foop, u0, (0.0, 10.0), p) + + # Test that gradient computation with ZygoteVJP works for out-of-place + loss(u0, + p) = sum(solve( + prob, Tsit5(), u0 = u0, p = p, saveat = 0.1, + sensealg = GaussAdjoint(autojacvec = ZygoteVJP()) + )) + + du0, dp = Zygote.gradient(loss, u0, p) + + @test du0 !== nothing + @test dp !== nothing + @test length(du0) == 2 + @test length(dp) == 4 +end