Skip to content
This repository was archived by the owner on May 15, 2025. It is now read-only.

Commit ec64612

Browse files
committed
Allow allocations for the termination condition version
1 parent c071dd3 commit ec64612

File tree

5 files changed

+85
-12
lines changed

5 files changed

+85
-12
lines changed

.buildkite/pipeline.yml

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
steps:
2+
- label: "Julia 1"
3+
plugins:
4+
- JuliaCI/julia#v1:
5+
version: "1"
6+
agents:
7+
queue: "juliagpu"
8+
cuda: "*"
9+
timeout_in_minutes: 30
10+
# Don't run Buildkite if the commit message includes the text [skip tests]
11+
if: build.message !~ /\[skip tests\]/
12+
13+
env:
14+
GROUP: CUDA
15+
JULIA_PKG_SERVER: ""

src/nlsolve/lbroyden.jl

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,22 @@ function SimpleLimitedMemoryBroyden(; threshold::Union{Val, Int} = Val(27))
2121
return SimpleLimitedMemoryBroyden{SciMLBase._unwrap_val(threshold)}()
2222
end
2323

24-
@views function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleLimitedMemoryBroyden,
24+
function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleLimitedMemoryBroyden,
25+
args...; termination_condition = nothing, kwargs...)
26+
if prob.u0 isa SArray
27+
if termination_condition === nothing ||
28+
termination_condition isa AbsNormTerminationMode
29+
return __static_solve(prob, alg, args...; termination_condition, kwargs...)
30+
end
31+
@warn "Specifying `termination_condition = $(termination_condition)` for \
32+
`SimpleLimitedMemoryBroyden` with `SArray` is not non-allocating. Use \
33+
either `termination_condition = AbsNormTerminationMode()` or \
34+
`termination_condition = nothing`." maxlog=1
35+
end
36+
return __generic_solve(prob, alg, args...; termination_condition, kwargs...)
37+
end
38+
39+
@views function __generic_solve(prob::NonlinearProblem, alg::SimpleLimitedMemoryBroyden,
2540
args...; abstol = nothing, reltol = nothing, maxiters = 1000, alias_u0 = false,
2641
termination_condition = nothing, kwargs...)
2742
x = __maybe_unaliased(prob.u0, alias_u0)
@@ -85,17 +100,9 @@ end
85100

86101
# Non-allocating StaticArrays version of SimpleLimitedMemoryBroyden is actually quite
87102
# finicky, so we'll implement it separately from the generic version
88-
# We make an exception here and don't support termination conditions
89-
@views function SciMLBase.__solve(prob::NonlinearProblem{<:SArray},
90-
alg::SimpleLimitedMemoryBroyden, args...; abstol = nothing,
91-
termination_condition = nothing,
92-
maxiters = 1000, kwargs...)
93-
if termination_condition !== nothing &&
94-
!(termination_condition isa AbsNormTerminationMode)
95-
error("SimpleLimitedMemoryBroyden with StaticArrays does not support termination \
96-
conditions!")
97-
end
98-
103+
# Ignore termination_condition. Don't pass things into internal functions
104+
function __static_solve(prob::NonlinearProblem{<:SArray}, alg::SimpleLimitedMemoryBroyden,
105+
args...; abstol = nothing, maxiters = 1000, kwargs...)
99106
x = prob.u0
100107
fx = _get_fx(prob, x)
101108
threshold = __get_threshold(alg)

test/cuda.jl

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
using SimpleNonlinearSolve, StaticArrays, CUDA, Test
2+
3+
CUDA.allowscalar(false)
4+
5+
f(u, p) = u .* u .- 2
6+
f!(du, u, p) = du .= u .* u .- 2
7+
8+
@testset "Solving on GPUs" begin
9+
for alg in (SimpleNewtonRaphson(), SimpleDFSane(), SimpleTrustRegion(), SimpleBroyden(),
10+
SimpleLimitedMemoryBroyden(), SimpleKlement(), SimpleHalley())
11+
@info "Testing $alg on CUDA"
12+
13+
# Static Arrays
14+
u0 = @SVector[1.0f0, 1.0f0]
15+
probN = NonlinearProblem{false}(f, u0)
16+
sol = solve(probN, alg; abstol = 1.0f-6)
17+
@test SciMLBase.successful_retcode(sol)
18+
@test maximum(abs, sol.resid) 1.0f-6
19+
20+
# Regular Arrays
21+
u0 = [1.0, 1.0]
22+
probN = NonlinearProblem{false}(f, u0)
23+
sol = solve(probN, alg; abstol = 1.0f-6)
24+
@test SciMLBase.successful_retcode(sol)
25+
@test maximum(abs, sol.resid) 1.0f-6
26+
27+
# Regular Arrays Inplace
28+
alg isa SimpleHalley && continue
29+
u0 = [1.0, 1.0]
30+
probN = NonlinearProblem{true}(f!, u0)
31+
sol = solve(probN, alg; abstol = 1.0f-6)
32+
@test SciMLBase.successful_retcode(sol)
33+
@test maximum(abs, sol.resid) 1.0f-6
34+
end
35+
end

test/cuda/Project.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
[deps]
2+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
3+
SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7"
4+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
5+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

test/runtests.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,12 @@ using SafeTestsets, Test
22

33
const GROUP = get(ENV, "GROUP", "All")
44

5+
function activate_env(env)
6+
Pkg.activate(env)
7+
Pkg.develop(PackageSpec(path = dirname(@__DIR__)))
8+
Pkg.instantiate()
9+
end
10+
511
@time @testset "SimpleNonlinearSolve.jl" begin
612
if GROUP == "All" || GROUP == "Core"
713
@time @safetestset "Basic Tests" include("basictests.jl")
@@ -10,4 +16,9 @@ const GROUP = get(ENV, "GROUP", "All")
1016
@time @safetestset "Least Squares Tests" include("least_squares.jl")
1117
@time @safetestset "23 Test Problems" include("23_test_problems.jl")
1218
end
19+
20+
if GROUP == "CUDA"
21+
activate_env("cuda")
22+
@time @safetestset "CUDA Tests" include("cuda.jl")
23+
end
1324
end

0 commit comments

Comments
 (0)