diff --git a/benchmarks/adam_opt.jl b/benchmarks/adam_opt.jl
index a347fb7..ae77416 100644
--- a/benchmarks/adam_opt.jl
+++ b/benchmarks/adam_opt.jl
@@ -58,7 +58,7 @@ optprob = Optimization.OptimizationProblem(optf, p_nn)
 @time res = Optimization.solve(optprob, ADAM(0.05), maxiters = 100)
 @show res.objective
 
-@benchmark Optimization.solve(optprob, ADAM(0.05), maxiters = 100)
+# @benchmark Optimization.solve(optprob, ADAM(0.05), maxiters = 100)
 
 ## PSOGPU stuff
 
@@ -129,16 +129,16 @@ solver_cache = (; losses, gpu_particles, gpu_data, gbest)
     prob_func = prob_func,
     maxiters = 100)
 
-@benchmark PSOGPU.parameter_estim_ode!($prob_nn,
-    $(deepcopy(solver_cache)),
-    $lb,
-    $ub;
-    saveat = tsteps,
-    dt = 0.1f0,
-    prob_func = prob_func,
-    maxiters = 100)
+# @benchmark PSOGPU.parameter_estim_ode!($prob_nn,
+#     $(deepcopy(solver_cache)),
+#     $lb,
+#     $ub;
+#     saveat = tsteps,
+#     dt = 0.1f0,
+#     prob_func = prob_func,
+#     maxiters = 100)
 
-@show gsol.best
+@show gsol.cost
 
 using Plots
 
@@ -151,7 +151,16 @@ using Plots
 plt = scatter(tsteps, data[1, :], label = "data")
 pred_pso = predict_neuralode((sc, gsol.position))
 
-scatter!(plt, tsteps, pred[1, :], label = "PSO prediction")
+scatter!(plt, tsteps, pred_pso[1, :], label = "PSO prediction")
 
 pred_adam = predict_neuralode((sc, res.u))
 scatter!(plt, tsteps, pred_adam[1, :], label = "Adam prediction")
+
+@time gsol1 = PSOGPU.parameter_estim_odehybrid(prob_nn,
+    solver_cache,
+    lb,
+    ub;
+    saveat = tsteps,
+    dt = 0.1f0,
+    prob_func = prob_func,
+    maxiters = 100)
\ No newline at end of file
diff --git a/src/PSOGPU.jl b/src/PSOGPU.jl
index 99ce544..02c8859 100644
--- a/src/PSOGPU.jl
+++ b/src/PSOGPU.jl
@@ -1,6 +1,6 @@
 module PSOGPU
 
-using SciMLBase, StaticArrays, Setfield, KernelAbstractions
+using SciMLBase, StaticArrays, Setfield, KernelAbstractions, OrdinaryDiffEq
 using QuasiMonteCarlo, Optimization, SimpleNonlinearSolve, ForwardDiff
 import Adapt
 import Enzyme: autodiff_deferred, Active, Reverse
@@ -58,12 +58,12 @@ end
 
 include("./algorithms.jl")
 include("./utils.jl")
-include("./ode_pso.jl")
 include("./kernels.jl")
 include("./lowerlevel_solve.jl")
 include("./solve.jl")
 include("./bfgs.jl")
 include("./hybrid.jl")
+include("./ode_pso.jl")
 
 export ParallelPSOKernel, ParallelSyncPSOKernel, ParallelPSOArray, SerialPSO,
     OptimizationProblem, solve
diff --git a/src/ode_pso.jl b/src/ode_pso.jl
index 3fd03d5..55e04d1 100644
--- a/src/ode_pso.jl
+++ b/src/ode_pso.jl
@@ -85,3 +85,87 @@ function parameter_estim_ode!(prob::ODEProblem, cache,
     end
     return gbest
 end
+
+# Solve the ODE for one parameter candidate and return the saved states
+@inline function ode_loss(particle, second_arg; kwargs...)
+    prob, ode_alg = second_arg
+    prob = remake(prob, p = particle)
+
+    us = OrdinaryDiffEq.__solve(prob,
+        ode_alg; kwargs...).u
+    us
+end
+
+function parameter_estim_odehybrid(prob::ODEProblem, cache,
+        lb,
+        ub;
+        ode_alg = GPUTsit5(),
+        prob_func = default_prob_func,
+        w = 0.72980f0,
+        wdamp = 1.0f0,
+        maxiters = 100,
+        local_maxiters = 100,
+        abstol = nothing,
+        reltol = nothing,
+        kwargs...)
+    (losses, gpu_particles, gpu_data, gbest) = cache
+    backend = get_backend(gpu_particles)
+    update_states! = PSOGPU._update_particle_states!(backend)
+    update_costs! = PSOGPU._update_particle_costs!(backend)
+
+    improb = make_prob_compatible(prob)
+
+    # Global phase: PSO iterations over the particle ensemble
+    for i in 1:maxiters
+        update_states!(gpu_particles,
+            lb,
+            ub,
+            gbest,
+            w;
+            ndrange = length(gpu_particles))
+
+        probs = prob_func.(Ref(improb), gpu_particles)
+
+        ts, us = vectorized_asolve(probs,
+            prob,
+            ode_alg; kwargs...)
+
+        sum!(losses, map(x -> sum(x .^ 2), gpu_data .- us))
+
+        update_costs!(losses, gpu_particles; ndrange = length(losses))
+
+        w = w * wdamp
+    end
+
+    # Local phase: polish each particle by driving the gradient of the
+    # data-fitting objective to zero with a quasi-Newton solve
+    _f = p -> ode_loss(p, (improb, Tsit5()); kwargs...)
+    obj = p -> sum(map(x -> sum(x .^ 2), gpu_data .- _f(p)))
+    # autodiff_deferred(Reverse, sum(abs2, gpu_data .- us), Active, Active(particle))[1][1]
+    f = x -> ForwardDiff.gradient(obj, x)
+
+    kernel = simplebfgs_run!(backend)
+    x0s = get_pos.(gpu_particles)
+    result = KernelAbstractions.allocate(backend, eltype(x0s), length(x0s))
+    nlprob = NonlinearProblem{false}((x, p) -> f(x), prob.p)
+
+    nlalg = SimpleBroyden(; linesearch = Val(true))
+
+    kernel(nlprob,
+        x0s,
+        result,
+        nlalg,
+        local_maxiters,
+        abstol,
+        reltol;
+        ndrange = length(x0s))
+
+    sol_bfgs = obj.(result)
+    sol_bfgs = (x -> isnan(x) ? convert(eltype(prob.p), Inf) : x).(sol_bfgs)
+
+    minobj, ind = findmin(sol_bfgs)
+
+    SciMLBase.build_solution(
+        SciMLBase.DefaultOptimizationCache(OptimizationFunction(ode_loss), prob.p),
+        nlalg, view(result, ind), minobj)
+end
\ No newline at end of file
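
Note for reviewers: the hybrid scheme added above runs a PSO global search and then polishes candidates by solving grad(loss) = 0 with a quasi-Newton method. Below is a minimal CPU sketch of the same two-phase idea on a toy decay ODE. It is not part of the patch; all names here (decay, loss, grad, candidates) are hypothetical, and a best-of-random search stands in for the PSO phase. It assumes only OrdinaryDiffEq, SimpleNonlinearSolve, and ForwardDiff.

using OrdinaryDiffEq, SimpleNonlinearSolve, ForwardDiff

# Toy problem: recover the decay rate p of u' = -p*u from saved trajectory data
decay(u, p, t) = -p[1] * u
tsteps = 0.0:0.1:1.0
prob = ODEProblem{false}(decay, 10.0, (0.0, 1.0), [2.0])
data = Array(solve(prob, Tsit5(), saveat = tsteps))

# Sum-of-squares data-fitting loss, differentiable through the solver via ForwardDiff
loss(p) = sum(abs2, data .- Array(solve(remake(prob, p = p), Tsit5(), saveat = tsteps)))

# "Global" phase: keep the best of a few random candidates (stand-in for PSO)
candidates = [[4 * rand()] for _ in 1:32]
best = candidates[argmin(loss.(candidates))]

# "Local" phase: drive the gradient of the loss to zero starting from `best`,
# mirroring the SimpleBroyden root solve used in parameter_estim_odehybrid
grad(x, p) = ForwardDiff.gradient(loss, x)
nlprob = NonlinearProblem{false}(grad, best)
polished = solve(nlprob, SimpleBroyden())
@show best polished.u loss(polished.u)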