Skip to content
This repository was archived by the owner on May 15, 2025. It is now read-only.

Commit 351a975

Browse files
really start repo
1 parent b888b3d commit 351a975

19 files changed

+820
-1
lines changed

.JuliaFormatter.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
style = "sciml"

LICENSE

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
MIT License
2+
3+
Copyright (c) 2020 Julia Computing, Inc.
4+
5+
Permission is hereby granted, free of charge, to any person obtaining a copy
6+
of this software and associated documentation files (the "Software"), to deal
7+
in the Software without restriction, including without limitation the rights
8+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9+
copies of the Software, and to permit persons to whom the Software is
10+
furnished to do so, subject to the following conditions:
11+
12+
The above copyright notice and this permission notice shall be included in all
13+
copies or substantial portions of the Software.
14+
15+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21+
SOFTWARE.

Project.toml

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
name = "SimpleNonlinearSolve"
2+
uuid = "727e6d20-b764-4bd8-a329-72de5adea6c7"
3+
authors = ["Kanav Gupta <kanav0610@gmail.com>"]
4+
version = "0.1.0"
5+
6+
[deps]
7+
ArrayInterfaceCore = "30b0a656-2188-435a-8636-2ec0e6a096e2"
8+
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
9+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
10+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
11+
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
12+
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
13+
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
14+
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
15+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
16+
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
17+
18+
[compat]
19+
ArrayInterfaceCore = "0.1.1"
20+
FiniteDiff = "2"
21+
ForwardDiff = "0.10.3"
22+
RecursiveArrayTools = "2"
23+
Reexport = "0.2, 1"
24+
SciMLBase = "1.32"
25+
Setfield = "0.7, 0.8, 1"
26+
StaticArrays = "0.12,1.0"
27+
UnPack = "1.0"
28+
julia = "1.6"
29+
30+
[extras]
31+
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
32+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
33+
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
34+
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
35+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
36+
37+
[targets]
38+
test = ["BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff"]

