############ Sampling ############
| 2 | +""" |
| 3 | +$TYPEDEF |
| 4 | +
|
| 5 | +### Fields |
| 6 | +$TYPEDFIELDS |
| 7 | +
|
| 8 | +The sampled configurations are stored in `samples`, which is a vector of vector. |
| 9 | +`labels` is a vector of variable names for labeling configurations. |
| 10 | +The `setmask` is an boolean indicator to denote whether the sampling process of a variable is complete. |
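
### Example
A minimal sketch of how the fields fit together, for three integer-labeled binary
variables and two configurations (the values are hypothetical):

```julia
s = Samples([[0, 1, 0], [1, 1, 0]], [1, 2, 3], falses(3))
setmask!(s, [2])   # mark variable `2` as sampled
```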
| 11 | +""" |
struct Samples{L}
    samples::Vector{Vector{Int}}
    labels::Vector{L}
    setmask::BitVector
end

function setmask!(samples::Samples, eliminated_variables)
    for var in eliminated_variables
        loc = findfirst(==(var), samples.labels)
        samples.setmask[loc] && error("variable `$var` is already eliminated.")
        samples.setmask[loc] = true
    end
    return samples
end

# return the locations of `labels` inside the collection `totalset`
idx4labels(totalset, labels) = map(v -> findfirst(==(v), totalset), labels)
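# e.g. `idx4labels([:a, :b, :c], [:c, :a])` returns `[3, 1]` (illustrative)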

"""
$(TYPEDSIGNATURES)

The backward process for sampling configurations.

* `ixs` and `xs` are the labels and tensor data for the input tensors,
* `iy` and `y` are the labels and tensor data for the output tensor,
* `samples` stores the configurations sampled for the eliminated variables,
* `size_dict` is a key-value map from tensor label to dimension size.
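
The eliminated variables are drawn by conditional sampling: with the output variables
already fixed, the eliminated ones are sampled with probability proportional to the
contraction of the sliced inputs, divided by the corresponding output element.
A minimal sketch of this idea on a hypothetical 2×2 joint distribution (plain Julia,
not this function's API):

```julia
using StatsBase

p = [0.2 0.4; 0.1 0.3]                       # joint distribution p(a, b)
b = 2                                        # already-sampled value of `b`
a = StatsBase.sample(1:2, Weights(p[:, b]))  # draw a ~ p(a | b) ∝ p(a, b)
```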
| 37 | +""" |
function backward_sampling!(ixs, @nospecialize(xs::Tuple), iy, @nospecialize(y), samples::Samples, size_dict)
    eliminated_variables = setdiff(vcat(ixs...), iy)
    eliminated_locs = idx4labels(samples.labels, eliminated_variables)
    setmask!(samples, eliminated_variables)

    # the contraction code to compute the probability of the eliminated variables
    newiy = eliminated_variables
    iy_in_sample = idx4labels(samples.labels, iy)
    slice_y_dim = collect(1:length(iy))
    newixs = map(ix -> setdiff(ix, iy), ixs)
    ix_in_sample = map(ix -> idx4labels(samples.labels, ix ∩ iy), ixs)
    slice_xs_dim = map(ix -> idx4labels(ix, ix ∩ iy), ixs)
    code = DynamicEinCode(newixs, newiy)

    totalset = CartesianIndices((map(x -> size_dict[x], eliminated_variables)...,))
    for (i, sample) in enumerate(samples.samples)
        # slice the inputs and the output with the already-sampled variables
        newxs = [get_slice(x, dimx, sample[ixloc]) for (x, dimx, ixloc) in zip(xs, slice_xs_dim, ix_in_sample)]
        newy = get_element(y, slice_y_dim, sample[iy_in_sample])
        # contract the slices to obtain the conditional probability, then draw a configuration
        probabilities = einsum(code, (newxs...,), size_dict) / newy
        config = StatsBase.sample(totalset, Weights(vec(probabilities)))
        # store the drawn configuration as 0-based integers
        samples.samples[i][eliminated_locs] .= config.I .- 1
    end
    return samples
end

# NOTE: type unstable, because the index vector mixes integers and colons
function get_slice(x, dim, config)
    asarray(x[[i ∈ dim ? config[findfirst(==(i), dim)] + 1 : Colon() for i in 1:ndims(x)]...], x)
end
function get_element(x, dim, config)
    x[[config[findfirst(==(i), dim)] + 1 for i in 1:ndims(x)]...]
end
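# e.g. for `x` of size (2, 3, 2) (illustrative; `config` entries are 0-based):
#   `get_slice(x, [1, 3], [0, 1])` gives `x[1, :, 2]` as an array,
#   `get_element(x, [1, 2, 3], [0, 2, 1])` gives the scalar `x[1, 3, 2]`.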

"""
$(TYPEDSIGNATURES)

Generate `n` samples from a tensor network based probabilistic model.
Returns a vector of vectors, each element being a configuration defined on `get_vars(tn)`.

### Arguments
* `tn` is the tensor network model.
* `n` is the number of samples to be returned.
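
### Example
A hypothetical usage sketch, where `model` is any `TensorNetworkModel`:

```julia
configs = sample(model, 100)   # 100 configurations over `get_vars(model)`
```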
| 81 | +""" |
function sample(tn::TensorNetworkModel, n::Int; usecuda = false)::Vector{Vector{Int}}
    # generate the input tensors, with elements being probabilities.
    xs = adapt_tensors(tn; usecuda, rescale = false)
    # infer sizes from the contraction code and the input tensors `xs`, returning a label-size dictionary.
    size_dict = OMEinsum.get_size_dict!(getixsv(tn.code), xs, Dict{Int, Int}())
    # forward compute and cache intermediate results.
    cache = cached_einsum(tn.code, xs, size_dict)
    # initialize the samples by drawing the output variables from the root tensor.
    labels = get_vars(tn)
    iy = getiyv(tn.code)
    setmask = falses(length(labels))
    idx = map(l -> findfirst(==(l), labels), iy)
    setmask[idx] .= true
    indices = StatsBase.sample(CartesianIndices(size(cache.content)), Weights(normalize!(vec(cache.content))), n)
    configs = map(indices) do ind
        c = zeros(Int, length(labels))
        c[idx] .= ind.I .- 1
        c
    end
    samples = Samples(configs, labels, setmask)
    # back-propagate to sample the variables eliminated during the contraction
    generate_samples(tn.code, cache, samples, size_dict)
    return samples.samples
end

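# Walk the contraction tree from the root towards the leaves; at each internal node,
# sample the variables that were eliminated by that contraction step.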
function generate_samples(code::NestedEinsum, cache::CacheTree{T}, samples, size_dict::Dict) where {T}
    if !OMEinsum.isleaf(code)
        xs = ntuple(i -> cache.siblings[i].content, length(cache.siblings))
        backward_sampling!(OMEinsum.getixs(code.eins), xs, OMEinsum.getiy(code.eins), cache.content, samples, size_dict)
        generate_samples.(code.args, cache.siblings, Ref(samples), Ref(size_dict))
    end
end