Skip to content
This repository was archived by the owner on May 15, 2025. It is now read-only.

Commit c320b75

Browse files
committed
Resolve tolerances differently for kernel launches
1 parent e1e1be8 commit c320b75

File tree

10 files changed

+44
-8
lines changed

10 files changed

+44
-8
lines changed

src/SimpleNonlinearSolve.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ import PrecompileTools: @compile_workload, @setup_workload, @recompile_invalidat
99
import DiffEqBase: AbstractNonlinearTerminationMode,
1010
AbstractSafeNonlinearTerminationMode, AbstractSafeBestNonlinearTerminationMode,
1111
NonlinearSafeTerminationReturnCode, get_termination_mode,
12-
NONLINEARSOLVE_DEFAULT_NORM, _get_tolerance
12+
NONLINEARSOLVE_DEFAULT_NORM
1313
import ForwardDiff: Dual
1414
import MaybeInplace: @bb, setindex_trait, CanSetindex, CannotSetindex
1515
import SciMLBase: AbstractNonlinearAlgorithm, build_solution, isinplace, _unwrap_val
@@ -65,6 +65,15 @@ function SciMLBase.solve(prob::NonlinearProblem, alg::AbstractSimpleNonlinearSol
6565
return SciMLBase.__solve(prob, alg, args...; kwargs...)
6666
end
6767

68+
function SciMLBase.solve(prob::NonlinearProblem{<:Union{<:Number, <:SArray}},
69+
alg::AbstractSimpleNonlinearSolveAlgorithm, args...; abstol = nothing,
70+
reltol = nothing, kwargs...)
71+
_abstol = __get_tolerance(prob.u0, abstol, eltype(prob.u0))
72+
_reltol = __get_tolerance(prob.u0, reltol, eltype(prob.u0))
73+
return SciMLBase.__solve(prob, alg, args...; abstol = _abstol, reltol = _reltol,
74+
kwargs...)
75+
end
76+
6877
@setup_workload begin
6978
for T in (Float32, Float64)
7079
prob_no_brack_scalar = NonlinearProblem{false}((u, p) -> u .* u .- p, T(0.1), T(2))

src/ad.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,20 @@ function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, <:AbstractArray}
88
end
99

1010
# Handle Ambiguities
11+
for algType in (SimpleNewtonRaphson, SimpleDFSane, SimpleTrustRegion, SimpleBroyden,
12+
SimpleLimitedMemoryBroyden, SimpleKlement, SimpleHalley)
13+
@eval begin
14+
function SciMLBase.solve(prob::NonlinearProblem{uType, iip,
15+
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}},
16+
alg::$(algType), args...; kwargs...) where {uType, T, V, P, iip}
17+
sol, partials = __nlsolve_ad(prob, alg, args...; kwargs...)
18+
dual_soln = __nlsolve_dual_soln(sol.u, partials, prob.p)
19+
return SciMLBase.build_solution(prob, alg, dual_soln, sol.resid; sol.retcode,
20+
sol.stats, sol.original)
21+
end
22+
end
23+
end
24+
1125
for algType in (Bisection, Brent, Alefeld, Falsi, ITP, Ridder)
1226
@eval begin
1327
function SciMLBase.solve(prob::IntervalNonlinearProblem{uType, iip,

src/bracketing/bisection.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Bisection, args...
2626
left, right = prob.tspan
2727
fl, fr = f(left), f(right)
2828

29-
abstol = _get_tolerance(abstol,
29+
abstol = __get_tolerance(nothing, abstol,
3030
promote_type(eltype(first(prob.tspan)), eltype(last(prob.tspan))))
3131

3232
if iszero(fl)

src/bracketing/brent.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Brent, args...;
1313
fl, fr = f(left), f(right)
1414
ϵ = eps(convert(typeof(fl), 1))
1515

16-
abstol = _get_tolerance(abstol,
16+
abstol = __get_tolerance(nothing, abstol,
1717
promote_type(eltype(first(prob.tspan)), eltype(last(prob.tspan))))
1818

1919
if iszero(fl)

src/bracketing/falsi.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Falsi, args...;
1212
left, right = prob.tspan
1313
fl, fr = f(left), f(right)
1414

15-
abstol = _get_tolerance(abstol,
15+
abstol = __get_tolerance(nothing, abstol,
1616
promote_type(eltype(first(prob.tspan)), eltype(last(prob.tspan))))
1717

1818
if iszero(fl)

src/bracketing/itp.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::ITP, args...;
5858
left, right = prob.tspan
5959
fl, fr = f(left), f(right)
6060

61-
abstol = _get_tolerance(abstol,
61+
abstol = __get_tolerance(nothing, abstol,
6262
promote_type(eltype(first(prob.tspan)), eltype(last(prob.tspan))))
6363

6464
if iszero(fl)

src/bracketing/ridder.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Ridder, args...;
1212
left, right = prob.tspan
1313
fl, fr = f(left), f(right)
1414

15-
abstol = _get_tolerance(abstol,
15+
abstol = __get_tolerance(nothing, abstol,
1616
promote_type(eltype(first(prob.tspan)), eltype(last(prob.tspan))))
1717

1818
if iszero(fl)

src/nlsolve/lbroyden.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,12 @@ function SimpleLimitedMemoryBroyden(; threshold::Union{Val, Int} = Val(27),
2828
return SimpleLimitedMemoryBroyden{_unwrap_val(threshold), _unwrap_val(linesearch)}()
2929
end
3030

31+
function SciMLBase.solve(prob::NonlinearProblem{<:Union{<:Number, <:SArray}},
32+
alg::SimpleLimitedMemoryBroyden, args...; kwargs...)
33+
# Don't resolve the `abstol` and `reltol` here
34+
return SciMLBase.__solve(prob, alg, args...; kwargs...)
35+
end
36+
3137
function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleLimitedMemoryBroyden,
3238
args...; termination_condition = nothing, kwargs...)
3339
if prob.u0 isa SArray
@@ -120,7 +126,7 @@ function __static_solve(prob::NonlinearProblem{<:SArray}, alg::SimpleLimitedMemo
120126

121127
U, Vᵀ = __init_low_rank_jacobian(vec(x), vec(fx), threshold)
122128

123-
abstol = DiffEqBase._get_tolerance(abstol, eltype(x))
129+
abstol = __get_tolerance(x, abstol, eltype(x))
124130

125131
xo, δx, fo, δf = x, -fx, fx, fx
126132

src/utils.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,3 +370,10 @@ end
370370

371371
@inline __reshape(x::Number, args...) = x
372372
@inline __reshape(x::AbstractArray, args...) = reshape(x, args...)
373+
374+
# Override cases which might be used in a kernel launch
375+
__get_tolerance(x, η, ::Type{T}) where {T} = DiffEqBase._get_tolerance(η, T)
376+
function __get_tolerance(x::Union{SArray, Number}, ::Nothing, ::Type{T}) where {T}
377+
η = real(oneunit(T)) * (eps(real(one(T))))^(real(T)(0.8))
378+
return T(η)
379+
end

test/cuda.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ f!(du, u, p) = du .= u .* u .- 2
3737
end
3838

3939
function kernel_function(prob, alg)
40-
solve(prob, alg; abstol = 1.0f-6, reltol = 1.0f-6)
40+
solve(prob, alg)
4141
return nothing
4242
end
4343

0 commit comments

Comments
 (0)