Skip to content

Commit c4d1eec

Browse files
committed
[WIP] Do not (mis)use objective as state
1 parent f9fe222 commit c4d1eec

32 files changed

+695
-635
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ ExplicitImports = "1.13.2"
3030
FillArrays = "0.6.2, 0.7, 0.8, 0.9, 0.10, 0.11, 0.12, 0.13, 1"
3131
ForwardDiff = "0.10, 1"
3232
JET = "0.9, 0.10"
33-
LineSearches = "7.4.0"
33+
LineSearches = "7.5.1"
3434
LinearAlgebra = "<0.0.1, 1.6"
3535
MathOptInterface = "1.17"
3636
Measurements = "2.14.1"

src/Manifolds.jl

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -28,28 +28,19 @@ end
2828
# TODO: is it safe here to call retract! and change x?
2929
function NLSolversBase.value!(obj::ManifoldObjective, x)
3030
xin = retract(obj.manifold, x)
31-
value!(obj.inner_obj, xin)
32-
end
33-
function NLSolversBase.value(obj::ManifoldObjective)
34-
value(obj.inner_obj)
35-
end
36-
function NLSolversBase.gradient(obj::ManifoldObjective)
37-
gradient(obj.inner_obj)
38-
end
39-
function NLSolversBase.gradient(obj::ManifoldObjective, i::Int)
40-
gradient(obj.inner_obj, i)
31+
return value!(obj.inner_obj, xin)
4132
end
4233
function NLSolversBase.gradient!(obj::ManifoldObjective, x)
4334
xin = retract(obj.manifold, x)
44-
gradient!(obj.inner_obj, xin)
45-
project_tangent!(obj.manifold, gradient(obj.inner_obj), xin)
46-
return gradient(obj.inner_obj)
35+
g_xin = gradient!(obj.inner_obj, xin)
36+
project_tangent!(obj.manifold, g_xin, xin)
37+
return g_xin
4738
end
4839
function NLSolversBase.value_gradient!(obj::ManifoldObjective, x)
4940
xin = retract(obj.manifold, x)
50-
value_gradient!(obj.inner_obj, xin)
51-
project_tangent!(obj.manifold, gradient(obj.inner_obj), xin)
52-
return value(obj.inner_obj)
41+
f_xin, g_xin = value_gradient!(obj.inner_obj, xin)
42+
project_tangent!(obj.manifold, g_xin, xin)
43+
return f_xin, g_xin
5344
end
5445

5546
"""Flat Euclidean space {R,C}^N, with projections equal to the identity."""

src/Optim.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,7 @@ using NLSolversBase:
4141
TwiceDifferentiableConstraints,
4242
nconstraints,
4343
nconstraints_x,
44-
hessian,
4544
hessian!,
46-
hessian!!,
47-
hv_product,
4845
hv_product!
4946

5047
# var for NelderMead

src/multivariate/optimize/optimize.jl

Lines changed: 56 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,57 @@
1-
update_g!(d, state, method) = nothing
1+
update_g!(d, state, ::ZerothOrderOptimizer) = nothing
22
function update_g!(d, state, method::FirstOrderOptimizer)
33
# Update the function value and gradient
4-
value_gradient!(d, state.x)
5-
project_tangent!(method.manifold, gradient(d), state.x)
4+
f_x, g_x = value_gradient!(d, state.x)
5+
project_tangent!(method.manifold, g_x, state.x)
6+
state.f_x = f_x
7+
copyto!(state.g_x, g_x)
8+
return nothing
69
end
7-
function update_g!(d, state, method::Newton)
10+
function update_g!(d, state, ::Newton)
811
# Update the function value and gradient
9-
value_gradient!(d, state.x)
12+
f_x, g_x = value_gradient!(d, state.x)
13+
state.f_x = f_x
14+
copyto!(state.g_x, g_x)
15+
return nothing
1016
end
11-
update_fg!(d, state, method) = nothing
12-
update_fg!(d, state, method::ZerothOrderOptimizer) = value!(d, state.x)
13-
function update_fg!(d, state, method::FirstOrderOptimizer)
14-
value_gradient!(d, state.x)
15-
project_tangent!(method.manifold, gradient(d), state.x)
17+
18+
function update_fg!(d, state, method::ZerothOrderOptimizer)
19+
f_x = value!(d, state.x)
20+
state.f_x = f_x
21+
return nothing
22+
end
23+
function update_fg!(d, state, method)
24+
f_x, g_x = value_gradient!(d, state.x)
25+
project_tangent!(method.manifold, g_x, state.x)
26+
state.f_x = f_x
27+
copyto!(state.g_x, g_x)
28+
return nothing
1629
end
1730
function update_fg!(d, state, method::Newton)
18-
value_gradient!(d, state.x)
31+
f_x, g_x = value_gradient!(d, state.x)
32+
state.f_x = f_x
33+
copyto!(state.g_x, g_x)
34+
return nothing
1935
end
2036

