Skip to content

Commit ca954d6

Browse files
committed
refactor UAI file reading
1 parent 79797ea commit ca954d6

File tree

15 files changed

+232
-301
lines changed

15 files changed

+232
-301
lines changed

docs/src/api/public.md

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ TreeSA
4141
MMAPModel
4242
RescaledArray
4343
TensorNetworkModel
44+
ArtifactProblemSpec
4445
UAIInstance
4546
```
4647

@@ -55,13 +56,15 @@ marginals
5556
maximum_logp
5657
most_probable_config
5758
probability
58-
read_evidence_file
59+
dataset_from_artifact
60+
problem_from_artifact
5961
read_instance
60-
read_instance_from_artifact
61-
read_model_file
62+
read_evidence
63+
read_solution
64+
read_queryvars
65+
read_instance_from_file
66+
read_evidence_file
6267
read_solution_file
6368
read_td_file
6469
sample
65-
set_evidence!
66-
set_query!
6770
```

examples/asia/main.jl

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -76,13 +76,9 @@ get_vars(tn)
7676
# ---
7777

7878
# Set an evidence: Assume that the "X-ray" result (variable 7) is positive.
79-
set_evidence!(instance, 7 => 0)
80-
81-
# ---
82-
8379
# Since setting an evidence may affect the contraction order of the tensor
8480
# network, recompute it.
85-
tn = TensorNetworkModel(instance)
81+
tn = TensorNetworkModel(instance, evidence=Dict(7=>0))
8682

8783
# ---
8884

@@ -107,8 +103,7 @@ logp, cfg = most_probable_config(tn)
107103
# Compute the most probable values of certain variables (e.g., 4 and 7) while
108104
# marginalizing over others. This is known as Maximum a Posteriori (MAP)
109105
# estimation.
110-
set_query!(instance, [4, 7])
111-
mmap = MMAPModel(instance)
106+
mmap = MMAPModel(instance, queryvars=[4,7])
112107

113108
# ---
114109

src/Core.jl

