Skip to content

Commit 34c4f13

Browse files
Merge pull request #470 from chriselrod/addweightedensembleproblem
Add `WeightedEnsembleProblem`
2 parents f2eab6b + 3b994a4 commit 34c4f13

File tree

2 files changed

+32
-0
lines changed

2 files changed

+32
-0
lines changed

src/ensemble/ensemble_problems.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,3 +35,22 @@ function EnsembleProblem(; prob,
3535
safetycopy = prob_func !== DEFAULT_PROB_FUNC)
3636
EnsembleProblem(prob, prob_func, output_func, reduction, u_init, safetycopy)
3737
end
38+
39+
struct WeightedEnsembleProblem{T1<:AbstractEnsembleProblem, T2<:AbstractVector} <: AbstractEnsembleProblem
40+
ensembleprob::T1
41+
weights::T2
42+
end
43+
Base.propertynames(e::WeightedEnsembleProblem) = (Base.propertynames(getfield(e, :ensembleprob))..., :weights)
44+
function Base.getproperty(e::WeightedEnsembleProblem, f::Symbol)
45+
f === :weights && return getfield(e, :weights)
46+
f === :ensembleprob && return getfield(e, :ensembleprob)
47+
return getproperty(getfield(e, :ensembleprob), f)
48+
end
49+
function WeightedEnsembleProblem(args...; weights, kwargs...)
50+
# TODO: allow skipping checks?
51+
@assert sum(weights) 1
52+
ep = EnsembleProblem(args...; kwargs...)
53+
@assert length(ep.prob) == length(weights)
54+
WeightedEnsembleProblem(ep, weights)
55+
end
56+

src/ensemble/ensemble_solutions.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,15 @@ function EnsembleSolution(sim::T, elapsedTime,
4646
converged)
4747
end
4848

49+
struct WeightedEnsembleSolution{T1<:AbstractEnsembleSolution, T2<:Number}
50+
ensol::T1
51+
weights::Vector{T2}
52+
function WeightedEnsembleSolution(ensol, weights)
53+
@assert length(weights) == length(ensol)
54+
new{typeof(ensol), eltype(weights)}(ensol, weights)
55+
end
56+
end
57+
4958
function Base.reverse(sim::EnsembleSolution)
5059
EnsembleSolution(reverse(sim.u), sim.elapsedTime, sim.converged)
5160
end
@@ -213,3 +222,7 @@ end
213222
function (sol::AbstractEnsembleSolution)(args...; kwargs...)
214223
[s(args...; kwargs...) for s in sol]
215224
end
225+
226+
Base.@propagate_inbounds function Base.getindex(sol::WeightedEnsembleSolution, S)
227+
return [sum(stack(sol.weights .* sol.ensol[:, S]), dims = 2)]
228+
end

0 commit comments

Comments
 (0)