Skip to content

Commit 749d482

Browse files
run the formatter
1 parent 811bc33 commit 749d482

File tree

2 files changed

+91
-48
lines changed

2 files changed

+91
-48
lines changed

src/MATLABDiffEq.jl

Lines changed: 66 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,21 @@ struct ode15i <: MATLABAlgorithm end
1616

1717
function DiffEqBase.__solve(
1818
prob::DiffEqBase.AbstractODEProblem{uType,tupType,isinplace},
19-
alg::AlgType,timeseries=[],ts=[],ks=[];
20-
saveat=eltype(tupType)[],timeseries_errors=true,reltol = 1e-3, abstol = 1e-6,
19+
alg::AlgType,
20+
timeseries = [],
21+
ts = [],
22+
ks = [];
23+
saveat = eltype(tupType)[],
24+
timeseries_errors = true,
25+
reltol = 1e-3,
26+
abstol = 1e-6,
2127
callback = nothing,
22-
kwargs...) where {uType,tupType,isinplace,AlgType<:MATLABAlgorithm}
28+
kwargs...,
29+
) where {uType,tupType,isinplace,AlgType<:MATLABAlgorithm}
2330

2431
tType = eltype(tupType)
2532

26-
if prob.tspan[end]-prob.tspan[1]<tType(0)
33+
if prob.tspan[end] - prob.tspan[1] < tType(0)
2734
error("final time must be greater than starting time. Aborting.")
2835
end
2936

