@@ -9,11 +9,11 @@ using TensorInference
99 op = ein " ij, j -> i"
1010 @test Array (x) ≈ exp (2.0 ) .* [2.0 , 3.0 ]
1111 @test op (Array (A), Array (x)) ≈ Array (op (A, x))
12- println (x)
1312end
1413
1514@testset " cached, rescaled contract" begin
1615 problem = read_uai_problem (" Promedus_14" )
16+ ref_sol = problem. reference_marginals
1717 optimizer = TreeSA (ntrials = 1 , niters = 5 , βs = 0.1 : 0.1 : 100 )
1818 tn = TensorNetworkModel (problem; optimizer)
1919 p1 = probability (tn; usecuda = false , rescale = false )
2828 @test Array (cache. content) ≈ p1
2929
3030 # compute marginals
31- marginals2 = marginals (tn)
32- npass = 0
33- for i in 1 : (problem. nvars)
34- npass += length (marginals2[i]) == 1 || isapprox (marginals2[i], problem. reference_marginals[i]; atol = 1e-6 )
35- end
36- @test npass == problem. nvars
31+ ti_sol = marginals (tn)
32+ ref_sol[problem. obsvars] .= fill ([1.0 ], length (problem. obsvars)) # imitate dummy vars
33+ @test isapprox (ti_sol, ref_sol; atol = 1e-5 )
34+ end
35+
36+ function get_problems (problem_set:: String )
37+ # Capture the problem names that belong to the current problem_set
38+ regex = Regex (" ($(problem_set) _\\ d*)(\\ .uai)\$ " )
39+ return readdir (artifact " MAR_prob" ; sort = false ) |>
40+ x -> map (y -> match (regex, y), x) |> # apply regex
41+ x -> filter (! isnothing, x) |> # filter out `nothing` values
42+ x -> map (first, x) # get the first capture of each element
3743end
3844
3945@testset " gradient-based tensor network solvers" begin
40- @testset " UAI 2014 problem set" begin
41- benchmarks = [
42- # ("Alchemy", TreeSA(ntrials=1, niters=5, βs=0.1:0.1:100)),
43- # ("CSP", TreeSA(ntrials=1, niters=5, βs=0.1:0.1:100)),
44- # ("DBN", KaHyParBipartite(sc_target=25)),
45- # ("Grids", TreeSA(ntrials=1, niters=5, βs=0.1:0.1:100)), # greedy also works
46- # ("linkage", TreeSA(ntrials=3, niters=20, βs=0.1:0.1:40)), # linkage_15 fails
47- # ("ObjectDetection", TreeSA(ntrials=1, niters=5, βs=1:0.1:100)), # ObjectDetection_35 fails
48- # ("Pedigree", TreeSA(ntrials=1, niters=5, βs=0.1:0.1:100)), # greedy also works
49- (" Promedus" , TreeSA (ntrials = 1 , niters = 5 , βs = 0.1 : 0.1 : 100 )), # greedy also works
50- # ("relational", TreeSA(ntrials=1, niters=5, βs=0.1:0.1:100)),
51- (" Segmentation" , TreeSA (ntrials = 1 , niters = 5 , βs = 0.1 : 0.1 : 100 )) # greedy also works
52- ]
53- # benchmarks = [("relational", fill(1.0, 5), TreeSA(ntrials=1, niters=5, βs=0.1:0.1:100))]
54- # benchmarks = [("DBN",fill(1.0, 6), SABipartite(sc_target=25, βs=0.1:0.01:50))]
55- for (benchmark, optimizer) in benchmarks
56- @testset " $(benchmark) benchmark" begin
46+ problem_sets = [
47+ # ("Alchemy", TreeSA(ntrials = 1, niters = 5, βs = 0.1:0.1:100)),
48+ # ("CSP", TreeSA(ntrials = 1, niters = 5, βs = 0.1:0.1:100)),
49+ # ("DBN", KaHyParBipartite(sc_target = 25)),
50+ # ("Grids", TreeSA(ntrials = 1, niters = 5, βs = 0.1:0.1:100)), # greedy also works
51+ # ("linkage", TreeSA(ntrials = 3, niters = 20, βs = 0.1:0.1:40)), # linkage_15 fails
52+ # ("ObjectDetection", TreeSA(ntrials = 1, niters = 5, βs = 1:0.1:100)),
53+ # ("Pedigree", TreeSA(ntrials = 1, niters = 5, βs = 0.1:0.1:100)), # greedy also works
54+ (" Promedus" , TreeSA (ntrials = 1 , niters = 5 , βs = 0.1 : 0.1 : 100 )), # greedy also works
55+ # ("relational", TreeSA(ntrials=1, niters=5, βs=0.1:0.1:100)),
56+ (" Segmentation" , TreeSA (ntrials = 1 , niters = 5 , βs = 0.1 : 0.1 : 100 )) # greedy also works
57+ ]
58+
59+ for (problem_set, optimizer) in problem_sets
60+ @testset " $(problem_set) problem_set" begin
5761
58- # Capture the problem names that belong to the current benchmark
59- rexp = Regex (" ($(benchmark) _\\ d*)(\\ .uai)\$ " )
60- problems = readdir (artifact " MAR_prob" ; sort = false ) |>
61- x -> map (y -> match (rexp, y), x) |> # apply regex
62- x -> filter (! isnothing, x) |> # filter out `nothing` values
63- x -> map (first, x) # get the first capture of each element
62+ # Capture the problem names that belong to the current problem set
63+ problems = get_problems (problem_set)
6464
65- for problem in problems
66- @info " Testing: $problem "
67- @testset " $(problem) " begin
68- problem = read_uai_problem (problem)
65+ for problem in problems
66+ @info " Testing: $problem "
67+ @testset " $(problem) " begin
68+ problem = read_uai_problem (problem)
69+ ref_sol = problem. reference_marginals
70+ obsvars = problem. obsvars
6971
70- # does not optimize over open vertices
71- tn = TensorNetworkModel (problem; optimizer)
72- sc = contraction_complexity (tn). sc
73- if sc > 28
74- error (" space complexity too large! got $(sc) " )
75- end
76- @debug contraction_complexity (tn)
77- marginals2 = marginals (tn)
78- # for dangling vertices, the output size is 1.
79- npass = 0
80- for i in 1 : (problem. nvars)
81- npass += length (marginals2[i]) == 1 || isapprox (marginals2[i], problem. reference_marginals[i]; atol = 1e-6 )
82- end
83- @test npass == problem. nvars
84- end
72+ # does not optimize over open vertices
73+ tn = TensorNetworkModel (problem; optimizer)
74+ sc = contraction_complexity (tn). sc
75+ sc > 28 && error (" space complexity too large! got $(sc) " )
76+ @debug contraction_complexity (tn)
77+ ti_sol = marginals (tn)
78+ ref_sol[obsvars] .= fill ([1.0 ], length (obsvars)) # imitate dummy vars
79+ @test isapprox (ti_sol, ref_sol; atol = 1e-4 )
8580 end
8681 end
8782 end
0 commit comments