@@ -10,7 +10,7 @@ The sampled configurations are stored in `samples`, which is a vector of vectors.
 The `setmask` is a boolean indicator that denotes whether the sampling process of a variable is complete.
 """
 struct Samples{L}
-    samples::Vector{Vector{Int}}
+    samples::Matrix{Int}  # size is nvars × nsample
     labels::Vector{L}
     setmask::BitVector
 end
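For orientation, the new layout stores one configuration per column, so a variable's value in a given sample is read as `samples.samples[varindex, sampleindex]`. A minimal sketch (labels and sizes invented for illustration):

    # hypothetical instance: 3 variables, 2 samples, nothing sampled yet
    s = Samples(zeros(Int, 3, 2), [:a, :b, :c], falses(3))
    s.samples[2, 1]  # value assigned to variable :b in the first sample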
@@ -23,7 +23,7 @@ function setmask!(samples::Samples, eliminated_variables)
     return samples
 end
 
-idx4labels(totalset, labels) = map(v->findfirst(==(v), totalset), labels)
+idx4labels(totalset, labels)::Vector{Int} = map(v->findfirst(==(v), totalset), labels)
 
 """
 $(TYPEDSIGNATURES)
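A quick illustration of `idx4labels` above: it maps each label to its position in `totalset`. The new `Vector{Int}` return assertion assumes every label is present; a missing label would make `findfirst` return `nothing` and the conversion throw. For example (made-up labels):

    idx4labels([:x, :y, :z], [:z, :x])  # == [3, 1]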
@@ -41,32 +41,52 @@ function backward_sampling!(ixs, @nospecialize(xs::Tuple), iy, @nospecialize(y),
     setmask!(samples, eliminated_variables)
 
     # the contraction code to get probability
-    newiy = eliminated_variables
-    iy_in_sample = idx4labels(samples.labels, iy)
-    slice_y_dim = collect(1:length(iy))
     newixs = map(ix->setdiff(ix, iy), ixs)
     ix_in_sample = map(ix->idx4labels(samples.labels, ix ∩ iy), ixs)
     slice_xs_dim = map(ix->idx4labels(ix, ix ∩ iy), ixs)
-    code = DynamicEinCode(newixs, newiy)
+
+    # relabel and compute probabilities
+    uniquelabels = unique!(vcat(ixs..., iy))
+    labelmap = Dict(zip(uniquelabels, 1:length(uniquelabels)))
+    batchdim = length(labelmap) + 1
+    newnewixs = [Int[getindex.(Ref(labelmap), ix)..., batchdim] for ix in newixs]
+    newnewiy = Int[getindex.(Ref(labelmap), eliminated_variables)..., batchdim]
+    newnewxs = [get_slice(x, dimx, samples.samples[ixloc, :]) for (x, dimx, ixloc) in zip(xs, slice_xs_dim, ix_in_sample)]
+    code = DynamicEinCode(newnewixs, newnewiy)
+    probabilities = code(newnewxs...)
 
     totalset = CartesianIndices((map(x->size_dict[x], eliminated_variables)...,))
-    for (i, sample) in enumerate(samples.samples)
-        newxs = [get_slice(x, dimx, sample[ixloc]) for (x, dimx, ixloc) in zip(xs, slice_xs_dim, ix_in_sample)]
-        newy = get_element(y, slice_y_dim, sample[iy_in_sample])
-        probabilities = einsum(code, (newxs...,), size_dict) / newy
-        config = StatsBase.sample(totalset, Weights(vec(probabilities)))
-        # update the samples
-        samples.samples[i][eliminated_locs] .= config.I .- 1
+    for i = axes(samples.samples, 2)
+        config = StatsBase.sample(totalset, Weights(vec(selectdim(probabilities, ndims(probabilities), i))))
+        # update the samples
+        samples.samples[eliminated_locs, i] .= config.I .- 1
     end
     return samples
 end
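The main change above: instead of contracting once per sample, every sliced tensor carries an extra batch axis `batchdim`, and a single `DynamicEinCode` contraction produces the probabilities for all samples at once. A toy sketch of this batched-contraction idea (labels and shapes invented; assumes OMEinsum is loaded):

    using OMEinsum
    # contract two 2×2 tensors over label 2, with label 4 as a shared batch axis
    code = DynamicEinCode([[1, 2, 4], [2, 3, 4]], [1, 3, 4])
    a, b = randn(2, 2, 5), randn(2, 2, 5)
    out = code(a, b)                        # size (2, 2, 5): one result per batch
    out[:, :, 1] ≈ a[:, :, 1] * b[:, :, 1]  # true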
 
 # type unstable
-function get_slice(x, dim, config)
-    asarray(x[[i ∈ dim ? config[findfirst(==(i), dim)]+1 : Colon() for i in 1:ndims(x)]...], x)
+function get_slice(x::AbstractArray{T}, slicedim, configs::AbstractMatrix) where T
+    outdim = setdiff(1:ndims(x), slicedim)
+    res = similar(x, [size(x, d) for d in outdim]..., size(configs, 2))
+    return get_slice!(res, x, outdim, slicedim, configs)
 end
-function get_element(x, dim, config)
-    x[[config[findfirst(==(i), dim)]+1 for i in 1:ndims(x)]...]
+
+function get_slice!(res, x::AbstractArray{T}, outdim, slicedim, configs::AbstractMatrix) where T
+    xstrides = strides(x)
+    @inbounds for ci in CartesianIndices(res)
+        idx = 1
+        # the output dimension part
+        for (dim, k) in zip(outdim, ci.I)
+            idx += (k-1) * xstrides[dim]
+        end
+        # the sliced part
+        batchidx = ci.I[end]
+        for (dim, k) in zip(slicedim, view(configs, :, batchidx))
+            idx += k * xstrides[dim]
+        end
+        res[ci] = x[idx]
+    end
+    return res
 end
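A small check of the batched `get_slice` above: each column of `configs` is one zero-based assignment to the dimensions in `slicedim`, and each column of the result is the corresponding slice of `x` (values chosen arbitrarily):

    x = reshape(collect(1.0:24.0), 2, 3, 4)
    configs = [0 2; 1 3]               # two samples fixing dims 2 and 3 (zero-based)
    y = get_slice(x, [2, 3], configs)
    size(y)                            # (2, 2): the remaining dim 1, then the batch
    y[:, 1] == x[:, 1, 2] && y[:, 2] == x[:, 3, 4]  # true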
 
 """
@@ -79,7 +99,7 @@ Returns a vector of vectors, each element being a configuration defined on `get_
 * `tn` is the tensor network model.
 * `n` is the number of samples to be returned.
 """
-function sample(tn::TensorNetworkModel, n::Int; usecuda = false)::Vector{Vector{Int}}
+function sample(tn::TensorNetworkModel, n::Int; usecuda = false)::AbstractMatrix{Int}
     # generate tropical tensors with their elements being log(p).
     xs = adapt_tensors(tn; usecuda, rescale = false)
     # infer sizes from the contraction code and the input tensors `xs`; returns a label-size dictionary.
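With the new return type, `sample` yields a matrix rather than a vector of vectors. A hypothetical usage sketch (assuming `tn` is an already-constructed `TensorNetworkModel`):

    res = sample(tn, 100)  # length(labels) × 100 matrix of zero-based assignments
    res[:, 1]              # the first sampled configuration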
@@ -93,21 +113,27 @@ function sample(tn::TensorNetworkModel, n::Int; usecuda = false)::Vector{Vector{
     idx = map(l->findfirst(==(l), labels), iy)
     setmask[idx] .= true
     indices = StatsBase.sample(CartesianIndices(size(cache.content)), Weights(normalize!(vec(LinearAlgebra.normalize!(cache.content)))), n)
-    configs = map(indices) do ind
-        c = zeros(Int, length(labels))
-        c[idx] .= ind.I .- 1
-        c
+    configs = zeros(Int, length(labels), n)
+    for i = 1:n
+        configs[idx, i] .= indices[i].I .- 1
     end
     samples = Samples(configs, labels, setmask)
     # back-propagate
     generate_samples(tn.code, cache, samples, size_dict)
+    # set evidence variables
+    for (k, v) in tn.fixedvertices
+        idx = findfirst(==(k), labels)
+        samples.samples[idx, :] .= v
+    end
     return samples.samples
 end
 
 function generate_samples(code::NestedEinsum, cache::CacheTree{T}, samples, size_dict::Dict) where {T}
     if !OMEinsum.isleaf(code)
         xs = ntuple(i -> cache.siblings[i].content, length(cache.siblings))
         backward_sampling!(OMEinsum.getixs(code.eins), xs, OMEinsum.getiy(code.eins), cache.content, samples, size_dict)
-        generate_samples.(code.args, cache.siblings, Ref(samples), Ref(size_dict))
+        for (arg, sib) in zip(code.args, cache.siblings)
+            generate_samples(arg, sib, samples, size_dict)
+        end
     end
 end