Skip to content

Commit cbdaff1

Browse files
Merge pull request #99 from SciML/general_fixes
more general fixes
2 parents 37ef686 + 651d56e commit cbdaff1

File tree

6 files changed

+60
-140
lines changed

6 files changed

+60
-140
lines changed

Project.toml

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@ LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
1313
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
1414
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1515
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
16-
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
1716
SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7"
18-
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
17+
SnoopPrecompile = "66db9d55-30c0-4569-8b51-7e840670fc0c"
18+
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
1919
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
2020

2121
[compat]
@@ -27,9 +27,9 @@ LinearSolve = "1"
2727
RecursiveArrayTools = "2"
2828
Reexport = "0.2, 1"
2929
SciMLBase = "1.73"
30-
Setfield = "0.7, 0.8, 1"
3130
SimpleNonlinearSolve = "0.1"
32-
StaticArrays = "0.12,1.0"
31+
SnoopPrecompile = "1"
32+
StaticArraysCore = "1.4"
3333
UnPack = "1.0"
3434
julia = "1.6"
3535

@@ -38,7 +38,8 @@ BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
3838
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
3939
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
4040
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
41+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
4142
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
4243

4344
[targets]
44-
test = ["BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff"]
45+
test = ["BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays"]

src/NonlinearSolve.jl

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,18 @@ using Reexport
44
using UnPack: @unpack
55
using FiniteDiff, ForwardDiff
66
using ForwardDiff: Dual
7-
using Setfield
8-
using StaticArrays
9-
using RecursiveArrayTools
107
using LinearAlgebra
8+
using StaticArraysCore
9+
using RecursiveArrayTools
1110
import ArrayInterfaceCore
1211
import LinearSolve
1312
using DiffEqBase
1413

1514
@reexport using SciMLBase
1615
@reexport using SimpleNonlinearSolve
1716

17+
import SciMLBase: _unwrap_val
18+
1819
abstract type AbstractNonlinearSolveAlgorithm <: SciMLBase.AbstractNonlinearAlgorithm end
1920
abstract type AbstractNewtonAlgorithm{CS, AD, FDT, ST, CJ} <:
2021
AbstractNonlinearSolveAlgorithm end
@@ -31,6 +32,20 @@ include("jacobian.jl")
3132
include("raphson.jl")
3233
include("ad.jl")
3334

35+
import SnoopPrecompile
36+
37+
SnoopPrecompile.@precompile_all_calls begin for T in (Float32, Float64)
38+
prob = NonlinearProblem{false}((u, p) -> u .* u .- p, T(0.1), T(2))
39+
for alg in (NewtonRaphson,)
40+
solve(prob, alg(), abstol = T(1e-2))
41+
end
42+
43+
prob = NonlinearProblem{true}((du, u, p) -> du[1] = u[1] * u[1] - p[1], T[0.1], T[2])
44+
for alg in (NewtonRaphson,)
45+
solve(prob, alg(), abstol = T(1e-2))
46+
end
47+
end end
48+
3449
export NewtonRaphson
3550

3651
end # module

src/ad.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,16 @@ function scalar_nlsolve_ad(prob, alg, args...; kwargs...)
2323
return sol, partials
2424
end
2525

26-
function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, SVector}, iip,
26+
function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, StaticArraysCore.SVector},
27+
iip,
2728
<:Dual{T, V, P}}, alg::NewtonRaphson,
2829
args...; kwargs...) where {iip, T, V, P}
2930
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
3031
return SciMLBase.build_solution(prob, alg, Dual{T, V, P}(sol.u, partials), sol.resid;
3132
retcode = sol.retcode)
3233
end
33-
function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, SVector}, iip,
34+
function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, StaticArraysCore.SVector},
35+
iip,
3436
<:AbstractArray{<:Dual{T, V, P}}},
3537
alg::NewtonRaphson, args...; kwargs...) where {iip, T, V, P}
3638
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)

src/raphson.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -114,13 +114,14 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::NewtonRaphson
114114
end
115115

116116
function perform_step!(cache::NewtonRaphsonCache{true})
117-
@unpack u, fu, f, p, cache = cache
117+
@unpack u, fu, f, p, alg = cache
118118
@unpack J, linsolve, du1 = cache
119119
calc_J!(J, cache, cache)
120+
120121
# u = u - J \ fu
121-
linsolve = dolinsolve(alg.precs, linsolve, A = J, b = fu, u = du1,
122-
p = p, reltol = cache.abstol)
123-
cache.linsolve = linsolve
122+
linres = dolinsolve(alg.precs, linsolve, A = J, b = fu, linu = du1,
123+
p = p, reltol = cache.abstol)
124+
cache.linsolve = linres.cache
124125
@. u = u - du1
125126
f(fu, u, p)
126127

