@@ -43,6 +43,7 @@ function subset(samples::Samples{L}, labels::AbstractVector{L}) where L
4343end
4444
4545function eliminate_dimensions (x:: AbstractArray{T, N} , ix:: AbstractVector{L} , el:: Pair{<:AbstractVector{L}} ) where {T, N, L}
46+ @assert length (ix) == N
4647 idx = ntuple (N) do i
4748 if ix[i] ∈ el. first
4849 k = el. second[findfirst (== (ix[i]), el. first)] + 1
@@ -51,7 +52,6 @@ function eliminate_dimensions(x::AbstractArray{T, N}, ix::AbstractVector{L}, el:
5152 1 : size (x, i)
5253 end
5354 end
54- @show idx
5555 return asarray (x[idx... ], x)
5656end
5757
@@ -143,62 +143,48 @@ function generate_samples!(se::SlicedEinsum, cache::CacheTree{T}, env::AbstractA
143143end
144144
145145# pool is a vector of labels that are not eliminated yet.
146- function generate_samples! (code:: NestedEinsum , cache:: CacheTree{T} , env:: AbstractArray{T} , samples:: Samples{L} , pool:: Vector{L} , size_dict:: Dict ) where {T, L}
146+ function generate_samples! (code:: NestedEinsum , cache:: CacheTree{T} , env:: AbstractArray{T} , samples:: Samples{L} , pool:: Vector{L} , size_dict:: Dict{L} ) where {T, L}
147147 if ! (OMEinsum. isleaf (code))
148- ixs, iy = getixsv (code), getiyv (code)
148+ ixs, iy = getixsv (code. eins ), getiyv (code. eins )
149149 for (subcode, child, ix) in zip (code. args, cache. children, ixs)
150150 # subenv for the current child, use it to sample and update its cache
151151 siblings = filter (x-> x != = child, cache. children)
152152 siblings_ixs = filter (x-> x != = ix, ixs)
153153 envcode = optimize_code (EinCode ([siblings_ixs... , iy], ix), size_dict, GreedyMethod (; nrepeat= 1 ))
154154 subenv = einsum (envcode, (getfield .(siblings, :content )... , env), size_dict)
155155
156- # sample
156+ # get samples
157157 sample_vars = ix ∩ pool
158- update_samples! (child. content, subenv, samples, ix, sample_vars, size_dict)
158+ probabilities = einsum (DynamicEinCode ([ix, ix], sample_vars), (child. content, subenv), size_dict)
159+ update_samples! (samples, sample_vars, probabilities)
160+ pool = setdiff (pool, sample_vars)
159161
162+ # eliminate the sampled variables
163+ setindex! .(Ref (size_dict), 1 , sample_vars)
164+ subsamples = subset (samples, sample_vars)[:, 1 ]
165+ udpate_cache_tree! (code, cache, sample_vars=> subsamples)
166+ subenv = eliminate_dimensions (subenv, ix, sample_vars=> subsamples)
167+
168+ # recurse
160169 generate_samples! (subcode, child, subenv, samples, setdiff (pool, sample_vars), size_dict)
161170 end
162171 end
163172end
164173
165- """
166- $(TYPEDSIGNATURES)
167-
168- The backward process for sampling configurations.
169-
170- ### Arguments
171- * `code` is the contraction code in the current step,
172- * `env` is the environment tensor,
173- * `samples` is the samples generated for eliminated variables,
174- * `size_dict` is a key-value map from tensor label to dimension size.
175- """
176- function update_samples! (code:: EinCode , @nospecialize (xs:: Tuple ), @nospecialize (env), samples:: Samples , size_dict)
177- ixs, iy = getixsv (code), getiyv (code)
178- el = setdiff (vcat (ixs... ), iy) ∩ samples. labels
179-
180- # get probability
181- prob_code = optimize_code (EinCode ([ixs... , iy], el), size_dict, GreedyMethod (; nrepeat= 1 ))
182- el_prev = eliminated_variables (samples)
183- xs = [eliminate_dimensions (x, ix, el_prev=> subset (samples, el_prev)[:,1 ]) for (ix, x) in zip (ixs, xs)]
184- probabilities = einsum (prob_code, (xs... , env), size_dict)
185-
186- # sample from the probability tensor
187- totalset = CartesianIndices ((map (x-> size_dict[x], el)... ,))
188- eliminated_locs = idx4labels (samples. labels, el)
174+ # probabilities is a tensor of probabilities for each variable in `vars`.
175+ function update_samples! (samples:: Samples , vars:: AbstractVector{L} , probabilities:: AbstractArray{T, N} ) where {L, T, N}
176+ @assert length (vars) == N
177+ totalset = CartesianIndices (probabilities)
178+ eliminated_locs = idx4labels (samples. labels, vars)
189179 config = StatsBase. sample (totalset, _Weights (vec (probabilities)))
190180 samples. samples[eliminated_locs, 1 ] .= config. I .- 1
181+ set_eliminated! (samples, vars)
182+ end
191183
192- # eliminate the sampled variables
193- set_eliminated! (samples, el)
194- setindex! .(Ref (size_dict), 1 , el)
195- sub = subset (samples, el)[:, 1 ]
196- xs = [eliminate_dimensions (x, ix, el=> sub) for (ix, x) in zip (ixs, xs)]
197-
198- # update environment
199- return map (1 : length (ixs)) do i
200- rest = setdiff (1 : length (ixs), i)
201- code = optimize_code (EinCode ([ixs[rest]. .. , iy], ixs[i]), size_dict, GreedyMethod (; nrepeat= 1 ))
202- einsum (code, (xs[rest]. .. , env), size_dict)
184+ function udpate_cache_tree! (ne:: NestedEinsum , cache:: CacheTree{T} , el:: Pair{<:AbstractVector{L}} ) where {T, L}
185+ OMEinsum. isleaf (ne) && return
186+ for (subcode, child, ix) in zip (ne. args, cache. children, getixsv (ne. eins))
187+ child. content = eliminate_dimensions (child. content, ix, el)
188+ udpate_cache_tree! (subcode, child, el)
203189 end
204- end
190+ end
0 commit comments