Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 20 additions & 11 deletions benchmarks/adam_opt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ optprob = Optimization.OptimizationProblem(optf, p_nn)
@time res = Optimization.solve(optprob, ADAM(0.05), maxiters = 100)
@show res.objective

@benchmark Optimization.solve(optprob, ADAM(0.05), maxiters = 100)
# @benchmark Optimization.solve(optprob, ADAM(0.05), maxiters = 100)

## PSOGPU stuff

Expand Down Expand Up @@ -129,16 +129,16 @@ solver_cache = (; losses, gpu_particles, gpu_data, gbest)
prob_func = prob_func,
maxiters = 100)

@benchmark PSOGPU.parameter_estim_ode!($prob_nn,
$(deepcopy(solver_cache)),
$lb,
$ub;
saveat = tsteps,
dt = 0.1f0,
prob_func = prob_func,
maxiters = 100)
# @benchmark PSOGPU.parameter_estim_ode!($prob_nn,
# $(deepcopy(solver_cache)),
# $lb,
# $ub;
# saveat = tsteps,
# dt = 0.1f0,
# prob_func = prob_func,
# maxiters = 100)

@show gsol.best
@show gsol.cost

using Plots

Expand All @@ -151,7 +151,16 @@ using Plots
plt = scatter(tsteps, data[1, :], label = "data")

pred_pso = predict_neuralode((sc, gsol.position))
scatter!(plt, tsteps, pred[1, :], label = "PSO prediction")
scatter!(plt, tsteps, pred_pso[1, :], label = "PSO prediction")

pred_adam = predict_neuralode((sc, res.u))
scatter!(plt, tsteps, pred_adam[1, :], label = "Adam prediction")

@time gsol1 = PSOGPU.parameter_estim_odehybrid(prob_nn,
solver_cache,
lb,
ub;
saveat = tsteps,
dt = 0.1f0,
prob_func = prob_func,
maxiters = 100)
4 changes: 2 additions & 2 deletions src/PSOGPU.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module PSOGPU

using SciMLBase, StaticArrays, Setfield, KernelAbstractions
using SciMLBase, StaticArrays, Setfield, KernelAbstractions, OrdinaryDiffEq
using QuasiMonteCarlo, Optimization, SimpleNonlinearSolve, ForwardDiff
import Adapt
import Enzyme: autodiff_deferred, Active, Reverse
Expand Down Expand Up @@ -58,12 +58,12 @@ end

include("./algorithms.jl")
include("./utils.jl")
include("./ode_pso.jl")
include("./kernels.jl")
include("./lowerlevel_solve.jl")
include("./solve.jl")
include("./bfgs.jl")
include("./hybrid.jl")
include("./ode_pso.jl")

export ParallelPSOKernel,
ParallelSyncPSOKernel, ParallelPSOArray, SerialPSO, OptimizationProblem, solve
Expand Down
80 changes: 80 additions & 0 deletions src/ode_pso.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,83 @@ function parameter_estim_ode!(prob::ODEProblem, cache,
end
return gbest
end

@inline function ode_loss(particle, second_arg; kwargs...)
prob, ode_alg = second_arg
prob = remake(prob, p = particle)

us = OrdinaryDiffEq.__solve(prob,
ode_alg; kwargs...).u
us
end

function parameter_estim_odehybrid(prob::ODEProblem, cache,
lb,
ub;
ode_alg = GPUTsit5(),
prob_func = default_prob_func,
w = 0.72980f0,
wdamp = 1.0f0,
maxiters = 100,
local_maxiters = 100,
abstol = nothing,
reltol = nothing,
kwargs...)

(losses, gpu_particles, gpu_data, gbest) = cache
backend = get_backend(gpu_particles)
update_states! = PSOGPU._update_particle_states!(backend)
update_costs! = PSOGPU._update_particle_costs!(backend)

improb = make_prob_compatible(prob)

for i in 1:maxiters
update_states!(gpu_particles,
lb,
ub,
gbest,
w;
ndrange = length(gpu_particles))

probs = prob_func.(Ref(improb), gpu_particles)

ts, us = vectorized_asolve(probs,
prob,
ode_alg; kwargs...)

sum!(losses, (map(x -> sum(x .^ 2), gpu_data .- us)))

update_costs!(losses, gpu_particles; ndrange = length(losses))

w = w * wdamp
end

_f = Base.Fix2(ode_loss, (improb, Tsit5()))
# autodiff_deferred(Reverse, sum(abs2, gpu_data .- us), Active, Active(particle))[1][1]
f = x -> ForwardDiff.gradient(sum(abs2, gpu_data .- _f(x)), x)

kernel = simplebfgs_run!(backend)
x0s = get_pos.(gpu_particles)
result = KernelAbstractions.allocate(backend, typeof(prob.u0), length(x0s))
nlprob = NonlinearProblem{false}((x, p) -> f(x), prob.p)

nlalg = SimpleBroyden(; linesearch = Val(true))

kernel(nlprob,
x0s,
result,
nlalg,
local_maxiters,
abstol,
reltol;
ndrange = length(x0s))

t1 = time()
sol_bfgs = (x -> ode_loss(x, (prob, Tsit5()))).(result)
sol_bfgs = (x -> isnan(x) ? convert(eltype(prob.p), Inf) : x).(sol_bfgs)

minobj, ind = findmin(sol_bfgs)

SciMLBase.build_solution(SciMLBase.DefaultOptimizationCache(OptimizationFunction(ode_loss), prob.p), opt,
view(result, ind), minobj)
end