Skip to content

Commit 9580452

Browse files
authored
Fix implementation of NLSolversBase API (#1213) (#1216)
(cherry picked from commit 2f1e619)
1 parent f9fe222 commit 9580452

File tree

3 files changed

+61
-56
lines changed

3 files changed

+61
-56
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Optim"
22
uuid = "429524aa-4258-5aef-a3af-852621145aeb"
3-
version = "1.14.0"
3+
version = "2.0.0-dev"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

src/Manifolds.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,15 +41,15 @@ function NLSolversBase.gradient(obj::ManifoldObjective, i::Int)
4141
end
4242
function NLSolversBase.gradient!(obj::ManifoldObjective, x)
4343
xin = retract(obj.manifold, x)
44-
gradient!(obj.inner_obj, xin)
45-
project_tangent!(obj.manifold, gradient(obj.inner_obj), xin)
46-
return gradient(obj.inner_obj)
44+
g_x = gradient!(obj.inner_obj, xin)
45+
project_tangent!(obj.manifold, g_x, xin)
46+
return g_x
4747
end
4848
function NLSolversBase.value_gradient!(obj::ManifoldObjective, x)
4949
xin = retract(obj.manifold, x)
50-
value_gradient!(obj.inner_obj, xin)
51-
project_tangent!(obj.manifold, gradient(obj.inner_obj), xin)
52-
return value(obj.inner_obj)
50+
f_x, g_x = value_gradient!(obj.inner_obj, xin)
51+
project_tangent!(obj.manifold, g_x, xin)
52+
return f_x, g_x
5353
end
5454

5555
"""Flat Euclidean space {R,C}^N, with projections equal to the identity."""

src/multivariate/solvers/constrained/fminbox.jl

Lines changed: 54 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import NLSolversBase:
1+
using NLSolversBase:
22
value, value!, value!!, gradient, gradient!, value_gradient!, value_gradient!!
33
####### FIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIX THE MIDDLE OF BOX CASE THAT WAS THERE
44
mutable struct BarrierWrapper{TO,TB,Tm,TF,TDF} <: AbstractObjective
@@ -52,6 +52,9 @@ function _barrier_term_value(x::T, l, u) where {T}
5252
vu = ifelse(isfinite(dxu), -log(dxu), T(0))
5353
return vl + vu
5454
end
55+
_barrier_value(bb::BoxBarrier, x) =
56+
mapreduce(x -> _barrier_term_value(x...), +, zip(x, bb.lower, bb.upper))
57+
5558
function _barrier_term_gradient(x::T, l, u) where {T}
5659
dxl = x - l
5760
dxu = u - x
@@ -64,72 +67,74 @@ function _barrier_term_gradient(x::T, l, u) where {T}
6467
end
6568
return g
6669
end
67-
function value_gradient!(bb::BoxBarrier, g, x)
68-
g .= _barrier_term_gradient.(x, bb.lower, bb.upper)
69-
value(bb, x)
70-
end
71-
function gradient(bb::BoxBarrier, g, x)
72-
g = copy(g)
73-
g .= _barrier_term_gradient.(x, bb.lower, bb.upper)
74-
end
70+
7571
# Wrappers
76-
function value!!(bw::BarrierWrapper, x)
77-
bw.Fb = value(bw.b, x)
78-
bw.Ftotal = bw.mu * bw.Fb
72+
function NLSolversBase.value!!(bw::BarrierWrapper, x)
73+
bw.Fb = _barrier_value(bw.b, x)
7974
if in_box(bw, x)
80-
value!!(bw.obj, x)
81-
bw.Ftotal += value(bw.obj)
75+
F = value!!(bw.obj, x)
76+
bw.Ftotal = muladd(bw.mu, bw.Fb, F)
77+
else
78+
bw.Ftotal = bw.mu * bw.Fb
8279
end
80+
return bw.Ftotal
8381
end
84-
function value_gradient!!(bw::BarrierWrapper, x)
85-
bw.Fb = value(bw.b, x)
86-
bw.Ftotal = bw.mu * bw.Fb
82+
function NLSolversBase.value_gradient!!(bw::BarrierWrapper, x)
83+
bw.Fb = _barrier_value(bw.b, x)
8784
bw.DFb .= _barrier_term_gradient.(x, bw.b.lower, bw.b.upper)
88-
bw.DFtotal .= bw.mu .* bw.DFb
8985
if in_box(bw, x)
90-
value_gradient!!(bw.obj, x)
91-
bw.Ftotal += value(bw.obj)
92-
bw.DFtotal .+= gradient(bw.obj)
86+
F, DF = value_gradient!!(bw.obj, x)
87+
bw.Ftotal = muladd(bw.mu, bw.Fb, F)
88+
bw.DFtotal .= muladd.(bw.mu, bw.DFb, DF)
89+
else
90+
bw.Ftotal = bw.mu * bw.Fb
91+
bw.DFtotal .= bw.mu .* bw.DFb
9392
end
94-
93+
return bw.Ftotal, bw.DFtotal
9594
end
96-
function value_gradient!(bb::BarrierWrapper, x)
95+
function NLSolversBase.value_gradient!(bb::BarrierWrapper, x)
9796
bb.DFb .= _barrier_term_gradient.(x, bb.b.lower, bb.b.upper)
98-
bb.Fb = value(bb.b, x)
99-
bb.DFtotal .= bb.mu .* bb.DFb
100-
bb.Ftotal = bb.mu * bb.Fb
101-
97+
bb.Fb = _barrier_value(bb.b, x)
10298
if in_box(bb, x)
103-
value_gradient!(bb.obj, x)
104-
bb.DFtotal .+= gradient(bb.obj)
105-
bb.Ftotal += value(bb.obj)
99+
F, DF = value_gradient!(bb.obj, x)
100+
bb.DFtotal .= muladd.(bb.mu, bb.DFb, DF)
101+
bb.Ftotal = muladd(bb.mu, bb.Fb, F)
102+
else
103+
bb.DFtotal .= bb.mu .* bb.DFb
104+
bb.Ftotal = bb.mu * bb.Fb
106105
end
106+
return bb.Ftotal, bb.DFtotal
107107
end
108-
value(bb::BoxBarrier, x) =
109-
mapreduce(x -> _barrier_term_value(x...), +, zip(x, bb.lower, bb.upper))
110-
function value!(obj::BarrierWrapper, x)
111-
obj.Fb = value(obj.b, x)
112-
obj.Ftotal = obj.mu * obj.Fb
108+
function NLSolversBase.value!(obj::BarrierWrapper, x)
109+
obj.Fb = _barrier_value(obj.b, x)
113110
if in_box(obj, x)
114-
value!(obj.obj, x)
115-
obj.Ftotal += value(obj.obj)
111+
F = value!(obj.obj, x)
112+
obj.Ftotal = muladd(obj.mu, obj.Fb, F)
113+
else
114+
obj.Ftotal = obj.mu * obj.Fb
116115
end
117-
obj.Ftotal
116+
return obj.Ftotal
118117
end
119-
value(obj::BarrierWrapper) = obj.Ftotal
120-
function value(obj::BarrierWrapper, x)
121-
F = obj.mu * value(obj.b, x)
118+
NLSolversBase.value(obj::BarrierWrapper) = obj.Ftotal
119+
function NLSolversBase.value(obj::BarrierWrapper, x)
120+
Fb = _barrier_value(obj.b, x)
122121
if in_box(obj, x)
123-
F += value(obj.obj, x)
122+
return muladd(obj.mu, Fb, value(obj.obj, x))
123+
else
124+
return obj.mu * Fb
124125
end
125-
F
126126
end
127-
function gradient!(obj::BarrierWrapper, x)
128-
gradient!(obj.obj, x)
129-
obj.DFb .= gradient(obj.b, obj.DFb, x) # this should just be inplace?
130-
obj.DFtotal .= gradient(obj.obj) .+ obj.mu * obj.Fb
127+
function NLSolversBase.gradient!(obj::BarrierWrapper, x)
128+
obj.DFb .= _barrier_term_gradient.(x, obj.b.lower, obj.b.upper)
129+
if in_box(obj, x)
130+
DF = gradient!(obj.obj, x)
131+
obj.DFtotal .= muladd.(obj.mu, obj.DFb, DF)
132+
else
133+
obj.DFtotal .= obj.mu .* obj.DFb
134+
end
135+
return obj.DFtotal
131136
end
132-
gradient(obj::BarrierWrapper) = obj.DFtotal
137+
NLSolversBase.gradient(obj::BarrierWrapper) = obj.DFtotal
133138

134139
# this mutates mu but not the gradients
135140
# Super unsafe in that it depends on x_df being correct!
@@ -489,7 +494,7 @@ function optimize(
489494
if F.method isa NelderMead
490495
for i = 1:length(state.f_simplex)
491496
x = state.simplex[i]
492-
boxval = value(dfbox.b, x)
497+
boxval = _barrier_value(dfbox.b, x)
493498
state.f_simplex[i] += boxval
494499
end
495500
state.i_order = sortperm(state.f_simplex)

0 commit comments

Comments
 (0)