From 0718c98c46607e5f5ac555deaa359fb67b2d7d42 Mon Sep 17 00:00:00 2001 From: Herman Sletmoen Date: Sat, 12 Apr 2025 23:04:11 +0200 Subject: [PATCH 1/5] Support in-place interpolation of symbolic idxs --- src/solutions/ode_solutions.jl | 41 ++++++++++++++++++++++++++- test/downstream/solution_interface.jl | 21 ++++++++++++++ 2 files changed, 61 insertions(+), 1 deletion(-) diff --git a/src/solutions/ode_solutions.jl b/src/solutions/ode_solutions.jl index dcae2c08f..e5274dc15 100644 --- a/src/solutions/ode_solutions.jl +++ b/src/solutions/ode_solutions.jl @@ -213,6 +213,7 @@ function is_discrete_expression(indp, expr) length(ts_idxs) > 1 || length(ts_idxs) == 1 && only(ts_idxs) != ContinuousTimeseries() end +# These are the two main documented user-facing interpolation API functions (out-of-place and in-place versions) function (sol::AbstractODESolution)(t, ::Type{deriv} = Val{0}; idxs = nothing, continuity = :left) where {deriv} if t isa IndexedClock @@ -225,9 +226,12 @@ function (sol::AbstractODESolution)(v, t, ::Type{deriv} = Val{0}; idxs = nothing if t isa IndexedClock t = canonicalize_indexed_clock(t, sol) end - sol.interp(v, t, idxs, deriv, sol.prob.p, continuity) + sol(v, t, deriv, idxs, continuity) end +# Below are many internal dispatches for different combinations of arguments to the main API +# TODO: could use a clever rewrite, since a lot of reused code has accumulated + function (sol::AbstractODESolution)(t::Number, ::Type{deriv}, idxs::Nothing, continuity) where {deriv} sol.interp(t, idxs, deriv, sol.prob.p, continuity) @@ -365,6 +369,41 @@ function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv}, return DiffEqArray(u, t, p, sol; discretes) end +function (sol::AbstractODESolution)(v::AbstractArray, t::AbstractVector{<:Number}, ::Type{deriv}, idxs, + continuity) where {deriv} + symbolic_type(idxs) == NotSymbolic() && error("Incorrect specification of `idxs`") + error_if_observed_derivative(sol, idxs, deriv) + p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing + getter = getsym(sol, idxs) + if is_parameter_timeseries(sol) == NotTimeseries() || !is_discrete_expression(sol, idxs) + u = zeros(eltype(sol), size(sol)[1]) + v .= map(eachindex(t)) do ti + sol.interp(u, t[ti], nothing, deriv, p, continuity) + return getter(ProblemState(; u = u, p = p, t = t[ti])) + end + return v + end + error("In-place interpolation with discretes is not implemented.") +end +function (sol::AbstractODESolution)(v::AbstractArray, t::AbstractVector{<:Number}, ::Type{deriv}, + idxs::AbstractVector, continuity) where {deriv} + if symbolic_type(idxs) == NotSymbolic() && isempty(idxs) + return map(_ -> eltype(eltype(sol.u))[], t) + end + error_if_observed_derivative(sol, idxs, deriv) + p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing + getter = getsym(sol, idxs) + if is_parameter_timeseries(sol) == NotTimeseries() || !is_discrete_expression(sol, idxs) + u = zeros(eltype(sol), size(sol)[1]) + v .= map(eachindex(t)) do ti + sol.interp(u, t[ti], nothing, deriv, p, continuity) + return getter(ProblemState(; u = u, p = p, t = t[ti])) + end + return v + end + error("In-place interpolation with discretes is not implemented.") +end + struct DDESolutionHistoryWrapper{T} sol::T end diff --git a/test/downstream/solution_interface.jl b/test/downstream/solution_interface.jl index b59424dd2..0bf30f1bb 100644 --- a/test/downstream/solution_interface.jl +++ b/test/downstream/solution_interface.jl @@ -148,6 +148,27 @@ sol9 = sol(0.0:1.0:10.0, idxs = 2) sol10 = sol(0.1, idxs = 2) @test sol10 isa Real +# in-place interpolation with single (unknown) symbolic index +ts = 0.0:0.1:10.0 +out = zeros(eltype(sol), size(ts)) +idxs = unknowns(sys)[1] +@test sol(out, ts; idxs) == sol(ts; idxs) +@test (@allocated sol(out, ts; idxs)) < (@allocated sol(ts; idxs)) +@test_nowarn @inferred sol(out, ts; idxs) + +# in-place interpolation with single (observed) symbolic index +idxs = observed(sys)[1].lhs +@test sol(out, ts; idxs) == sol(ts; idxs) +@test (@allocated sol(out, ts; idxs)) < (@allocated sol(ts; idxs)) +@test_nowarn @inferred sol(out, ts; idxs) + +# in-place interpolation with multiple (unknown+observed) symbolic indices +idxs = [unknowns(sys)[1], observed(sys)[1].lhs] +out = [zeros(eltype(sol), size(idxs)) for _ in eachindex(ts)] +@test sol(out, ts; idxs) == sol(ts; idxs).u +@test (@allocated sol(out, ts; idxs)) < (@allocated sol(ts; idxs)) +@test_nowarn @inferred sol(out, ts; idxs) + @testset "Plot idxs" begin @variables x(t) y(t) @parameters p From 469d1c017f0abfc684e4ce83fa653be1f95b8273 Mon Sep 17 00:00:00 2001 From: Herman Sletmoen Date: Tue, 22 Apr 2025 21:00:51 +0200 Subject: [PATCH 2/5] Restore old in-place interpolation of unspecified and integer idxs --- src/solutions/ode_solutions.jl | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/solutions/ode_solutions.jl b/src/solutions/ode_solutions.jl index e5274dc15..425174281 100644 --- a/src/solutions/ode_solutions.jl +++ b/src/solutions/ode_solutions.jl @@ -369,6 +369,14 @@ function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv}, return DiffEqArray(u, t, p, sol; discretes) end +function (sol::AbstractODESolution)(v::AbstractArray, t::Number, ::Type{deriv}, + idxs::Union{Nothing, Integer, AbstractArray{<:Integer}}, continuity) where {deriv} + return sol.interp(v, t, idxs, deriv, sol.prob.p, continuity) +end +function (sol::AbstractODESolution)(v::AbstractArray, t::AbstractVector{<:Number}, ::Type{deriv}, + idxs::Union{Nothing, Integer, AbstractArray{<:Integer}}, continuity) where {deriv} + return sol.interp(v, t, idxs, deriv, sol.prob.p, continuity) +end function (sol::AbstractODESolution)(v::AbstractArray, t::AbstractVector{<:Number}, ::Type{deriv}, idxs, continuity) where {deriv} symbolic_type(idxs) == NotSymbolic() && error("Incorrect specification of `idxs`") From 9d1eab7529494f7c08defff565c75fce2de6e196 Mon Sep 17 00:00:00 2001 From: Herman Sletmoen Date: Tue, 22 Apr 2025 21:01:32 +0200 Subject: [PATCH 3/5] Resolve MTK vs SII import name clash in test --- test/downstream/solution_interface.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/downstream/solution_interface.jl b/test/downstream/solution_interface.jl index 0bf30f1bb..d030f92e0 100644 --- a/test/downstream/solution_interface.jl +++ b/test/downstream/solution_interface.jl @@ -1,7 +1,7 @@ using ModelingToolkit, OrdinaryDiffEq, RecursiveArrayTools, StochasticDiffEq, Test using StochasticDiffEq using SymbolicIndexingInterface -using ModelingToolkit: t_nounits as t, D_nounits as D +using ModelingToolkit: observed, t_nounits as t, D_nounits as D using Plots: Plots, plot ### Tests on non-layered model (everything should work). ### From 2737f519b6ad19a36bce9998849adfa5c8823d65 Mon Sep 17 00:00:00 2001 From: Herman Sletmoen Date: Fri, 25 Apr 2025 23:44:16 +0200 Subject: [PATCH 4/5] Format --- src/solutions/ode_solutions.jl | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/src/solutions/ode_solutions.jl b/src/solutions/ode_solutions.jl index 425174281..9f66a0ccb 100644 --- a/src/solutions/ode_solutions.jl +++ b/src/solutions/ode_solutions.jl @@ -370,15 +370,17 @@ function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv}, end function (sol::AbstractODESolution)(v::AbstractArray, t::Number, ::Type{deriv}, - idxs::Union{Nothing, Integer, AbstractArray{<:Integer}}, continuity) where {deriv} + idxs::Union{Nothing, Integer, AbstractArray{<:Integer}}, continuity) where {deriv} return sol.interp(v, t, idxs, deriv, sol.prob.p, continuity) end -function (sol::AbstractODESolution)(v::AbstractArray, t::AbstractVector{<:Number}, ::Type{deriv}, - idxs::Union{Nothing, Integer, AbstractArray{<:Integer}}, continuity) where {deriv} +function (sol::AbstractODESolution)( + v::AbstractArray, t::AbstractVector{<:Number}, ::Type{deriv}, + idxs::Union{Nothing, Integer, AbstractArray{<:Integer}}, continuity) where {deriv} return sol.interp(v, t, idxs, deriv, sol.prob.p, continuity) end -function (sol::AbstractODESolution)(v::AbstractArray, t::AbstractVector{<:Number}, ::Type{deriv}, idxs, - continuity) where {deriv} +function (sol::AbstractODESolution)( + v::AbstractArray, t::AbstractVector{<:Number}, ::Type{deriv}, idxs, + continuity) where {deriv} symbolic_type(idxs) == NotSymbolic() && error("Incorrect specification of `idxs`") error_if_observed_derivative(sol, idxs, deriv) p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing @@ -393,8 +395,9 @@ function (sol::AbstractODESolution)(v::AbstractArray, t::AbstractVector{<:Number end error("In-place interpolation with discretes is not implemented.") end -function (sol::AbstractODESolution)(v::AbstractArray, t::AbstractVector{<:Number}, ::Type{deriv}, - idxs::AbstractVector, continuity) where {deriv} +function (sol::AbstractODESolution)( + v::AbstractArray, t::AbstractVector{<:Number}, ::Type{deriv}, + idxs::AbstractVector, continuity) where {deriv} if symbolic_type(idxs) == NotSymbolic() && isempty(idxs) return map(_ -> eltype(eltype(sol.u))[], t) end From be9dca15e482d0810cb5fe6a664f9dbe51a99be2 Mon Sep 17 00:00:00 2001 From: Herman Sletmoen Date: Sat, 26 Apr 2025 17:21:43 +0200 Subject: [PATCH 5/5] Incorporate suggestions: assign in-place when possible; combine dispatches to one; test in-place symbolic interpolation with one time value --- src/solutions/ode_solutions.jl | 49 +++++++++++---------------- test/downstream/solution_interface.jl | 8 +++++ 2 files changed, 28 insertions(+), 29 deletions(-) diff --git a/src/solutions/ode_solutions.jl b/src/solutions/ode_solutions.jl index 9f66a0ccb..f70fb3a77 100644 --- a/src/solutions/ode_solutions.jl +++ b/src/solutions/ode_solutions.jl @@ -369,46 +369,37 @@ function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv}, return DiffEqArray(u, t, p, sol; discretes) end -function (sol::AbstractODESolution)(v::AbstractArray, t::Number, ::Type{deriv}, - idxs::Union{Nothing, Integer, AbstractArray{<:Integer}}, continuity) where {deriv} - return sol.interp(v, t, idxs, deriv, sol.prob.p, continuity) -end function (sol::AbstractODESolution)( - v::AbstractArray, t::AbstractVector{<:Number}, ::Type{deriv}, + v, t::Union{Number, AbstractVector{<:Number}}, ::Type{deriv}, idxs::Union{Nothing, Integer, AbstractArray{<:Integer}}, continuity) where {deriv} return sol.interp(v, t, idxs, deriv, sol.prob.p, continuity) end function (sol::AbstractODESolution)( - v::AbstractArray, t::AbstractVector{<:Number}, ::Type{deriv}, idxs, + v, t::Union{Number, AbstractVector{<:Number}}, ::Type{deriv}, idxs, continuity) where {deriv} - symbolic_type(idxs) == NotSymbolic() && error("Incorrect specification of `idxs`") - error_if_observed_derivative(sol, idxs, deriv) - p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing - getter = getsym(sol, idxs) - if is_parameter_timeseries(sol) == NotTimeseries() || !is_discrete_expression(sol, idxs) - u = zeros(eltype(sol), size(sol)[1]) - v .= map(eachindex(t)) do ti - sol.interp(u, t[ti], nothing, deriv, p, continuity) - return getter(ProblemState(; u = u, p = p, t = t[ti])) - end - return v - end - error("In-place interpolation with discretes is not implemented.") -end -function (sol::AbstractODESolution)( - v::AbstractArray, t::AbstractVector{<:Number}, ::Type{deriv}, - idxs::AbstractVector, continuity) where {deriv} - if symbolic_type(idxs) == NotSymbolic() && isempty(idxs) - return map(_ -> eltype(eltype(sol.u))[], t) + if idxs isa AbstractArray && any(idx -> idx == NotSymbolic(), symbolic_type.(idxs)) || + !(idxs isa AbstractArray) && symbolic_type(idxs) == NotSymbolic() + error("Incorrect specification of `idxs`") end error_if_observed_derivative(sol, idxs, deriv) p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing - getter = getsym(sol, idxs) + getter = getsym(sol, idxs) # TODO: breaks type inference and allocates if is_parameter_timeseries(sol) == NotTimeseries() || !is_discrete_expression(sol, idxs) u = zeros(eltype(sol), size(sol)[1]) - v .= map(eachindex(t)) do ti - sol.interp(u, t[ti], nothing, deriv, p, continuity) - return getter(ProblemState(; u = u, p = p, t = t[ti])) + if t isa AbstractVector + for ti in eachindex(t) + sol.interp(u, t[ti], nothing, deriv, p, continuity) + state = ProblemState(; u = u, p = p, t = t[ti]) + if eltype(v) <: Number + v[ti] = getter(state) + else + v[ti] .= getter(state) + end + end + else # t isa Number + sol.interp(u, t, nothing, deriv, p, continuity) + state = ProblemState(; u = u, p = p, t = t) + v .= getter(state) end return v end diff --git a/test/downstream/solution_interface.jl b/test/downstream/solution_interface.jl index d030f92e0..00790e40a 100644 --- a/test/downstream/solution_interface.jl +++ b/test/downstream/solution_interface.jl @@ -169,6 +169,14 @@ out = [zeros(eltype(sol), size(idxs)) for _ in eachindex(ts)] @test (@allocated sol(out, ts; idxs)) < (@allocated sol(ts; idxs)) @test_nowarn @inferred sol(out, ts; idxs) +# same as above, but with one time value only +@test sol(out[1], ts[1]; idxs) == sol(ts[1]; idxs) +#@test (@allocated sol(out[1], ts[1]; idxs)) < (@allocated sol(ts[1]; idxs)) # TODO: reduce allocations and fix +@test_nowarn @inferred sol(out[1], ts[1]; idxs) + +idxs = [unknowns(sys)[1], 1] +@test_throws "Incorrect specification of `idxs`" sol(out, ts; idxs) + @testset "Plot idxs" begin @variables x(t) y(t) @parameters p