Skip to content

Commit 070ce23

Browse files
authored
Merge pull request #31 from TensorBFS/inference-tests-refactoring
Refactor inference tests
2 parents 02cc203 + 0abfbf8 commit 070ce23

File tree

1 file changed

+44
-49
lines changed

1 file changed

+44
-49
lines changed

test/inference.jl

Lines changed: 44 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,11 @@ using TensorInference
99
op = ein"ij, j -> i"
1010
@test Array(x) exp(2.0) .* [2.0, 3.0]
1111
@test op(Array(A), Array(x)) Array(op(A, x))
12-
println(x)
1312
end
1413

1514
@testset "cached, rescaled contract" begin
1615
problem = read_uai_problem("Promedus_14")
16+
ref_sol = problem.reference_marginals
1717
optimizer = TreeSA(ntrials = 1, niters = 5, βs = 0.1:0.1:100)
1818
tn = TensorNetworkModel(problem; optimizer)
1919
p1 = probability(tn; usecuda = false, rescale = false)
@@ -28,60 +28,55 @@ end
2828
@test Array(cache.content) p1
2929

3030
# compute marginals
31-
marginals2 = marginals(tn)
32-
npass = 0
33-
for i in 1:(problem.nvars)
34-
npass += length(marginals2[i]) == 1 || isapprox(marginals2[i], problem.reference_marginals[i]; atol = 1e-6)
35-
end
36-
@test npass == problem.nvars
31+
ti_sol = marginals(tn)
32+
ref_sol[problem.obsvars] .= fill([1.0], length(problem.obsvars)) # imitate dummy vars
33+
@test isapprox(ti_sol, ref_sol; atol = 1e-5)
34+
end
35+
36+
function get_problems(problem_set::String)
37+
# Capture the problem names that belong to the current problem_set
38+
regex = Regex("($(problem_set)_\\d*)(\\.uai)\$")
39+
return readdir(artifact"MAR_prob"; sort = false) |>
40+
x -> map(y -> match(regex, y), x) |> # apply regex
41+
x -> filter(!isnothing, x) |> # filter out `nothing` values
42+
x -> map(first, x) # get the first capture of each element
3743
end
3844

3945
@testset "gradient-based tensor network solvers" begin
40-
@testset "UAI 2014 problem set" begin
41-
benchmarks = [
42-
#("Alchemy", TreeSA(ntrials=1, niters=5, βs=0.1:0.1:100)),
43-
#("CSP", TreeSA(ntrials=1, niters=5, βs=0.1:0.1:100)),
44-
#("DBN", KaHyParBipartite(sc_target=25)),
45-
#("Grids", TreeSA(ntrials=1, niters=5, βs=0.1:0.1:100)), # greedy also works
46-
#("linkage", TreeSA(ntrials=3, niters=20, βs=0.1:0.1:40)), # linkage_15 fails
47-
#("ObjectDetection", TreeSA(ntrials=1, niters=5, βs=1:0.1:100)), # ObjectDetection_35 fails
48-
#("Pedigree", TreeSA(ntrials=1, niters=5, βs=0.1:0.1:100)), # greedy also works
49-
("Promedus", TreeSA(ntrials = 1, niters = 5, βs = 0.1:0.1:100)), # greedy also works
50-
#("relational", TreeSA(ntrials=1, niters=5, βs=0.1:0.1:100)),
51-
("Segmentation", TreeSA(ntrials = 1, niters = 5, βs = 0.1:0.1:100)) # greedy also works
52-
]
53-
#benchmarks = [("relational", fill(1.0, 5), TreeSA(ntrials=1, niters=5, βs=0.1:0.1:100))]
54-
#benchmarks = [("DBN",fill(1.0, 6), SABipartite(sc_target=25, βs=0.1:0.01:50))]
55-
for (benchmark, optimizer) in benchmarks
56-
@testset "$(benchmark) benchmark" begin
46+
problem_sets = [
47+
#("Alchemy", TreeSA(ntrials = 1, niters = 5, βs = 0.1:0.1:100)),
48+
#("CSP", TreeSA(ntrials = 1, niters = 5, βs = 0.1:0.1:100)),
49+
#("DBN", KaHyParBipartite(sc_target = 25)),
50+
#("Grids", TreeSA(ntrials = 1, niters = 5, βs = 0.1:0.1:100)), # greedy also works
51+
#("linkage", TreeSA(ntrials = 3, niters = 20, βs = 0.1:0.1:40)), # linkage_15 fails
52+
#("ObjectDetection", TreeSA(ntrials = 1, niters = 5, βs = 1:0.1:100)),
53+
#("Pedigree", TreeSA(ntrials = 1, niters = 5, βs = 0.1:0.1:100)), # greedy also works
54+
("Promedus", TreeSA(ntrials = 1, niters = 5, βs = 0.1:0.1:100)), # greedy also works
55+
#("relational", TreeSA(ntrials=1, niters=5, βs=0.1:0.1:100)),
56+
("Segmentation", TreeSA(ntrials = 1, niters = 5, βs = 0.1:0.1:100)) # greedy also works
57+
]
58+
59+
for (problem_set, optimizer) in problem_sets
60+
@testset "$(problem_set) problem_set" begin
5761

58-
# Capture the problem names that belong to the current benchmark
59-
rexp = Regex("($(benchmark)_\\d*)(\\.uai)\$")
60-
problems = readdir(artifact"MAR_prob"; sort = false) |>
61-
x -> map(y -> match(rexp, y), x) |> # apply regex
62-
x -> filter(!isnothing, x) |> # filter out `nothing` values
63-
x -> map(first, x) # get the first capture of each element
62+
# Capture the problem names that belong to the current problem set
63+
problems = get_problems(problem_set)
6464

65-
for problem in problems
66-
@info "Testing: $problem"
67-
@testset "$(problem)" begin
68-
problem = read_uai_problem(problem)
65+
for problem in problems
66+
@info "Testing: $problem"
67+
@testset "$(problem)" begin
68+
problem = read_uai_problem(problem)
69+
ref_sol = problem.reference_marginals
70+
obsvars = problem.obsvars
6971

70-
# does not optimize over open vertices
71-
tn = TensorNetworkModel(problem; optimizer)
72-
sc = contraction_complexity(tn).sc
73-
if sc > 28
74-
error("space complexity too large! got $(sc)")
75-
end
76-
@debug contraction_complexity(tn)
77-
marginals2 = marginals(tn)
78-
# for dangling vertices, the output size is 1.
79-
npass = 0
80-
for i in 1:(problem.nvars)
81-
npass += length(marginals2[i]) == 1 || isapprox(marginals2[i], problem.reference_marginals[i]; atol = 1e-6)
82-
end
83-
@test npass == problem.nvars
84-
end
72+
# does not optimize over open vertices
73+
tn = TensorNetworkModel(problem; optimizer)
74+
sc = contraction_complexity(tn).sc
75+
sc > 28 && error("space complexity too large! got $(sc)")
76+
@debug contraction_complexity(tn)
77+
ti_sol = marginals(tn)
78+
ref_sol[obsvars] .= fill([1.0], length(obsvars)) # imitate dummy vars
79+
@test isapprox(ti_sol, ref_sol; atol = 1e-4)
8580
end
8681
end
8782
end

0 commit comments

Comments
 (0)