From 02bc14d83fc1614b484e536d98c5793087b4ab37 Mon Sep 17 00:00:00 2001 From: ChrisRackauckas Date: Sun, 7 Sep 2025 23:29:02 -0400 Subject: [PATCH 1/5] Add Enzyme inactive rules for VJP choice types MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds proper Enzyme extension for SciMLSensitivity to make VJP choice types inactive during differentiation. This complements the AbstractSensitivityAlgorithm rule in SciMLBase. Fixes SciMLSensitivity.jl#1225: When `sensealg=BacksolveAdjoint(autojacvec=EnzymeVJP())` is passed to ODEProblem constructor, Enzyme would fail with "Error handling recursive stores for String" because it tried to differentiate through the VJP choice objects. Changes: - Add ext/SciMLSensitivityEnzymeExt.jl with VJPChoice inactive rule - Add comprehensive tests in test/enzyme_vjp_inactive.jl - Tests include original failing case from issue #1225 This works in conjunction with SciMLBase PR that adds AbstractSensitivityAlgorithm inactive rule to avoid type piracy. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- Project.toml | 1 + ext/SciMLSensitivityEnzymeExt.jl | 19 +++++++++ test/enzyme_vjp_inactive.jl | 71 ++++++++++++++++++++++++++++++++ test/runtests.jl | 1 + 4 files changed, 92 insertions(+) create mode 100644 ext/SciMLSensitivityEnzymeExt.jl create mode 100644 test/enzyme_vjp_inactive.jl diff --git a/Project.toml b/Project.toml index 74342aa99..f6584bf9b 100644 --- a/Project.toml +++ b/Project.toml @@ -47,6 +47,7 @@ Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" [extensions] SciMLSensitivityMooncakeExt = "Mooncake" +SciMLSensitivityEnzymeExt = "Enzyme" [compat] ADTypes = "1.9" diff --git a/ext/SciMLSensitivityEnzymeExt.jl b/ext/SciMLSensitivityEnzymeExt.jl new file mode 100644 index 000000000..072bd9bf6 --- /dev/null +++ b/ext/SciMLSensitivityEnzymeExt.jl @@ -0,0 +1,19 @@ +module SciMLSensitivityEnzymeExt + +using SciMLSensitivity +import Enzyme: EnzymeRules + +# Enzyme rules for VJP choice types defined in SciMLSensitivity +# +# VJP choice types configure how jacobian-vector products are computed within +# sensitivity algorithms. They should be treated as inactive (constant) during +# Enzyme differentiation to prevent errors when they are stored in problem +# structures or other data that Enzyme differentiates through. +# +# Note: AbstractSensitivityAlgorithm inactive rule is handled in SciMLBase +# to avoid type piracy. + +# VJP choice types should be inactive since they configure computation methods +EnzymeRules.inactive_type(::Type{<:SciMLSensitivity.VJPChoice}) = true + +end \ No newline at end of file diff --git a/test/enzyme_vjp_inactive.jl b/test/enzyme_vjp_inactive.jl new file mode 100644 index 000000000..621a4739c --- /dev/null +++ b/test/enzyme_vjp_inactive.jl @@ -0,0 +1,71 @@ +using Test, SciMLSensitivity, Enzyme, OrdinaryDiffEq + +# Test that VJP choice types are treated as inactive by Enzyme +# The AbstractSensitivityAlgorithm inactive rule is handled in SciMLBase +# This addresses issue #1225 where sensealg in ODEProblem constructor would fail + +@testset "Enzyme VJP Choice Inactive Types" begin + + # Test 1: Basic test that VJP objects can be stored in data structures during Enzyme differentiation + @testset "VJP types in data structures" begin + vjp = EnzymeVJP() + + function test_func(x) + # Store the VJP in a data structure (this would fail without inactive rules) + data = (value=x[1] + x[2], vjp=vjp) + return data.value * 2.0 + end + + x = [1.0, 2.0] + dx = Enzyme.make_zero(x) + + # This should not throw an error + @test_nowarn Enzyme.autodiff(Enzyme.Reverse, test_func, Enzyme.Active, Enzyme.Duplicated(x, dx)) + @test dx ≈ [2.0, 2.0] + end + + # Test 2: Test different VJP choice types are inactive + @testset "Different VJP types inactive" begin + vjp_types = [EnzymeVJP(), ZygoteVJP(), ReverseDiffVJP(), TrackerVJP()] + + for vjp in vjp_types + function test_func(x) + data = (value=x[1] * x[2], vjp=vjp) + return data.value + 1.0 + end + + x = [2.0, 3.0] + dx = Enzyme.make_zero(x) + + @test_nowarn Enzyme.autodiff(Enzyme.Reverse, test_func, Enzyme.Active, Enzyme.Duplicated(x, dx)) + end + end + + # Test 3: Test sensitivity algorithms with VJP choices (integration test) + # Note: This test also depends on SciMLBase having AbstractSensitivityAlgorithm as inactive + @testset "Sensitivity algorithms with VJP choices" begin + function f(du, u, p, t) + du[1] = -p[1] * u[1] + du[2] = p[2] * u[2] + end + + function loss_func(p) + u0 = [1.0, 2.0] + # Both VJP choice and sensitivity algorithm should be inactive + prob = ODEProblem(f, u0, (0.0, 0.1), p, sensealg=BacksolveAdjoint(autojacvec=EnzymeVJP())) + sol = solve(prob, Tsit5()) + return sol.u[end][1] + sol.u[end][2] + end + + p = [0.5, 1.5] + dp = Enzyme.make_zero(p) + + # This should not throw the "Error handling recursive stores for String" error + # This is the original failing case from issue #1225 + @test_nowarn Enzyme.autodiff(Enzyme.Reverse, loss_func, Enzyme.Active, Enzyme.Duplicated(p, dp)) + + # Verify the gradient is computed (non-zero and finite) + @test all(isfinite, dp) + @test any(x -> abs(x) > 1e-10, dp) # At least one component should be non-trivial + end +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index cfb90641e..5ddde73c7 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -40,6 +40,7 @@ end @time @safetestset "Scalar u0" include("scalar_u.jl") @time @safetestset "Error Messages" include("error_messages.jl") @time @safetestset "Autodiff Events" include("autodiff_events.jl") + @time @safetestset "Enzyme VJP Inactive" include("enzyme_vjp_inactive.jl") end end From ae6c3ece481ea5701d1f8aff97cba4915ffc8a33 Mon Sep 17 00:00:00 2001 From: ChrisRackauckas Date: Mon, 8 Sep 2025 00:01:39 -0400 Subject: [PATCH 2/5] Bump SciMLBase requirement to v2.117.0 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This version includes the Enzyme extension for AbstractSensitivityAlgorithm inactive_type rule, which is needed for this PR's VJP choice rules to work together to fix SciMLSensitivity.jl#1225. Corresponds to General registry PR #138113. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index f6584bf9b..2998f1ba9 100644 --- a/Project.toml +++ b/Project.toml @@ -96,7 +96,7 @@ RecursiveArrayTools = "3.27.2" Reexport = "1.0" ReverseDiff = "1.15.1" SafeTestsets = "0.1.0" -SciMLBase = "2.103.1" +SciMLBase = "2.117.0" SciMLJacobianOperators = "0.1" SciMLStructures = "1.3" SparseArrays = "1.10" From 94be61e1aef6cafc84f96af42050924852988a38 Mon Sep 17 00:00:00 2001 From: ChrisRackauckas Date: Mon, 8 Sep 2025 01:05:39 -0400 Subject: [PATCH 3/5] Apply SciMLStyle formatting MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Run JuliaFormatter with SciMLStyle on the new files to fix format check errors. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- ext/SciMLSensitivityEnzymeExt.jl | 2 +- test/enzyme_vjp_inactive.jl | 35 ++++++++++++++++---------------- 2 files changed, 19 insertions(+), 18 deletions(-) diff --git a/ext/SciMLSensitivityEnzymeExt.jl b/ext/SciMLSensitivityEnzymeExt.jl index 072bd9bf6..eb5dd19ef 100644 --- a/ext/SciMLSensitivityEnzymeExt.jl +++ b/ext/SciMLSensitivityEnzymeExt.jl @@ -16,4 +16,4 @@ import Enzyme: EnzymeRules # VJP choice types should be inactive since they configure computation methods EnzymeRules.inactive_type(::Type{<:SciMLSensitivity.VJPChoice}) = true -end \ No newline at end of file +end diff --git a/test/enzyme_vjp_inactive.jl b/test/enzyme_vjp_inactive.jl index 621a4739c..d99025896 100644 --- a/test/enzyme_vjp_inactive.jl +++ b/test/enzyme_vjp_inactive.jl @@ -5,42 +5,42 @@ using Test, SciMLSensitivity, Enzyme, OrdinaryDiffEq # This addresses issue #1225 where sensealg in ODEProblem constructor would fail @testset "Enzyme VJP Choice Inactive Types" begin - + # Test 1: Basic test that VJP objects can be stored in data structures during Enzyme differentiation @testset "VJP types in data structures" begin vjp = EnzymeVJP() - + function test_func(x) # Store the VJP in a data structure (this would fail without inactive rules) - data = (value=x[1] + x[2], vjp=vjp) + data = (value = x[1] + x[2], vjp = vjp) return data.value * 2.0 end - + x = [1.0, 2.0] dx = Enzyme.make_zero(x) - + # This should not throw an error @test_nowarn Enzyme.autodiff(Enzyme.Reverse, test_func, Enzyme.Active, Enzyme.Duplicated(x, dx)) @test dx ≈ [2.0, 2.0] end - + # Test 2: Test different VJP choice types are inactive @testset "Different VJP types inactive" begin vjp_types = [EnzymeVJP(), ZygoteVJP(), ReverseDiffVJP(), TrackerVJP()] - + for vjp in vjp_types function test_func(x) - data = (value=x[1] * x[2], vjp=vjp) + data = (value = x[1] * x[2], vjp = vjp) return data.value + 1.0 end - + x = [2.0, 3.0] dx = Enzyme.make_zero(x) - + @test_nowarn Enzyme.autodiff(Enzyme.Reverse, test_func, Enzyme.Active, Enzyme.Duplicated(x, dx)) end end - + # Test 3: Test sensitivity algorithms with VJP choices (integration test) # Note: This test also depends on SciMLBase having AbstractSensitivityAlgorithm as inactive @testset "Sensitivity algorithms with VJP choices" begin @@ -48,24 +48,25 @@ using Test, SciMLSensitivity, Enzyme, OrdinaryDiffEq du[1] = -p[1] * u[1] du[2] = p[2] * u[2] end - + function loss_func(p) u0 = [1.0, 2.0] # Both VJP choice and sensitivity algorithm should be inactive - prob = ODEProblem(f, u0, (0.0, 0.1), p, sensealg=BacksolveAdjoint(autojacvec=EnzymeVJP())) + prob = ODEProblem( + f, u0, (0.0, 0.1), p, sensealg = BacksolveAdjoint(autojacvec = EnzymeVJP())) sol = solve(prob, Tsit5()) return sol.u[end][1] + sol.u[end][2] end - + p = [0.5, 1.5] dp = Enzyme.make_zero(p) - + # This should not throw the "Error handling recursive stores for String" error # This is the original failing case from issue #1225 @test_nowarn Enzyme.autodiff(Enzyme.Reverse, loss_func, Enzyme.Active, Enzyme.Duplicated(p, dp)) - + # Verify the gradient is computed (non-zero and finite) @test all(isfinite, dp) @test any(x -> abs(x) > 1e-10, dp) # At least one component should be non-trivial end -end \ No newline at end of file +end From e1254e042c53b1d55a07209bd17419f736e85051 Mon Sep 17 00:00:00 2001 From: ChrisRackauckas Date: Mon, 8 Sep 2025 02:12:25 -0400 Subject: [PATCH 4/5] Replace extension with direct Enzyme rules file MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Since SciMLSensitivity directly depends on Enzyme, use direct include instead of extension for VJP choice inactive rules. Changes: - Remove ext/SciMLSensitivityEnzymeExt.jl - Add src/enzyme_rules.jl with VJPChoice inactive rule - Include in main SciMLSensitivity.jl module - Remove extension registration from Project.toml Note: This addresses the original "Error handling recursive stores for String" error from #1225, but there may be additional Enzyme compatibility issues with array operations as noted in recent issue comments. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- Project.toml | 1 - src/SciMLSensitivity.jl | 1 + .../enzyme_rules.jl | 11 +++-------- 3 files changed, 4 insertions(+), 9 deletions(-) rename ext/SciMLSensitivityEnzymeExt.jl => src/enzyme_rules.jl (81%) diff --git a/Project.toml b/Project.toml index 2998f1ba9..d5872cfdb 100644 --- a/Project.toml +++ b/Project.toml @@ -47,7 +47,6 @@ Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" [extensions] SciMLSensitivityMooncakeExt = "Mooncake" -SciMLSensitivityEnzymeExt = "Enzyme" [compat] ADTypes = "1.9" diff --git a/src/SciMLSensitivity.jl b/src/SciMLSensitivity.jl index 7a003b4d3..619425ad4 100644 --- a/src/SciMLSensitivity.jl +++ b/src/SciMLSensitivity.jl @@ -83,6 +83,7 @@ include("concrete_solve.jl") include("second_order.jl") include("steadystate_adjoint.jl") include("sde_tools.jl") +include("enzyme_rules.jl") export extract_local_sensitivities diff --git a/ext/SciMLSensitivityEnzymeExt.jl b/src/enzyme_rules.jl similarity index 81% rename from ext/SciMLSensitivityEnzymeExt.jl rename to src/enzyme_rules.jl index eb5dd19ef..e11b6acd4 100644 --- a/ext/SciMLSensitivityEnzymeExt.jl +++ b/src/enzyme_rules.jl @@ -1,8 +1,3 @@ -module SciMLSensitivityEnzymeExt - -using SciMLSensitivity -import Enzyme: EnzymeRules - # Enzyme rules for VJP choice types defined in SciMLSensitivity # # VJP choice types configure how jacobian-vector products are computed within @@ -13,7 +8,7 @@ import Enzyme: EnzymeRules # Note: AbstractSensitivityAlgorithm inactive rule is handled in SciMLBase # to avoid type piracy. -# VJP choice types should be inactive since they configure computation methods -EnzymeRules.inactive_type(::Type{<:SciMLSensitivity.VJPChoice}) = true +import Enzyme: EnzymeRules -end +# VJP choice types should be inactive since they configure computation methods +EnzymeRules.inactive_type(::Type{<:VJPChoice}) = true From 57a0b60b99e503bac8c2da9f10d245916e2d9bbf Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Mon, 8 Sep 2025 08:00:40 -0400 Subject: [PATCH 5/5] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index d5872cfdb..3241a9533 100644 --- a/Project.toml +++ b/Project.toml @@ -95,7 +95,7 @@ RecursiveArrayTools = "3.27.2" Reexport = "1.0" ReverseDiff = "1.15.1" SafeTestsets = "0.1.0" -SciMLBase = "2.117.0" +SciMLBase = "2.117.1" SciMLJacobianOperators = "0.1" SciMLStructures = "1.3" SparseArrays = "1.10"