Skip to content
This repository was archived by the owner on May 15, 2025. It is now read-only.

Commit 4f9ee0e

Browse files
authored
Merge branch 'main' into qqy/downgrade_ci
2 parents 1c017bb + 0a1fdce commit 4f9ee0e

30 files changed

+850
-379
lines changed

.buildkite/pipeline.yml

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
steps:
2+
- label: "Julia 1"
3+
plugins:
4+
- JuliaCI/julia#v1:
5+
version: "1"
6+
- JuliaCI/julia-test#v1:
7+
agents:
8+
queue: "juliagpu"
9+
cuda: "*"
10+
timeout_in_minutes: 30
11+
# Don't run Buildkite if the commit message includes the text [skip tests]
12+
if: build.message !~ /\[skip tests\]/
13+
14+
env:
15+
GROUP: CUDA
16+
JULIA_PKG_SERVER: ""

.github/workflows/CI.yml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ jobs:
2020
- uses: julia-actions/setup-julia@v1
2121
with:
2222
version: ${{ matrix.version }}
23-
- uses: actions/cache@v3
23+
- uses: actions/cache@v4
2424
env:
2525
cache-name: cache-artifacts
2626
with:
@@ -32,9 +32,11 @@ jobs:
3232
${{ runner.os }}-
3333
- uses: julia-actions/julia-buildpkg@v1
3434
- uses: julia-actions/julia-runtest@v1
35+
with:
36+
annotate: true
3537
env:
3638
GROUP: ${{ matrix.group }}
37-
JULIA_NUM_THREADS: 11
39+
JULIA_NUM_THREADS: "auto"
3840
- uses: julia-actions/julia-processcoverage@v1
3941
- uses: codecov/codecov-action@v3
4042
with:

Project.toml

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
name = "SimpleNonlinearSolve"
22
uuid = "727e6d20-b764-4bd8-a329-72de5adea6c7"
33
authors = ["SciML"]
4-
version = "1.1.0"
4+
version = "1.3.2"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
88
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
99
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
1010
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
11+
FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
1112
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
1213
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1314
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -17,23 +18,30 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1718
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
1819
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
1920

20-
[extensions]
21-
SimpleNonlinearSolvePolyesterForwardDiffExt = "PolyesterForwardDiff"
22-
2321
[weakdeps]
22+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
2423
PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
24+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
25+
26+
[extensions]
27+
SimpleNonlinearSolveChainRulesCoreExt = "ChainRulesCore"
28+
SimpleNonlinearSolvePolyesterForwardDiffExt = "PolyesterForwardDiff"
29+
SimpleNonlinearSolveStaticArraysExt = "StaticArrays"
2530

