Skip to content

Commit 8ab2807

Browse files
committed
use matrix instead of vectors
1 parent 9c43716 commit 8ab2807

File tree

2 files changed

+11
-13
lines changed

2 files changed

+11
-13
lines changed

src/sampling.jl

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ The sampled configurations are stored in `samples`, which is a vector of vector.
1010
The `setmask` is an boolean indicator to denote whether the sampling process of a variable is complete.
1111
"""
1212
struct Samples{L}
13-
samples::Vector{Vector{Int}}
13+
samples::Matrix{Int} # size is nvars × nsample
1414
labels::Vector{L}
1515
setmask::BitVector
1616
end
@@ -50,13 +50,14 @@ function backward_sampling!(ixs, @nospecialize(xs::Tuple), iy, @nospecialize(y),
5050
code = DynamicEinCode(newixs, newiy)
5151

5252
totalset = CartesianIndices((map(x->size_dict[x], eliminated_variables)...,))
53-
for (i, sample) in enumerate(samples.samples)
53+
for i in axes(samples.samples, 2)
54+
sample = samples.samples[:, i]
5455
newxs = [get_slice(x, dimx, sample[ixloc]) for (x, dimx, ixloc) in zip(xs, slice_xs_dim, ix_in_sample)]
5556
newy = get_element(y, slice_y_dim, sample[iy_in_sample])
5657
probabilities = einsum(code, (newxs...,), size_dict) / newy
5758
config = StatsBase.sample(totalset, Weights(vec(probabilities)))
5859
# update the samples
59-
samples.samples[i][eliminated_locs] .= config.I .- 1
60+
samples.samples[eliminated_locs, i] .= config.I .- 1
6061
end
6162
return samples
6263
end
@@ -79,7 +80,7 @@ Returns a vector of vector, each element being a configurations defined on `get_
7980
* `tn` is the tensor network model.
8081
* `n` is the number of samples to be returned.
8182
"""
82-
function sample(tn::TensorNetworkModel, n::Int; usecuda = false)::Vector{Vector{Int}}
83+
function sample(tn::TensorNetworkModel, n::Int; usecuda = false)::AbstractMatrix{Int}
8384
# generate tropical tensors with its elements being log(p).
8485
xs = adapt_tensors(tn; usecuda, rescale = false)
8586
# infer size from the contraction code and the input tensors `xs`, returns a label-size dictionary.
@@ -93,20 +94,17 @@ function sample(tn::TensorNetworkModel, n::Int; usecuda = false)::Vector{Vector{
9394
idx = map(l->findfirst(==(l), labels), iy)
9495
setmask[idx] .= true
9596
indices = StatsBase.sample(CartesianIndices(size(cache.content)), Weights(normalize!(vec(LinearAlgebra.normalize!(cache.content)))), n)
96-
configs = map(indices) do ind
97-
c=zeros(Int, length(labels))
98-
c[idx] .= ind.I .- 1
99-
c
97+
configs = zeros(Int, length(labels), n)
98+
for i=1:n
99+
configs[idx, i] .= indices[i].I .- 1
100100
end
101101
samples = Samples(configs, labels, setmask)
102102
# back-propagate
103103
generate_samples(tn.code, cache, samples, size_dict)
104104
# set evidence variables
105105
for (k, v) in tn.fixedvertices
106106
idx = findfirst(==(k), labels)
107-
for c in samples.samples
108-
c[idx] = v
109-
end
107+
samples.samples[idx, :] .= v
110108
end
111109
return samples.samples
112110
end

test/sampling.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,14 +51,14 @@ using TensorInference, Test
5151
tnet = TensorNetworkModel(instance)
5252
samples = sample(tnet, n)
5353
mars = getindex.(marginals(tnet), 2)
54-
mars_sample = [count(s->s[k]==(1), samples) for k=1:8] ./ n
54+
mars_sample = [count(i->samples[k, i]==(1), axes(samples, 2)) for k=1:8] ./ n
5555
@test isapprox(mars, mars_sample, atol=0.05)
5656

5757
# fix the evidence
5858
set_evidence!(instance, 7=>1)
5959
tnet = TensorNetworkModel(instance)
6060
samples = sample(tnet, n)
6161
mars = getindex.(marginals(tnet), 1)
62-
mars_sample = [count(s->s[k]==(0), samples) for k=1:8] ./ n
62+
mars_sample = [count(i->samples[k, i]==(0), axes(samples, 2)) for k=1:8] ./ n
6363
@test isapprox([mars[1:6]..., mars[8]], [mars_sample[1:6]..., mars_sample[8]], atol=0.05)
6464
end

0 commit comments

Comments
 (0)