Lines changed: 24 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -22,62 +22,23 @@ $(TYPEDEF)
2222
* `nclique` is the number of cliques,
2323
* `cards` is a vector of cardinalities for variables,
2424
* `factors` is a vector of factors,
25-
26-
* `obsvars` is a vector of observed variables,
27-
* `obsvals` is a vector of observed values,
28-
* `queryvars` is a vector of query variables,
29-
* `reference_solution` is a vector with the reference solution.
3025
"""
3126
struct UAIInstance{ET, FT <: Factor{ET}}
3227
nvars::Int
3328
nclique::Int
3429
cards::Vector{Int}
3530
factors::Vector{FT}
36-
37-
obsvars::Vector{Int}
38-
obsvals::Vector{Int}
39-
queryvars::Vector{Int}
40-
reference_solution
4131
end
4232

4333
Base.show(io::IO, ::MIME"text/plain", uai::UAIInstance) = Base.show(io, uai)
4434
function Base.show(io::IO, uai::UAIInstance)
4535
println(io, "UAIInstance(nvars = $(uai.nvars), nclique = $(uai.nclique))")
4636
println(io, " variables :")
47-
for (var, card) in zip(1:uai.nvars, uai.cards)
48-
println(io, string_var(" $var of size $card", var uai.queryvars, Dict(zip(uai.obsvars, uai.obsvals))))
49-
end
5037
println(io, " factors : ")
51-
for f in uai.factors
52-
println(io, " $(summary(f))")
38+
for (k, f) in enumerate(uai.factors)
39+
print(io, " $(summary(f))")
40+
k == length(uai.factors) || println(io)
5341
end
54-
print(io, " reference_solution : $(uai.reference_solution)")
55-
end
56-
57-
"""
58-
$TYPEDSIGNATURES
59-
60-
Set the evidence of an UAI instance.
61-
"""
62-
function set_evidence!(uai::UAIInstance, pairs::Pair{Int}...)
63-
empty!(uai.obsvars)
64-
empty!(uai.obsvals)
65-
for (var, val) in pairs
66-
push!(uai.obsvars, var)
67-
push!(uai.obsvals, val)
68-
end
69-
return uai
70-
end
71-
72-
"""
73-
$TYPEDSIGNATURES
74-
75-
Set the query variables of an UAI instance.
76-
"""
77-
function set_query!(uai::UAIInstance, vars::AbstractVector{Int})
78-
empty!(uai.queryvars)
79-
append!(uai.queryvars, vars)
80-
return uai
8142
end
8243

8344
"""
@@ -89,32 +50,32 @@ Probabilistic modeling with a tensor network.
8950
* `vars` is the degree of freedoms in the tensor network.
9051
* `code` is the tensor network contraction pattern.
9152
* `tensors` is the tensors fed into the tensor network.
92-
* `fixedvertices` is a dictionary to specifiy degree of freedoms fixed to certain values.
53+
* `evidence` is a dictionary to specifiy degree of freedoms fixed to certain values.
9354
"""
9455
struct TensorNetworkModel{LT, ET, MT <: AbstractArray}
9556
vars::Vector{LT}
9657
code::ET
9758
tensors::Vector{MT}
98-
fixedvertices::Dict{LT, Int}
59+
evidence::Dict{LT, Int}
9960
end
10061

10162
function Base.show(io::IO, tn::TensorNetworkModel)
10263
open = getiyv(tn.code)
103-
variables = join([string_var(var, open, tn.fixedvertices) for var in tn.vars], ", ")
64+
variables = join([string_var(var, open, tn.evidence) for var in tn.vars], ", ")
10465
tc, sc, rw = contraction_complexity(tn)
10566
println(io, "$(typeof(tn))")
10667
println(io, "variables: $variables")
10768
print_tcscrw(io, tc, sc, rw)
10869
end
10970
Base.show(io::IO, ::MIME"text/plain", tn::TensorNetworkModel) = Base.show(io, tn)
11071

111-
function string_var(var, open, fixedvertices)
112-
if var open && haskey(fixedvertices, var)
113-
"$var (open, fixed to $(fixedvertices[var]))"
72+
function string_var(var, open, evidence)
73+
if var open && haskey(evidence, var)
74+
"$var (open, fixed to $(evidence[var]))"
11475
elseif var open
11576
"$var (open)"
116-
elseif haskey(fixedvertices, var)
117-
"$var (evidence → $(fixedvertices[var]))"
77+
elseif haskey(evidence, var)
78+
"$var (evidence → $(evidence[var]))"
11879
else
11980
"$var"
12081
end
@@ -129,19 +90,17 @@ $(TYPEDSIGNATURES)
12990
"""
13091
function TensorNetworkModel(
13192
instance::UAIInstance;
132-
openvertices = (),
93+
openvars = (),
94+
evidence = Dict{Int,Int}(),
13395
optimizer = GreedyMethod(),
13496
simplifier = nothing
13597
)::TensorNetworkModel
136-
if !isempty(instance.queryvars)
137-
@warn "The `queryvars` field of the input `UAIInstance` instance is designed for the `MMAPModel`, which is not respected by `TensorNetworkModel`. Got non-empty value: $(uai.queryvars)"
138-
end
13998
return TensorNetworkModel(
14099
1:(instance.nvars),
141100
instance.cards,
142101
instance.factors;
143-
openvertices,
144-
fixedvertices = Dict(zip(instance.obsvars, instance.obsvals)),
102+
openvars,
103+
evidence,
145104
optimizer,
146105
simplifier
147106
)
@@ -154,18 +113,18 @@ function TensorNetworkModel(
154113
vars::AbstractVector{LT},
155114
cards::AbstractVector{Int},
156115
factors::Vector{<:Factor{T}};
157-
openvertices = (),
158-
fixedvertices = Dict{LT, Int}(),
116+
openvars = (),
117+
evidence = Dict{LT, Int}(),
159118
optimizer = GreedyMethod(),
160119
simplifier = nothing
161120
)::TensorNetworkModel where {T, LT}
162121
# The 1st argument of `EinCode` is a vector of vector of labels for specifying the input tensors,
163122
# The 2nd argument of `EinCode` is a vector of labels for specifying the output tensor,
164123
# e.g.
165124
# `EinCode([[1, 2], [2, 3]], [1, 3])` is the EinCode for matrix multiplication.
166-
rawcode = EinCode([[[var] for var in vars]..., [[factor.vars...] for factor in factors]...], collect(LT, openvertices)) # labels for vertex tensors (unity tensors) and edge tensors
125+
rawcode = EinCode([[[var] for var in vars]..., [[factor.vars...] for factor in factors]...], collect(LT, openvars)) # labels for vertex tensors (unity tensors) and edge tensors
167126
tensors = Array{T}[[ones(T, cards[i]) for i in 1:length(vars)]..., [t.vals for t in factors]...]
168-
return TensorNetworkModel(collect(LT, vars), rawcode, tensors; fixedvertices, optimizer, simplifier)
127+
return TensorNetworkModel(collect(LT, vars), rawcode, tensors; evidence, optimizer, simplifier)
169128
end
170129