2631
[compat]
2732
ADTypes = "0.2.6"
2833
ArrayInterface = "7.7"
34+
ChainRulesCore = "1"
2935
ConcreteStructs = "0.2.2"
3036
DiffEqBase = "6.144"
37+
FastClosures = "0.3"
3138
FiniteDiff = "2.21"
3239
ForwardDiff = "0.10.36"
3340
LinearAlgebra = "1.9"
3441
MaybeInplace = "0.1.1"
3542
PrecompileTools = "1.2"
3643
Reexport = "1.2"
37-
SciMLBase = "2.11"
44+
SciMLBase = "2.23"
45+
StaticArrays = "1"
3846
StaticArraysCore = "1.4.2"
39-
julia = "1.9"
47+
julia = "1.10"

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
[![codecov](https://codecov.io/gh/SciML/SimpleNonlinearSolve.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/SciML/SimpleNonlinearSolve.jl)
77
[![Build Status](https://github.com/SciML/SimpleNonlinearSolve.jl/workflows/CI/badge.svg)](https://github.com/SciML/SimpleNonlinearSolve.jl/actions?query=workflow%3ACI)
8+
[![Build status](https://badge.buildkite.com/c5f7db4f1b5e8a592514378b6fc807d934546cc7d5aa79d645.svg?branch=main)](https://buildkite.com/julialang/simplenonlinearsolve-dot-jl)
89

910
[![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor%27s%20Guide-blueviolet)](https://github.com/SciML/ColPrac)
1011
[![SciML Code Style](https://img.shields.io/static/v1?label=code%20style&message=SciML&color=9558b2&labelColor=389826)](https://github.com/SciML/SciMLStyle)
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
module SimpleNonlinearSolveChainRulesCoreExt
2+
3+
using ChainRulesCore, DiffEqBase, SciMLBase, SimpleNonlinearSolve
4+
5+
# The expectation here is that no-one is using this directly inside a GPU kernel. We can
6+
# eventually lift this requirement using a custom adjoint
7+
function ChainRulesCore.rrule(::typeof(SimpleNonlinearSolve.__internal_solve_up),
8+
prob::NonlinearProblem,
9+
sensealg::Union{Nothing, DiffEqBase.AbstractSensitivityAlgorithm}, u0, u0_changed,
10+
p, p_changed, alg, args...; kwargs...)
11+
out, ∇internal = DiffEqBase._solve_adjoint(prob, sensealg, u0, p,
12+
SciMLBase.ChainRulesOriginator(), alg, args...; kwargs...)
13+
function ∇__internal_solve_up(Δ)
14+
∂f, ∂prob, ∂sensealg, ∂u0, ∂p, ∂originator, ∂args... = ∇internal(Δ)
15+
return (∂f, ∂prob, ∂sensealg, ∂u0, NoTangent(), ∂p, NoTangent(), ∂originator,
16+
∂args...)
17+
end
18+
return out, ∇__internal_solve_up
19+
end
20+
21+
end
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
module SimpleNonlinearSolveStaticArraysExt
2+
3+
using SimpleNonlinearSolve
4+
5+
@inline SimpleNonlinearSolve.__is_extension_loaded(::Val{:StaticArrays}) = true
6+
7+
end

src/SimpleNonlinearSolve.jl

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,16 @@ module SimpleNonlinearSolve
33
import PrecompileTools: @compile_workload, @setup_workload, @recompile_invalidations
44

55
@recompile_invalidations begin
6-
using ADTypes,
7-
ArrayInterface, ConcreteStructs, DiffEqBase, Reexport, LinearAlgebra, SciMLBase
6+
using ADTypes, ArrayInterface, ConcreteStructs, DiffEqBase, FastClosures, FiniteDiff,
7+
ForwardDiff, Reexport, LinearAlgebra, SciMLBase
88

99
import DiffEqBase: AbstractNonlinearTerminationMode,
1010
AbstractSafeNonlinearTerminationMode, AbstractSafeBestNonlinearTerminationMode,
1111
NonlinearSafeTerminationReturnCode, get_termination_mode,
12-
NONLINEARSOLVE_DEFAULT_NORM, _get_tolerance
13-
using FiniteDiff, ForwardDiff
12+
NONLINEARSOLVE_DEFAULT_NORM
1413
import ForwardDiff: Dual
1514
import MaybeInplace: @bb, setindex_trait, CanSetindex, CannotSetindex
16-
import SciMLBase: AbstractNonlinearAlgorithm, build_solution, isinplace
15+
import SciMLBase: AbstractNonlinearAlgorithm, build_solution, isinplace, _unwrap_val
1716
import StaticArraysCore: StaticArray, SVector, SMatrix, SArray, MArray, MMatrix, Size
1817
end
1918

@@ -26,6 +25,7 @@ abstract type AbstractNewtonAlgorithm <: AbstractSimpleNonlinearSolveAlgorithm e
2625
@inline __is_extension_loaded(::Val) = false
2726

2827
include("utils.jl")
28+
include("linesearch.jl")
2929

3030
## Nonlinear Solvers
3131
include("nlsolve/raphson.jl")
@@ -50,13 +50,27 @@ include("ad.jl")
5050
## Default algorithm
5151

5252
# Set the default bracketing method to ITP
53-
function SciMLBase.solve(prob::IntervalNonlinearProblem; kwargs...)
54-
return solve(prob, ITP(); kwargs...)
53+
SciMLBase.solve(prob::IntervalNonlinearProblem; kwargs...) = solve(prob, ITP(); kwargs...)
54+
function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Nothing, args...; kwargs...)
55+
return solve(prob, ITP(), args...; kwargs...)
5556
end
5657

57-
function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Nothing,
58-
args...; kwargs...)
59-
return solve(prob, ITP(), args...; kwargs...)
58+
# By Pass the highlevel checks for NonlinearProblem for Simple Algorithms
59+
function SciMLBase.solve(prob::NonlinearProblem, alg::AbstractSimpleNonlinearSolveAlgorithm,
60+
args...; sensealg = nothing, u0 = nothing, p = nothing, kwargs...)
61+
if sensealg === nothing && haskey(prob.kwargs, :sensealg)
62+
sensealg = prob.kwargs[:sensealg]
63+
end
64+
new_u0 = u0 !== nothing ? u0 : prob.u0
65+
new_p = p !== nothing ? p : prob.p
66+
return __internal_solve_up(prob, sensealg, new_u0, u0 === nothing, new_p, p === nothing,
67+
alg, args...; kwargs...)
68+
end
69+
70+
function __internal_solve_up(_prob::NonlinearProblem, sensealg, u0, u0_changed, p,
71+
p_changed, alg::AbstractSimpleNonlinearSolveAlgorithm, args...; kwargs...)
72+
prob = u0_changed || p_changed ? remake(_prob; u0, p) : _prob
73+
return SciMLBase.__solve(prob, alg, args...; kwargs...)
6074
end
6175

6276
@setup_workload begin

src/ad.jl

Lines changed: 63 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,43 @@
1-
function scalar_nlsolve_ad(prob, alg, args...; kwargs...)
2-
f = prob.f
1+
function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, <:AbstractArray},
2+
iip, <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}},
3+
alg::AbstractSimpleNonlinearSolveAlgorithm, args...; kwargs...) where {T, V, P, iip}
4+
sol, partials = __nlsolve_ad(prob, alg, args...; kwargs...)
5+
dual_soln = __nlsolve_dual_soln(sol.u, partials, prob.p)
6+
return SciMLBase.build_solution(prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats,
7+
sol.original)
8+
end
9+
10+
for algType in (Bisection, Brent, Alefeld, Falsi, ITP, Ridder)
11+
@eval begin
12+
function SciMLBase.solve(prob::IntervalNonlinearProblem{uType, iip,
13+
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}},
14+
alg::$(algType), args...; kwargs...) where {uType, T, V, P, iip}
15+
sol, partials = __nlsolve_ad(prob, alg, args...; kwargs...)
16+
dual_soln = __nlsolve_dual_soln(sol.u, partials, prob.p)
17+
return SciMLBase.build_solution(prob, alg, dual_soln, sol.resid; sol.retcode,
18+
sol.stats, sol.original, left = Dual{T, V, P}(sol.left, partials),
19+
right = Dual{T, V, P}(sol.right, partials))
20+
end
21+
end
22+
end
23+
24+
function __nlsolve_ad(prob, alg, args...; kwargs...)
325
p = value(prob.p)
426
if prob isa IntervalNonlinearProblem
527
tspan = value.(prob.tspan)
6-
newprob = IntervalNonlinearProblem(f, tspan, p; prob.kwargs...)
28+
newprob = IntervalNonlinearProblem(prob.f, tspan, p; prob.kwargs...)
729
else
830
u0 = value(prob.u0)
9-
newprob = NonlinearProblem(f, u0, p; prob.kwargs...)
31+
newprob = NonlinearProblem(prob.f, u0, p; prob.kwargs...)
1032
end
1133

1234
sol = solve(newprob, alg, args...; kwargs...)
1335

1436
uu = sol.u
15-
f_p = scalar_nlsolve_∂f_∂p(f, uu, p)
16-
f_x = scalar_nlsolve_∂f_∂u(f, uu, p)
37+
f_p = __nlsolve_∂f_∂p(prob, prob.f, uu, p)
38+
f_x = __nlsolve_∂f_∂u(prob, prob.f, uu, p)
1739

18-
z_arr = -inv(f_x) * f_p
40+
z_arr = -f_x \ f_p
1941

2042
pp = prob.p
2143
sumfun = ((z, p),) -> map(zᵢ -> zᵢ * ForwardDiff.partials(p), z)
@@ -30,60 +52,47 @@ function scalar_nlsolve_ad(prob, alg, args...; kwargs...)
3052
return sol, partials
3153
end
3254

33-
function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, SVector, <:AbstractArray},
34-
false, <:Dual{T, V, P}}, alg::AbstractSimpleNonlinearSolveAlgorithm, args...;
35-
kwargs...) where {T, V, P}
36-
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
37-
dual_soln = scalar_nlsolve_dual_soln(sol.u, partials, prob.p)
38-
return SciMLBase.build_solution(prob, alg, dual_soln, sol.resid; sol.retcode)
39-
end
40-
41-
function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, SVector, <:AbstractArray},
42-
false, <:AbstractArray{<:Dual{T, V, P}}},
43-
alg::AbstractSimpleNonlinearSolveAlgorithm, args...; kwargs...) where {T, V, P}
44-
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
45-
dual_soln = scalar_nlsolve_dual_soln(sol.u, partials, prob.p)
46-
return SciMLBase.build_solution(prob, alg, dual_soln, sol.resid; sol.retcode)
47-
end
48-
49-
function scalar_nlsolve_∂f_∂p(f, u, p)
50-
ff = p isa Number ? ForwardDiff.derivative :
51-
(u isa Number ? ForwardDiff.gradient : ForwardDiff.jacobian)
52-
return ff(Base.Fix1(f, u), p)
55+
@inline function __nlsolve_∂f_∂p(prob, f::F, u, p) where {F}
56+
if isinplace(prob)
57+
__f = p -> begin
58+
du = similar(u, promote_type(eltype(u), eltype(p)))
59+
f(du, u, p)
60+
return du
61+
end
62+
else
63+
__f = Base.Fix1(f, u)
64+
end
65+
if p isa Number
66+
return __reshape(ForwardDiff.derivative(__f, p), :, 1)
67+
elseif u isa Number
68+
return __reshape(ForwardDiff.gradient(__f, p), 1, :)
69+
else
70+
return ForwardDiff.jacobian(__f, p)
71+
end
5372
end
5473

