Skip to content
This repository was archived by the owner on May 15, 2025. It is now read-only.

Commit 46dd2c9

Browse files
setup with IntervalNonlinearProblem
1 parent 351a975 commit 46dd2c9

File tree

6 files changed

+37
-23
lines changed

6 files changed

+37
-23
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "SimpleNonlinearSolve"
22
uuid = "727e6d20-b764-4bd8-a329-72de5adea6c7"
3-
authors = ["Kanav Gupta <kanav0610@gmail.com>"]
3+
authors = ["SciML"]
44
version = "0.1.0"
55

66
[deps]
@@ -21,7 +21,7 @@ FiniteDiff = "2"
2121
ForwardDiff = "0.10.3"
2222
RecursiveArrayTools = "2"
2323
Reexport = "0.2, 1"
24-
SciMLBase = "1.32"
24+
SciMLBase = "1.73"
2525
Setfield = "0.7, 0.8, 1"
2626
StaticArrays = "0.12,1.0"
2727
UnPack = "1.0"

src/ad.jl

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,15 @@
11
function scalar_nlsolve_ad(prob, alg, args...; kwargs...)
22
f = prob.f
33
p = value(prob.p)
4-
u0 = value(prob.u0)
54

6-
newprob = NonlinearProblem(f, u0, p; prob.kwargs...)
5+
if prob isa IntervalNonlinearProblem
6+
tspan = value(prob.tspan)
7+
newprob = IntervalNonlinearProblem(f, tspan, p; prob.kwargs...)
8+
else
9+
u0 = value(prob.u0)
10+
newprob = NonlinearProblem(f, u0, p; prob.kwargs...)
11+
end
12+
713
sol = solve(newprob, alg, args...; kwargs...)
814

915
uu = sol.u
@@ -39,7 +45,8 @@ end
3945

4046
# avoid ambiguities
4147
for Alg in [Bisection]
42-
@eval function SciMLBase.solve(prob::NonlinearProblem{uType, iip, <:Dual{T, V, P}},
48+
@eval function SciMLBase.solve(prob::IntervalNonlinearProblem{uType, iip,
49+
<:Dual{T, V, P}},
4350
alg::$Alg, args...;
4451
kwargs...) where {uType, iip, T, V, P}
4552
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
@@ -49,8 +56,12 @@ for Alg in [Bisection]
4956
right = Dual{T, V, P}(sol.right, partials))
5057
#return BracketingSolution(Dual{T,V,P}(sol.left, partials), Dual{T,V,P}(sol.right, partials), sol.retcode, sol.resid)
5158
end
52-
@eval function SciMLBase.solve(prob::NonlinearProblem{uType, iip,
53-
<:AbstractArray{<:Dual{T, V, P}}},
59+
@eval function SciMLBase.solve(prob::IntervalNonlinearProblem{uType, iip,
60+
<:AbstractArray{
61+
<:Dual{T,
62+
V,
63+
P}
64+
}},
5465
alg::$Alg, args...;
5566
kwargs...) where {uType, iip, T, V, P}
5667
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)

src/bisection.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,11 @@ function Bisection(; exact_left = false, exact_right = false)
77
Bisection(exact_left, exact_right)
88
end
99

10-
function SciMLBase.solve(prob::NonlinearProblem, alg::Bisection, args...; maxiters = 1000,
10+
function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Bisection, args...;
11+
maxiters = 1000,
1112
kwargs...)
1213
f = Base.Fix2(prob.f, prob.p)
13-
left, right = prob.u0
14+
left, right = prob.tspan
1415
fl, fr = f(left), f(right)
1516

1617
if iszero(fl)

src/falsi.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
struct Falsi <: AbstractBracketingAlgorithm end
22

3-
function SciMLBase.solve(prob::NonlinearProblem, alg::Falsi, args...; maxiters = 1000,
3+
function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Falsi, args...;
4+
maxiters = 1000,
45
kwargs...)
56
f = Base.Fix2(prob.f, prob.p)
6-
left, right = prob.u0
7+
left, right = prob.tspan
78
fl, fr = f(left), f(right)
89

