Skip to content

Commit 7dd6ebc

Browse files
authored
Merge pull request #48 from TensorBFS/jg/revert-tropicalgemm
remove tropicalgemm as a dependency
2 parents 3f43238 + cacfee2 commit 7dd6ebc

File tree

12 files changed

+47
-17
lines changed

12 files changed

+47
-17
lines changed

Project.toml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ OMEinsum = "ebe7aa44-baf0-506c-a96f-8464559b3922"
1111
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
1212
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
1313
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
14-
TropicalGEMM = "a4ad3063-64a7-4bad-8738-34ed09bc0236"
1514
TropicalNumbers = "b3a74e9c-7526-4576-a4eb-79c0d4c32334"
1615

1716
[compat]
@@ -21,6 +20,5 @@ OMEinsum = "0.7"
2120
PrecompileTools = "1"
2221
Requires = "1"
2322
StatsBase = "0.34"
24-
TropicalGEMM = "0.1"
2523
TropicalNumbers = "0.5.4"
2624
julia = "1.3"

docs/src/api/public.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,4 +62,5 @@ read_solution_file
6262
read_td_file
6363
sample
6464
set_evidence!
65+
set_query!
6566
```

examples/asia/main.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,8 @@ logp, cfg = most_probable_config(tn)
107107
# Compute the most probable values of certain variables (e.g., 4 and 7) while
108108
# marginalizing over others. This is known as Maximum a Posteriori (MAP)
109109
# estimation.
110-
mmap = MMAPModel(instance; queryvars = [4, 7])
110+
set_query!(instance, [4, 7])
111+
mmap = MMAPModel(instance)
111112

112113
# ---
113114

src/Core.jl

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ struct UAIInstance{ET, FT <: Factor{ET}}
3737
obsvars::Vector{Int}
3838
obsvals::Vector{Int}
3939
queryvars::Vector{Int}
40-
reference_solution::Union{Vector{Vector{ET}}, Vector{Int}, Float64}
40+
reference_solution
4141
end
4242

4343
Base.show(io::IO, ::MIME"text/plain", uai::UAIInstance) = Base.show(io, uai)
@@ -69,6 +69,17 @@ function set_evidence!(uai::UAIInstance, pairs::Pair{Int}...)
6969
return uai
7070
end
7171

72+
"""
73+
$TYPEDSIGNATURES
74+
75+
Set the query variables of an UAI instance.
76+
"""
77+
function set_query!(uai::UAIInstance, vars::AbstractVector{Int})
78+
empty!(uai.queryvars)
79+
append!(uai.queryvars, vars)
80+
return uai
81+
end
82+
7283
"""
7384
$(TYPEDEF)
7485
@@ -122,6 +133,9 @@ function TensorNetworkModel(
122133
optimizer = GreedyMethod(),
123134
simplifier = nothing
124135
)::TensorNetworkModel
136+
if !isempty(instance.queryvars)
137+
@warn "The `queryvars` field of the input `UAIInstance` instance is designed for the `MMAPModel`, which is not respected by `TensorNetworkModel`. Got non-empty value: $(uai.queryvars)"
138+
end
125139
return TensorNetworkModel(
126140
1:(instance.nvars),
127141
instance.cards,

src/TensorInference.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ module TensorInference
1010
using OMEinsum, LinearAlgebra
1111
using DocStringExtensions, TropicalNumbers
1212
# The Tropical GEMM support
13-
using TropicalGEMM
1413
using StatsBase
1514

1615
# reexport OMEinsum functions
@@ -19,7 +18,7 @@ export contraction_complexity, TreeSA, GreedyMethod, KaHyParBipartite, SABiparti
1918

2019
# read and load uai files
2120
export read_model_file, read_td_file, read_evidence_file, read_solution_file, read_instance, UAIInstance
22-
export set_evidence!
21+
export set_evidence!, set_query!
2322

2423
# marginals
2524
export TensorNetworkModel, get_vars, get_cards, log_probability, probability, marginals

src/mmap.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,9 @@ end
5858
"""
5959
$(TYPEDSIGNATURES)
6060
"""
61-
function MMAPModel(instance::UAIInstance; queryvars, openvertices = (), optimizer = GreedyMethod(), simplifier = nothing)::MMAPModel
61+
function MMAPModel(instance::UAIInstance; openvertices = (), optimizer = GreedyMethod(), simplifier = nothing)::MMAPModel
6262
return MMAPModel(
63-
1:(instance.nvars), instance.cards, instance.factors; queryvars, fixedvertices = Dict(zip(instance.obsvars, instance.obsvals)), optimizer, simplifier, openvertices
63+
1:(instance.nvars), instance.cards, instance.factors; queryvars=instance.queryvars, fixedvertices = Dict(zip(instance.obsvars, instance.obsvals)), optimizer, simplifier, openvertices
6464
)
6565
end
6666

src/sampling.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,13 @@ function sample(tn::TensorNetworkModel, n::Int; usecuda = false)::AbstractMatrix
128128
return samples.samples
129129
end
130130

131+
function generate_samples(se::SlicedEinsum, cache::CacheTree{T}, samples, size_dict::Dict) where {T}
132+
# slicing is not supported yet.
133+
if length(se.slicing) != 0
134+
@warn "Slicing is not supported for caching, got nslices = $(length(se.slicing))! Fallback to `NestedEinsum`."
135+
end
136+
return generate_samples(se.eins, cache, samples, size_dict)
137+
end
131138
function generate_samples(code::NestedEinsum, cache::CacheTree{T}, samples, size_dict::Dict) where {T}
132139
if !OMEinsum.isleaf(code)
133140
xs = ntuple(i -> cache.siblings[i].content, length(cache.siblings))

src/utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ end
246246

247247
function read_instance_from_string(uai::AbstractString; eltype = Float64)::UAIInstance
248248
nvars, cards, ncliques, factors = read_model_string(uai; factor_eltype = eltype)
249-
return UAIInstance(nvars, ncliques, cards, factors, Int[], Int[], Vector{eltype}[])
249+
return UAIInstance(nvars, ncliques, cards, factors, Int[], Int[], Int[], nothing)
250250
end
251251

252252
# patch to get content by broadcasting into array, while keep array size unchanged.

test/cuda.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,22 +42,25 @@ end
4242
optimizer = TreeSA(ntrials = 1, niters = 2, βs = 1:0.1:40)
4343
tn_ref = TensorNetworkModel(instance; optimizer)
4444
# does not marginalize any var
45-
tn = MMAPModel(instance; queryvars = collect(1:instance.nvars), optimizer)
45+
set_query!(instance, collect(1:instance.nvars))
46+
tn = MMAPModel(instance; optimizer)
4647
r1, r2 = maximum_logp(tn_ref; usecuda = true), maximum_logp(tn; usecuda = true)
4748
@test r1 isa CuArray
4849
@test r2 isa CuArray
4950
@test r1 r2
5051

5152
# marginalize all vars
52-
tn2 = MMAPModel(instance; queryvars = Int[], optimizer)
53+
set_query!(instance, Int[])
54+
tn2 = MMAPModel(instance; optimizer)
5355
cup = probability(tn_ref; usecuda = true)
5456
culogp = maximum_logp(tn2; usecuda = true)
5557
@test cup isa RescaledArray{T, N, <:CuArray} where {T, N}
5658
@test culogp isa CuArray
5759
@test Array(cup)[] exp(Array(culogp)[])
5860

5961
# does not optimize over open vertices
60-
tn3 = MMAPModel(instance; queryvars = setdiff(1:instance.nvars, [2, 4, 6]), optimizer)
62+
set_query!(instance, setdiff(1:instance.nvars, [2, 4, 6]))
63+
tn3 = MMAPModel(instance; optimizer)
6164
logp, config = most_probable_config(tn3; usecuda = true)
6265
@test log_probability(tn3, config) logp
6366
end

test/mmap.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,20 @@ end
1515
tn_ref = TensorNetworkModel(instance; optimizer)
1616

1717
# Does not marginalize any var
18-
mmap = MMAPModel(instance; queryvars = collect(1:instance.nvars), optimizer)
18+
set_query!(instance, collect(1:instance.nvars))
19+
mmap = MMAPModel(instance; optimizer)
1920
@debug(mmap)
2021
@test maximum_logp(tn_ref) maximum_logp(mmap)
2122

2223
# Marginalize all vars
23-
mmap2 = MMAPModel(instance; queryvars = Int[], optimizer)
24+
set_query!(instance, Int[])
25+
mmap2 = MMAPModel(instance; optimizer)
2426
@debug(mmap2)
2527
@test Array(probability(tn_ref))[] exp(maximum_logp(mmap2)[])
2628

2729
# Does not optimize over open vertices
28-
mmap3 = MMAPModel(instance; queryvars = setdiff(1:instance.nvars, [2, 4, 6]), optimizer)
30+
set_query!(instance, setdiff(1:instance.nvars, [2, 4, 6]))
31+
mmap3 = MMAPModel(instance; optimizer)
2932
@debug(mmap3)
3033
logp, config = most_probable_config(mmap3)
3134
@test log_probability(mmap3, config) logp
@@ -42,7 +45,7 @@ end
4245
@info "Testing: $problem_name"
4346
model_filepath, evidence_filepath, query_filepath, solution_filepath = get_instance_filepaths(problem_name, "MMAP")
4447
instance = read_instance(model_filepath; evidence_filepath, query_filepath, solution_filepath)
45-
model = MMAPModel(instance; queryvars = instance.queryvars, optimizer)
48+
model = MMAPModel(instance; optimizer)
4649
_, solution = most_probable_config(model)
4750
@test solution == instance.reference_solution
4851
end

0 commit comments

Comments
 (0)