Skip to content

Commit 4d5b452

Browse files
authored
Merge pull request #61 from TensorBFS/jg/update-marginals
Let marginals return dict
2 parents a066560 + 6266d7c commit 4d5b452

File tree

8 files changed

+89
-24
lines changed

8 files changed

+89
-24
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "TensorInference"
22
uuid = "c2297e78-99bd-40ad-871d-f50e56b81012"
33
authors = ["Jin-Guo Liu", "Martin Roa Villescas"]
4-
version = "0.3.0"
4+
version = "0.4.0"
55

66
[deps]
77
Artifacts = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"

examples/hard-core-lattice-gas/main.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,14 +55,14 @@ partition_func[]
5555

5656
# The marginal probabilities can be computed with the [`marginals`](@ref) function, which measures how likely a site is occupied.
5757
mars = marginals(pmodel)
58-
show_graph(graph; locs=sites, vertex_colors=[(1-b, 1-b, 1-b) for b in getindex.(mars, 2)], texts=fill("", nv(graph)))
58+
show_graph(graph; locs=sites, vertex_colors=[(b = mars[[i]][2]; (1-b, 1-b, 1-b)) for i in 1:nv(graph)], texts=fill("", nv(graph)))
5959
# We can see that the sites at the corners are more likely to be occupied.
6060
# To obtain two-site correlations, one can set the variables to query marginal probabilities manually.
6161
pmodel2 = TensorNetworkModel(problem, β; mars=[[e.src, e.dst] for e in edges(graph)])
6262
mars = marginals(pmodel2);
6363

6464
# We show the probability that both sites on an edge are not occupied
65-
show_graph(graph; locs=sites, edge_colors=[(b=mar[1, 1]; (1-b, 1-b, 1-b)) for mar in mars], texts=fill("", nv(graph)), edge_line_width=5)
65+
show_graph(graph; locs=sites, edge_colors=[(b = mars[[e.src, e.dst]][1, 1]; (1-b, 1-b, 1-b)) for e in edges(graph)], texts=fill("", nv(graph)), edge_line_width=5)
6666

6767
# ## The most likely configuration
6868
# The MAP and MMAP can be used to get the most likely configuration given an evidence.
@@ -91,4 +91,4 @@ sum(config2)
9191
# The return value is a matrix, with the columns corresponding to different samples.
9292
configs = sample(pmodel3, 1000)
9393
sizes = sum(configs; dims=1)
94-
[count(==(i), sizes) for i=0:34] # counting sizes
94+
[count(==(i), sizes) for i=0:34] # counting sizes

src/mar.jl

