@@ -42,53 +42,6 @@ function subset(samples::Samples{L}, labels::AbstractVector{L}) where L
4242 return samples. samples[idx, :]
4343end
4444
45- """
46- $(TYPEDSIGNATURES)
47-
48- The backward process for sampling configurations.
49-
50- ### Arguments
51- * `code` is the contraction code in the current step,
52- * `env` is the environment tensor,
53- * `samples` is the samples generated for eliminated variables,
54- * `size_dict` is a key-value map from tensor label to dimension size.
55- """
56- function backward_sampling! (code:: EinCode , @nospecialize (xs:: Tuple ), @nospecialize (env), samples:: Samples , size_dict)
57- ixs, iy = getixsv (code), getiyv (code)
58- el = setdiff (vcat (ixs... ), iy) ∩ samples. labels
59-
60- # get probability
61- prob_code = optimize_code (EinCode ([ixs... , iy], el), size_dict, GreedyMethod (; nrepeat= 1 ))
62- el_prev = eliminated_variables (samples)
63- @show el_prev=> subset (samples, el_prev)[:,1 ]
64- xs = [eliminate_dimensions (x, ix, el_prev=> subset (samples, el_prev)[:,1 ]) for (ix, x) in zip (ixs, xs)]
65- probabilities = einsum (prob_code, (xs... , env), size_dict)
66- @show el
67- @show normalize (real .(vec (probabilities)), 1 )
68-
69- # sample from the probability tensor
70- totalset = CartesianIndices ((map (x-> size_dict[x], el)... ,))
71- eliminated_locs = idx4labels (samples. labels, el)
72- config = StatsBase. sample (totalset, _Weights (vec (probabilities)))
73- @show eliminated_locs, config. I .- 1
74- samples. samples[eliminated_locs, 1 ] .= config. I .- 1
75-
76- # eliminate the sampled variables
77- set_eliminated! (samples, el)
78- setindex! .(Ref (size_dict), 1 , el)
79- sub = subset (samples, el)[:, 1 ]
80- @show ixs, el=> sub
81- xs = [eliminate_dimensions (x, ix, el=> sub) for (ix, x) in zip (ixs, xs)]
82-
83- # update environment
84- envs = map (1 : length (ixs)) do i
85- rest = setdiff (1 : length (ixs), i)
86- code = optimize_code (EinCode ([ixs[rest]. .. , iy], ixs[i]), size_dict, GreedyMethod (; nrepeat= 1 ))
87- einsum (code, (xs[rest]. .. , env), size_dict)
88- end
89- @show envs
90- end
91-
9245function eliminate_dimensions (x:: AbstractArray{T, N} , ix:: AbstractVector{L} , el:: Pair{<:AbstractVector{L}} ) where {T, N, L}
9346 idx = ntuple (N) do i
9447 if ix[i] ∈ el. first
@@ -167,7 +120,7 @@ function sample(tn::TensorNetworkModel, n::Int; usecuda = false, queryvars = get
167120 # back-propagate
168121 env = copy (cache. content)
169122 fill! (env, one (eltype (env)))
170- generate_samples! (tn. code, cache, env, samples, size_dict)
123+ generate_samples! (tn. code, cache, env, samples, samples . labels, size_dict)
171124 # set evidence variables
172125 for (k, v) in tn. evidence
173126 idx = findfirst (== (k), samples. labels)
@@ -181,30 +134,71 @@ function _Weights(x::AbstractArray{<:Complex})
181134 return Weights (real .(x))
182135end
183136
184- function generate_samples! (se:: SlicedEinsum , cache:: CacheTree{T} , env:: AbstractArray{T} , samples, size_dict:: Dict ) where {T}
137+ function generate_samples! (se:: SlicedEinsum , cache:: CacheTree{T} , env:: AbstractArray{T} , samples, pool, size_dict:: Dict ) where {T}
185138 # slicing is not supported yet.
186139 if length (se. slicing) != 0
187140 @warn " Slicing is not supported for caching, got nslices = $(length (se. slicing)) ! Fallback to `NestedEinsum`."
188141 end
189- return generate_samples! (se. eins, cache, env, samples, size_dict)
142+ return generate_samples! (se. eins, cache, env, samples, pool, size_dict)
190143end
191- function generate_samples! (code:: NestedEinsum , cache:: CacheTree{T} , env:: AbstractArray{T} , samples:: Samples , size_dict:: Dict ) where {T}
192- @info " @"
144+
145+ # pool is a vector of labels that are not eliminated yet.
146+ function generate_samples! (code:: NestedEinsum , cache:: CacheTree{T} , env:: AbstractArray{T} , samples:: Samples{L} , pool:: Vector{L} , size_dict:: Dict ) where {T, L}
193147 if ! (OMEinsum. isleaf (code))
194- @info " non-leaf node"
195- @show env
196- xs = ntuple (i -> cache. children[i]. content, length (cache. children))
197- envs = backward_sampling! (code. eins, xs, env, samples, size_dict)
198- @show envs
199- fucks = map (1 : length (code. args)) do k
200- @info k
201- generate_samples! (code. args[k], cache. children[k], envs[k], samples, size_dict)
202- return " fuck"
148+ ixs, iy = getixsv (code), getiyv (code)
149+ for (subcode, child, ix) in zip (code. args, cache. children, ixs)
150+ # subenv for the current child, use it to sample and update its cache
151+ siblings = filter (x-> x != = child, cache. children)
152+ siblings_ixs = filter (x-> x != = ix, ixs)
153+ envcode = optimize_code (EinCode ([siblings_ixs... , iy], ix), size_dict, GreedyMethod (; nrepeat= 1 ))
154+ subenv = einsum (envcode, (getfield .(siblings, :content )... , env), size_dict)
155+
156+ # sample
157+ sample_vars = ix ∩ pool
158+ update_samples! (child. content, subenv, samples, ix, sample_vars, size_dict)
159+
160+ generate_samples! (subcode, child, subenv, samples, setdiff (pool, sample_vars), size_dict)
203161 end
204- @info fucks
205- return
206- else
207- @info " leaf node"
208- return
162+ end
163+ end
164+
165+ """
166+ $(TYPEDSIGNATURES)
167+
168+ The backward process for sampling configurations.
169+
170+ ### Arguments
171+ * `code` is the contraction code in the current step,
172+ * `env` is the environment tensor,
173+ * `samples` is the samples generated for eliminated variables,
174+ * `size_dict` is a key-value map from tensor label to dimension size.
175+ """
176+ function update_samples! (code:: EinCode , @nospecialize (xs:: Tuple ), @nospecialize (env), samples:: Samples , size_dict)
177+ ixs, iy = getixsv (code), getiyv (code)
178+ el = setdiff (vcat (ixs... ), iy) ∩ samples. labels
179+
180+ # get probability
181+ prob_code = optimize_code (EinCode ([ixs... , iy], el), size_dict, GreedyMethod (; nrepeat= 1 ))
182+ el_prev = eliminated_variables (samples)
183+ xs = [eliminate_dimensions (x, ix, el_prev=> subset (samples, el_prev)[:,1 ]) for (ix, x) in zip (ixs, xs)]
184+ probabilities = einsum (prob_code, (xs... , env), size_dict)
185+
186+ # sample from the probability tensor
187+ totalset = CartesianIndices ((map (x-> size_dict[x], el)... ,))
188+ eliminated_locs = idx4labels (samples. labels, el)
189+ config = StatsBase. sample (totalset, _Weights (vec (probabilities)))
190+ samples. samples[eliminated_locs, 1 ] .= config. I .- 1
191+
192+ # eliminate the sampled variables
193+ set_eliminated! (samples, el)
194+ setindex! .(Ref (size_dict), 1 , el)
195+ sub = subset (samples, el)[:, 1 ]
196+ xs = [eliminate_dimensions (x, ix, el=> sub) for (ix, x) in zip (ixs, xs)]
197+
198+ # update environment
199+ return map (1 : length (ixs)) do i
200+ rest = setdiff (1 : length (ixs), i)
201+ code = optimize_code (EinCode ([ixs[rest]. .. , iy], ixs[i]), size_dict, GreedyMethod (; nrepeat= 1 ))
202+ einsum (code, (xs[rest]. .. , env), size_dict)
209203 end
210204end
0 commit comments