|
| 1 | +module GeometryOptimizationOptimizationExt |
| 2 | +using AtomsCalculators |
| 3 | +using Optimization |
| 4 | +using GeometryOptimization |
| 5 | +import GeometryOptimization: GeoOptProblem, GeoOptConvergence |
| 6 | +const GO = GeometryOptimization |
| 7 | + |
| 8 | +function GeometryOptimization.solve_problem( |
| 9 | + prob::GeoOptProblem, solver, cvg::GeoOptConvergence; |
| 10 | + callback, maxiters, maxtime, kwargs...) |
| 11 | + |
| 12 | + function inner_callback(optim_state, ::Any, geoopt_state) |
| 13 | + cache_evaluations = geoopt_state.cache_evaluations |
| 14 | + |
| 15 | + # Find position in the cache matching the current state |
| 16 | + i_match = findlast(cache_evaluations) do eval |
| 17 | + isnothing(eval.gradnorm) && return false |
| 18 | + |
| 19 | + tol = 10eps(typeof(optim_state.objective)) |
| 20 | + obj_matches = abs(eval.objective - optim_state.objective) < tol |
| 21 | + |
| 22 | + if isnothing(optim_state.grad) |
| 23 | + # Nothing we can do, let's just hope it's ok |
| 24 | + grad_matches = true |
| 25 | + else |
| 26 | + g_norm = maximum(abs, optim_state.grad) |
| 27 | + grad_matches = abs(eval.gradnorm - g_norm) < tol |
| 28 | + end |
| 29 | + |
| 30 | + obj_matches && grad_matches |
| 31 | + end |
| 32 | + i_match = @something i_match length(cache_evaluations) |
| 33 | + |
| 34 | + # Commit data from state and discard the rest |
| 35 | + geoopt_state.n_iter = optim_state.iter |
| 36 | + push!(geoopt_state.history_energy, cache_evaluations[i_match].energy) |
| 37 | + if !isnothing(cache_evaluations[i_match].forces) |
| 38 | + geoopt_state.forces .= cache_evaluations[i_match].forces |
| 39 | + end |
| 40 | + if !isnothing(cache_evaluations[i_match].virial) |
| 41 | + geoopt_state.virial .= cache_evaluations[i_match].virial |
| 42 | + end |
| 43 | + empty!(cache_evaluations) |
| 44 | + |
| 45 | + # Check for convergence |
| 46 | + geoopt_state.converged = GO.is_converged(cvg, geoopt_state) |
| 47 | + |
| 48 | + # Callback and possible abortion |
| 49 | + halt = callback(optim_state, geoopt_state) |
| 50 | + halt && return true |
| 51 | + |
| 52 | + geoopt_state.converged |
| 53 | + end |
| 54 | + |
| 55 | + optimres = solve(Optimization.OptimizationProblem(prob), solver; |
| 56 | + maxiters, maxtime, callback=inner_callback, kwargs...) |
| 57 | + (; minimizer=optimres.u, minimum=optimres.objective, optimres) |
| 58 | +end |
| 59 | + |
| 60 | + |
| 61 | +function Optimization.OptimizationProblem(prob::GeoOptProblem; kwargs...) |
| 62 | + f = function(x::AbstractVector{<:Real}, ps) |
| 63 | + GO.eval_objective_gradient!(nothing, prob, ps, x), prob.geoopt_state |
| 64 | + end |
| 65 | + g! = function(G::AbstractVector{<:Real}, x::AbstractVector{<:Real}, ps) |
| 66 | + GO.eval_objective_gradient!(G, prob, ps, x), prob.geoopt_state |
| 67 | + G |
| 68 | + end |
| 69 | + f_opt = OptimizationFunction(f; grad=g!) |
| 70 | + |
| 71 | + # Note: Some optimisers modify Dofs x0 in-place, so x0 needs to be mutable type. |
| 72 | + x0 = GO.get_dofs(prob.system, prob.dofmgr) |
| 73 | + OptimizationProblem(f_opt, x0, AtomsCalculators.get_parameters(prob.calculator); |
| 74 | + sense=Optimization.MinSense, kwargs...) |
| 75 | +end |
| 76 | +end |
0 commit comments