Commit 550867a

Merge pull request #30 from TensorBFS/jg/sampling
Improve the performance of sampling
2 parents a3c3181 + a55eabc commit 550867a

File tree

4 files changed, +114 −79 lines changed

src/sampling.jl

Lines changed: 49 additions & 23 deletions
@@ -10,7 +10,7 @@ The sampled configurations are stored in `samples`, which is a vector of vector.
 The `setmask` is a boolean indicator to denote whether the sampling process of a variable is complete.
 """
 struct Samples{L}
-    samples::Vector{Vector{Int}}
+    samples::Matrix{Int}  # size is nvars × nsample
     labels::Vector{L}
     setmask::BitVector
 end
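
Note: the struct change above swaps a vector of per-sample vectors for a single nvars × nsample matrix. A minimal sketch (illustration only, not part of this commit) of why the column layout helps in column-major Julia arrays:

    nvars, nsample = 8, 10_000
    samples = zeros(Int, nvars, nsample)   # was: 10_000 separate length-8 vectors
    samples[[3, 5], 1] .= (1, 0)           # write variables 3 and 5 of sample 1
    col = view(samples, :, 1)              # one whole configuration, no copy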
@@ -23,7 +23,7 @@ function setmask!(samples::Samples, eliminated_variables)
     return samples
 end

-idx4labels(totalset, labels) = map(v->findfirst(==(v), totalset), labels)
+idx4labels(totalset, labels)::Vector{Int} = map(v->findfirst(==(v), totalset), labels)

 """
 $(TYPEDSIGNATURES)
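
Note on the new ::Vector{Int} return annotation (my reading; the commit does not spell it out): findfirst returns Union{Nothing, Int}, so the unannotated map infers Vector{Union{Nothing, Int}}; the annotation converts and asserts to Vector{Int}, keeping downstream indexing type stable. A toy call for illustration:

    idx4labels(totalset, labels)::Vector{Int} = map(v->findfirst(==(v), totalset), labels)
    idx4labels([:a, :b, :c], [:c, :a])   # == [3, 1]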
@@ -41,32 +41,52 @@ function backward_sampling!(ixs, @nospecialize(xs::Tuple), iy, @nospecialize(y),
     setmask!(samples, eliminated_variables)

     # the contraction code to get probability
-    newiy = eliminated_variables
-    iy_in_sample = idx4labels(samples.labels, iy)
-    slice_y_dim = collect(1:length(iy))
     newixs = map(ix->setdiff(ix, iy), ixs)
     ix_in_sample = map(ix->idx4labels(samples.labels, ix ∩ iy), ixs)
     slice_xs_dim = map(ix->idx4labels(ix, ix ∩ iy), ixs)
-    code = DynamicEinCode(newixs, newiy)
+
+    # relabel and compute probabilities
+    uniquelabels = unique!(vcat(ixs..., iy))
+    labelmap = Dict(zip(uniquelabels, 1:length(uniquelabels)))
+    batchdim = length(labelmap) + 1
+    newnewixs = [Int[getindex.(Ref(labelmap), ix)..., batchdim] for ix in newixs]
+    newnewiy = Int[getindex.(Ref(labelmap), eliminated_variables)..., batchdim]
+    newnewxs = [get_slice(x, dimx, samples.samples[ixloc, :]) for (x, dimx, ixloc) in zip(xs, slice_xs_dim, ix_in_sample)]
+    code = DynamicEinCode(newnewixs, newnewiy)
+    probabilities = code(newnewxs...)

     totalset = CartesianIndices((map(x->size_dict[x], eliminated_variables)...,))
-    for (i, sample) in enumerate(samples.samples)
-        newxs = [get_slice(x, dimx, sample[ixloc]) for (x, dimx, ixloc) in zip(xs, slice_xs_dim, ix_in_sample)]
-        newy = get_element(y, slice_y_dim, sample[iy_in_sample])
-        probabilities = einsum(code, (newxs...,), size_dict) / newy
-        config = StatsBase.sample(totalset, Weights(vec(probabilities)))
-        # update the samples
-        samples.samples[i][eliminated_locs] .= config.I .- 1
+    for i=axes(samples.samples, 2)
+        config = StatsBase.sample(totalset, Weights(vec(selectdim(probabilities, ndims(probabilities), i))))
+        # update the samples
+        samples.samples[eliminated_locs, i] .= config.I .- 1
     end
     return samples
 end

 # type unstable
-function get_slice(x, dim, config)
-    asarray(x[[i ∈ dim ? config[findfirst(==(i), dim)]+1 : Colon() for i in 1:ndims(x)]...], x)
+function get_slice(x::AbstractArray{T}, slicedim, configs::AbstractMatrix) where T
+    outdim = setdiff(1:ndims(x), slicedim)
+    res = similar(x, [size(x, d) for d in outdim]..., size(configs, 2))
+    return get_slice!(res, x, outdim, slicedim, configs)
 end
-function get_element(x, dim, config)
-    x[[config[findfirst(==(i), dim)]+1 for i in 1:ndims(x)]...]
+
+function get_slice!(res, x::AbstractArray{T}, outdim, slicedim, configs::AbstractMatrix) where T
+    xstrides = strides(x)
+    @inbounds for ci in CartesianIndices(res)
+        idx = 1
+        # the output dimension part
+        for (dim, k) in zip(outdim, ci.I)
+            idx += (k-1) * xstrides[dim]
+        end
+        # the sliced part
+        batchidx = ci.I[end]
+        for (dim, k) in zip(slicedim, view(configs, :, batchidx))
+            idx += k * xstrides[dim]
+        end
+        res[ci] = x[idx]
+    end
+    return res
 end

 """
@@ -79,7 +99,7 @@ Returns a vector of vector, each element being a configurations defined on `get_
 * `tn` is the tensor network model.
 * `n` is the number of samples to be returned.
 """
-function sample(tn::TensorNetworkModel, n::Int; usecuda = false)::Vector{Vector{Int}}
+function sample(tn::TensorNetworkModel, n::Int; usecuda = false)::AbstractMatrix{Int}
     # generate tropical tensors with its elements being log(p).
     xs = adapt_tensors(tn; usecuda, rescale = false)
     # infer size from the contraction code and the input tensors `xs`, returns a label-size dictionary.
@@ -93,21 +113,27 @@ function sample(tn::TensorNetworkModel, n::Int; usecuda = false)::Vector{Vector{
     idx = map(l->findfirst(==(l), labels), iy)
     setmask[idx] .= true
     indices = StatsBase.sample(CartesianIndices(size(cache.content)), Weights(normalize!(vec(LinearAlgebra.normalize!(cache.content)))), n)
-    configs = map(indices) do ind
-        c=zeros(Int, length(labels))
-        c[idx] .= ind.I .- 1
-        c
+    configs = zeros(Int, length(labels), n)
+    for i=1:n
+        configs[idx, i] .= indices[i].I .- 1
     end
     samples = Samples(configs, labels, setmask)
     # back-propagate
     generate_samples(tn.code, cache, samples, size_dict)
+    # set evidence variables
+    for (k, v) in tn.fixedvertices
+        idx = findfirst(==(k), labels)
+        samples.samples[idx, :] .= v
+    end
     return samples.samples
 end

 function generate_samples(code::NestedEinsum, cache::CacheTree{T}, samples, size_dict::Dict) where {T}
     if !OMEinsum.isleaf(code)
         xs = ntuple(i -> cache.siblings[i].content, length(cache.siblings))
         backward_sampling!(OMEinsum.getixs(code.eins), xs, OMEinsum.getiy(code.eins), cache.content, samples, size_dict)
-        generate_samples.(code.args, cache.siblings, Ref(samples), Ref(size_dict))
+        for (arg, sib) in zip(code.args, cache.siblings)
+            generate_samples(arg, sib, samples, size_dict)
+        end
     end
 end
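
Note: the essence of the speedup in backward_sampling! is replacing n tiny per-sample contractions with one batched einsum whose output carries a trailing batch leg. A toy sketch (my example, not from the commit) of that idea using the same OMEinsum/StatsBase API:

    using OMEinsum, StatsBase

    A = rand(2, 2)                # factor over variables (1, 2)
    B = rand(2, 2)                # factor over variables (2, 3)
    nsample = 4
    fixed3 = rand(0:1, nsample)   # variable 3 already sampled, 0-based

    # slice B on variable 3 once for all samples; columns form the batch leg
    Bsliced = hcat([B[:, c + 1] for c in fixed3]...)   # size 2 × nsample

    # labels: A -> (1, 2), Bsliced -> (2, batch), output -> (1, batch)
    code = DynamicEinCode([[1, 2], [2, 4]], [1, 4])
    probs = code(A, Bsliced)      # conditional weights of variable 1, per sample

    # draw variable 1 for every sample from its (unnormalized) conditional
    draws = [StatsBase.sample(0:1, Weights(probs[:, i])) for i in 1:nsample]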

test/mmap.jl

Lines changed: 1 addition & 55 deletions
@@ -28,58 +28,4 @@ end
     @debug(mmap3)
     logp, config = most_probable_config(mmap3)
     @test log_probability(mmap3, config) ≈ logp
-end
-
-@testset "sampling" begin
-    instance = TensorInference.read_instance_from_string("""MARKOV
-8
-2 2 2 2 2 2 2 2
-8
-1 0
-2 1 0
-1 2
-2 3 2
-2 4 2
-3 5 3 1
-2 6 5
-3 7 5 4
-
-2
-0.01
-0.99
-
-4
-0.05 0.01
-0.95 0.99
-
-2
-0.5
-0.5
-
-4
-0.1 0.01
-0.9 0.99
-
-4
-0.6 0.3
-0.4 0.7
-
-8
-1 1 1 0
-0 0 0 1
-
-4
-0.98 0.05
-0.02 0.95
-
-8
-0.9 0.7 0.8 0.1
-0.1 0.3 0.2 0.9
-""")
-    n = 10000
-    tnet = TensorNetworkModel(instance)
-    samples = sample(tnet, n)
-    mars = getindex.(marginals(tnet), 2)
-    mars_sample = [count(s->s[k]==(1), samples) for k=1:8] ./ n
-    @test isapprox(mars, mars_sample, atol=0.05)
-end
+end

test/runtests.jl

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 using Test, TensorInference, Documenter, Pkg, Artifacts

 import Pkg;
-Pkg.ensure_artifact_installed("uai2014", "Artifacts.toml");
+Pkg.ensure_artifact_installed("uai2014", joinpath(@__DIR__, "Artifacts.toml"));

 include("utils.jl")

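Note: the fix makes the artifact lookup independent of the working directory — a bare "Artifacts.toml" resolves against pwd(), so it only worked when the suite was launched from test/. A sketch of the resolved behavior (the artifact_path lookup is my addition and assumes the entry exists):

    using Pkg, Artifacts

    artifacts_toml = joinpath(@__DIR__, "Artifacts.toml")   # anchored at this file
    Pkg.ensure_artifact_installed("uai2014", artifacts_toml)
    data_dir = Artifacts.artifact_path(Artifacts.artifact_hash("uai2014", artifacts_toml))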
test/sampling.jl

Lines changed: 63 additions & 0 deletions
@@ -0,0 +1,63 @@
+using TensorInference, Test
+
+@testset "sampling" begin
+    instance = TensorInference.read_instance_from_string("""MARKOV
+8
+2 2 2 2 2 2 2 2
+8
+1 0
+2 1 0
+1 2
+2 3 2
+2 4 2
+3 5 3 1
+2 6 5
+3 7 5 4
+
+2
+0.01
+0.99
+
+4
+0.05 0.01
+0.95 0.99
+
+2
+0.5
+0.5
+
+4
+0.1 0.01
+0.9 0.99
+
+4
+0.6 0.3
+0.4 0.7
+
+8
+1 1 1 0
+0 0 0 1
+
+4
+0.98 0.05
+0.02 0.95
+
+8
+0.9 0.7 0.8 0.1
+0.1 0.3 0.2 0.9
+""")
+    n = 10000
+    tnet = TensorNetworkModel(instance)
+    samples = sample(tnet, n)
+    mars = getindex.(marginals(tnet), 2)
+    mars_sample = [count(i->samples[k, i]==(1), axes(samples, 2)) for k=1:8] ./ n
+    @test isapprox(mars, mars_sample, atol=0.05)
+
+    # fix the evidence
+    set_evidence!(instance, 7=>1)
+    tnet = TensorNetworkModel(instance)
+    samples = sample(tnet, n)
+    mars = getindex.(marginals(tnet), 1)
+    mars_sample = [count(i->samples[k, i]==(0), axes(samples, 2)) for k=1:8] ./ n
+    @test isapprox([mars[1:6]..., mars[8]], [mars_sample[1:6]..., mars_sample[8]], atol=0.05)
+end
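
Note: a hypothetical follow-up to the test above — with the matrix layout, the empirical marginals used in the assertions reduce to row-wise means:

    using Statistics
    p1 = vec(mean(samples .== 1; dims=2))   # empirical P(x_k = 1) for every variable k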
