@@ -14,7 +14,7 @@ struct Samples{L} <: AbstractVector{SubArray{Float64, 1, Matrix{Float64}, Tuple{
1414 labels:: Vector{L}
1515 setmask:: BitVector
1616end
17- function setmask ! (samples:: Samples , eliminated_variables)
17+ function set_eliminated ! (samples:: Samples , eliminated_variables)
1818 for var in eliminated_variables
1919 loc = findfirst (== (var), samples. labels)
2020 samples. setmask[loc] && error (" varaible `$var ` is already eliminated." )
# A `Samples` behaves as an abstract vector whose elements are per-sample
# configuration views (one column of the sample matrix per element).
Base.getindex(s::Samples, i::Int) = view(s.samples, :, i)
Base.length(s::Samples) = size(s.samples, 2)
Base.size(s::Samples) = (length(s),)
# Labels whose configurations have already been sampled (their mask bit is set).
eliminated_variables(samples::Samples) = samples.labels[samples.setmask]
2829idx4labels (totalset, labels):: Vector{Int} = map (v-> findfirst (== (v), totalset), labels)
2930
"""
$(TYPEDSIGNATURES)

The backward process for sampling configurations.

### Arguments
* `code` is the contraction code in the current step,
* `xs` are the input tensors of this contraction step,
* `y` is the output tensor of this contraction step,
* `env` is the environment tensor,
* `samples` is the samples generated for eliminated variables,
* `size_dict` is a key-value map from tensor label to dimension size.

Returns a vector of environment tensors, one for each input tensor in `xs`.
"""
function backward_sampling!(code::EinCode, @nospecialize(xs::Tuple), @nospecialize(y), @nospecialize(env), samples::Samples, size_dict)
    ixs, iy = getixsv(code), getiyv(code)
    # variables present in the inputs but eliminated by this contraction
    el = setdiff(vcat(ixs...), iy)

    # get probability: contract all inputs with the environment tensor,
    # keeping only the eliminated variables as output labels.
    prob_code = optimize_code(EinCode([ixs..., iy], el), size_dict, GreedyMethod(; nrepeat=1))
    probabilities = einsum(prob_code, (xs..., env), size_dict)

    # sample from the probability tensor: draw one joint configuration of the
    # eliminated variables for each sample (configurations are stored 0-based).
    # NOTE(review): `selectdim` over the last axis treats it as a batch index,
    # but `probabilities` is indexed by `el` only — confirm the batch axis is
    # actually present here (the pre-refactor code built an explicit batch dim).
    totalset = CartesianIndices((map(x->size_dict[x], el)...,))
    eliminated_locs = idx4labels(samples.labels, el)
    for i = axes(samples.samples, 2)
        config = StatsBase.sample(totalset, Weights(vec(selectdim(probabilities, ndims(probabilities), i))))
        samples.samples[eliminated_locs, i] .= config.I .- 1
    end

    # eliminate the sampled variables: mark them as set and shrink their
    # dimensions to 1 in both the size map and the input tensors.
    set_eliminated!(samples, el)
    for l in el
        size_dict[l] = 1
    end
    # BUG FIX: was `for sample in sampels` — `sampels` is undefined.
    # NOTE(review): `sample` here is a full per-sample label view, and the same
    # `xs` are mutated once per sample — confirm `eliminate_dimensions!` is
    # batch-aware, otherwise later iterations clobber earlier ones.
    for sample in samples
        map(x->eliminate_dimensions!(x, el=>sample), xs)
    end

    # update environment: the environment of the i-th input is the contraction
    # of every other input together with `env`, projected onto `ixs[i]`.
    # NOTE(review): reconstructed from the empty loop in the original WIP
    # commit (`for (i, ix) in enumerate(ixs) end` followed by `return envs`
    # with `envs` undefined) — verify against the intended message-passing rule.
    envs = map(1:length(ixs)) do i
        other_ixs = [ixs[1:i-1]..., ixs[i+1:end]..., iy]
        other_xs = (xs[1:i-1]..., xs[i+1:end]..., env)
        env_code = optimize_code(EinCode(other_ixs, ixs[i]), size_dict, GreedyMethod(; nrepeat=1))
        einsum(env_code, other_xs, size_dict)
    end
    return envs
end
71+
# NOTE(review): dead/incomplete leftover from the pre-refactor
# `backward_sampling!` body. `ixs`, `iy`, `newixs`, `slice_xs_dim`,
# `ix_in_sample`, `xs` and `get_slice` are not defined in this scope, so
# calling this function raises `UndefVarError`; it also returns the last
# expression (`newnewxs`) rather than a documented value. Either finish
# extracting the batching logic (pass the missing data as arguments) or
# delete this function.
function addbatch(samples::Samples, eliminated_variables)
    uniquelabels = unique!(vcat(ixs..., iy))
    # map each label to a dense integer index; the batch axis gets the next one
    labelmap = Dict(zip(uniquelabels, 1:length(uniquelabels)))
    batchdim = length(labelmap) + 1
    newnewixs = [Int[getindex.(Ref(labelmap), ix)..., batchdim] for ix in newixs]
    newnewiy = Int[getindex.(Ref(labelmap), eliminated_variables)..., batchdim]
    newnewxs = [get_slice(x, dimx, samples.samples[ixloc, :]) for (x, dimx, ixloc) in zip(xs, slice_xs_dim, ix_in_sample)]
end
6880
6981# type unstable
@@ -137,12 +149,12 @@ function generate_samples(se::SlicedEinsum, cache::CacheTree{T}, samples, size_d
137149 end
138150 return generate_samples (se. eins, cache, samples, size_dict)
139151end
# Recursively sample configurations along the contraction tree.
# At each non-leaf node, sample the variables eliminated by that node's
# contraction, then recurse into each child with its own environment tensor.
# `size_dict` is copied before the backward step because it is mutated there.
function generate_samples(code::NestedEinsum, cache::CacheTree{T}, env::AbstractArray{T}, samples, size_dict::Dict) where {T}
    OMEinsum.isleaf(code) && return
    child_tensors = ntuple(k -> cache.siblings[k].content, length(cache.siblings))
    child_envs = backward_sampling!(code.eins, child_tensors, cache.content, env, samples, copy(size_dict))
    for (subcode, subcache, subenv) in zip(code.args, cache.siblings, child_envs)
        generate_samples(subcode, subcache, subenv, samples, size_dict)
    end
    return
end
0 commit comments