Skip to content

Commit 7a22b7b

Browse files
committed
improve sampling speed on CPU
1 parent 8108ab3 commit 7a22b7b

File tree

1 file changed

+38
-16
lines changed

1 file changed

+38
-16
lines changed

src/sampling.jl

Lines changed: 38 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ function setmask!(samples::Samples, eliminated_variables)
2323
return samples
2424
end
2525

26-
idx4labels(totalset, labels) = map(v->findfirst(==(v), totalset), labels)
26+
idx4labels(totalset, labels)::Vector{Int} = map(v->findfirst(==(v), totalset), labels)
2727

2828
"""
2929
$(TYPEDSIGNATURES)
@@ -41,32 +41,52 @@ function backward_sampling!(ixs, @nospecialize(xs::Tuple), iy, @nospecialize(y),
4141
setmask!(samples, eliminated_variables)
4242

4343
# the contraction code to get probability
44-
newiy = eliminated_variables
45-
iy_in_sample = idx4labels(samples.labels, iy)
46-
slice_y_dim = collect(1:length(iy))
4744
newixs = map(ix->setdiff(ix, iy), ixs)
4845
ix_in_sample = map(ix->idx4labels(samples.labels, ix ∩ iy), ixs)
4946
slice_xs_dim = map(ix->idx4labels(ix, ix ∩ iy), ixs)
50-
code = DynamicEinCode(newixs, newiy)
47+
48+
# relabel and compute probabilities
49+
uniquelabels = unique!(vcat(ixs..., iy))
50+
labelmap = Dict(zip(uniquelabels, 1:length(uniquelabels)))
51+
batchdim = length(labelmap) + 1
52+
newnewixs = [Int[getindex.(Ref(labelmap), ix)..., batchdim] for ix in newixs]
53+
newnewiy = Int[getindex.(Ref(labelmap), eliminated_variables)..., batchdim]
54+
newnewxs = [get_slice(x, dimx, samples.samples[ixloc, :]) for (x, dimx, ixloc) in zip(xs, slice_xs_dim, ix_in_sample)]
55+
code = DynamicEinCode(newnewixs, newnewiy)
56+
probabilities = code(newnewxs...)
5157

5258
totalset = CartesianIndices((map(x->size_dict[x], eliminated_variables)...,))
53-
for i in axes(samples.samples, 2)
54-
newxs = [get_slice(x, dimx, samples.samples[ixloc, i]) for (x, dimx, ixloc) in zip(xs, slice_xs_dim, ix_in_sample)]
55-
newy = get_element(y, slice_y_dim, samples.samples[iy_in_sample, i])
56-
probabilities = einsum(code, (newxs...,), size_dict) / newy
57-
config = StatsBase.sample(totalset, Weights(vec(probabilities)))
58-
# update the samples
59+
for i=axes(samples.samples, 2)
60+
config = StatsBase.sample(totalset, Weights(vec(selectdim(probabilities, ndims(probabilities), i))))
61+
# update the samples
5962
samples.samples[eliminated_locs, i] .= config.I .- 1
6063
end
6164
return samples
6265
end
6366

6467
# type unstable
65-
function get_slice(x, dim, config)
66-
asarray(x[[i ∈ dim ? config[findfirst(==(i), dim)]+1 : Colon() for i in 1:ndims(x)]...], x)
68+
function get_slice(x::AbstractArray{T}, slicedim, configs::AbstractMatrix) where T
69+
outdim = setdiff(1:ndims(x), slicedim)
70+
res = similar(x, [size(x, d) for d in outdim]..., size(configs, 2))
71+
return get_slice!(res, x, outdim, slicedim, configs)
6772
end
68-
function get_element(x, dim, config)
69-
x[[config[findfirst(==(i), dim)]+1 for i in 1:ndims(x)]...]
73+
74+
function get_slice!(res, x::AbstractArray{T}, outdim, slicedim, configs::AbstractMatrix) where T
75+
xstrides = strides(x)
76+
@inbounds for ci in CartesianIndices(res)
77+
idx = 1
78+
# the output dimension part
79+
for (dim, k) in zip(outdim, ci.I)
80+
idx += (k-1) * xstrides[dim]
81+
end
82+
# the sliced part
83+
batchidx = ci.I[end]
84+
for (dim, k) in zip(slicedim, view(configs, :, batchidx))
85+
idx += k * xstrides[dim]
86+
end
87+
res[ci] = x[idx]
88+
end
89+
return res
7090
end
7191

7292
"""
@@ -112,6 +132,8 @@ function generate_samples(code::NestedEinsum, cache::CacheTree{T}, samples, size
112132
if !OMEinsum.isleaf(code)
113133
xs = ntuple(i -> cache.siblings[i].content, length(cache.siblings))
114134
backward_sampling!(OMEinsum.getixs(code.eins), xs, OMEinsum.getiy(code.eins), cache.content, samples, size_dict)
115-
generate_samples.(code.args, cache.siblings, Ref(samples), Ref(size_dict))
135+
for (arg, sib) in zip(code.args, cache.siblings)
136+
generate_samples(arg, sib, samples, size_dict)
137+
end
116138
end
117139
end

0 commit comments

Comments
 (0)