55-
function scalar_nlsolve_∂f_∂u(f, u, p)
56-
ff = u isa Number ? ForwardDiff.derivative : ForwardDiff.jacobian
57-
return ff(Base.Fix2(f, p), u)
74+
@inline function __nlsolve_∂f_∂u(prob, f::F, u, p) where {F}
75+
if isinplace(prob)
76+
du = similar(u)
77+
__f = (du, u) -> f(du, u, p)
78+
ForwardDiff.jacobian(__f, du, u)
79+
else
80+
__f = Base.Fix2(f, p)
81+
if u isa Number
82+
return ForwardDiff.derivative(__f, u)
83+
else
84+
return ForwardDiff.jacobian(__f, u)
85+
end
86+
end
5887
end
5988

60-
function scalar_nlsolve_dual_soln(u::Number, partials,
89+
@inline function __nlsolve_dual_soln(u::Number, partials,
6190
::Union{<:AbstractArray{<:Dual{T, V, P}}, Dual{T, V, P}}) where {T, V, P}
6291
return Dual{T, V, P}(u, partials)
6392
end
6493

65-
function scalar_nlsolve_dual_soln(u::AbstractArray, partials,
94+
@inline function __nlsolve_dual_soln(u::AbstractArray, partials,
6695
::Union{<:AbstractArray{<:Dual{T, V, P}}, Dual{T, V, P}}) where {T, V, P}
67-
return map(((uᵢ, pᵢ),) -> Dual{T, V, P}(uᵢ, pᵢ), zip(u, partials))
68-
end
69-
70-
# avoid ambiguities
71-
for Alg in [Bisection]
72-
@eval function SciMLBase.solve(prob::IntervalNonlinearProblem{uType, iip,
73-
<:Dual{T, V, P}}, alg::$Alg, args...; kwargs...) where {uType, iip, T, V, P}
74-
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
75-
dual_soln = scalar_nlsolve_dual_soln(sol.u, partials, prob.p)
76-
return SciMLBase.build_solution(prob, alg, dual_soln, sol.resid; sol.retcode,
77-
left = Dual{T, V, P}(sol.left, partials),
78-
right = Dual{T, V, P}(sol.right, partials))
79-
end
80-
@eval function SciMLBase.solve(prob::IntervalNonlinearProblem{uType, iip,
81-
<:AbstractArray{<:Dual{T, V, P}}}, alg::$Alg, args...;
82-
kwargs...) where {uType, iip, T, V, P}
83-
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
84-
dual_soln = scalar_nlsolve_dual_soln(sol.u, partials, prob.p)
85-
return SciMLBase.build_solution(prob, alg, dual_soln, sol.resid; sol.retcode,
86-
left = Dual{T, V, P}(sol.left, partials),
87-
right = Dual{T, V, P}(sol.right, partials))
88-
end
96+
_partials = _restructure(u, partials)
97+
return map(((uᵢ, pᵢ),) -> Dual{T, V, P}(uᵢ, pᵢ), zip(u, _partials))
8998
end

src/bracketing/bisection.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Bisection, args...
2626
left, right = prob.tspan
2727
fl, fr = f(left), f(right)
2828

29-
abstol = _get_tolerance(abstol,
29+
abstol = __get_tolerance(nothing, abstol,
3030
promote_type(eltype(first(prob.tspan)), eltype(last(prob.tspan))))
3131

3232
if iszero(fl)

src/bracketing/brent.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Brent, args...;
1313
fl, fr = f(left), f(right)
1414
ϵ = eps(convert(typeof(fl), 1))
1515

16-
abstol = _get_tolerance(abstol,
16+
abstol = __get_tolerance(nothing, abstol,
1717
promote_type(eltype(first(prob.tspan)), eltype(last(prob.tspan))))
1818

1919
if iszero(fl)

0 commit comments

Comments
 (0)