Skip to content

Commit 654f65f

Browse files
SebastianM-Cclaude
andcommitted
Fix Lagrangian Hessian prototype dimensions in OptimizationZygoteExt
The Lagrangian Hessian prototype was incorrectly sized as (num_constraints × num_variables) instead of (num_variables × num_variables). This caused a `BoundsError` when computing the Lagrangian Hessian with more variables than constraints, as the prototype was used as a buffer for the n×n Hessian matrix. Changes: - Fix lag_hess_prototype initialization to use zeros(Bool, length(x), length(x)) - Add comprehensive tests to verify prototype dimensions and usability as a buffer This ensures the Lagrangian Hessian is always correctly sized as n×n regardless of the number of constraints, matching the mathematical definition of the Lagrangian Hessian as the second derivative with respect to the decision variables. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
1 parent 764af82 commit 654f65f

File tree

2 files changed

+23
-1
lines changed

2 files changed

+23
-1
lines changed

lib/OptimizationBase/ext/OptimizationZygoteExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ function OptimizationBase.instantiate_function(
208208
lag_extras = prepare_hessian(
209209
lagrangian, soadtype, x, Constant(one(eltype(x))),
210210
Constant(ones(eltype(x), num_cons)), Constant(p), strict = Val(false))
211-
lag_hess_prototype = zeros(Bool, num_cons, length(x))
211+
lag_hess_prototype = zeros(Bool, length(x), length(x))
212212

213213
function lag_h!(H::AbstractMatrix, θ, σ, λ)
214214
if σ == zero(eltype(θ))

lib/OptimizationBase/test/adtests.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,17 @@ optprob.cons_h(H3, x0)
257257
optprob.lag_h(H4, x0, σ, μ)
258258
@test H4σ * H2 + μ[1] * H3[1] rtol=1e-6
259259

260+
# Test that the AD-generated lag_hess_prototype has correct dimensions
261+
@test !isnothing(optprob.lag_hess_prototype)
262+
@test size(optprob.lag_hess_prototype) == (length(x0), length(x0)) # Should be n×n, not num_cons×n
263+
264+
# Test that we can actually use it as a buffer (this would fail with the bug)
265+
if !isnothing(optprob.lag_hess_prototype)
266+
H_proto = similar(optprob.lag_hess_prototype, Float64)
267+
optprob.lag_h(H_proto, x0, σ, μ)
268+
@test H_proto σ * H2 + μ[1] * H3[1] rtol=1e-6
269+
end
270+
260271
G2 = Array{Float64}(undef, 2)
261272
H2 = Array{Float64}(undef, 2, 2)
262273

@@ -490,6 +501,17 @@ end
490501
optprob.lag_h(H4, x0, σ, μ)
491502
@test H4σ * H1 + sum.* H3) rtol=1e-6
492503

504+
# Test that the AD-generated lag_hess_prototype has correct dimensions
505+
@test !isnothing(optprob.lag_hess_prototype)
506+
@test size(optprob.lag_hess_prototype) == (length(x0), length(x0)) # Should be n×n, not num_cons×n
507+
508+
# Test that we can actually use it as a buffer (this would fail with the bug)
509+
if !isnothing(optprob.lag_hess_prototype)
510+
H_proto = similar(optprob.lag_hess_prototype, Float64)
511+
optprob.lag_h(H_proto, x0, σ, μ)
512+
@test H_proto σ * H1 + sum.* H3) rtol=1e-6
513+
end
514+
493515
G2 = Array{Float64}(undef, 2)
494516
H2 = Array{Float64}(undef, 2, 2)
495517

0 commit comments

Comments
 (0)