@@ -10,7 +10,7 @@ The sampled configurations are stored in `samples`, which is a vector of vector.
1010The `setmask` is an boolean indicator to denote whether the sampling process of a variable is complete.
1111"""
1212struct Samples{L}
13- samples:: Vector{Vector{ Int}}
13+ samples:: Matrix{ Int} # size is nvars × nsample
1414 labels:: Vector{L}
1515 setmask:: BitVector
1616end
@@ -50,13 +50,14 @@ function backward_sampling!(ixs, @nospecialize(xs::Tuple), iy, @nospecialize(y),
5050 code = DynamicEinCode (newixs, newiy)
5151
5252 totalset = CartesianIndices ((map (x-> size_dict[x], eliminated_variables)... ,))
53- for (i, sample) in enumerate (samples. samples)
53+ for i in axes (samples. samples, 2 )
54+ sample = samples. samples[:, i]
5455 newxs = [get_slice (x, dimx, sample[ixloc]) for (x, dimx, ixloc) in zip (xs, slice_xs_dim, ix_in_sample)]
5556 newy = get_element (y, slice_y_dim, sample[iy_in_sample])
5657 probabilities = einsum (code, (newxs... ,), size_dict) / newy
5758 config = StatsBase. sample (totalset, Weights (vec (probabilities)))
5859 # update the samples
59- samples. samples[i][ eliminated_locs] .= config. I .- 1
60+ samples. samples[eliminated_locs, i ] .= config. I .- 1
6061 end
6162 return samples
6263end
@@ -79,7 +80,7 @@ Returns a vector of vector, each element being a configurations defined on `get_
7980* `tn` is the tensor network model.
8081* `n` is the number of samples to be returned.
8182"""
82- function sample (tn:: TensorNetworkModel , n:: Int ; usecuda = false ):: Vector{Vector{ Int} }
83+ function sample (tn:: TensorNetworkModel , n:: Int ; usecuda = false ):: AbstractMatrix{ Int}
8384 # generate tropical tensors with its elements being log(p).
8485 xs = adapt_tensors (tn; usecuda, rescale = false )
8586 # infer size from the contraction code and the input tensors `xs`, returns a label-size dictionary.
@@ -93,20 +94,17 @@ function sample(tn::TensorNetworkModel, n::Int; usecuda = false)::Vector{Vector{
9394 idx = map (l-> findfirst (== (l), labels), iy)
9495 setmask[idx] .= true
9596 indices = StatsBase. sample (CartesianIndices (size (cache. content)), Weights (normalize! (vec (LinearAlgebra. normalize! (cache. content)))), n)
96- configs = map (indices) do ind
97- c= zeros (Int, length (labels))
98- c[idx] .= ind. I .- 1
99- c
97+ configs = zeros (Int, length (labels), n)
98+ for i= 1 : n
99+ configs[idx, i] .= indices[i]. I .- 1
100100 end
101101 samples = Samples (configs, labels, setmask)
102102 # back-propagate
103103 generate_samples (tn. code, cache, samples, size_dict)
104104 # set evidence variables
105105 for (k, v) in tn. fixedvertices
106106 idx = findfirst (== (k), labels)
107- for c in samples. samples
108- c[idx] = v
109- end
107+ samples. samples[idx, :] .= v
110108 end
111109 return samples. samples
112110end
0 commit comments