Commit 95ca2d8

[WIP] Do not (mis)use objective as state
1 parent 9290b55 commit 95ca2d8

34 files changed, with 1025 additions and 972 deletions
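
The theme of the commit: cached evaluation results move out of the objective (`value(d)`, `gradient(d)`, `hessian(d)`) and into the optimizer state (`state.f_x`, `state.g_x`, `state.H_x`). A minimal sketch of the new ownership, using a hypothetical state type — only the field names `x`, `f_x`, `g_x`, `H_x` come from the diff:

# Hypothetical sketch: the state, not the objective, owns the caches.
mutable struct SketchState{T,Tx<:AbstractVector{T},TH<:AbstractMatrix{T}}
    x::Tx     # current iterate
    f_x::T    # objective value at x
    g_x::Tx   # gradient at x
    H_x::TH   # Hessian at x (second-order methods only)
end

# Convergence checks can then read the state directly instead of asking
# the objective for whatever it happened to evaluate last:
sketch_g_residual(state::SketchState) = maximum(abs, state.g_x)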

Project.toml

Lines changed: 1 addition & 1 deletion
@@ -30,7 +30,7 @@ ExplicitImports = "1.13.2"
 FillArrays = "0.6.2, 0.7, 0.8, 0.9, 0.10, 0.11, 0.12, 0.13, 1"
 ForwardDiff = "0.10, 1"
 JET = "0.9, 0.10"
-LineSearches = "7.4.0"
+LineSearches = "7.5.1"
 LinearAlgebra = "<0.0.1, 1.6"
 MathOptInterface = "1.17"
 Measurements = "2.14.1"

src/Manifolds.jl

Lines changed: 1 addition & 10 deletions
@@ -28,16 +28,7 @@ end
 # TODO: is it safe here to call retract! and change x?
 function NLSolversBase.value!(obj::ManifoldObjective, x)
     xin = retract(obj.manifold, x)
-    value!(obj.inner_obj, xin)
-end
-function NLSolversBase.value(obj::ManifoldObjective)
-    value(obj.inner_obj)
-end
-function NLSolversBase.gradient(obj::ManifoldObjective)
-    gradient(obj.inner_obj)
-end
-function NLSolversBase.gradient(obj::ManifoldObjective, i::Int)
-    gradient(obj.inner_obj, i)
+    return value!(obj.inner_obj, xin)
 end
 function NLSolversBase.gradient!(obj::ManifoldObjective, x)
     xin = retract(obj.manifold, x)
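
The surviving `value!` method retracts the candidate point onto the manifold before delegating to the wrapped objective; the read-only pass-throughs (`value`, `gradient`) are dropped because callers now read cached values from the optimizer state instead. A sketch of the retract-then-delegate pattern, with a hypothetical unit-sphere retraction standing in for Optim's manifold types:

using LinearAlgebra: norm

# Illustrative retraction: project an ambient-space point back onto the
# unit sphere by normalization.
struct SketchSphere end
sketch_retract(::SketchSphere, x) = x / norm(x)

# Evaluate the wrapped objective at the retracted point, so the solver
# can take unconstrained steps in the ambient space.
function sketch_value!(inner_value!, manifold, x)
    xin = sketch_retract(manifold, x)
    return inner_value!(xin)
end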

src/Optim.jl

Lines changed: 0 additions & 3 deletions
@@ -41,10 +41,7 @@ using NLSolversBase:
     TwiceDifferentiableConstraints,
     nconstraints,
     nconstraints_x,
-    hessian,
     hessian!,
-    hessian!!,
-    hv_product,
     hv_product!
 
 # var for NelderMead

src/multivariate/optimize/optimize.jl

Lines changed: 61 additions & 49 deletions
@@ -1,37 +1,42 @@
-update_g!(d, state, method) = nothing
-function update_g!(d, state, method::FirstOrderOptimizer)
-    # Update the function value and gradient
-    value_gradient!(d, state.x)
-    project_tangent!(method.manifold, gradient(d), state.x)
+# Update function value, gradient and Hessian
+function update_fgh!(d, state, ::ZerothOrderOptimizer)
+    f_x = value!(d, state.x)
+    state.f_x = f_x
+    return nothing
 end
-function update_g!(d, state, method::Newton)
-    # Update the function value and gradient
-    value_gradient!(d, state.x)
-end
-update_fg!(d, state, method) = nothing
-update_fg!(d, state, method::ZerothOrderOptimizer) = value!(d, state.x)
-function update_fg!(d, state, method::FirstOrderOptimizer)
-    value_gradient!(d, state.x)
-    project_tangent!(method.manifold, gradient(d), state.x)
+function update_fgh!(d, state, method::FirstOrderOptimizer)
+    f_x, g_x = value_gradient!(d, state.x)
+    if hasproperty(method, :manifold)
+        project_tangent!(method.manifold, g_x, state.x)
+    end
+    state.f_x = f_x
+    copyto!(state.g_x, g_x)
+    return nothing
 end
-function update_fg!(d, state, method::Newton)
-    value_gradient!(d, state.x)
+function update_fgh!(d, state, method::SecondOrderOptimizer)
+    # Manifold optimization is currently not supported for second order optimization algorithms
+    @assert !hasproperty(method, :manifold)
+
+    # TODO: Switch to `value_gradient_hessian!` when it becomes available
+    f_x, g_x = value_gradient!(d, state.x)
+    H_x = hessian!(d, state.x)
+    state.f_x = f_x
+    copyto!(state.g_x, g_x)
+    copyto!(state.H_x, H_x)
+
+    return nothing
 end
 
-# Update the Hessian
-update_h!(d, state, method) = nothing
-update_h!(d, state, method::SecondOrderOptimizer) = hessian!(d, state.x)
-
 after_while!(d, state, method, options) = nothing
 
-function initial_convergence(d, state, method::AbstractOptimizer, initial_x, options)
-    gradient!(d, initial_x)
-    stopped = !isfinite(value(d)) || any(!isfinite, gradient(d))
-    g_residual(d, state) <= options.g_abstol, stopped
+function initial_convergence(state::AbstractOptimizerState, options::Options)
+    stopped = !isfinite(state.f_x) || any(!isfinite, state.g_x)
+    return g_residual(state) <= options.g_abstol, stopped
 end
-function initial_convergence(d, state, method::ZerothOrderOptimizer, initial_x, options)
+function initial_convergence(::ZerothOrderState, ::Options)
     false, false
 end
+
 function optimize(
     d::D,
     initial_x::Tx,
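
The former `update_g!`/`update_fg!`/`update_h!` trio collapses into a single `update_fgh!` that dispatches on the optimizer's order, and the `hasproperty(method, :manifold)` guard replaces the dedicated `Newton` overloads. A stripped-down sketch of that guard with hypothetical method types (Optim's real types carry many more fields):

# hasproperty lets one generic first-order method serve optimizers with
# and without a manifold field, with no per-method overloads.
struct SketchGradientDescent{M}
    manifold::M
end
struct SketchNewton end  # no manifold field

needs_projection(method) = hasproperty(method, :manifold)

needs_projection(SketchGradientDescent(:sphere))  # true
needs_projection(SketchNewton())                  # false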
@@ -41,7 +46,7 @@ function optimize(
 ) where {D<:AbstractObjective,M<:AbstractOptimizer,Tx<:AbstractArray,T,TCallback}
 
     t0 = time() # Initial time stamp used to control early stopping by options.time_limit
-    tr = OptimizationTrace{typeof(value(d)),typeof(method)}()
+    tr = OptimizationTrace{typeof(state.f_x),typeof(method)}()
     tracing =
         options.store_trace ||
         options.show_trace ||
@@ -51,7 +56,7 @@ function optimize(
     f_limit_reached, g_limit_reached, h_limit_reached = false, false, false
     x_converged, f_converged, f_increased, counter_f_tol = false, false, false, 0
 
-    g_converged, stopped = initial_convergence(d, state, method, initial_x, options)
+    g_converged, stopped = initial_convergence(state, options)
     converged = g_converged || stopped
     # prepare iteration counter (used to make "initial state" trace entry)
     iteration = 0
@@ -62,22 +67,29 @@
     ls_success::Bool = true
     while !converged && !stopped && iteration < options.iterations
         iteration += 1
+
+        # Convention: When `update_state!` is called, `state` satisfies:
+        # - `state.x`: Current iterate
+        # - `state.f_x`: Objective function value at the current iterate, i.e. `d(state.x)`
+        # - `state.g_x` (if available): Gradient of the objective function at the current iterate, i.e. `gradient(d, state.x)`
+        # - `state.H_x` (if available): Hessian of the objective function at the current iterate, i.e. `hessian(d, state.x)`
         ls_success = !update_state!(d, state, method)
         if !ls_success
             break # it returns true if it's forced by something in update! to stop (eg dx_dg == 0.0 in BFGS, or linesearch errors)
         end
-        if !(method isa NewtonTrustRegion)
-            update_g!(d, state, method) # TODO: Should this be `update_fg!`?
-        end
+
+        # Update function value, gradient and Hessian matrix (skipped by some methods that already update those in `update_state!`)
+        # TODO: Already perform in `update_state!`?
+        update_fgh!(d, state, method)
+
+        # Check convergence
         x_converged, f_converged, g_converged, f_increased =
             assess_convergence(state, d, options)
         # For some problems it may be useful to require `f_converged` to be hit multiple times
         # TODO: Do the same for x_tol?
         counter_f_tol = f_converged ? counter_f_tol + 1 : 0
         converged = x_converged || g_converged || (counter_f_tol > options.successive_f_tol)
-        if !(converged && method isa Newton) && !(method isa NewtonTrustRegion)
-            update_h!(d, state, method) # only relevant if not converged
-        end
+
         if tracing
             # update trace; callbacks can stop routine early by returning true
             stopped_by_callback =
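
The new comment block documents an invariant rather than code: on entry to `update_state!`, every cache in `state` must agree with a fresh evaluation of the objective at `state.x`. A hypothetical debug helper that checks this invariant, with plain functions `f` and `∇f` standing in for the objective:

# Hypothetical check of the documented invariant: cached value and
# gradient must match fresh evaluations at state.x.
function check_state_invariant(f, ∇f, state)
    @assert state.f_x == f(state.x) "stale objective value cache"
    @assert state.g_x == ∇f(state.x) "stale gradient cache"
    return true
end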
@@ -113,11 +125,11 @@ function optimize(
             end
         end
 
-        if g_calls(d) > 0 && !all(isfinite, gradient(d))
+        if hasproperty(state, :g_x) && !all(isfinite, state.g_x)
             options.show_warnings && @warn "Terminated early due to NaN in gradient."
             break
         end
-        if h_calls(d) > 0 && !(d isa TwiceDifferentiableHV) && !all(isfinite, hessian(d))
+        if hasproperty(state, :H_x) && !all(isfinite, state.H_x)
             options.show_warnings && @warn "Terminated early due to NaN in Hessian."
             break
         end
@@ -127,7 +139,7 @@ function optimize(
 
     # we can just check minimum, as we've earlier enforced same types/eltypes
    # in variables besides the option settings
-    Tf = typeof(value(d))
+    Tf = typeof(state.f_x)
     f_incr_pick = f_increased && !options.allow_f_increases
     stopped_by = (x_converged, f_converged, g_converged,
         f_limit_reached = f_limit_reached,
@@ -141,7 +153,7 @@ function optimize(
     )
 
     termination_code =
-        _termination_code(d, g_residual(d, state), state, stopped_by, options)
+        _termination_code(d, g_residual(state), state, stopped_by, options)
 
     return MultivariateOptimizationResults{
         typeof(method),
@@ -154,18 +166,18 @@ function optimize(
         method,
         initial_x,
         pick_best_x(f_incr_pick, state),
-        pick_best_f(f_incr_pick, state, d),
+        pick_best_f(f_incr_pick, state),
         iteration,
         Tf(options.x_abstol),
         Tf(options.x_reltol),
         x_abschange(state),
         x_relchange(state),
         Tf(options.f_abstol),
         Tf(options.f_reltol),
-        f_abschange(d, state),
-        f_relchange(d, state),
+        f_abschange(state),
+        f_relchange(state),
         Tf(options.g_abstol),
-        g_residual(d, state),
+        g_residual(state),
         tr,
         f_calls(d),
         g_calls(d),
@@ -186,13 +198,13 @@ function _termination_code(d, gres, state, stopped_by, options)
     elseif (iszero(options.x_abstol) && x_abschange(state) <= options.x_abstol) ||
            (iszero(options.x_reltol) && x_relchange(state) <= options.x_reltol)
         TerminationCode.NoXChange
-    elseif (iszero(options.f_abstol) && f_abschange(d, state) <= options.f_abstol) ||
-           (iszero(options.f_reltol) && f_relchange(d, state) <= options.f_reltol)
+    elseif (iszero(options.f_abstol) && f_abschange(state) <= options.f_abstol) ||
+           (iszero(options.f_reltol) && f_relchange(state) <= options.f_reltol)
         TerminationCode.NoObjectiveChange
     elseif x_abschange(state) <= options.x_abstol || x_relchange(state) <= options.x_reltol
         TerminationCode.SmallXChange
-    elseif f_abschange(d, state) <= options.f_abstol ||
-           f_relchange(d, state) <= options.f_reltol
+    elseif f_abschange(state) <= options.f_abstol ||
+           f_relchange(state) <= options.f_reltol
         TerminationCode.SmallObjectiveChange
     elseif stopped_by.ls_failed
         TerminationCode.FailedLinesearch
@@ -210,11 +222,11 @@ function _termination_code(d, gres, state, stopped_by, options)
         TerminationCode.HessianCalls
     elseif stopped_by.f_increased
         TerminationCode.ObjectiveIncreased
-    elseif f_calls(d) > 0 && !isfinite(value(d))
-        TerminationCode.GradientNotFinite
-    elseif g_calls(d) > 0 && !all(isfinite, gradient(d))
+    elseif !isfinite(state.f_x)
+        TerminationCode.ObjectiveNotFinite
+    elseif hasproperty(state, :g_x) && !all(isfinite, state.g_x)
         TerminationCode.GradientNotFinite
-    elseif h_calls(d) > 0 && !(d isa TwiceDifferentiableHV) && !all(isfinite, hessian(d))
+    elseif hasproperty(state, :H_x) && !all(isfinite, state.H_x)
         TerminationCode.HessianNotFinite
     else
         TerminationCode.NotImplemented
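
Two behavioural changes are folded into `_termination_code`: a non-finite objective now reports `ObjectiveNotFinite` instead of the old, mislabelled `GradientNotFinite`, and the finiteness checks key off which caches the state carries rather than off call counters. A sketch of the resulting decision cascade, with symbols standing in for the `TerminationCode` values:

# Sketch of the new finiteness cascade: the diagnostic depends on which
# caches the state type has, not on how often the objective was called.
function finiteness_code(state)
    if !isfinite(state.f_x)
        :ObjectiveNotFinite
    elseif hasproperty(state, :g_x) && !all(isfinite, state.g_x)
        :GradientNotFinite
    elseif hasproperty(state, :H_x) && !all(isfinite, state.H_x)
        :HessianNotFinite
    else
        :Finite
    end
end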
