Skip to content

Commit 8a313af

Browse files
Merge pull request #3985 from ChrisRackauckas-Claude/fix-sde-noise-kwarg
Fix noise kwarg propagation for SDEProblem(f::SDESystem)
2 parents 865e339 + f3b84be commit 8a313af

File tree

2 files changed

+59
-1
lines changed

2 files changed

+59
-1
lines changed

src/problems/sdeproblem.jl

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,20 @@ end
7777
t = tspan !== nothing ? tspan[1] : tspan, check_length, eval_expression,
7878
eval_module, check_compatibility, sparse, expression, kwargs...)
7979

80-
noise, noise_rate_prototype = calculate_noise_and_rate_prototype(sys, u0; sparsenoise)
80+
# Only calculate noise and noise_rate_prototype if not provided by user
81+
if !haskey(kwargs, :noise) && !haskey(kwargs, :noise_rate_prototype)
82+
noise, noise_rate_prototype = calculate_noise_and_rate_prototype(sys, u0; sparsenoise)
83+
elseif !haskey(kwargs, :noise)
84+
noise, _ = calculate_noise_and_rate_prototype(sys, u0; sparsenoise)
85+
noise_rate_prototype = kwargs[:noise_rate_prototype]
86+
elseif !haskey(kwargs, :noise_rate_prototype)
87+
_, noise_rate_prototype = calculate_noise_and_rate_prototype(sys, u0; sparsenoise)
88+
noise = kwargs[:noise]
89+
else
90+
noise = kwargs[:noise]
91+
noise_rate_prototype = kwargs[:noise_rate_prototype]
92+
end
93+
8194
kwargs = process_kwargs(sys; expression, callback, eval_expression, eval_module,
8295
op, kwargs...)
8396

test/sdesystem.jl

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using ModelingToolkit, StaticArrays, LinearAlgebra
22
using StochasticDiffEq, OrdinaryDiffEq, SparseArrays
3+
using DiffEqNoiseProcess: NoiseWrapper
34
using Random, Test
45
using Setfield
56
using Statistics
@@ -953,3 +954,47 @@ end
953954
@test ModelingToolkit.isbrownian(p)
954955
@test ModelingToolkit.isbrownian(q)
955956
end
957+
958+
@testset "noise kwarg propagation (issue #3664)" begin
959+
@parameters σ ρ β
960+
@variables x(tt) y(tt) z(tt)
961+
962+
u0 = [1.0, 0.0, 0.0]
963+
T = (0.0, 5.0)
964+
965+
eqs = [D(x) ~ σ * (y - x),
966+
D(y) ~ x *- z) - y,
967+
D(z) ~ x * y - β * z]
968+
noiseeqs = [3.0,
969+
3.0,
970+
3.0]
971+
@mtkbuild sde_lorentz = SDESystem(eqs, noiseeqs, tt, [x, y, z], [σ, ρ, β])
972+
parammap = [σ, ρ, β] .=> [10, 28.0, 8 / 3]
973+
974+
# Test that user-provided noise is respected
975+
Random.seed!(1)
976+
noise1 = StochasticDiffEq.RealWienerProcess(0.0, 0.0, 0.0; save_everystep = true)
977+
u0_dict = Dict(unknowns(sde_lorentz) .=> u0)
978+
prob1 = SDEProblem(sde_lorentz, merge(u0_dict, Dict(parammap)), T; noise = noise1)
979+
sol1 = solve(prob1, SRIW1())
980+
981+
# Verify noise was actually used (curW should be modified)
982+
@test noise1.curW != 0.0
983+
984+
# Test that using the same noise via NoiseWrapper gives deterministic results
985+
noise2 = NoiseWrapper(noise1)
986+
prob2 = SDEProblem(sde_lorentz, merge(u0_dict, Dict(parammap)), T; noise = noise2)
987+
sol2 = solve(prob2, SRIW1())
988+
989+
# Same noise should give same results
990+
@test sol1.u[end] sol2.u[end]
991+
992+
# Test that without providing noise, different results are obtained
993+
Random.seed!(1)
994+
prob3 = SDEProblem(sde_lorentz, merge(u0_dict, Dict(parammap)), T)
995+
Random.seed!(2)
996+
prob4 = SDEProblem(sde_lorentz, merge(u0_dict, Dict(parammap)), T)
997+
sol3 = solve(prob3, SRIW1(), seed = 1)
998+
sol4 = solve(prob4, SRIW1(), seed = 2)
999+
@test !(sol3.u[end] sol4.u[end])
1000+
end

0 commit comments

Comments
 (0)