Skip to content

Commit 152208e

Browse files
committed
Start implementing trimming test
We introduce `SafeTestsets` as part of this, inspired by the way the downstream tests are set up in the `SciMLBase` repo.
1 parent b11983c commit 152208e

File tree

5 files changed

+129
-10
lines changed

5 files changed

+129
-10
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
143143
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
144144
ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
145145
SIAMFANLEquations = "084e46ad-d928-497d-ad5e-07fa361a48c4"
146+
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
146147
SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
147148
SpeedMapping = "f1835b91-879b-4a3f-a438-e4baacf14412"
148149
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
@@ -170,5 +171,4 @@ path = "lib/NonlinearSolveSpectralMethods"
170171
path = "lib/SimpleNonlinearSolve"
171172

172173
[targets]
173-
test = ["Aqua", "BandedMatrices", "BenchmarkTools", "CUDA", "Enzyme", "ExplicitImports", "FastLevenbergMarquardt", "FixedPointAcceleration", "Hwloc", "InteractiveUtils", "LeastSquaresOptim", "LineSearches", "MINPACK", "NLSolvers", "NLsolve", "NaNMath", "NonlinearProblemLibrary", "OrdinaryDiffEqTsit5", "PETSc", "Pkg", "PolyesterForwardDiff", "Random", "ReTestItems", "SIAMFANLEquations", "SparseConnectivityTracer", "SpeedMapping", "StableRNGs", "StaticArrays", "Sundials", "Test", "Zygote"]
174-
174+
test = ["Aqua", "BandedMatrices", "BenchmarkTools", "CUDA", "Enzyme", "ExplicitImports", "FastLevenbergMarquardt", "FixedPointAcceleration", "Hwloc", "InteractiveUtils", "LeastSquaresOptim", "LineSearches", "MINPACK", "NLSolvers", "NLsolve", "NaNMath", "NonlinearProblemLibrary", "OrdinaryDiffEqTsit5", "PETSc", "Pkg", "PolyesterForwardDiff", "Random", "ReTestItems", "SafeTestsets", "SIAMFANLEquations", "SparseConnectivityTracer", "SpeedMapping", "StableRNGs", "StaticArrays", "Sundials", "Test", "Zygote"]

test/runtests.jl

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,18 @@
1-
using ReTestItems, NonlinearSolve, Hwloc, InteractiveUtils, Pkg
1+
using NonlinearSolve, Hwloc, InteractiveUtils, Pkg
2+
using SafeTestsets
3+
using ReTestItems
24

35
@info sprint(InteractiveUtils.versioninfo)
46

57
const GROUP = lowercase(get(ENV, "GROUP", "All"))
68

9+
function activate_trim_env!()
10+
Pkg.activate(abspath(joinpath(dirname(@__FILE__), "trim")))
11+
Pkg.develop(PackageSpec(path = dirname(@__DIR__)))
12+
Pkg.instantiate()
13+
return nothing
14+
end
15+
716
const EXTRA_PKGS = Pkg.PackageSpec[]
817
if GROUP == "all" || GROUP == "downstream"
918
push!(EXTRA_PKGS, Pkg.PackageSpec("ModelingToolkit"))
@@ -12,11 +21,13 @@ end
1221
length(EXTRA_PKGS) 1 && Pkg.add(EXTRA_PKGS)
1322

