Skip to content
This repository was archived by the owner on May 15, 2025. It is now read-only.

Commit c73f560

Browse files
committed
Fix tests
1 parent fadbffa commit c73f560

File tree

3 files changed

+30
-42
lines changed

3 files changed

+30
-42
lines changed

src/ad.jl

Lines changed: 23 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,30 @@ function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, <:AbstractArray}
77
sol.original)
88
end
99

10-
function __nlsolve_ad(prob::NonlinearProblem{uType, iip}, alg, args...;
11-
kwargs...) where {uType, iip}
10+
# Handle Ambiguities
11+
for algType in (Bisection, Brent, Alefeld, Falsi, ITP, Ridder)
12+
@eval begin
13+
function SciMLBase.solve(prob::IntervalNonlinearProblem{uType, iip,
14+
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}},
15+
alg::$(algType), args...; kwargs...) where {uType, T, V, P, iip}
16+
sol, partials = __nlsolve_ad(prob, alg, args...; kwargs...)
17+
dual_soln = __nlsolve_dual_soln(sol.u, partials, prob.p)
18+
return SciMLBase.build_solution(prob, alg, dual_soln, sol.resid; sol.retcode,
19+
sol.stats, sol.original, left = Dual{T, V, P}(sol.left, partials),
20+
right = Dual{T, V, P}(sol.right, partials))
21+
end
22+
end
23+
end
24+
25+
function __nlsolve_ad(prob, alg, args...; kwargs...)
1226
p = value(prob.p)
13-
newprob = NonlinearProblem(prob.f, value(prob.u0), p; prob.kwargs...)
27+
if prob isa IntervalNonlinearProblem
28+
tspan = value.(prob.tspan)
29+
newprob = IntervalNonlinearProblem(prob.f, tspan, p; prob.kwargs...)
30+
else
31+
u0 = value(prob.u0)
32+
newprob = NonlinearProblem(prob.f, u0, p; prob.kwargs...)
33+
end
1434

1535
sol = solve(newprob, alg, args...; kwargs...)
1636

@@ -77,24 +97,3 @@ end
7797
_partials = _restructure(u, partials)
7898
return map(((uᵢ, pᵢ),) -> Dual{T, V, P}(uᵢ, pᵢ), zip(u, _partials))
7999
end
80-
81-
# avoid ambiguities
82-
for Alg in [Bisection]
83-
@eval function SciMLBase.solve(prob::IntervalNonlinearProblem{uType, iip,
84-
<:Dual{T, V, P}}, alg::$Alg, args...; kwargs...) where {uType, iip, T, V, P}
85-
sol, partials = __nlsolve_ad(prob, alg, args...; kwargs...)
86-
dual_soln = __nlsolve_dual_soln(sol.u, partials, prob.p)
87-
return SciMLBase.build_solution(prob, alg, dual_soln, sol.resid; sol.retcode,
88-
left = Dual{T, V, P}(sol.left, partials),
89-
right = Dual{T, V, P}(sol.right, partials))
90-
end
91-
@eval function SciMLBase.solve(prob::IntervalNonlinearProblem{uType, iip,
92-
<:AbstractArray{<:Dual{T, V, P}}}, alg::$Alg, args...;
93-
kwargs...) where {uType, iip, T, V, P}
94-
sol, partials = __nlsolve_ad(prob, alg, args...; kwargs...)
95-
dual_soln = __nlsolve_dual_soln(sol.u, partials, prob.p)
96-
return SciMLBase.build_solution(prob, alg, dual_soln, sol.resid; sol.retcode,
97-
left = Dual{T, V, P}(sol.left, partials),
98-
right = Dual{T, V, P}(sol.right, partials))
99-
end
100-
end

src/nlsolve/halley.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleHalley, args...;
7171

7272
@bb Aaᵢ = A × aᵢ
7373
@bb A .*= -1
74-
bᵢ = dfx_fact \ Aaᵢ
74+
bᵢ = dfx_fact \ _vec(Aaᵢ)
7575

7676
cᵢ_ = _vec(cᵢ)
7777
@bb @. cᵢ_ = (aᵢ * aᵢ) / (-aᵢ + (T(0.5) * bᵢ))

test/forward_ad.jl

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using ForwardDiff, SimpleNonlinearSolve, StaticArrays, Test, LinearAlgebra
2+
import SimpleNonlinearSolve: AbstractSimpleNonlinearSolveAlgorithm
23

34
test_f!(du, u, p) = (@. du = u^2 - p)
45
test_f(u, p) = (@. u^2 - p)
@@ -26,24 +27,12 @@ __compatible(::Any, ::Number) = true
2627
__compatible(::Number, ::AbstractArray) = false
2728
__compatible(u::AbstractArray, p::AbstractArray) = size(u) == size(p)
2829

29-
__compatible(u::Number, ::SimpleNonlinearSolve.AbstractSimpleNonlinearSolveAlgorithm) = true
30-
function __compatible(u::AbstractArray,
31-
::SimpleNonlinearSolve.AbstractSimpleNonlinearSolveAlgorithm)
32-
true
33-
end
34-
function __compatible(u::StaticArray,
35-
::SimpleNonlinearSolve.AbstractSimpleNonlinearSolveAlgorithm)
36-
true
37-
end
30+
__compatible(u::Number, ::AbstractSimpleNonlinearSolveAlgorithm) = true
31+
__compatible(u::AbstractArray, ::AbstractSimpleNonlinearSolveAlgorithm) = true
32+
__compatible(u::StaticArray, ::AbstractSimpleNonlinearSolveAlgorithm) = true
3833

39-
function __compatible(::SimpleNonlinearSolve.AbstractSimpleNonlinearSolveAlgorithm,
40-
::Val{:iip})
41-
true
42-
end
43-
function __compatible(::SimpleNonlinearSolve.AbstractSimpleNonlinearSolveAlgorithm,
44-
::Val{:oop})
45-
true
46-
end
34+
__compatible(::AbstractSimpleNonlinearSolveAlgorithm, ::Val{:iip}) = true
35+
__compatible(::AbstractSimpleNonlinearSolveAlgorithm, ::Val{:oop}) = true
4736
__compatible(::SimpleHalley, ::Val{:iip}) = false
4837

4938
@testset "ForwardDiff.jl Integration: $(alg)" for alg in (SimpleNewtonRaphson(),

0 commit comments

Comments
 (0)