2525Base. getindex (s:: Samples , i:: Int ) = view (s. samples, :, i)
2626Base. length (s:: Samples ) = size (s. samples, 2 )
2727Base. size (s:: Samples ) = (size (s. samples, 2 ),)
28+ function Base. show (io:: IO , s:: Samples ) # display with PrettyTables
29+ println (io, typeof (s))
30+ PrettyTables. pretty_table (io, s. samples' , header= s. labels)
31+ end
32+ num_samples (samples:: Samples ) = size (samples. samples, 2 )
2833eliminated_variables (samples:: Samples ) = samples. labels[samples. setmask]
29- idx4labels (totalset, labels):: Vector{Int} = map (v-> findfirst (== (v), totalset), labels)
34+ is_eliminated (samples:: Samples{L} , var:: L ) where L = samples. setmask[findfirst (== (var), samples. labels)]
35+ function idx4labels (totalset:: AbstractVector{L} , labels:: AbstractVector{L} ):: Vector{Int} where L
36+ map (v-> findfirst (== (v), totalset), labels)
37+ end
38+ idx4labels (samples:: Samples{L} , lb:: L ) where L = findfirst (== (lb), samples. labels)
39+ function subset (samples:: Samples{L} , labels:: AbstractVector{L} ) where L
40+ idx = idx4labels (samples. labels, labels)
41+ @assert all (i-> samples. setmask[i], idx)
42+ return samples. samples[idx, :]
43+ end
3044
3145"""
3246$(TYPEDSIGNATURES)
@@ -39,34 +53,49 @@ The backward process for sampling configurations.
3953* `samples` is the samples generated for eliminated variables,
4054* `size_dict` is a key-value map from tensor label to dimension size.
4155"""
42- function backward_sampling! (code:: EinCode , @nospecialize (xs:: Tuple ), @nospecialize (y), @nospecialize ( env), samples:: Samples , size_dict)
56+ function backward_sampling! (code:: EinCode , @nospecialize (xs:: Tuple ), @nospecialize (env), samples:: Samples , size_dict)
4357 ixs, iy = getixsv (code), getiyv (code)
44- el = setdiff (vcat (ixs... ), iy)
58+ el = setdiff (vcat (ixs... ), iy) ∩ samples. labels
59+
4560 # get probability
4661 prob_code = optimize_code (EinCode ([ixs... , iy], el), size_dict, GreedyMethod (; nrepeat= 1 ))
62+ el_prev = eliminated_variables (samples)
63+ xs = [eliminate_dimensions (x, ix, el_prev=> subset (samples, el_prev)[:,1 ]) for (ix, x) in zip (ixs, xs)]
4764 probabilities = einsum (prob_code, (xs... , env), size_dict)
4865
4966 # sample from the probability tensor
5067 totalset = CartesianIndices ((map (x-> size_dict[x], el)... ,))
5168 eliminated_locs = idx4labels (samples. labels, el)
52- for i= axes (samples. samples, 2 )
53- config = StatsBase. sample (totalset, Weights (vec (selectdim (probabilities, ndims (probabilities), i))))
54- samples. samples[eliminated_locs, i] .= config. I .- 1
55- end
69+ config = StatsBase. sample (totalset, _Weights (vec (probabilities)))
70+ samples. samples[eliminated_locs, 1 ] .= config. I .- 1
5671
5772 # eliminate the sampled variables
5873 set_eliminated! (samples, el)
5974 for l in el
6075 size_dict[l] = 1
6176 end
62- for sample in sampels
63- map (x -> eliminate_dimensions! (x, el=> sample), xs)
64- end
77+ sub = subset (samples, el)[:, 1 ]
78+ xs = [ eliminate_dimensions (x, ix, el=> sub) for (ix, x) in zip (ixs, xs)]
79+ env = eliminate_dimensions (env, iy, el => sub)
6580
6681 # update environment
67- for (i, ix) in enumerate (ixs)
82+ return map (1 : length (ixs)) do i
83+ rest = setdiff (1 : length (ixs), i)
84+ code = optimize_code (EinCode ([ixs[rest]. .. , iy], ixs[i]), size_dict, GreedyMethod (; nrepeat= 1 ))
85+ einsum (code, (xs[rest]. .. , env), size_dict)
6886 end
69- return envs
87+ end
88+
89+ function eliminate_dimensions (x:: AbstractArray{T, N} , ix:: AbstractVector{L} , el:: Pair{<:AbstractVector{L}} ) where {T, N, L}
90+ idx = ntuple (N) do i
91+ if ix[i] ∈ el. first
92+ k = el. second[findfirst (== (ix[i]), el. first)] + 1
93+ k: k
94+ else
95+ 1 : size (x, i)
96+ end
97+ end
98+ return asarray (x[idx... ], x)
7099end
71100
72101function addbatch (samples:: Samples , eliminated_variables)
@@ -113,48 +142,54 @@ Returns a vector of vector, each element being a configurations defined on `get_
113142* `tn` is the tensor network model.
114143* `n` is the number of samples to be returned.
115144"""
116- function sample (tn:: TensorNetworkModel , n:: Int ; usecuda = false ):: Samples
145+ function sample (tn:: TensorNetworkModel , n:: Int ; usecuda = false , queryvars = get_vars (tn) ):: Samples
117146 # generate tropical tensors with its elements being log(p).
118147 xs = adapt_tensors (tn; usecuda, rescale = false )
119148 # infer size from the contraction code and the input tensors `xs`, returns a label-size dictionary.
120149 size_dict = OMEinsum. get_size_dict! (getixsv (tn. code), xs, Dict {Int, Int} ())
121150 # forward compute and cache intermediate results.
122151 cache = cached_einsum (tn. code, xs, size_dict)
123152 # initialize `y̅` as the initial batch of samples.
124- labels = get_vars (tn)
125153 iy = getiyv (tn. code)
126- setmask = falses (length (labels ))
127- idx = map (l-> findfirst (== (l), labels ), iy)
154+ setmask = falses (length (queryvars ))
155+ idx = map (l-> findfirst (== (l), queryvars ), iy ∩ queryvars )
128156 setmask[idx] .= true
129- indices = StatsBase. sample (CartesianIndices (size (cache. content)), Weights ( normalize! ( vec (LinearAlgebra . normalize! ( cache. content)) )), n)
130- configs = zeros (Int, length (labels ), n)
157+ indices = StatsBase. sample (CartesianIndices (size (cache. content)), _Weights ( vec (cache. content)), n)
158+ configs = zeros (Int, length (queryvars ), n)
131159 for i= 1 : n
132160 configs[idx, i] .= indices[i]. I .- 1
133161 end
134- samples = Samples (configs, labels , setmask)
162+ samples = Samples (configs, queryvars , setmask)
135163 # back-propagate
136- generate_samples (tn. code, cache, samples, size_dict)
164+ env = copy (cache. content)
165+ fill! (env, one (eltype (env)))
166+ generate_samples! (tn. code, cache, env, samples, size_dict)
137167 # set evidence variables
138168 for (k, v) in tn. evidence
139- idx = findfirst (== (k), labels)
169+ idx = findfirst (== (k), samples . labels)
140170 samples. samples[idx, :] .= v
141171 end
142172 return samples
143173end
174+ _Weights (x:: AbstractVector{<:Real} ) = Weights (x)
175+ function _Weights (x:: AbstractArray{<:Complex} )
176+ @assert all (e-> abs (imag (e)) < 100 * eps (abs (e)), x)
177+ return Weights (real .(x))
178+ end
144179
145- function generate_samples (se:: SlicedEinsum , cache:: CacheTree{T} , samples, size_dict:: Dict ) where {T}
180+ function generate_samples! (se:: SlicedEinsum , cache:: CacheTree{T} , env :: AbstractArray {T} , samples, size_dict:: Dict ) where {T}
146181 # slicing is not supported yet.
147182 if length (se. slicing) != 0
148183 @warn " Slicing is not supported for caching, got nslices = $(length (se. slicing)) ! Fallback to `NestedEinsum`."
149184 end
150- return generate_samples (se. eins, cache, samples, size_dict)
185+ return generate_samples! (se. eins, cache, env , samples, size_dict)
151186end
152- function generate_samples (code:: NestedEinsum , cache:: CacheTree{T} , env:: AbstractArray{T} , samples, size_dict:: Dict ) where {T}
187+ function generate_samples! (code:: NestedEinsum , cache:: CacheTree{T} , env:: AbstractArray{T} , samples:: Samples , size_dict:: Dict ) where {T}
153188 if ! OMEinsum. isleaf (code)
154- xs = ntuple (i -> cache. siblings [i]. content, length (cache. siblings ))
155- envs = backward_sampling! (code. eins, xs, cache . content, env, samples, copy ( size_dict) )
156- for (arg, sib, env) in zip (code. args, cache. siblings , envs)
157- generate_samples (arg, sib, env, samples, size_dict)
189+ xs = ntuple (i -> cache. children [i]. content, length (cache. children ))
190+ envs = backward_sampling! (code. eins, xs, env, samples, size_dict)
191+ for (arg, sib, env) in zip (code. args, cache. children , envs)
192+ generate_samples! (arg, sib, env, samples, size_dict)
158193 end
159194 end
160195end
0 commit comments