|
1 | 1 | using ModelingToolkit, StaticArrays, LinearAlgebra |
2 | 2 | using StochasticDiffEq, OrdinaryDiffEq, SparseArrays |
| 3 | +using DiffEqNoiseProcess: NoiseWrapper |
3 | 4 | using Random, Test |
4 | 5 | using Setfield |
5 | 6 | using Statistics |
|
953 | 954 | @test ModelingToolkit.isbrownian(p) |
954 | 955 | @test ModelingToolkit.isbrownian(q) |
955 | 956 | end |
| 957 | + |
| 958 | +@testset "noise kwarg propagation (issue #3664)" begin |
| 959 | + @parameters σ ρ β |
| 960 | + @variables x(tt) y(tt) z(tt) |
| 961 | + |
| 962 | + u0 = [1.0, 0.0, 0.0] |
| 963 | + T = (0.0, 5.0) |
| 964 | + |
| 965 | + eqs = [D(x) ~ σ * (y - x), |
| 966 | + D(y) ~ x * (ρ - z) - y, |
| 967 | + D(z) ~ x * y - β * z] |
| 968 | + noiseeqs = [3.0, |
| 969 | + 3.0, |
| 970 | + 3.0] |
| 971 | + @mtkbuild sde_lorentz = SDESystem(eqs, noiseeqs, tt, [x, y, z], [σ, ρ, β]) |
| 972 | + parammap = [σ, ρ, β] .=> [10, 28.0, 8 / 3] |
| 973 | + |
| 974 | + # Test that user-provided noise is respected |
| 975 | + Random.seed!(1) |
| 976 | + noise1 = StochasticDiffEq.RealWienerProcess(0.0, 0.0, 0.0; save_everystep = true) |
| 977 | + u0_dict = Dict(unknowns(sde_lorentz) .=> u0) |
| 978 | + prob1 = SDEProblem(sde_lorentz, merge(u0_dict, Dict(parammap)), T; noise = noise1) |
| 979 | + sol1 = solve(prob1, SRIW1()) |
| 980 | + |
| 981 | + # Verify noise was actually used (curW should be modified) |
| 982 | + @test noise1.curW != 0.0 |
| 983 | + |
| 984 | + # Test that using the same noise via NoiseWrapper gives deterministic results |
| 985 | + noise2 = NoiseWrapper(noise1) |
| 986 | + prob2 = SDEProblem(sde_lorentz, merge(u0_dict, Dict(parammap)), T; noise = noise2) |
| 987 | + sol2 = solve(prob2, SRIW1()) |
| 988 | + |
| 989 | + # Same noise should give same results |
| 990 | + @test sol1.u[end] ≈ sol2.u[end] |
| 991 | + |
| 992 | + # Test that without providing noise, different results are obtained |
| 993 | + Random.seed!(1) |
| 994 | + prob3 = SDEProblem(sde_lorentz, merge(u0_dict, Dict(parammap)), T) |
| 995 | + Random.seed!(2) |
| 996 | + prob4 = SDEProblem(sde_lorentz, merge(u0_dict, Dict(parammap)), T) |
| 997 | + sol3 = solve(prob3, SRIW1(), seed = 1) |
| 998 | + sol4 = solve(prob4, SRIW1(), seed = 2) |
| 999 | + @test !(sol3.u[end] ≈ sol4.u[end]) |
| 1000 | +end |
0 commit comments