171130
"""
@@ -175,7 +134,7 @@ function TensorNetworkModel(
175134
vars::AbstractVector{LT},
176135
rawcode::EinCode,
177136
tensors::Vector{<:AbstractArray};
178-
fixedvertices = Dict{LT, Int}(),
137+
evidence = Dict{LT, Int}(),
179138
optimizer = GreedyMethod(),
180139
simplifier = nothing
181140
)::TensorNetworkModel where {LT}
@@ -185,7 +144,7 @@ function TensorNetworkModel(
185144
# The 3rd and 4th arguments are the optimizer and simplifier that configures which algorithm to use and simplify.
186145
size_dict = OMEinsum.get_size_dict(getixsv(rawcode), tensors)
187146
code = optimize_code(rawcode, size_dict, optimizer, simplifier)
188-
TensorNetworkModel(collect(LT, vars), code, tensors, fixedvertices)
147+
TensorNetworkModel(collect(LT, vars), code, tensors, evidence)
189148
end
190149

191150
"""
@@ -202,10 +161,10 @@ Get the cardinalities of variables in this tensor network.
202161
"""
203162
function get_cards(tn::TensorNetworkModel; fixedisone = false)::Vector
204163
vars = get_vars(tn)
205-
[fixedisone && haskey(tn.fixedvertices, vars[k]) ? 1 : length(tn.tensors[k]) for k in 1:length(vars)]
164+
[fixedisone && haskey(tn.evidence, vars[k]) ? 1 : length(tn.tensors[k]) for k in 1:length(vars)]
206165
end
207166

208-
chfixedvertices(tn::TensorNetworkModel, fixedvertices) = TensorNetworkModel(tn.vars, tn.code, tn.tensors, fixedvertices)
167+
chevidence(tn::TensorNetworkModel, evidence) = TensorNetworkModel(tn.vars, tn.code, tn.tensors, evidence)
209168

210169
"""
211170
$(TYPEDSIGNATURES)
@@ -221,7 +180,7 @@ end
221180
$(TYPEDSIGNATURES)
222181
223182
Contract the tensor network and return a probability array with its rank specified in the contraction code `tn.code`.
224-
The returned array may not be l1-normalized even if the total probability is l1-normalized, because the evidence `tn.fixedvertices` may not be empty.
183+
The returned array may not be l1-normalized even if the total probability is l1-normalized, because the evidence `tn.evidence` may not be empty.
225184
"""
226185
function probability(tn::TensorNetworkModel; usecuda = false, rescale = true)::AbstractArray
227186
return tn.code(adapt_tensors(tn; usecuda, rescale)...)

src/TensorInference.jl

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@ export RescaledArray
1818
export contraction_complexity, TreeSA, GreedyMethod, KaHyParBipartite, SABipartite, MergeGreedy, MergeVectors
1919

2020
# read and load uai files
21-
export read_model_file, read_td_file, read_evidence_file, read_solution_file
22-
export read_instance, read_instance_from_artifact, UAIInstance
23-
export set_evidence!, set_query!
21+
export read_instance_from_file, read_td_file, read_evidence_file, read_solution_file
22+
export problem_from_artifact, ArtifactProblemSpec
23+
export read_instance, UAIInstance, read_evidence, read_solution, read_queryvars, dataset_from_artifact
2424

2525
# marginals
2626
export TensorNetworkModel, get_vars, get_cards, log_probability, probability, marginals
@@ -47,13 +47,13 @@ function __init__()
4747
@require CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" include("cuda.jl")
4848
end
4949

50-
import PrecompileTools
51-
PrecompileTools.@setup_workload begin
52-
# Putting some things in `@setup_workload` instead of `@compile_workload` can reduce the size of the
53-
# precompile file and potentially make loading faster.
54-
#PrecompileTools.@compile_workload begin
55-
# include("../example/asia/main.jl")
56-
#end
57-
end
50+
# import PrecompileTools
51+
# PrecompileTools.@setup_workload begin
52+
# # Putting some things in `@setup_workload` instead of `@compile_workload` can reduce the size of the
53+
# # precompile file and potentially make loading faster.
54+
# PrecompileTools.@compile_workload begin
55+
# include("../example/asia/main.jl")
56+
# end
57+
# end
5858

5959
end # module

src/map.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ function most_probable_config(tn::TensorNetworkModel; usecuda = false)::Tuple{Re
4949
tensors = map(t -> Tropical.(log.(t)), adapt_tensors(tn; usecuda, rescale = false))
5050
logp, grads = cost_and_gradient(tn.code, tensors)
5151
# use Array to convert CuArray to CPU arrays
52-
return content(Array(logp)[]), map(k -> haskey(tn.fixedvertices, vars[k]) ? tn.fixedvertices[vars[k]] : argmax(grads[k]) - 1, 1:length(vars))
52+
return content(Array(logp)[]), map(k -> haskey(tn.evidence, vars[k]) ? tn.evidence[vars[k]] : argmax(grads[k]) - 1, 1:length(vars))
5353
end
5454

5555
"""