Lines changed: 66 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -124,16 +124,77 @@ end
124124
"""
125125
$(TYPEDSIGNATURES)
126126
127-
Returns the marginal probability distribution of variables.
128-
One can use `get_vars(tn)` to get the full list of variables in this tensor network.
127+
Queries the marginals of the variables in a [`TensorNetworkModel`](@ref). The
128+
function returns a dictionary, where the keys are the variables and the values
129+
are their respective marginals. A marginal is a probability distribution over
130+
a subset of variables, obtained by integrating or summing over the remaining
131+
variables in the model. By default, the function returns the marginals of all
132+
individual variables. To specify which marginal variables to query, set the
133+
`mars` field when constructing a [`TensorNetworkModel`](@ref). Note that
134+
the choice of marginal variables will affect the contraction order of the
135+
tensor network.
136+
137+
### Arguments
138+
- `tn`: The [`TensorNetworkModel`](@ref) to query.
139+
- `usecuda`: Specifies whether to use CUDA for tensor contraction.
140+
- `rescale`: Specifies whether to rescale the tensors during contraction.
141+
142+
### Example
143+
The following example is taken from [`examples/asia/main.jl`](@ref).
144+
145+
```jldoctest; setup = :(using TensorInference, Random; Random.seed!(0))
146+
julia> model = read_model_file(pkgdir(TensorInference, "examples", "asia", "asia.uai"));
147+
148+
julia> tn = TensorNetworkModel(model; evidence=Dict(1=>0))
149+
TensorNetworkModel{Int64, DynamicNestedEinsum{Int64}, Array{Float64}}
150+
variables: 1 (evidence → 0), 2, 3, 4, 5, 6, 7, 8
151+
contraction time = 2^6.022, space = 2^2.0, read-write = 2^7.077
152+
153+
julia> marginals(tn)
154+
Dict{Vector{Int64}, Vector{Float64}} with 8 entries:
155+
[8] => [0.450138, 0.549863]
156+
[3] => [0.5, 0.5]
157+
[1] => [1.0]
158+
[5] => [0.45, 0.55]
159+
[4] => [0.055, 0.945]
160+
[6] => [0.10225, 0.89775]
161+
[7] => [0.145092, 0.854908]
162+
[2] => [0.05, 0.95]
163+
164+
julia> tn2 = TensorNetworkModel(model; evidence=Dict(1=>0), mars=[[2, 3], [3, 4]])
165+
TensorNetworkModel{Int64, DynamicNestedEinsum{Int64}, Array{Float64}}
166+
variables: 1 (evidence → 0), 2, 3, 4, 5, 6, 7, 8
167+
contraction time = 2^7.781, space = 2^5.0, read-write = 2^8.443
168+
169+
julia> marginals(tn2)
170+
Dict{Vector{Int64}, Matrix{Float64}} with 2 entries:
171+
[2, 3] => [0.025 0.025; 0.475 0.475]
172+
[3, 4] => [0.05 0.45; 0.005 0.495]
173+
```
174+
175+
In this example, we first set the evidence for variable 1 to 0 and then query
176+
the marginals of all individual variables. The returned dictionary has keys
177+
that correspond to the queried variables and values that represent their
178+
marginals. These marginals are vectors, with each entry corresponding to the
179+
probability of the variable taking a specific value. In this example, the
180+
possible values are 0 or 1. For the evidence variable 1, the marginal is
181+
always [1.0] since its value is fixed at 0.
182+
183+
Next, we specify the marginal variables to query as variables 2 and 3, and
184+
variables 3 and 4, respectively. The joint marginals may or may not affect the
185+
contraction time and space. In this example, the contraction space complexity
186+
increases from 2^{2.0} to 2^{5.0}, and the contraction time complexity
187+
increases from 2^{6.022} to 2^{7.781}. The output marginals are the joint
188+
probabilities of the queried variables, represented by tensors.
189+
129190
"""
130-
function marginals(tn::TensorNetworkModel; usecuda = false, rescale = true)::Vector
191+
function marginals(tn::TensorNetworkModel; usecuda = false, rescale = true)::Dict{Vector{Int}}
131192
# sometimes, the cost can overflow, then we need to rescale the tensors during contraction.
132193
cost, grads = cost_and_gradient(tn.code, adapt_tensors(tn; usecuda, rescale))
133194
@debug "cost = $cost"
134195
if rescale
135-
return LinearAlgebra.normalize!.(getfield.(grads[1:length(tn.mars)], :normalized_value), 1)
196+
return Dict(zip(tn.mars, LinearAlgebra.normalize!.(getfield.(grads[1:length(tn.mars)], :normalized_value), 1)))
136197
else
137-
return LinearAlgebra.normalize!.(grads[1:length(tn.mars)], 1)
198+
return Dict(zip(tn.mars, LinearAlgebra.normalize!.(grads[1:length(tn.mars)], 1)))
138199
end
139200
end

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,5 @@ GenericTensorNetworks = "3521c873-ad32-4bb4-b63d-f4f178f42b49"
55
KaHyPar = "2a6221f6-aa48-11e9-3542-2d9e0ef01880"
66
OMEinsum = "ebe7aa44-baf0-506c-a96f-8464559b3922"
77
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
8+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
89
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

test/cuda.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@ CUDA.allowscalar(false)
1111
tn = TensorNetworkModel(model; optimizer = TreeSA(ntrials = 1, niters = 2, βs = 1:0.1:40), evidence)
1212
@debug contraction_complexity(tn)
1313
@time marginals2 = marginals(tn; usecuda = true)
14-
@test all(x -> x isa CuArray, marginals2)
14+
@test all(x -> x.second isa CuArray, marginals2)
1515
# for dangling vertices, the output size is 1.
1616
npass = 0
1717
for i in 1:(model.nvars)
18-
npass += (length(marginals2[i]) == 1 && reference_solution[i] == [0.0, 1]) || isapprox(Array(marginals2[i]), reference_solution[i]; atol = 1e-6)
18+
npass += (length(marginals2[[i]]) == 1 && reference_solution[i] == [0.0, 1]) || isapprox(Array(marginals2[[i]]), reference_solution[i]; atol = 1e-6)
1919
end
2020
@test npass == model.nvars
2121
end

