Skip to content

Commit 69a6ccc

Browse files
committed
fix tests
1 parent 5812f71 commit 69a6ccc

File tree

7 files changed

+36
-12
lines changed

7 files changed

+36
-12
lines changed

src/Core.jl

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ struct UAIInstance{ET, FT <: Factor{ET}}
3737
obsvars::Vector{Int}
3838
obsvals::Vector{Int}
3939
queryvars::Vector{Int}
40-
reference_solution::Union{Vector{Vector{ET}}, Vector{Int}, Float64}
40+
reference_solution
4141
end
4242

4343
Base.show(io::IO, ::MIME"text/plain", uai::UAIInstance) = Base.show(io, uai)
@@ -69,6 +69,17 @@ function set_evidence!(uai::UAIInstance, pairs::Pair{Int}...)
6969
return uai
7070
end
7171

72+
"""
73+
$TYPEDSIGNATURES
74+
75+
Set the query variables of an UAI instance.
76+
"""
77+
function set_query!(uai::UAIInstance, vars::AbstractVector{Int})
78+
empty!(uai.queryvars)
79+
append!(uai.queryvars, vars)
80+
return uai
81+
end
82+
7283
"""
7384
$(TYPEDEF)
7485
@@ -122,6 +133,9 @@ function TensorNetworkModel(
122133
optimizer = GreedyMethod(),
123134
simplifier = nothing
124135
)::TensorNetworkModel
136+
if !isempty(instance.queryvars)
137+
@warn "The `queryvars` field of the input `UAIInstance` instance is designed for the `MMAPModel`, which is not respected by `TensorNetworkModel`. Got non-empty value: $(uai.queryvars)"
138+
end
125139
return TensorNetworkModel(
126140
1:(instance.nvars),
127141
instance.cards,

src/TensorInference.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ export contraction_complexity, TreeSA, GreedyMethod, KaHyParBipartite, SABiparti
1818

1919
# read and load uai files
2020
export read_model_file, read_td_file, read_evidence_file, read_solution_file, read_instance, UAIInstance
21-
export set_evidence!
21+
export set_evidence!, set_query!
2222

2323
# marginals
2424
export TensorNetworkModel, get_vars, get_cards, log_probability, probability, marginals

src/mmap.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,9 @@ end
5858
"""
5959
$(TYPEDSIGNATURES)
6060
"""
61-
function MMAPModel(instance::UAIInstance; queryvars, openvertices = (), optimizer = GreedyMethod(), simplifier = nothing)::MMAPModel
61+
function MMAPModel(instance::UAIInstance; openvertices = (), optimizer = GreedyMethod(), simplifier = nothing)::MMAPModel
6262
return MMAPModel(
63-
1:(instance.nvars), instance.cards, instance.factors; queryvars, fixedvertices = Dict(zip(instance.obsvars, instance.obsvals)), optimizer, simplifier, openvertices
63+
1:(instance.nvars), instance.cards, instance.factors; queryvars=instance.queryvars, fixedvertices = Dict(zip(instance.obsvars, instance.obsvals)), optimizer, simplifier, openvertices
6464
)
6565
end
6666

src/utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ end
246246

247247
function read_instance_from_string(uai::AbstractString; eltype = Float64)::UAIInstance
248248
nvars, cards, ncliques, factors = read_model_string(uai; factor_eltype = eltype)
249-
return UAIInstance(nvars, ncliques, cards, factors, Int[], Int[], Vector{eltype}[])
249+
return UAIInstance(nvars, ncliques, cards, factors, Int[], Int[], Int[], nothing)
250250
end
251251

252252
# patch to get content by broadcasting into array, while keep array size unchanged.

test/cuda.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,22 +42,25 @@ end
4242
optimizer = TreeSA(ntrials = 1, niters = 2, βs = 1:0.1:40)
4343
tn_ref = TensorNetworkModel(instance; optimizer)
4444
# does not marginalize any var
45-
tn = MMAPModel(instance; queryvars = collect(1:instance.nvars), optimizer)
45+
set_query!(instance, collect(1:instance.nvars))
46+
tn = MMAPModel(instance; optimizer)
4647
r1, r2 = maximum_logp(tn_ref; usecuda = true), maximum_logp(tn; usecuda = true)
4748
@test r1 isa CuArray
4849
@test r2 isa CuArray
4950
@test r1 r2
5051

5152
# marginalize all vars
52-
tn2 = MMAPModel(instance; queryvars = Int[], optimizer)
53+
set_query!(instance, Int[])
54+
tn2 = MMAPModel(instance; optimizer)
5355
cup = probability(tn_ref; usecuda = true)
5456
culogp = maximum_logp(tn2; usecuda = true)
5557
@test cup isa RescaledArray{T, N, <:CuArray} where {T, N}
5658
@test culogp isa CuArray
5759
@test Array(cup)[] exp(Array(culogp)[])
5860

5961
# does not optimize over open vertices
60-
tn3 = MMAPModel(instance; queryvars = setdiff(1:instance.nvars, [2, 4, 6]), optimizer)
62+
set_query!(instance, setdiff(1:instance.nvars, [2, 4, 6]))
63+
tn3 = MMAPModel(instance; optimizer)
6164
logp, config = most_probable_config(tn3; usecuda = true)
6265
@test log_probability(tn3, config) logp
6366
end

test/mmap.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,20 @@ end
1515
tn_ref = TensorNetworkModel(instance; optimizer)
1616

1717
# Does not marginalize any var
18-
mmap = MMAPModel(instance; queryvars = collect(1:instance.nvars), optimizer)
18+
set_query!(instance, collect(1:instance.nvars))
19+
mmap = MMAPModel(instance; optimizer)
1920
@debug(mmap)
2021
@test maximum_logp(tn_ref) maximum_logp(mmap)
2122

2223
# Marginalize all vars
23-
mmap2 = MMAPModel(instance; queryvars = Int[], optimizer)
24+
set_query!(instance, Int[])
25+
mmap2 = MMAPModel(instance; optimizer)
2426
@debug(mmap2)
2527
@test Array(probability(tn_ref))[] exp(maximum_logp(mmap2)[])
2628

2729
# Does not optimize over open vertices
28-
mmap3 = MMAPModel(instance; queryvars = setdiff(1:instance.nvars, [2, 4, 6]), optimizer)
30+
set_query!(instance, setdiff(1:instance.nvars, [2, 4, 6]))
31+
mmap3 = MMAPModel(instance; optimizer)
2932
@debug(mmap3)
3033
logp, config = most_probable_config(mmap3)
3134
@test log_probability(mmap3, config) logp
@@ -42,7 +45,7 @@ end
4245
@info "Testing: $problem_name"
4346
model_filepath, evidence_filepath, query_filepath, solution_filepath = get_instance_filepaths(problem_name, "MMAP")
4447
instance = read_instance(model_filepath; evidence_filepath, query_filepath, solution_filepath)
45-
model = MMAPModel(instance; queryvars = instance.queryvars, optimizer)
48+
model = MMAPModel(instance; optimizer)
4649
_, solution = most_probable_config(model)
4750
@test solution == instance.reference_solution
4851
end

test/runtests.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@ end
2121
include("pr.jl")
2222
end
2323

24+
@testset "sampling" begin
25+
include("sampling.jl")
26+
end
27+
2428
using CUDA
2529
if CUDA.functional()
2630
include("cuda.jl")

0 commit comments

Comments
 (0)