1423
const RETESTITEMS_NWORKERS = parse(
15-
Int, get(ENV, "RETESTITEMS_NWORKERS",
24+
Int, get(
25+
ENV, "RETESTITEMS_NWORKERS",
1626
string(min(ifelse(Sys.iswindows(), 0, Hwloc.num_physical_cores()), 4))
1727
)
1828
)
19-
const RETESTITEMS_NWORKER_THREADS = parse(Int,
29+
const RETESTITEMS_NWORKER_THREADS = parse(
30+
Int,
2031
get(
2132
ENV, "RETESTITEMS_NWORKER_THREADS",
2233
string(max(Hwloc.num_virtual_cores() ÷ max(RETESTITEMS_NWORKERS, 1), 1))
@@ -25,8 +36,22 @@ const RETESTITEMS_NWORKER_THREADS = parse(Int,
2536

2637
@info "Running tests for group: $(GROUP) with $(RETESTITEMS_NWORKERS) workers"
2738

28-
ReTestItems.runtests(
29-
NonlinearSolve; tags = (GROUP == "all" ? nothing : [Symbol(GROUP)]),
30-
nworkers = RETESTITEMS_NWORKERS, nworker_threads = RETESTITEMS_NWORKER_THREADS,
31-
testitem_timeout = 3600
32-
)
39+
if GROUP != "trim"
40+
ReTestItems.runtests(
41+
NonlinearSolve; tags = (GROUP == "all" ? nothing : [Symbol(GROUP)]),
42+
nworkers = RETESTITEMS_NWORKERS, nworker_threads = RETESTITEMS_NWORKER_THREADS,
43+
testitem_timeout = 3600
44+
)
45+
elseif GROUP == "trim" && VERSION >= v"1.12.0-rc1" # trimming has been introduced in julia 1.12
46+
activate_trim_env!()
47+
@safetestset "Clean implementation (non-trimmable)" begin
48+
using SciMLBase: successful_retcode
49+
include("trim/clean_optimization.jl")
50+
@test successful_retcode(minimize(1.0).retcode)
51+
end
52+
@safetestset "Trimmable implementation" begin
53+
using SciMLBase: successful_retcode
54+
include("trim/trimmable_optimization.jl")
55+
@test successful_retcode(minimize(1.0).retcode)
56+
end
57+
end

test/trim/Project.toml

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
[deps]
2+
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
3+
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
4+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
5+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
6+
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
7+
NonlinearSolveFirstOrder = "5959db7a-ea39-4486-b5fe-2dd0bf03d60d"
8+
Polyester = "f517fe37-dbe3-4b94-8317-1923a5111588"
9+
PolyesterWeave = "1d0040c9-8b98-4ee7-8388-3f51789ca0ad"
10+
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
11+
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
12+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
13+
14+
[sources]
15+
ForwardDiff = {rev = "rv/remove-quote-assert-string-interpolation", url = "https://github.com/RomeoV/ForwardDiff.jl"}
16+
LinearSolve = {rev = "rv/remove-linsolve-forwarddiff-special-path", url = "https://github.com/RomeoV/LinearSolve.jl"}
17+
NonlinearSolveFirstOrder = {path = "../../lib/NonlinearSolveFirstOrder"}
18+
Polyester = {rev = "master", url = "https://github.com/RomeoV/Polyester.jl"}
19+
PolyesterWeave = {rev = "main", url = "https://github.com/RomeoV/PolyesterWeave.jl"}
20+
SciMLBase = {rev = "as/fix-jet-opt", url = "https://github.com/AayushSabharwal/SciMLBase.jl"}
21+
22+
[compat]
23+
ADTypes = "1.15.0"
24+
DiffEqBase = "6.179.0"
25+
LinearAlgebra = "1.12.0"
26+
NonlinearSolveFirstOrder = "1.6.0"
27+
StaticArrays = "1.9.0"

test/trim/clean_optimization.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
using NonlinearSolveFirstOrder
2+
using ADTypes: AutoForwardDiff
3+
using ForwardDiff
4+
using LinearAlgebra
5+
using StaticArrays
6+
using LinearSolve
7+
const LS = LinearSolve
8+
9+
function f(u, p)
10+
L, U = cholesky(p.Σ)
11+
rhs = (u .* u .- p.λ)
12+
# there are some issues currently with LinearSolve and triangular matrices,
13+
# so we just make `L` dense here.
14+
linprob = LinearProblem(Matrix(L), rhs)
15+
alg = LS.GenericLUFactorization()
16+
sol = LinearSolve.solve(linprob, alg)
17+
return sol.u
18+
end
19+
20+
struct MyParams{T, M}
21+
λ::T
22+
Σ::M
23+
end
24+
25+
function minimize(x)
26+
autodiff = AutoForwardDiff(; chunksize=1)
27+
alg = TrustRegion(; autodiff, linsolve=LS.CholeskyFactorization())
28+
ps = MyParams(rand(), hermitianpart(rand(2,2)+2I))
29+
prob = NonlinearLeastSquaresProblem{false}(f, rand(2), ps)
30+
sol = solve(prob, alg)
31+
return sol
32+
end
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
using NonlinearSolveFirstOrder
2+
using ADTypes: AutoForwardDiff
3+
using ForwardDiff
4+
using LinearAlgebra
5+
using StaticArrays
6+
using LinearSolve
7+
const LS = LinearSolve
8+
9+
function f(u, p)
10+
L, U = cholesky(p.Σ)
11+
rhs = (u .* u .- p.λ)
12+
# there are some issues currently with LinearSolve and triangular matrices,
13+
# so we just make `L` dense here.
14+
linprob = LinearProblem(Matrix(L), rhs)
15+
alg = LS.GenericLUFactorization()
16+
sol = LinearSolve.solve(linprob, alg)
17+
return sol.u
18+
end
19+
20+
struct MyParams{T, M}
21+
λ::T
22+
Σ::M
23+
end
24+
25+
const autodiff = AutoForwardDiff(; chunksize = 1)
26+
const alg = TrustRegion(; autodiff, linsolve = LS.CholeskyFactorization())
27+
const prob = NonlinearLeastSquaresProblem{false}(f, rand(2), MyParams(rand(), hermitianpart(rand(2, 2) + 2I)))
28+
const cache = init(prob, alg)
29+
30+
function minimize(x)
31+
ps = MyParams(x, hermitianpart(rand(2, 2) + 2I))
32+
reinit!(cache, rand(2); p = ps)
33+
solve!(cache)
34+
return cache
35+
end

0 commit comments

Comments
 (0)