test/generictensornetworks.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ using GenericTensorNetworks, TensorInference
77
g = GenericTensorNetworks.Graphs.smallgraph(:petersen)
88
problem = IndependentSet(g)
99
model = TensorNetworkModel(problem, β; mars=[[2, 3]])
10-
mars = marginals(model)[1]
10+
mars = marginals(model)[[2, 3]]
1111
problem2 = IndependentSet(g; openvertices=[2,3])
1212
mars2 = TensorInference.normalize!(GenericTensorNetworks.solve(problem2, PartitionFunction(β)), 1)
1313
@test mars ≈ mars2

test/mar.jl

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ end
3131
# compute marginals
3232
ti_sol = marginals(tn)
3333
ref_sol[collect(keys(evidence))] .= fill([1.0], length(evidence)) # imitate dummy vars
34-
@test isapprox(ti_sol, ref_sol; atol = 1e-5)
34+
@test isapprox([ti_sol[[i]] for i=1:length(ref_sol)], ref_sol; atol = 1e-5)
3535
end
3636

3737
@testset "UAI Reference Solution Comparison" begin
@@ -63,7 +63,7 @@ end
6363
@debug contraction_complexity(tn)
6464
ti_sol = marginals(tn)
6565
ref_sol[collect(keys(evidence))] .= fill([1.0], length(evidence)) # imitate dummy vars
66-
@test isapprox(ti_sol, ref_sol; atol = 1e-4)
66+
@test isapprox([ti_sol[[i]] for i=1:length(ref_sol)], ref_sol; atol = 1e-4)
6767
end
6868
end
6969
end
@@ -120,15 +120,18 @@ end
120120
mars = marginals(tnet)
121121
tnet23 = TensorNetworkModel(model; openvars=[2,3])
122122
tnet34 = TensorNetworkModel(model; openvars=[3,4])
123-
@test mars[1] ≈ probability(tnet23)
124-
@test mars[2] ≈ probability(tnet34)
123+
@test mars[[2, 3]] ≈ probability(tnet23)
124+
@test mars[[3, 4]] ≈ probability(tnet34)
125125

126-
tnet1 = TensorNetworkModel(model; mars=[[2, 3], [3, 4]], evidence=Dict(3=>1))
127-
tnet2 = TensorNetworkModel(model; mars=[[2, 3], [3, 4]], evidence=Dict(3=>0))
126+
vars = [[2, 4], [3, 5]]
127+
tnet1 = TensorNetworkModel(model; mars=vars, evidence=Dict(3=>1))
128+
tnet2 = TensorNetworkModel(model; mars=vars, evidence=Dict(3=>0))
128129
mars1 = marginals(tnet1)
129130
mars2 = marginals(tnet2)
130131
update_evidence!(tnet1, Dict(3=>0))
131132
mars1b = marginals(tnet1)
132-
@test !(mars1 ≈ mars2)
133-
@test mars1b ≈ mars2
133+
for k in vars
134+
@test !(mars1[k] ≈ mars2[k])
135+
@test mars1b[k] ≈ mars2[k]
136+
end
134137
end

test/sampling.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,14 +49,14 @@ using TensorInference, Test
4949
n = 10000
5050
tnet = TensorNetworkModel(model)
5151
samples = sample(tnet, n)
52-
mars = getindex.(marginals(tnet), 2)
52+
mars = marginals(tnet)
5353
mars_sample = [count(s->s[k]==(1), samples) for k=1:8] ./ n
54-
@test isapprox(mars, mars_sample, atol=0.05)
54+
@test isapprox([mars[[i]][2] for i=1:8], mars_sample, atol=0.05)
5555

5656
# fix the evidence
5757
tnet = TensorNetworkModel(model, optimizer=TreeSA(), evidence=Dict(7=>1))
5858
samples = sample(tnet, n)
59-
mars = getindex.(marginals(tnet), 1)
59+
mars = marginals(tnet)
6060
mars_sample = [count(s->s[k]==(0), samples) for k=1:8] ./ n
61-
@test isapprox([mars[1:6]..., mars[8]], [mars_sample[1:6]..., mars_sample[8]], atol=0.05)
61+
@test isapprox([[mars[[i]][1] for i=1:6]..., mars[[8]][1]], [mars_sample[1:6]..., mars_sample[8]], atol=0.05)
6262
end

0 commit comments

Comments
 (0)