910
if iszero(fl)
@@ -15,14 +16,14 @@ function SciMLBase.solve(prob::NonlinearProblem, alg::Falsi, args...; maxiters =
1516
i = 1
1617
if !iszero(fr)
1718
while i < maxiters
18-
if nextfloat_tdir(left, prob.u0...) == right
19+
if nextfloat_tdir(left, prob.tspan...) == right
1920
return SciMLBase.build_solution(prob, alg, left, fl;
2021
retcode = ReturnCode.FloatingPointLimit,
2122
left = left, right = right)
2223
end
2324
mid = (fr * left - fl * right) / (fr - fl)
2425
for i in 1:10
25-
mid = max_tdir(left, prevfloat_tdir(mid, prob.u0...), prob.u0...)
26+
mid = max_tdir(left, prevfloat_tdir(mid, prob.tspan...), prob.tspan...)
2627
end
2728
if mid == right || mid == left
2829
break

src/raphson.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
struct SimpleNewtonRaphson{CS, AD, FDT} <: AbstractNewtonAlgorithm{CS, AD, FDT}
22
function SimpleNewtonRaphson(; chunk_size = Val{0}(), autodiff = Val{true}(),
3-
diff_type = Val{:forward})
3+
diff_type = Val{:forward})
44
new{SciMLBase._unwrap_val(chunk_size), SciMLBase._unwrap_val(autodiff),
5-
SciMLBase._unwrap_val(diff_type)}()
5+
SciMLBase._unwrap_val(diff_type)}()
66
end
77
end
88

99
function SciMLBase.solve(prob::NonlinearProblem,
10-
alg::SimpleNewtonRaphson, args...; xatol = nothing, xrtol = nothing,
10+
alg::SimpleNewtonRaphson, args...; xatol = nothing,
11+
xrtol = nothing,
1112
maxiters = 1000, kwargs...)
1213
f = Base.Fix2(prob.f, prob.p)
1314
x = float(prob.u0)

test/basictests.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -57,10 +57,10 @@ for p in 1.1:0.1:100.0
5757
@test ForwardDiff.derivative(g, p) 1 / (2 * sqrt(p))
5858
end
5959

60-
u0 = (1.0, 20.0)
60+
tspan = (1.0, 20.0)
6161
# Falsi
6262
g = function (p)
63-
probN = NonlinearProblem{false}(f, typeof(p).(u0), p)
63+
probN = IntervalNonlinearProblem{false}(f, typeof(p).(tspan), p)
6464
sol = solve(probN, Falsi())
6565
return sol.left
6666
end
@@ -70,13 +70,13 @@ for p in 1.1:0.1:100.0
7070
@test ForwardDiff.derivative(g, p) 1 / (2 * sqrt(p))
7171
end
7272

73-
f, u0 = (u, p) -> p[1] * u * u - p[2], (1.0, 100.0)
73+
f, tspan = (u, p) -> p[1] * u * u - p[2], (1.0, 100.0)
7474
t = (p) -> [sqrt(p[2] / p[1])]
7575
p = [0.9, 50.0]
7676
for alg in [Bisection(), Falsi()]
7777
global g, p
7878
g = function (p)
79-
probN = NonlinearProblem{false}(f, u0, p)
79+
probN = IntervalNonlinearProblem{false}(f, tspan, p)
8080
sol = solve(probN, Bisection())
8181
return [sol.left]
8282
end
@@ -115,8 +115,8 @@ for u0 in [1.0, [1, 1.0]]
115115
end
116116

117117
# Bisection Tests
118-
f, u0 = (u, p) -> u .* u .- 2.0, (1.0, 2.0)
119-
probB = NonlinearProblem(f, u0)
118+
f, tspan = (u, p) -> u .* u .- 2.0, (1.0, 2.0)
119+
probB = IntervalNonlinearProblem(f, tspan)
120120

121121
# Falsi
122122
sol = solve(probB, Falsi())
@@ -135,7 +135,7 @@ f = function (u, p)
135135
return 0.0
136136
end
137137
end
138-
probB = NonlinearProblem(f, (0.0, 4.0))
138+
probB = IntervalNonlinearProblem(f, (0.0, 4.0))
139139

140140
sol = solve(probB, Bisection(; exact_left = true))
141141
@test f(sol.left, nothing) < 0.0

0 commit comments

Comments
 (0)