Skip to content

Commit 651d56e

Browse files
format
1 parent 4bd17e1 commit 651d56e

File tree

4 files changed

+12
-6
lines changed

4 files changed

+12
-6
lines changed

src/NonlinearSolve.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@ SnoopPrecompile.@precompile_all_calls begin for T in (Float32, Float64)
4444
for alg in (NewtonRaphson,)
4545
solve(prob, alg(), abstol = T(1e-2))
4646
end
47-
4847
end end
4948

5049
export NewtonRaphson

src/ad.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,16 @@ function scalar_nlsolve_ad(prob, alg, args...; kwargs...)
2323
return sol, partials
2424
end
2525

26-
function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, StaticArraysCore.SVector}, iip,
26+
function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, StaticArraysCore.SVector},
27+
iip,
2728
<:Dual{T, V, P}}, alg::NewtonRaphson,
2829
args...; kwargs...) where {iip, T, V, P}
2930
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
3031
return SciMLBase.build_solution(prob, alg, Dual{T, V, P}(sol.u, partials), sol.resid;
3132
retcode = sol.retcode)
3233
end
33-
function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, StaticArraysCore.SVector}, iip,
34+
function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, StaticArraysCore.SVector},
35+
iip,
3436
<:AbstractArray{<:Dual{T, V, P}}},
3537
alg::NewtonRaphson, args...; kwargs...) where {iip, T, V, P}
3638
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)

src/raphson.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ function perform_step!(cache::NewtonRaphsonCache{true})
120120

121121
# u = u - J \ fu
122122
linres = dolinsolve(alg.precs, linsolve, A = J, b = fu, linu = du1,
123-
p = p, reltol = cache.abstol)
123+
p = p, reltol = cache.abstol)
124124
cache.linsolve = linres.cache
125125
@. u = u - du1
126126
f(fu, u, p)

src/utils.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@
44
@inline function DEFAULT_NORM(u::Array{T}) where {T <: Union{AbstractFloat, Complex}}
55
sqrt(real(sum(abs2, u)) / length(u))
66
end
7-
@inline function DEFAULT_NORM(u::StaticArraysCore.StaticArray{T}) where {T <: Union{AbstractFloat, Complex}}
7+
@inline function DEFAULT_NORM(u::StaticArraysCore.StaticArray{T}) where {
8+
T <: Union{
9+
AbstractFloat,
10+
Complex}}
811
sqrt(real(sum(abs2, u)) / length(u))
912
end
1013
@inline function DEFAULT_NORM(u::RecursiveArrayTools.AbstractVectorOfArray)
@@ -28,7 +31,9 @@ function value_derivative(f::F, x::R) where {F, R}
2831
end
2932

3033
# Todo: improve this dispatch
31-
value_derivative(f::F, x::StaticArraysCore.SVector) where {F} = f(x), ForwardDiff.jacobian(f, x)
34+
function value_derivative(f::F, x::StaticArraysCore.SVector) where {F}
35+
f(x), ForwardDiff.jacobian(f, x)
36+
end
3237

3338
value(x) = x
3439
value(x::Dual) = ForwardDiff.value(x)

0 commit comments

Comments
 (0)