@@ -150,6 +151,8 @@ function SciMLBase.solve!(cache::NewtonRaphsonCache)
150151

151152
if cache.iter == cache.maxiters
152153
cache.retcode = ReturnCode.MaxIters
154+
else
155+
cache.retcode = ReturnCode.Success
153156
end
154157

155158
SciMLBase.build_solution(cache.prob, cache.alg, cache.u, cache.fu;

src/utils.jl

Lines changed: 7 additions & 123 deletions
Original file line numberDiff line numberDiff line change
@@ -1,111 +1,13 @@
1-
"""
2-
@add_kwonly function_definition
3-
4-
Define keyword-only version of the `function_definition`.
5-
6-
@add_kwonly function f(x; y=1)
7-
...
8-
end
9-
10-
expands to:
11-
12-
function f(x; y=1)
13-
...
14-
end
15-
function f(; x = error("No argument x"), y=1)
16-
...
17-
end
18-
"""
19-
macro add_kwonly(ex)
20-
esc(add_kwonly(ex))
21-
end
22-
23-
add_kwonly(ex::Expr) = add_kwonly(Val{ex.head}, ex)
24-
25-
function add_kwonly(::Type{<:Val}, ex)
26-
error("add_only does not work with expression $(ex.head)")
27-
end
28-
29-
function add_kwonly(::Union{Type{Val{:function}},
30-
Type{Val{:(=)}}}, ex::Expr)
31-
body = ex.args[2:end] # function body
32-
default_call = ex.args[1] # e.g., :(f(a, b=2; c=3))
33-
kwonly_call = add_kwonly(default_call)
34-
if kwonly_call === nothing
35-
return ex
36-
end
37-
38-
return quote
39-
begin
40-
$ex
41-
$(Expr(ex.head, kwonly_call, body...))
42-
end
43-
end
44-
end
45-
46-
function add_kwonly(::Type{Val{:where}}, ex::Expr)
47-
default_call = ex.args[1]
48-
rest = ex.args[2:end]
49-
kwonly_call = add_kwonly(default_call)
50-
if kwonly_call === nothing
51-
return nothing
52-
end
53-
return Expr(:where, kwonly_call, rest...)
54-
end
55-
56-
function add_kwonly(::Type{Val{:call}}, default_call::Expr)
57-
# default_call is, e.g., :(f(a, b=2; c=3))
58-
funcname = default_call.args[1] # e.g., :f
59-
required = [] # required positional arguments; e.g., [:a]
60-
optional = [] # optional positional arguments; e.g., [:(b=2)]
61-
default_kwargs = []
62-
for arg in default_call.args[2:end]
63-
if isa(arg, Symbol)
64-
push!(required, arg)
65-
elseif arg.head == :(::)
66-
push!(required, arg)
67-
elseif arg.head == :kw
68-
push!(optional, arg)
69-
elseif arg.head == :parameters
70-
@assert default_kwargs == [] # can I have :parameters twice?
71-
default_kwargs = arg.args
72-
else
73-
error("Not expecting to see: $arg")
74-
end
75-
end
76-
if isempty(required) && isempty(optional)
77-
# If the function is already keyword-only, do nothing:
78-
return nothing
79-
end
80-
if isempty(required)
81-
# It's not clear what should be done. Let's not support it at
82-
# the moment:
83-
error("At least one positional mandatory argument is required.")
84-
end
85-
86-
kwonly_kwargs = Expr(:parameters,
87-
[Expr(:kw, pa, :(error($("No argument $pa"))))
88-
for pa in required]..., optional..., default_kwargs...)
89-
kwonly_call = Expr(:call, funcname, kwonly_kwargs)
90-
# e.g., :(f(; a=error(...), b=error(...), c=1, d=2))
91-
92-
return kwonly_call
93-
end
94-
95-
function num_types_in_tuple(sig)
96-
length(sig.parameters)
97-
end
98-
99-
function num_types_in_tuple(sig::UnionAll)
100-
length(Base.unwrap_unionall(sig).parameters)
101-
end
1021

1032
@inline UNITLESS_ABS2(x) = real(abs2(x))
1043
@inline DEFAULT_NORM(u::Union{AbstractFloat, Complex}) = @fastmath abs(u)
1054
@inline function DEFAULT_NORM(u::Array{T}) where {T <: Union{AbstractFloat, Complex}}
1065
sqrt(real(sum(abs2, u)) / length(u))
1076
end
108-
@inline function DEFAULT_NORM(u::StaticArray{T}) where {T <: Union{AbstractFloat, Complex}}
7+
@inline function DEFAULT_NORM(u::StaticArraysCore.StaticArray{T}) where {
8+
T <: Union{
9+
AbstractFloat,
10+
Complex}}
10911
sqrt(real(sum(abs2, u)) / length(u))
11012
end
11113
@inline function DEFAULT_NORM(u::RecursiveArrayTools.AbstractVectorOfArray)
@@ -114,23 +16,6 @@ end
11416
@inline DEFAULT_NORM(u::AbstractArray) = sqrt(real(sum(UNITLESS_ABS2, u)) / length(u))
11517
@inline DEFAULT_NORM(u) = norm(u)
11618

117-
"""
118-
prevfloat_tdir(x, x0, x1)
119-
120-
Move `x` one floating point towards x0.
121-
"""
122-
function prevfloat_tdir(x, x0, x1)
123-
x1 > x0 ? prevfloat(x) : nextfloat(x)
124-
end
125-
126-
function nextfloat_tdir(x, x0, x1)
127-
x1 > x0 ? nextfloat(x) : prevfloat(x)
128-
end
129-
130-
function max_tdir(a, b, x0, x1)
131-
x1 > x0 ? max(a, b) : min(a, b)
132-
end
133-
13419
alg_autodiff(alg::AbstractNewtonAlgorithm{CS, AD}) where {CS, AD} = AD
13520
alg_autodiff(alg) = false
13621

@@ -146,15 +31,14 @@ function value_derivative(f::F, x::R) where {F, R}
14631
end
14732

14833
# Todo: improve this dispatch
149-
value_derivative(f::F, x::SVector) where {F} = f(x), ForwardDiff.jacobian(f, x)
34+
function value_derivative(f::F, x::StaticArraysCore.SVector) where {F}
35+
f(x), ForwardDiff.jacobian(f, x)
36+
end
15037

15138
value(x) = x
15239
value(x::Dual) = ForwardDiff.value(x)
15340
value(x::AbstractArray{<:Dual}) = map(ForwardDiff.value, x)
15441

155-
_unwrap_val(::Val{B}) where {B} = B
156-
_unwrap_val(B) = B
157-
15842
_vec(v) = vec(v)
15943
_vec(v::Number) = v
16044
_vec(v::AbstractVector) = v

test/basictests.jl

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,19 +30,34 @@ end
3030
const csu0 = 1.0
3131

3232
sol = benchmark_immutable(ff, cu0)
33-
@test sol.retcode === ReturnCode.Default
33+
@test sol.retcode === ReturnCode.Success
3434
@test all(sol.u .* sol.u .- 2 .< 1e-9)
3535
sol = benchmark_mutable(ff, cu0)
36-
@test sol.retcode === ReturnCode.Default
36+
@test sol.retcode === ReturnCode.Success
3737
@test all(sol.u .* sol.u .- 2 .< 1e-9)
3838
sol = benchmark_scalar(sf, csu0)
39-
@test sol.retcode === ReturnCode.Default
39+
@test sol.retcode === ReturnCode.Success
4040
@test sol.u * sol.u - 2 < 1e-9
4141

4242
@test (@ballocated benchmark_immutable(ff, cu0)) < 200
4343
@test (@ballocated benchmark_mutable(ff, cu0)) < 200
4444
@test (@ballocated benchmark_scalar(sf, csu0)) < 400
4545

46+
function benchmark_inplace(f, u0)
47+
probN = NonlinearProblem{true}(f, u0)
48+
solver = init(probN, NewtonRaphson(), abstol = 1e-9)
49+
sol = solve!(solver)
50+
end
51+
52+
function ffiip(du, u, p)
53+
du .= u .* u .- 2
54+
end
55+
u0 = [1.0, 1.0]
56+
57+
sol = benchmark_inplace(ffiip, u0)
58+
@test sol.retcode === ReturnCode.Success
59+
@test all(sol.u .* sol.u .- 2 .< 1e-9)
60+
4661
# AD Tests
4762
using ForwardDiff
4863

0 commit comments

Comments
 (0)