@@ -23,7 +23,7 @@ function setmask!(samples::Samples, eliminated_variables)
2323 return samples
2424end
2525
26- idx4labels (totalset, labels) = map (v-> findfirst (== (v), totalset), labels)
26+ idx4labels (totalset, labels):: Vector{Int} = map (v-> findfirst (== (v), totalset), labels)
2727
2828"""
2929$(TYPEDSIGNATURES)
@@ -41,32 +41,52 @@ function backward_sampling!(ixs, @nospecialize(xs::Tuple), iy, @nospecialize(y),
4141 setmask! (samples, eliminated_variables)
4242
4343 # the contraction code to get probability
44- newiy = eliminated_variables
45- iy_in_sample = idx4labels (samples. labels, iy)
46- slice_y_dim = collect (1 : length (iy))
4744 newixs = map (ix-> setdiff (ix, iy), ixs)
4845 ix_in_sample = map (ix-> idx4labels (samples. labels, ix ∩ iy), ixs)
4946 slice_xs_dim = map (ix-> idx4labels (ix, ix ∩ iy), ixs)
50- code = DynamicEinCode (newixs, newiy)
47+
48+ # relabel and compute probabilities
49+ uniquelabels = unique! (vcat (ixs... , iy))
50+ labelmap = Dict (zip (uniquelabels, 1 : length (uniquelabels)))
51+ batchdim = length (labelmap) + 1
52+ newnewixs = [Int[getindex .(Ref (labelmap), ix)... , batchdim] for ix in newixs]
53+ newnewiy = Int[getindex .(Ref (labelmap), eliminated_variables)... , batchdim]
54+ newnewxs = [get_slice (x, dimx, samples. samples[ixloc, :]) for (x, dimx, ixloc) in zip (xs, slice_xs_dim, ix_in_sample)]
55+ code = DynamicEinCode (newnewixs, newnewiy)
56+ probabilities = code (newnewxs... )
5157
5258 totalset = CartesianIndices ((map (x-> size_dict[x], eliminated_variables)... ,))
53- for i in axes (samples. samples, 2 )
54- newxs = [get_slice (x, dimx, samples. samples[ixloc, i]) for (x, dimx, ixloc) in zip (xs, slice_xs_dim, ix_in_sample)]
55- newy = get_element (y, slice_y_dim, samples. samples[iy_in_sample, i])
56- probabilities = einsum (code, (newxs... ,), size_dict) / newy
57- config = StatsBase. sample (totalset, Weights (vec (probabilities)))
58- # update the samples
59+ for i= axes (samples. samples, 2 )
60+ config = StatsBase. sample (totalset, Weights (vec (selectdim (probabilities, ndims (probabilities), i))))
61+ # update the samplesS
5962 samples. samples[eliminated_locs, i] .= config. I .- 1
6063 end
6164 return samples
6265end
6366
6467# type unstable
65- function get_slice (x, dim, config)
66- asarray (x[[i ∈ dim ? config[findfirst (== (i), dim)]+ 1 : Colon () for i in 1 : ndims (x)]. .. ], x)
68+ function get_slice (x:: AbstractArray{T} , slicedim, configs:: AbstractMatrix ) where T
69+ outdim = setdiff (1 : ndims (x), slicedim)
70+ res = similar (x, [size (x, d) for d in outdim]. .. , size (configs, 2 ))
71+ return get_slice! (res, x, outdim, slicedim, configs)
6772end
68- function get_element (x, dim, config)
69- x[[config[findfirst (== (i), dim)]+ 1 for i in 1 : ndims (x)]. .. ]
73+
74+ function get_slice! (res, x:: AbstractArray{T} , outdim, slicedim, configs:: AbstractMatrix ) where T
75+ xstrides = strides (x)
76+ @inbounds for ci in CartesianIndices (res)
77+ idx = 1
78+ # the output dimension part
79+ for (dim, k) in zip (outdim, ci. I)
80+ idx += (k- 1 ) * xstrides[dim]
81+ end
82+ # the sliced part
83+ batchidx = ci. I[end ]
84+ for (dim, k) in zip (slicedim, view (configs, :, batchidx))
85+ idx += k * xstrides[dim]
86+ end
87+ res[ci] = x[idx]
88+ end
89+ return res
7090end
7191
7292"""
@@ -112,6 +132,8 @@ function generate_samples(code::NestedEinsum, cache::CacheTree{T}, samples, size
112132 if ! OMEinsum. isleaf (code)
113133 xs = ntuple (i -> cache. siblings[i]. content, length (cache. siblings))
114134 backward_sampling! (OMEinsum. getixs (code. eins), xs, OMEinsum. getiy (code. eins), cache. content, samples, size_dict)
115- generate_samples .(code. args, cache. siblings, Ref (samples), Ref (size_dict))
135+ for (arg, sib) in zip (code. args, cache. siblings)
136+ generate_samples (arg, sib, samples, size_dict)
137+ end
116138 end
117139end
0 commit comments