Skip to content
This repository was archived by the owner on May 15, 2025. It is now read-only.

Commit ea20ad6

Browse files
Merge pull request #91 from SciML/least_squares
Add a dispatch to SimpleNewtonRaphson for NNLS and SimpleGaussNewton
2 parents 0d73574 + cf03317 commit ea20ad6

File tree

4 files changed

+35
-3
lines changed

4 files changed

+35
-3
lines changed

src/SimpleNonlinearSolve.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ PrecompileTools.@compile_workload begin
8989
end
9090

9191
export Bisection, Brent, Broyden, LBroyden, SimpleDFSane, Falsi, SimpleHalley, Klement,
92-
Ridder, SimpleNewtonRaphson, SimpleTrustRegion, Alefeld, ITP
92+
Ridder, SimpleNewtonRaphson, SimpleTrustRegion, Alefeld, ITP, SimpleGaussNewton
9393
export BatchedBroyden, BatchedSimpleNewtonRaphson, BatchedSimpleDFSane
9494

9595
end # module

src/raphson.jl

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,9 @@ function SimpleNewtonRaphson(; batched = false,
6161
SciMLBase._unwrap_val(diff_type)}()
6262
end
6363

64-
function SciMLBase.__solve(prob::NonlinearProblem,
64+
const SimpleGaussNewton = SimpleNewtonRaphson
65+
66+
function SciMLBase.__solve(prob::Union{NonlinearProblem,NonlinearLeastSquaresProblem},
6567
alg::SimpleNewtonRaphson, args...; abstol = nothing,
6668
reltol = nothing,
6769
maxiters = 1000, kwargs...)
@@ -74,6 +76,10 @@ function SciMLBase.__solve(prob::NonlinearProblem,
7476
error("SimpleNewtonRaphson currently only supports out-of-place nonlinear problems")
7577
end
7678

79+
if prob isa NonlinearLeastSquaresProblem && !(typeof(prob.u0) <: Union{Number, AbstractVector})
80+
error("SimpleGaussNewton only supports Number and AbstactVector types. Please convert any problem of AbstractArray into one with u0 as AbstractVector")
81+
end
82+
7783
atol = abstol !== nothing ? abstol :
7884
real(oneunit(eltype(T))) * (eps(real(one(eltype(T)))))^(4 // 5)
7985
rtol = reltol !== nothing ? reltol : eps(real(one(eltype(T))))^(4 // 5)
@@ -100,7 +106,13 @@ function SciMLBase.__solve(prob::NonlinearProblem,
100106
end
101107
iszero(fx) &&
102108
return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.Success)
103-
Δx = _restructure(fx, dfx \ _vec(fx))
109+
110+
if prob isa NonlinearProblem
111+
Δx = _restructure(fx, dfx \ _vec(fx))
112+
else
113+
Δx = dfx \ fx
114+
end
115+
104116
x -= Δx
105117
if isapprox(x, xo, atol = atol, rtol = rtol)
106118
return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.Success)

test/least_squares.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
using SimpleNonlinearSolve, LinearAlgebra, Test
2+
3+
true_function(x, θ) = @. θ[1] * exp(θ[2] * x) * cos(θ[3] * x + θ[4])
4+
5+
θ_true = [1.0, 0.1, 2.0, 0.5]
6+
x = [-1.0, -0.5, 0.0, 0.5, 1.0]
7+
y_target = true_function(x, θ_true)
8+
9+
function loss_function(θ, p)
10+
= true_function(p, θ)
11+
return abs2.(ŷ .- y_target)
12+
end
13+
14+
θ_init = θ_true .+ 0.1
15+
prob_oop = NonlinearLeastSquaresProblem{false}(loss_function, θ_init, x)
16+
sol = solve(prob_oop, SimpleNewtonRaphson())
17+
sol = solve(prob_oop, SimpleGaussNewton())
18+
19+
@test norm(sol.resid) < 1e-12

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,6 @@ const GROUP = get(ENV, "GROUP", "All")
77
@time @safetestset "Basic Tests + Some AD" include("basictests.jl")
88
@time @safetestset "Inplace Tests" include("inplace.jl")
99
@time @safetestset "Matrix Resizing Tests" include("matrix_resizing_tests.jl")
10+
@time @safetestset "Least Squares Tests" include("least_squares.jl")
1011
end
1112
end

0 commit comments

Comments
 (0)