1- # This file contains the top level solve interface functionality moved from SciMLBase.jl
2- # These functions provide the core optimization solving interface
1+ # Skip the DiffEqBase handling
32
43struct IncompatibleOptimizerError <: Exception
54 err:: String
@@ -9,27 +8,109 @@ function Base.showerror(io::IO, e::IncompatibleOptimizerError)
98 print (io, e. err)
109end
1110
12- const OPTIMIZER_MISSING_ERROR_MESSAGE = """
13- Optimization algorithm not found. Either the chosen algorithm is not a valid solver
14- choice for the `OptimizationProblem`, or the Optimization solver library is not loaded.
15- Make sure that you have loaded an appropriate OptimizationBase.jl solver library, for example,
16- `solve(prob,Optim.BFGS())` requires `using OptimizationOptimJL` and
17- `solve(prob,Adam())` requires `using OptimizationOptimisers`.
11+ """
12+ ```julia
13+ solve(prob::OptimizationProblem, alg::AbstractOptimizationAlgorithm,
14+ args...; kwargs...)::OptimizationSolution
15+ ```
1816
19- For more information, see the OptimizationBase.jl documentation: <https://docs.sciml.ai/Optimization/stable/>.
20- """
17+ For information about the returned solution object, refer to the documentation for [`OptimizationSolution`](@ref)
2118
22- struct OptimizerMissingError <: Exception
23- alg:: Any
19+ ## Keyword Arguments
20+
21+ The arguments to `solve` are common across all of the optimizers.
22+ These common arguments are:
23+
24+ - `maxiters`: the maximum number of iterations
25+ - `maxtime`: the maximum amount of time (typically in seconds) the optimization runs for
26+ - `abstol`: absolute tolerance in changes of the objective value
27+ - `reltol`: relative tolerance in changes of the objective value
28+ - `callback`: a callback function
29+
30+ Some optimizer algorithms have special keyword arguments documented in the
31+ solver portion of the documentation and their respective documentation.
32+ These arguments can be passed as `kwargs...` to `solve`. Similarly, the special
33+ keyword arguments for the `local_method` of a global optimizer are passed as a
34+ `NamedTuple` to `local_options`.
35+
36+ Over time, we hope to cover more of these keyword arguments under the common interface.
37+
38+ A warning will be shown if a common argument is not implemented for an optimizer.
39+
40+ ## Callback Functions
41+
42+ The callback function `callback` is a function that is called after every optimizer
43+ step. Its signature is:
44+
45+ ```julia
46+ callback = (state, loss_val) -> false
47+ ```
48+
49+ where `state` is an `OptimizationState` and stores information for the current
50+ iteration of the solver and `loss_val` is loss/objective value. For more
51+ information about the fields of the `state` look at the `OptimizationState`
52+ documentation. The callback should return a Boolean value, and the default
53+ should be `false`, so the optimization stops if it returns `true`.
54+
55+ ### Callback Example
56+
57+ Here we show an example of a callback function that plots the prediction at the current value of the optimization variables.
58+ For a visualization callback, we would need the prediction at the current parameters i.e. the solution of the `ODEProblem` `prob`.
59+ So we call the `predict` function within the callback again.
60+
61+ ```julia
62+ function predict(u)
63+ Array(solve(prob, Tsit5(), p = u))
2464end
2565
26- function Base. showerror (io:: IO , e:: OptimizerMissingError )
27- println (io, OPTIMIZER_MISSING_ERROR_MESSAGE)
28- print (io, " Chosen Optimizer: " )
29- print (e. alg)
66+ function loss(u, p)
67+ pred = predict(u)
68+ sum(abs2, batch .- pred)
69+ end
70+
71+ callback = function (state, l; doplot = false) #callback function to observe training
72+ display(l)
73+ # plot current prediction against data
74+ if doplot
75+ pred = predict(state.u)
76+ pl = scatter(t, ode_data[1, :], label = "data")
77+ scatter!(pl, t, pred[1, :], label = "prediction")
78+ display(plot(pl))
79+ end
80+ return false
81+ end
82+ ```
83+
84+ If the chosen method is a global optimizer that employs a local optimization
85+ method, a similar set of common local optimizer arguments exists. Look at `MLSL` or `AUGLAG`
86+ from NLopt for an example. The common local optimizer arguments are:
87+
88+ - `local_method`: optimizer used for local optimization in global method
89+ - `local_maxiters`: the maximum number of iterations
90+ - `local_maxtime`: the maximum amount of time (in seconds) the optimization runs for
91+ - `local_abstol`: absolute tolerance in changes of the objective value
92+ - `local_reltol`: relative tolerance in changes of the objective value
93+ - `local_options`: `NamedTuple` of keyword arguments for local optimizer
94+ """
95+ function SciMLBase. solve (prob:: SciMLBase.OptimizationProblem , alg, args... ;
96+ kwargs... ):: SciMLBase.AbstractOptimizationSolution
97+ if SciMLBase. supports_opt_cache_interface (alg)
98+ SciMLBase. solve! (SciMLBase. init (prob, alg, args... ; kwargs... ))
99+ else
100+ if prob. u0 != = nothing && ! isconcretetype (eltype (prob. u0))
101+ throw (SciMLBase. NonConcreteEltypeError (eltype (prob. u0)))
102+ end
103+ _check_opt_alg (prob, alg; kwargs... )
104+ SciMLBase. __solve (prob, alg, args... ; kwargs... )
105+ end
106+ end
107+
108+ function SciMLBase. solve (
109+ prob:: SciMLBase.EnsembleProblem{T} , args... ; kwargs... ) where {T < :
110+ SciMLBase. OptimizationProblem}
111+ return SciMLBase. __solve (prob, args... ; kwargs... )
30112end
31113
32- # Algorithm compatibility checking function
33114function _check_opt_alg (prob:: SciMLBase.OptimizationProblem , alg; kwargs... )
34115 ! SciMLBase. allowsbounds (alg) && (! isnothing (prob. lb) || ! isnothing (prob. ub)) &&
35116 throw (IncompatibleOptimizerError (" The algorithm $(typeof (alg)) does not support box constraints. Either remove the `lb` or `ub` bounds passed to `OptimizationProblem` or use a different algorithm." ))
@@ -61,18 +142,80 @@ function _check_opt_alg(prob::SciMLBase.OptimizationProblem, alg; kwargs...)
61142 return
62143end
63144
64- # Base solver dispatch functions (these will be extended by specific solver packages)
65- supports_opt_cache_interface (alg) = false
145+ const OPTIMIZER_MISSING_ERROR_MESSAGE = """
146+ Optimization algorithm not found. Either the chosen algorithm is not a valid solver
147+ choice for the `OptimizationProblem`, or the Optimization solver library is not loaded.
148+ Make sure that you have loaded an appropriate Optimization.jl solver library, for example,
149+ `solve(prob,Optim.BFGS())` requires `using OptimizationOptimJL` and
150+ `solve(prob,Adam())` requires `using OptimizationOptimisers`.
151+
152+ For more information, see the Optimization.jl documentation: <https://docs.sciml.ai/Optimization/stable/>.
153+ """
66154
67- function __solve (cache :: SciMLBase.AbstractOptimizationCache ) :: SciMLBase.AbstractOptimizationSolution
68- throw ( OptimizerMissingError (cache . opt))
155+ struct OptimizerMissingError <: Exception
156+ alg :: Any
69157end
70158
71- function __init (prob:: SciMLBase.OptimizationProblem , alg, args... ;
159+ function Base. showerror (io:: IO , e:: OptimizerMissingError )
160+ println (io, OPTIMIZER_MISSING_ERROR_MESSAGE)
161+ print (io, " Chosen Optimizer: " )
162+ print (e. alg)
163+ end
164+
165+ """
166+ ```julia
167+ init(prob::OptimizationProblem, alg::AbstractOptimizationAlgorithm, args...; kwargs...)
168+ ```
169+
170+ ## Keyword Arguments
171+
172+ The arguments to `init` are the same as to `solve` and common across all of the optimizers.
173+ These common arguments are:
174+
175+ - `maxiters` (the maximum number of iterations)
176+ - `maxtime` (the maximum of time the optimization runs for)
177+ - `abstol` (absolute tolerance in changes of the objective value)
178+ - `reltol` (relative tolerance in changes of the objective value)
179+ - `callback` (a callback function)
180+
181+ Some optimizer algorithms have special keyword arguments documented in the
182+ solver portion of the documentation and their respective documentation.
183+ These arguments can be passed as `kwargs...` to `init`.
184+
185+ See also [`solve(prob::OptimizationProblem, alg, args...; kwargs...)`](@ref)
186+ """
187+ function SciMLBase. init (prob:: SciMLBase.OptimizationProblem , alg, args... ;
188+ kwargs... ):: SciMLBase.AbstractOptimizationCache
189+ if prob. u0 != = nothing && ! isconcretetype (eltype (prob. u0))
190+ throw (SciMLBase. NonConcreteEltypeError (eltype (prob. u0)))
191+ end
192+ _check_opt_alg (prob:: SciMLBase.OptimizationProblem , alg; kwargs... )
193+ cache = SciMLBase. __init (prob, alg, args... ; prob. kwargs... , kwargs... )
194+ return cache
195+ end
196+
197+ """
198+ ```julia
199+ solve!(cache::AbstractOptimizationCache)
200+ ```
201+
202+ Solves the given optimization cache.
203+
204+ See also [`init(prob::OptimizationProblem, alg, args...; kwargs...)`](@ref)
205+ """
206+ function SciMLBase. solve! (cache:: SciMLBase.AbstractOptimizationCache ):: SciMLBase.AbstractOptimizationSolution
207+ SciMLBase. __solve (cache)
208+ end
209+
210+ # needs to be defined for each cache
211+ SciMLBase. supports_opt_cache_interface (alg) = false
212+ function SciMLBase. __solve (cache:: SciMLBase.AbstractOptimizationCache ):: SciMLBase.AbstractOptimizationSolution end
213+ function SciMLBase. __init (prob:: SciMLBase.OptimizationProblem , alg, args... ;
72214 kwargs... ):: SciMLBase.AbstractOptimizationCache
73215 throw (OptimizerMissingError (alg))
74216end
75217
76- function __solve (prob:: SciMLBase.OptimizationProblem , alg, args... ; kwargs... )
218+ # if no cache interface is supported at least the following method has to be defined
219+ function SciMLBase. __solve (prob:: SciMLBase.OptimizationProblem , alg, args... ; kwargs... )
77220 throw (OptimizerMissingError (alg))
78- end
221+ end
0 commit comments