src/mar.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
# generate tensors based on which vertices are fixed.
2-
adapt_tensors(gp::TensorNetworkModel; usecuda, rescale) = adapt_tensors(gp.code, gp.tensors, gp.fixedvertices; usecuda, rescale)
3-
function adapt_tensors(code, tensors, fixedvertices; usecuda, rescale)
2+
adapt_tensors(gp::TensorNetworkModel; usecuda, rescale) = adapt_tensors(gp.code, gp.tensors, gp.evidence; usecuda, rescale)
3+
function adapt_tensors(code, tensors, evidence; usecuda, rescale)
44
ixs = getixsv(code)
55
# `ix` is the vector of labels (or a degree of freedoms) for a tensor,
66
# if a label in `ix` is fixed to a value, do the slicing to the tensor it associates to.
77
map(tensors, ixs) do t, ix
8-
dims = map(ixi -> ixi keys(fixedvertices) ? Colon() : ((fixedvertices[ixi] + 1):(fixedvertices[ixi] + 1)), ix)
8+
dims = map(ixi -> ixi keys(evidence) ? Colon() : ((evidence[ixi] + 1):(evidence[ixi] + 1)), ix)
99
t2 = t[dims...]
1010
t3 = usecuda ? CuArray(t2) : t2
1111
rescale ? rescale_array(t3) : t3

0 commit comments

Comments
 (0)