Skip to content

Commit fa3519c

Browse files
Merge pull request #464 from oscardssmith/EnsembleSolution-indexing
make indexing EnsembleSolution work
2 parents 32cae24 + 0c3f019 commit fa3519c

File tree

4 files changed

+38
-8
lines changed

4 files changed

+38
-8
lines changed

src/ensemble/basic_ensemble_solve.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,9 @@ tighten_container_eltype(u) = u
4545
function __solve(prob::EnsembleProblem{<:AbstractVector{<:AbstractSciMLProblem}},
4646
alg::Union{AbstractDEAlgorithm, Nothing},
4747
ensemblealg::BasicEnsembleAlgorithm; kwargs...)
48-
solve(prob, alg, ensemblealg; trajectories=length(prob.prob), kwargs...)
48+
# TODO: @invoke
49+
invoke(__solve, Tuple{AbstractEnsembleProblem, typeof(alg), typeof(ensemblealg)},
50+
prob, alg, ensemblealg; trajectories=length(prob.prob), kwargs...)
4951
end
5052

5153
function __solve(prob::AbstractEnsembleProblem,

src/ensemble/ensemble_problems.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@ DEFAULT_OUTPUT_FUNC(sol, i) = (sol, false)
1515
DEFAULT_REDUCTION(u, data, I) = append!(u, data), false
1616
DEFAULT_VECTOR_PROB_FUNC(prob, i, repeat) = prob[i]
1717
function EnsembleProblem(prob::AbstractVector{<:AbstractSciMLProblem}; kwargs...)
18-
EnsembleProblem(prob; kwargs..., prob_func=DEFAULT_VECTOR_PROB_FUNC)
18+
# TODO: @invoke
19+
invoke(EnsembleProblem, Tuple{Any}, prob; prob_func=DEFAULT_VECTOR_PROB_FUNC, kwargs...)
1920
end
2021
function EnsembleProblem(prob;
2122
output_func = DEFAULT_OUTPUT_FUNC,

src/ensemble/ensemble_solutions.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,3 +197,19 @@ end
197197
end
198198
end
199199
end
200+
201+
202+
Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, ::Colon, s)
203+
return [xi[s] for xi in x]
204+
end
205+
206+
Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, ::Colon, args::Colon...)
207+
return invoke(getindex, Tuple{RecursiveArrayTools.AbstractVectorOfArray, Colon, typeof.(args)...}, x, :, args...)
208+
end
209+
Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, ::Colon, args::Int...)
210+
return [xi[args...] for xi in x]
211+
end
212+
213+
function (sol::AbstractEnsembleSolution)(args...; kwargs...)
214+
[s(args...; kwargs...) for s in sol]
215+
end
Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,19 @@
1-
using OrdinaryDiffEq, Test
1+
using ModelingToolkit, OrdinaryDiffEq, Test
2+
3+
@variables t, x(t)
4+
D = Differential(t)
5+
6+
@named sys1 = ODESystem([D(x) ~ 1.1*x])
7+
@named sys2 = ODESystem([D(x) ~ 1.2*x])
8+
9+
prob1 = ODEProblem(sys1, [2.0], (0.0, 1.0))
10+
prob2 = ODEProblem(sys2, [1.0], (0.0, 1.0))
211

3-
prob1 = ODEProblem((u, p, t) -> 0.99u, 0.55, (0.0, 1.1))
4-
prob1 = ODEProblem((u, p, t) -> 1.0u, 0.45, (0.0, 0.9))
5-
output_func(sol, i) = (last(sol), false)
612
# test that when passing a vector of problems, trajectories and the prob_func are chosen appropriately
7-
ensemble_prob = EnsembleProblem([prob1, prob2], output_func = output_func)
8-
sim = solve(ensemble_prob, Tsit5(), EnsembleThreads())
13+
ensemble_prob = EnsembleProblem([prob1, prob2])
14+
sol = solve(ensemble_prob, Tsit5(), EnsembleThreads())
15+
@test isapprox(sol[:, x], [2,1] .* map(Base.Fix1(map, exp), [1.1, 1.2] .* sol[:, t]), rtol=1e-4)
16+
# Ensemble is a recursive array
17+
@test sol(0.0, idxs=[x]) == sol[:, 1] == first.(sol[:, x], 1)
18+
# TODO: fix the interpolation
19+
@test sol(1.0, idxs=[x]) last.(sol[:, x], 1)

0 commit comments

Comments
 (0)