README.md

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,38 @@
1-
# SimpleNonlinearSolve
1+
# NonlinearSolve.jl
2+
3+
[![Join the chat at https://julialang.zulipchat.com #sciml-bridged](https://img.shields.io/static/v1?label=Zulip&message=chat&color=9558b2&labelColor=389826)](https://julialang.zulipchat.com/#narrow/stream/279055-sciml-bridged)
4+
[![Global Docs](https://img.shields.io/badge/docs-SciML-blue.svg)](https://docs.sciml.ai/NonlinearSolve/stable/)
5+
6+
[![codecov](https://codecov.io/gh/SciML/NonlinearSolve.jl/branch/master/graph/badge.svg)](https://codecov.io/gh/SciML/NonlinearSolve.jl)
7+
[![Build Status](https://github.com/SciML/NonlinearSolve.jl/workflows/CI/badge.svg)](https://github.com/SciML/NonlinearSolve.jl/actions?query=workflow%3ACI)
8+
9+
[![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor's%20Guide-blueviolet)](https://github.com/SciML/ColPrac)
10+
[![SciML Code Style](https://img.shields.io/static/v1?label=code%20style&message=SciML&color=9558b2&labelColor=389826)](https://github.com/SciML/SciMLStyle)
11+
12+
13+
14+
15+
Fast implementations of root finding algorithms in Julia that satisfy the SciML common interface.
16+
17+
For information on using the package,
18+
[see the stable documentation](https://docs.sciml.ai/NonlinearSolve/stable/). Use the
19+
[in-development documentation](https://docs.sciml.ai/NonlinearSolve/dev/) for the version of
20+
the documentation which contains the unreleased features.
21+
22+
## High Level Examples
23+
24+
```julia
25+
using NonlinearSolve, StaticArrays
26+
27+
f(u,p) = u .* u .- 2
28+
u0 = @SVector[1.0, 1.0]
29+
probN = NonlinearProblem{false}(f, u0)
30+
solver = solve(probN, NewtonRaphson(), tol = 1e-9)
31+
32+
## Bracketing Methods
33+
34+
f(u, p) = u .* u .- 2.0
35+
u0 = (1.0, 2.0) # brackets
36+
probB = NonlinearProblem(f, u0)
37+
sol = solve(probB, Falsi())
38+
```

src/SimpleNonlinearSolve.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
module SimpleNonlinearSolve
2+
3+
using Reexport
4+
using UnPack: @unpack
5+
using FiniteDiff, ForwardDiff
6+
using ForwardDiff: Dual
7+
using Setfield
8+
using StaticArrays
9+
using RecursiveArrayTools
10+
using LinearAlgebra
11+
import ArrayInterfaceCore
12+
13+
@reexport using SciMLBase
14+
15+
abstract type AbstractSimpleNonlinearSolveAlgorithm <: SciMLBase.AbstractNonlinearAlgorithm end
16+
abstract type AbstractBracketingAlgorithm <: AbstractSimpleNonlinearSolveAlgorithm end
17+
abstract type AbstractNewtonAlgorithm{CS, AD, FDT} <: AbstractSimpleNonlinearSolveAlgorithm end
18+
abstract type AbstractImmutableNonlinearSolver <: AbstractSimpleNonlinearSolveAlgorithm end
19+
20+
include("utils.jl")
21+
include("bisection.jl")
22+
include("falsi.jl")
23+
include("raphson.jl")
24+
include("ad.jl")
25+
26+
# DiffEq styled algorithms
27+
export Bisection, Falsi, SimpleNewtonRaphson
28+
29+
end # module

src/ad.jl

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
function scalar_nlsolve_ad(prob, alg, args...; kwargs...)
2+
f = prob.f
3+
p = value(prob.p)
4+
u0 = value(prob.u0)
5+
6+
newprob = NonlinearProblem(f, u0, p; prob.kwargs...)
7+
sol = solve(newprob, alg, args...; kwargs...)
8+
9+
uu = sol.u
10+
if p isa Number
11+
f_p = ForwardDiff.derivative(Base.Fix1(f, uu), p)
12+
else
13+
f_p = ForwardDiff.gradient(Base.Fix1(f, uu), p)
14+
end
15+
16+
f_x = ForwardDiff.derivative(Base.Fix2(f, p), uu)
17+
pp = prob.p
18+
sumfun = let f_x′ = -f_x
19+
((fp, p),) -> (fp / f_x′) * ForwardDiff.partials(p)
20+
end
21+
partials = sum(sumfun, zip(f_p, pp))
22+
return sol, partials
23+
end
24+
25+
function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, SVector}, iip,
26+
<:Dual{T, V, P}}, alg::SimpleNewtonRaphson,
27+
args...; kwargs...) where {iip, T, V, P}
28+
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
29+
return SciMLBase.build_solution(prob, alg, Dual{T, V, P}(sol.u, partials), sol.resid;
30+
retcode = sol.retcode)
31+
end
32+
function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, SVector}, iip,
33+
<:AbstractArray{<:Dual{T, V, P}}},
34+
alg::SimpleNewtonRaphson, args...; kwargs...) where {iip, T, V, P}
35+
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
36+
return SciMLBase.build_solution(prob, alg, Dual{T, V, P}(sol.u, partials), sol.resid;
37+
retcode = sol.retcode)
38+
end
39+
40+
# avoid ambiguities
41+
for Alg in [Bisection]
42+
@eval function SciMLBase.solve(prob::NonlinearProblem{uType, iip, <:Dual{T, V, P}},
43+
alg::$Alg, args...;
44+
kwargs...) where {uType, iip, T, V, P}
45+
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
46+
return SciMLBase.build_solution(prob, alg, Dual{T, V, P}(sol.u, partials),
47+
sol.resid; retcode = sol.retcode,
48+
left = Dual{T, V, P}(sol.left, partials),
49+
right = Dual{T, V, P}(sol.right, partials))
50+
#return BracketingSolution(Dual{T,V,P}(sol.left, partials), Dual{T,V,P}(sol.right, partials), sol.retcode, sol.resid)
51+
end
52+
@eval function SciMLBase.solve(prob::NonlinearProblem{uType, iip,
53+
<:AbstractArray{<:Dual{T, V, P}}},
54+
alg::$Alg, args...;
55+
kwargs...) where {uType, iip, T, V, P}
56+
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
57+
return SciMLBase.build_solution(prob, alg, Dual{T, V, P}(sol.u, partials),
58+
sol.resid; retcode = sol.retcode,
59+
left = Dual{T, V, P}(sol.left, partials),
60+
right = Dual{T, V, P}(sol.right, partials))
61+
#return BracketingSolution(Dual{T,V,P}(sol.left, partials), Dual{T,V,P}(sol.right, partials), sol.retcode, sol.resid)
62+
end
63+
end

src/bisection.jl

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
struct Bisection <: AbstractBracketingAlgorithm
2+
exact_left::Bool
3+
exact_right::Bool
4+
end
5+
6+
function Bisection(; exact_left = false, exact_right = false)
7+
Bisection(exact_left, exact_right)
8+
end
9+
10+
function SciMLBase.solve(prob::NonlinearProblem, alg::Bisection, args...; maxiters = 1000,
11+
kwargs...)
12+
f = Base.Fix2(prob.f, prob.p)
13+
left, right = prob.u0
14+
fl, fr = f(left), f(right)
15+
16+
if iszero(fl)
17+
return SciMLBase.build_solution(prob, alg, left, fl;
18+
retcode = ReturnCode.ExactSolutionLeft, left = left,
19+
right = right)
20+
end
21+
22+
i = 1
23+
if !iszero(fr)
24+
while i < maxiters
25+
mid = (left + right) / 2
26+
(mid == left || mid == right) &&
27+
return SciMLBase.build_solution(prob, alg, left, fl;
28+
retcode = ReturnCode.FloatingPointLimit,
29+
left = left, right = right)
30+
fm = f(mid)
31+
if iszero(fm)
32+
right = mid
33+
break
34+
end
35+
if sign(fl) == sign(fm)
36+
fl = fm
37+
left = mid
38+
else
39+
fr = fm
40+
right = mid
41+
end
42+
i += 1
43+
end
44+
end
45+
46+
while i < maxiters
47+
mid = (left + right) / 2
48+
(mid == left || mid == right) &&
49+
return SciMLBase.build_solution(prob, alg, left, fl;
50+
retcode = ReturnCode.FloatingPointLimit,
51+
left = left, right = right)
52+
fm = f(mid)
53+
if iszero(fm)
54+
right = mid
55+
fr = fm
56+
else
57+
left = mid
58+
fl = fm
59+
end
60+
i += 1
61+
end
62+
63+
return SciMLBase.build_solution(prob, alg, left, fl; retcode = ReturnCode.MaxIters,
64+
left = left, right = right)
65+
end

src/falsi.jl

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
struct Falsi <: AbstractBracketingAlgorithm end
2+
3+
function SciMLBase.solve(prob::NonlinearProblem, alg::Falsi, args...; maxiters = 1000,
4+
kwargs...)
5+
f = Base.Fix2(prob.f, prob.p)
6+
left, right = prob.u0
7+
fl, fr = f(left), f(right)
8+
9+
if iszero(fl)
10+
return SciMLBase.build_solution(prob, alg, left, fl;
11+
retcode = ReturnCode.ExactSolutionLeft, left = left,
12+
right = right)
13+
end
14+
15+
i = 1
16+
if !iszero(fr)
17+
while i < maxiters
18+
if nextfloat_tdir(left, prob.u0...) == right
19+
return SciMLBase.build_solution(prob, alg, left, fl;
20+
retcode = ReturnCode.FloatingPointLimit,
21+
left = left, right = right)
22+
end
23+
mid = (fr * left - fl * right) / (fr - fl)
24+
for i in 1:10
25+
mid = max_tdir(left, prevfloat_tdir(mid, prob.u0...), prob.u0...)
26+
end
27+
if mid == right || mid == left
28+
break
29+
end
30+
fm = f(mid)
31+
if iszero(fm)
32+
right = mid
33+
break
34+
end
35+
if sign(fl) == sign(fm)
36+
fl = fm
37+
left = mid
38+
else
39+
fr = fm
40+
right = mid
41+
end
42+
i += 1
43+
end
44+
end
45+
46+
while i < maxiters
47+
mid = (left + right) / 2
48+
(mid == left || mid == right) &&
49+
return SciMLBase.build_solution(prob, alg, left, fl;
50+
retcode = ReturnCode.FloatingPointLimit,
51+
left = left, right = right)
52+
fm = f(mid)
53+
if iszero(fm)
54+
right = mid
55+
fr = fm
56+
elseif sign(fm) == sign(fl)
57+
left = mid
58+
fl = fm
59+
else
60+
right = mid
61+
fr = fm
62+
end
63+
i += 1
64+
end
65+
66+
return SciMLBase.build_solution(prob, alg, left, fl; retcode = ReturnCode.MaxIters,
67+
left = left, right = right)
68+
end

src/raphson.jl

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
struct SimpleNewtonRaphson{CS, AD, FDT} <: AbstractNewtonAlgorithm{CS, AD, FDT}
2+
function SimpleNewtonRaphson(; chunk_size = Val{0}(), autodiff = Val{true}(),
3+
diff_type = Val{:forward})
4+
new{SciMLBase._unwrap_val(chunk_size), SciMLBase._unwrap_val(autodiff),
5+
SciMLBase._unwrap_val(diff_type)}()
6+
end
7+
end
8+
9+
function SciMLBase.solve(prob::NonlinearProblem,
10+
alg::SimpleNewtonRaphson, args...; xatol = nothing, xrtol = nothing,
11+
maxiters = 1000, kwargs...)
12+
f = Base.Fix2(prob.f, prob.p)
13+
x = float(prob.u0)
14+
fx = float(prob.u0)
15+
T = typeof(x)
16+
17+
if SciMLBase.isinplace(prob)
18+
error("SimpleNewtonRaphson currently only supports out-of-place nonlinear problems")
19+
end
20+
21+
atol = xatol !== nothing ? xatol : oneunit(eltype(T)) * (eps(one(eltype(T))))^(4 // 5)
22+
rtol = xrtol !== nothing ? xrtol : eps(one(eltype(T)))^(4 // 5)
23+
24+
if typeof(x) <: Number
25+
xo = oftype(one(eltype(x)), Inf)
26+
else
27+
xo = map(x -> oftype(one(eltype(x)), Inf), x)
28+
end
29+
30+
for i in 1:maxiters
31+
if alg_autodiff(alg)
32+
fx, dfx = value_derivative(f, x)
33+
elseif x isa AbstractArray
34+
fx = f(x)
35+
dfx = FiniteDiff.finite_difference_jacobian(f, x, diff_type(alg), eltype(x), fx)
36+
else
37+
fx = f(x)
38+
dfx = FiniteDiff.finite_difference_derivative(f, x, diff_type(alg), eltype(x),
39+
fx)
40+
end
41+
iszero(fx) &&
42+
return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.Default)
43+
Δx = dfx \ fx
44+
x -= Δx
45+
if isapprox(x, xo, atol = atol, rtol = rtol)
46+
return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.Default)
47+
end
48+
xo = x
49+
end
50+
return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.MaxIters)
51+
end

0 commit comments

Comments
 (0)