@@ -32,9 +39,9 @@ function DiffEqBase.__solve(
3239

3340
if typeof(saveat) <: Number
3441
tspan = Array(prob.tspan[1]:saveat:prob.tspan[2])
35-
tspan = sort(unique([prob.tspan[1];tspan;prob.tspan[2]]))
42+
tspan = sort(unique([prob.tspan[1]; tspan; prob.tspan[2]]))
3643
else
37-
tspan = sort(unique([prob.tspan[1];saveat;prob.tspan[2]]))
44+
tspan = sort(unique([prob.tspan[1]; saveat; prob.tspan[2]]))
3845
end
3946

4047
sizeu = size(prob.u0)
@@ -49,16 +56,20 @@ function DiffEqBase.__solve(
4956

5057
sys = modelingtoolkitize(prob)
5158

52-
matstr = ModelingToolkit.build_function(map(x->x.rhs,equations(sys)),states(sys),
53-
parameters(sys),independent_variables(sys)[1],
54-
target = ModelingToolkit.MATLABTarget())
59+
matstr = ModelingToolkit.build_function(
60+
map(x -> x.rhs, equations(sys)),
61+
states(sys),
62+
parameters(sys),
63+
independent_variables(sys)[1],
64+
target = ModelingToolkit.MATLABTarget(),
65+
)
5566

5667
# Send the variables
57-
put_variable(get_default_msession(),:tspan,tspan)
58-
put_variable(get_default_msession(),:u0,u0)
59-
put_variable(get_default_msession(),:internal_var___p,prob.p)
60-
put_variable(get_default_msession(),:reltol,reltol)
61-
put_variable(get_default_msession(),:abstol,abstol)
68+
put_variable(get_default_msession(), :tspan, tspan)
69+
put_variable(get_default_msession(), :u0, u0)
70+
put_variable(get_default_msession(), :internal_var___p, prob.p)
71+
put_variable(get_default_msession(), :reltol, reltol)
72+
put_variable(get_default_msession(), :abstol, abstol)
6273

6374
# Send the function over
6475
eval_string(matstr)
@@ -75,29 +86,59 @@ function DiffEqBase.__solve(
7586

7687
# Reshape the result if needed
7788
if uType <: AbstractArray
78-
timeseries = Vector{uType}(undef,length(ts))
79-
for i=1:length(ts)
80-
timeseries[i] = @view timeseries_tmp[i,:]
89+
timeseries = Vector{uType}(undef, length(ts))
90+
for i = 1:length(ts)
91+
timeseries[i] = @view timeseries_tmp[i, :]
8192
end
8293
else
8394
timeseries = timeseries_tmp
8495
end
8596

8697
destats = buildDEStats(solstats)
8798

88-
DiffEqBase.build_solution(prob,alg,ts,timeseries,
89-
timeseries_errors = timeseries_errors,destats = destats)
99+
DiffEqBase.build_solution(
100+
prob,
101+
alg,
102+
ts,
103+
timeseries,
104+
timeseries_errors = timeseries_errors,
105+
destats = destats,
106+
)
90107
end
91108

92109
function buildDEStats(solverstats::Dict)
93110

94111
destats = DiffEqBase.DEStats(0)
95-
destats.nf = if (haskey(solverstats, "nfevals")) solverstats["nfevals"] else 0 end
96-
destats.nreject = if (haskey(solverstats, "nfailed")) solverstats["nfailed"] else 0 end
97-
destats.naccept = if (haskey(solverstats, "nsteps")) solverstats["nsteps"] else 0 end
98-
destats.nsolve = if (haskey(solverstats, "nsolves")) solverstats["nsolves"] else 0 end
99-
destats.njacs = if (haskey(solverstats, "npds")) solverstats["npds"] else 0 end
100-
destats.nw = if (haskey(solverstats, "ndecomps")) solverstats["ndecomps"] else 0 end
112+
destats.nf = if (haskey(solverstats, "nfevals"))
113+
solverstats["nfevals"]
114+
else
115+
0
116+
end
117+
destats.nreject = if (haskey(solverstats, "nfailed"))
118+
solverstats["nfailed"]
119+
else
120+
0
121+
end
122+
destats.naccept = if (haskey(solverstats, "nsteps"))
123+
solverstats["nsteps"]
124+
else
125+
0
126+
end
127+
destats.nsolve = if (haskey(solverstats, "nsolves"))
128+
solverstats["nsolves"]
129+
else
130+
0
131+
end
132+
destats.njacs = if (haskey(solverstats, "npds"))
133+
solverstats["npds"]
134+
else
135+
0
136+
end
137+
destats.nw = if (haskey(solverstats, "ndecomps"))
138+
solverstats["ndecomps"]
139+
else
140+
0
141+
end
101142
destats
102143
end
103144

test/runtests.jl

Lines changed: 25 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,35 @@
11
using DiffEqBase, MATLABDiffEq, ParameterizedFunctions, Test
22

33
f = @ode_def_bare LotkaVolterra begin
4-
dx = a*x - b*x*y
5-
dy = -c*y + d*x*y
4+
dx = a * x - b * x * y
5+
dy = -c * y + d * x * y
66
end a b c d
7-
p = [1.5,1,3,1]
8-
tspan = (0.0,10.0)
9-
u0 = [1.0,1.0]
10-
prob = ODEProblem(f,u0,tspan,p)
11-
sol = solve(prob,MATLABDiffEq.ode45())
7+
p = [1.5, 1, 3, 1]
8+
tspan = (0.0, 10.0)
9+
u0 = [1.0, 1.0]
10+
prob = ODEProblem(f, u0, tspan, p)
11+
sol = solve(prob, MATLABDiffEq.ode45())
1212

13-
function lorenz(du,u,p,t)
14-
du[1] = 10.0(u[2]-u[1])
15-
du[2] = u[1]*(28.0-u[3]) - u[2]
16-
du[3] = u[1]*u[2] - (8/3)*u[3]
13+
function lorenz(du, u, p, t)
14+
du[1] = 10.0(u[2] - u[1])
15+
du[2] = u[1] * (28.0 - u[3]) - u[2]
16+
du[3] = u[1] * u[2] - (8 / 3) * u[3]
1717
end
18-
u0 = [1.0;0.0;0.0]
19-
tspan = (0.0,100.0)
20-
prob = ODEProblem(lorenz,u0,tspan)
21-
sol = solve(prob,MATLABDiffEq.ode45())
18+
u0 = [1.0; 0.0; 0.0]
19+
tspan = (0.0, 100.0)
20+
prob = ODEProblem(lorenz, u0, tspan)
21+
sol = solve(prob, MATLABDiffEq.ode45())
2222

23-
algs = [MATLABDiffEq.ode23
24-
MATLABDiffEq.ode45
25-
MATLABDiffEq.ode113
26-
MATLABDiffEq.ode23s
27-
MATLABDiffEq.ode23t
28-
MATLABDiffEq.ode23tb
29-
MATLABDiffEq.ode15s]
23+
algs = [
24+
MATLABDiffEq.ode23
25+
MATLABDiffEq.ode45
26+
MATLABDiffEq.ode113
27+
MATLABDiffEq.ode23s
28+
MATLABDiffEq.ode23t
29+
MATLABDiffEq.ode23tb
30+
MATLABDiffEq.ode15s
31+
]
3032

3133
for alg in algs
32-
sol = solve(prob,alg())
34+
sol = solve(prob, alg())
3335
end

0 commit comments

Comments
 (0)