@@ -1,37 +1,42 @@
-update_g!(d, state, method) = nothing
-function update_g!(d, state, method::FirstOrderOptimizer)
-    # Update the function value and gradient
-    value_gradient!(d, state.x)
-    project_tangent!(method.manifold, gradient(d), state.x)
+# Update function value, gradient and Hessian
+function update_fgh!(d, state, ::ZerothOrderOptimizer)
+    f_x = value!(d, state.x)
+    state.f_x = f_x
+    return nothing
 end
-function update_g!(d, state, method::Newton)
-    # Update the function value and gradient
-    value_gradient!(d, state.x)
-end
-update_fg!(d, state, method) = nothing
-update_fg!(d, state, method::ZerothOrderOptimizer) = value!(d, state.x)
-function update_fg!(d, state, method::FirstOrderOptimizer)
-    value_gradient!(d, state.x)
-    project_tangent!(method.manifold, gradient(d), state.x)
+function update_fgh!(d, state, method::FirstOrderOptimizer)
+    f_x, g_x = value_gradient!(d, state.x)
+    if hasproperty(method, :manifold)
+        project_tangent!(method.manifold, g_x, state.x)
+    end
+    state.f_x = f_x
+    copyto!(state.g_x, g_x)
+    return nothing
 end
-function update_fg!(d, state, method::Newton)
-    value_gradient!(d, state.x)
+function update_fgh!(d, state, method::SecondOrderOptimizer)
+    # Manifold optimization is currently not supported for second-order optimization algorithms
+    @assert !hasproperty(method, :manifold)
+
+    # TODO: Switch to `value_gradient_hessian!` when it becomes available
+    f_x, g_x = value_gradient!(d, state.x)
+    H_x = hessian!(d, state.x)
+    state.f_x = f_x
+    copyto!(state.g_x, g_x)
+    copyto!(state.H_x, H_x)
+
+    return nothing
 end
 
-# Update the Hessian
-update_h!(d, state, method) = nothing
-update_h!(d, state, method::SecondOrderOptimizer) = hessian!(d, state.x)
-
 after_while!(d, state, method, options) = nothing
 
-function initial_convergence(d, state, method::AbstractOptimizer, initial_x, options)
-    gradient!(d, initial_x)
-    stopped = !isfinite(value(d)) || any(!isfinite, gradient(d))
-    g_residual(d, state) <= options.g_abstol, stopped
+function initial_convergence(state::AbstractOptimizerState, options::Options)
+    stopped = !isfinite(state.f_x) || any(!isfinite, state.g_x)
+    return g_residual(state) <= options.g_abstol, stopped
 end
-function initial_convergence(d, state, method::ZerothOrderOptimizer, initial_x, options)
+function initial_convergence(::ZerothOrderState, ::Options)
     false, false
 end
+
 function optimize(
     d::D,
     initial_x::Tx,
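The three `update_fgh!` methods above replace the old `update_g!`/`update_fg!`/`update_h!` trio: multiple dispatch on the optimizer's order decides which quantities get refreshed and stored in the state. A minimal, self-contained sketch of that pattern (the `Toy*` types and the finite-difference gradient are illustrative, not Optim.jl code):

```julia
# Toy re-creation of the dispatch pattern; none of these types are Optim.jl's.
abstract type ToyOptimizer end
struct ToyZerothOrder <: ToyOptimizer end
struct ToyFirstOrder <: ToyOptimizer end

mutable struct ToyState
    x::Vector{Float64}    # current iterate
    f_x::Float64          # objective value at x
    g_x::Vector{Float64}  # gradient at x
end

# Zeroth-order methods refresh only the objective value.
update_fgh!(f, state::ToyState, ::ToyZerothOrder) = (state.f_x = f(state.x); nothing)

# First-order methods refresh value and gradient (forward differences for brevity).
function update_fgh!(f, state::ToyState, ::ToyFirstOrder)
    state.f_x = f(state.x)
    h = 1e-8
    for i in eachindex(state.x)
        xp = copy(state.x)
        xp[i] += h
        state.g_x[i] = (f(xp) - state.f_x) / h
    end
    return nothing
end

fobj(x) = sum(abs2, x)
st = ToyState([1.0, 2.0], NaN, zeros(2))
update_fgh!(fobj, st, ToyFirstOrder())  # st.f_x == 5.0, st.g_x ≈ [2.0, 4.0]
```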
@@ -41,7 +46,7 @@ function optimize(
 ) where {D<:AbstractObjective,M<:AbstractOptimizer,Tx<:AbstractArray,T,TCallback}
 
     t0 = time() # Initial time stamp used to control early stopping by options.time_limit
-    tr = OptimizationTrace{typeof(value(d)),typeof(method)}()
+    tr = OptimizationTrace{typeof(state.f_x),typeof(method)}()
     tracing =
         options.store_trace ||
         options.show_trace ||
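Taking the trace element type from `state.f_x` rather than `value(d)` infers the type from data the state already holds instead of querying the objective wrapper. A small stand-in showing the idea (`SketchTrace` is hypothetical, not Optim.jl's trace type):

```julia
# Hypothetical trace type parameterized on the stored objective-value type.
struct SketchTrace{Tf}
    values::Vector{Tf}
end
SketchTrace{Tf}() where {Tf} = SketchTrace{Tf}(Tf[])

f_x = 1.0f0                      # e.g. state.f_x for a Float32 problem
tr = SketchTrace{typeof(f_x)}()  # SketchTrace{Float32}, no objective call needed
```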
@@ -51,7 +56,7 @@ function optimize(
     f_limit_reached, g_limit_reached, h_limit_reached = false, false, false
     x_converged, f_converged, f_increased, counter_f_tol = false, false, false, 0
 
-    g_converged, stopped = initial_convergence(d, state, method, initial_x, options)
+    g_converged, stopped = initial_convergence(state, options)
     converged = g_converged || stopped
     # prepare iteration counter (used to make "initial state" trace entry)
     iteration = 0
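`initial_convergence` lets the loop exit before the first iteration when the starting point is already stationary, and flags non-finite starting values. A runnable sketch, assuming `g_residual` is the infinity norm of the stored gradient:

```julia
# Sketch of the initial check; g_residual_sketch and the returned
# (converged, stopped) tuple mirror the code above but are assumptions here.
g_residual_sketch(g_x) = maximum(abs, g_x)

function initial_convergence_sketch(f_x, g_x, g_abstol)
    stopped = !isfinite(f_x) || any(!isfinite, g_x)
    return g_residual_sketch(g_x) <= g_abstol, stopped
end

initial_convergence_sketch(0.0, [0.0, 0.0], 1e-8)  # (true, false): converged at x0
initial_convergence_sketch(NaN, [1.0, 2.0], 1e-8)  # (false, true): invalid start
```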
@@ -62,22 +67,29 @@ function optimize(
     ls_success::Bool = true
     while !converged && !stopped && iteration < options.iterations
         iteration += 1
+
+        # Convention: when `update_state!` is called, `state` satisfies:
+        # - `state.x`: the current iterate
+        # - `state.f_x`: objective function value at the current iterate, i.e. `d(state.x)`
+        # - `state.g_x` (if available): gradient of the objective function at the current iterate, i.e. `gradient(d, state.x)`
+        # - `state.H_x` (if available): Hessian of the objective function at the current iterate, i.e. `hessian(d, state.x)`
         ls_success = !update_state!(d, state, method)
         if !ls_success
             break # it returns true if it's forced by something in update! to stop (e.g. dx_dg == 0.0 in BFGS, or linesearch errors)
         end
-        if !(method isa NewtonTrustRegion)
-            update_g!(d, state, method) # TODO: Should this be `update_fg!`?
-        end
+
+        # Update function value, gradient and Hessian matrix (skipped by some methods that already update these in `update_state!`)
+        # TODO: Already perform this in `update_state!`?
+        update_fgh!(d, state, method)
+
+        # Check convergence
         x_converged, f_converged, g_converged, f_increased =
             assess_convergence(state, d, options)
         # For some problems it may be useful to require `f_converged` to be hit multiple times
         # TODO: Do the same for x_tol?
         counter_f_tol = f_converged ? counter_f_tol + 1 : 0
         converged = x_converged || g_converged || (counter_f_tol > options.successive_f_tol)
+
         if tracing
             # update trace; callbacks can stop routine early by returning true
             stopped_by_callback =
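The invariant spelled out in the loop comment can be made concrete with a hypothetical state container; the field names mirror the convention, but the type itself is not part of Optim.jl:

```julia
# Hypothetical second-order state: after update_fgh!, every field is
# consistent with the current iterate x.
mutable struct SketchSecondOrderState{T}
    x::Vector{T}    # current iterate
    f_x::T          # d(x)
    g_x::Vector{T}  # gradient(d, x)
    H_x::Matrix{T}  # hessian(d, x)
end
```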
@@ -113,11 +125,11 @@ function optimize(
             end
         end
 
-        if g_calls(d) > 0 && !all(isfinite, gradient(d))
+        if hasproperty(state, :g_x) && !all(isfinite, state.g_x)
             options.show_warnings && @warn "Terminated early due to NaN in gradient."
             break
         end
-        if h_calls(d) > 0 && !(d isa TwiceDifferentiableHV) && !all(isfinite, hessian(d))
+        if hasproperty(state, :H_x) && !all(isfinite, state.H_x)
             options.show_warnings && @warn "Terminated early due to NaN in Hessian."
             break
         end
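These guards use `hasproperty` so that states without gradient or Hessian storage (e.g. for zeroth-order methods) skip the corresponding check entirely. A short demonstration with made-up state types:

```julia
# Made-up state types; only the field layout matters for the check.
struct ValueOnlyState
    x::Vector{Float64}
    f_x::Float64
end
struct GradientState
    x::Vector{Float64}
    f_x::Float64
    g_x::Vector{Float64}
end

hasproperty(ValueOnlyState([0.0], 1.0), :g_x)        # false: check is skipped
hasproperty(GradientState([0.0], 1.0, [NaN]), :g_x)  # true: check runs
all(isfinite, GradientState([0.0], 1.0, [NaN]).g_x)  # false: loop would break
```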
@@ -127,7 +139,7 @@ function optimize(
 
     # we can just check minimum, as we've earlier enforced same types/eltypes
     # in variables besides the option settings
-    Tf = typeof(value(d))
+    Tf = typeof(state.f_x)
     f_incr_pick = f_increased && !options.allow_f_increases
     stopped_by = (x_converged, f_converged, g_converged,
         f_limit_reached = f_limit_reached,
@@ -141,7 +153,7 @@ function optimize(
     )
 
     termination_code =
-        _termination_code(d, g_residual(d, state), state, stopped_by, options)
+        _termination_code(d, g_residual(state), state, stopped_by, options)
 
     return MultivariateOptimizationResults{
         typeof(method),
@@ -154,18 +166,18 @@ function optimize(
         method,
         initial_x,
         pick_best_x(f_incr_pick, state),
-        pick_best_f(f_incr_pick, state, d),
+        pick_best_f(f_incr_pick, state),
         iteration,
         Tf(options.x_abstol),
         Tf(options.x_reltol),
         x_abschange(state),
         x_relchange(state),
         Tf(options.f_abstol),
         Tf(options.f_reltol),
-        f_abschange(d, state),
-        f_relchange(d, state),
+        f_abschange(state),
+        f_relchange(state),
         Tf(options.g_abstol),
-        g_residual(d, state),
+        g_residual(state),
         tr,
         f_calls(d),
         g_calls(d),
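`pick_best_x` and `pick_best_f` return the previous iterate's data when the last step increased the objective and increases are disallowed; dropping the `d` argument means the value now comes from the state. A plausible one-line sketch (the real definitions live elsewhere, and `state.f_x_previous` is an assumed field name):

```julia
# Plausible sketch only; not the actual Optim.jl definition.
pick_best_f_sketch(f_incr_pick, state) =
    f_incr_pick ? state.f_x_previous : state.f_x
```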
@@ -186,13 +198,13 @@ function _termination_code(d, gres, state, stopped_by, options)
     elseif (iszero(options.x_abstol) && x_abschange(state) <= options.x_abstol) ||
            (iszero(options.x_reltol) && x_relchange(state) <= options.x_reltol)
         TerminationCode.NoXChange
-    elseif (iszero(options.f_abstol) && f_abschange(d, state) <= options.f_abstol) ||
-           (iszero(options.f_reltol) && f_relchange(d, state) <= options.f_reltol)
+    elseif (iszero(options.f_abstol) && f_abschange(state) <= options.f_abstol) ||
+           (iszero(options.f_reltol) && f_relchange(state) <= options.f_reltol)
         TerminationCode.NoObjectiveChange
     elseif x_abschange(state) <= options.x_abstol || x_relchange(state) <= options.x_reltol
         TerminationCode.SmallXChange
-    elseif f_abschange(d, state) <= options.f_abstol ||
-           f_relchange(d, state) <= options.f_reltol
+    elseif f_abschange(state) <= options.f_abstol ||
+           f_relchange(state) <= options.f_reltol
         TerminationCode.SmallObjectiveChange
     elseif stopped_by.ls_failed
         TerminationCode.FailedLinesearch
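Note the branch ordering: with a tolerance of exactly zero, convergence in `x` (or `f`) can only mean the quantity did not change at all, reported as `NoXChange`/`NoObjectiveChange`; a positive tolerance falls through to the `Small*Change` branches. A tiny worked example of the two cases:

```julia
# Worked example of the branch logic with hypothetical values.
x_abstol, x_change = 0.0, 0.0
iszero(x_abstol) && x_change <= x_abstol  # true -> NoXChange

x_abstol, x_change = 1e-8, 1e-9
iszero(x_abstol)                          # false: first branch not taken
x_change <= x_abstol                      # true -> SmallXChange
```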
@@ -210,11 +222,11 @@ function _termination_code(d, gres, state, stopped_by, options)
         TerminationCode.HessianCalls
     elseif stopped_by.f_increased
         TerminationCode.ObjectiveIncreased
-    elseif f_calls(d) > 0 && !isfinite(value(d))
-        TerminationCode.GradientNotFinite
-    elseif g_calls(d) > 0 && !all(isfinite, gradient(d))
+    elseif !isfinite(state.f_x)
+        TerminationCode.ObjectiveNotFinite
+    elseif hasproperty(state, :g_x) && !all(isfinite, state.g_x)
         TerminationCode.GradientNotFinite
-    elseif h_calls(d) > 0 && !(d isa TwiceDifferentiableHV) && !all(isfinite, hessian(d))
+    elseif hasproperty(state, :H_x) && !all(isfinite, state.H_x)
         TerminationCode.HessianNotFinite
     else
         TerminationCode.NotImplemented
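None of these internal changes alter the public entry point, so a standard call still exercises the whole loop above. A quick smoke test against Optim.jl's exported API:

```julia
using Optim

rosenbrock(x) = (1.0 - x[1])^2 + 100.0 * (x[2] - x[1]^2)^2
res = optimize(rosenbrock, [-1.2, 1.0], BFGS(), Optim.Options(g_abstol = 1e-8))
Optim.minimizer(res)  # ≈ [1.0, 1.0]
```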