2137
# Update the Hessian
22-
update_h!(d, state, method) = nothing
23-
update_h!(d, state, method::SecondOrderOptimizer) = hessian!(d, state.x)
38+
update_h!(d, state, method::Union{ZerothOrderOptimizer,FirstOrderOptimizer}) = nothing
39+
function update_h!(d, state, method::SecondOrderOptimizer)
40+
H_x = hessian!(d, state.x)
41+
copyto!(state.H_x, H_x)
42+
return nothing
43+
end
2444

2545
after_while!(d, state, method, options) = nothing
2646

27-
function initial_convergence(d, state, method::AbstractOptimizer, initial_x, options)
28-
gradient!(d, initial_x)
29-
stopped = !isfinite(value(d)) || any(!isfinite, gradient(d))
30-
g_residual(d, state) <= options.g_abstol, stopped
47+
function initial_convergence(state::AbstractOptimizerState, options::Options)
48+
stopped = !isfinite(state.f_x) || any(!isfinite, state.g_x)
49+
return g_residual(state) <= options.g_abstol, stopped
3150
end
32-
function initial_convergence(d, state, method::ZerothOrderOptimizer, initial_x, options)
51+
function initial_convergence(::ZerothOrderState, ::Options)
3352
false, false
3453
end
54+
3555
function optimize(
3656
d::D,
3757
initial_x::Tx,
@@ -41,7 +61,7 @@ function optimize(
4161
) where {D<:AbstractObjective,M<:AbstractOptimizer,Tx<:AbstractArray,T,TCallback}
4262

4363
t0 = time() # Initial time stamp used to control early stopping by options.time_limit
44-
tr = OptimizationTrace{typeof(value(d)),typeof(method)}()
64+
tr = OptimizationTrace{typeof(state.f_x),typeof(method)}()
4565
tracing =
4666
options.store_trace ||
4767
options.show_trace ||
@@ -51,7 +71,7 @@ function optimize(
5171
f_limit_reached, g_limit_reached, h_limit_reached = false, false, false
5272
x_converged, f_converged, f_increased, counter_f_tol = false, false, false, 0
5373

54-
g_converged, stopped = initial_convergence(d, state, method, initial_x, options)
74+
g_converged, stopped = initial_convergence(state, options)
5575
converged = g_converged || stopped
5676
# prepare iteration counter (used to make "initial state" trace entry)
5777
iteration = 0
@@ -113,11 +133,11 @@ function optimize(
113133
end
114134
end
115135

116-
if g_calls(d) > 0 && !all(isfinite, gradient(d))
136+
if hasproperty(state, :g_x) && !all(isfinite, state.g_x)
117137
options.show_warnings && @warn "Terminated early due to NaN in gradient."
118138
break
119139
end
120-
if h_calls(d) > 0 && !(d isa TwiceDifferentiableHV) && !all(isfinite, hessian(d))
140+
if hasproperty(state, :H_x) && !all(isfinite, state.H_x)
121141
options.show_warnings && @warn "Terminated early due to NaN in Hessian."
122142
break
123143
end
@@ -127,7 +147,7 @@ function optimize(
127147

128148
# we can just check minimum, as we've earlier enforced same types/eltypes
129149
# in variables besides the option settings
130-
Tf = typeof(value(d))
150+
Tf = typeof(state.f_x)
131151
f_incr_pick = f_increased && !options.allow_f_increases
132152
stopped_by = (x_converged, f_converged, g_converged,
133153
f_limit_reached = f_limit_reached,
@@ -141,7 +161,7 @@ function optimize(
141161
)
142162

143163
termination_code =
144-
_termination_code(d, g_residual(d, state), state, stopped_by, options)
164+
_termination_code(d, g_residual(state), state, stopped_by, options)
145165

146166
return MultivariateOptimizationResults{
147167
typeof(method),
@@ -154,18 +174,18 @@ function optimize(
154174
method,
155175
initial_x,
156176
pick_best_x(f_incr_pick, state),
157-
pick_best_f(f_incr_pick, state, d),
177+
pick_best_f(f_incr_pick, state),
158178
iteration,
159179
Tf(options.x_abstol),
160180
Tf(options.x_reltol),
161181
x_abschange(state),
162182
x_relchange(state),
163183
Tf(options.f_abstol),
164184
Tf(options.f_reltol),
165-
f_abschange(d, state),
166-
f_relchange(d, state),
185+
f_abschange(state),
186+
f_relchange(state),
167187
Tf(options.g_abstol),
168-
g_residual(d, state),
188+
g_residual(state),
169189
tr,
170190
f_calls(d),
171191
g_calls(d),
@@ -186,13 +206,13 @@ function _termination_code(d, gres, state, stopped_by, options)
186206
elseif (iszero(options.x_abstol) && x_abschange(state) <= options.x_abstol) ||
187207
(iszero(options.x_reltol) && x_relchange(state) <= options.x_reltol)
188208
TerminationCode.NoXChange
189-
elseif (iszero(options.f_abstol) && f_abschange(d, state) <= options.f_abstol) ||
190-
(iszero(options.f_reltol) && f_relchange(d, state) <= options.f_reltol)
209+
elseif (iszero(options.f_abstol) && f_abschange(state) <= options.f_abstol) ||
210+
(iszero(options.f_reltol) && f_relchange(state) <= options.f_reltol)
191211
TerminationCode.NoObjectiveChange
192212
elseif x_abschange(state) <= options.x_abstol || x_relchange(state) <= options.x_reltol
193213
TerminationCode.SmallXChange
194-
elseif f_abschange(d, state) <= options.f_abstol ||
195-
f_relchange(d, state) <= options.f_reltol
214+
elseif f_abschange(state) <= options.f_abstol ||
215+
f_relchange(state) <= options.f_reltol
196216
TerminationCode.SmallObjectiveChange
197217
elseif stopped_by.ls_failed
198218
TerminationCode.FailedLinesearch
@@ -210,11 +230,11 @@ function _termination_code(d, gres, state, stopped_by, options)
210230
TerminationCode.HessianCalls
211231
elseif stopped_by.f_increased
212232
TerminationCode.ObjectiveIncreased
213-
elseif f_calls(d) > 0 && !isfinite(value(d))
214-
TerminationCode.GradientNotFinite
215-
elseif g_calls(d) > 0 && !all(isfinite, gradient(d))
233+
elseif !isfinite(state.f_x)
234+
TerminationCode.ObjectiveNotFinite
235+
elseif hasproperty(state, :g_x) && !all(isfinite, state.g_x)
216236
TerminationCode.GradientNotFinite
217-
elseif h_calls(d) > 0 && !(d isa TwiceDifferentiableHV) && !all(isfinite, hessian(d))
237+
elseif hasproperty(state, :H_x) && !all(isfinite, state.H_x)
218238
TerminationCode.HessianNotFinite
219239
else
220240
TerminationCode.NotImplemented

0 commit comments

Comments
 (0)