
Commit 4166973

Commit message: update
1 parent 9986d3f commit 4166973

2 files changed: 62 additions, 26 deletions


src/sampling.jl

Lines changed: 38 additions & 26 deletions
@@ -14,7 +14,7 @@ struct Samples{L} <: AbstractVector{SubArray{Float64, 1, Matrix{Float64}, Tuple{
     labels::Vector{L}
     setmask::BitVector
 end
-function setmask!(samples::Samples, eliminated_variables)
+function set_eliminated!(samples::Samples, eliminated_variables)
     for var in eliminated_variables
         loc = findfirst(==(var), samples.labels)
         samples.setmask[loc] && error("variable `$var` is already eliminated.")
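
The rename from `setmask!` to `set_eliminated!` is pure bookkeeping: each label owns one bit in `setmask`, and eliminating a label twice is an error. A minimal sketch of the same logic with plain types (illustrative names, not the package API):

labels  = [:a, :b, :c]
setmask = falses(length(labels))

function mark_eliminated!(setmask, labels, vars)
    for var in vars
        loc = findfirst(==(var), labels)
        setmask[loc] && error("variable `$var` is already eliminated.")
        setmask[loc] = true
    end
    return setmask
end

mark_eliminated!(setmask, labels, [:b])    # flips the bit for :b
# mark_eliminated!(setmask, labels, [:b]) # would now throw an error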
@@ -25,45 +25,57 @@ end
 Base.getindex(s::Samples, i::Int) = view(s.samples, :, i)
 Base.length(s::Samples) = size(s.samples, 2)
 Base.size(s::Samples) = (size(s.samples, 2),)
+eliminated_variables(samples::Samples) = samples.labels[samples.setmask]
 idx4labels(totalset, labels)::Vector{Int} = map(v->findfirst(==(v), totalset), labels)
 
 """
 $(TYPEDSIGNATURES)
 
 The backward process for sampling configurations.
 
-* `ixs` and `xs` are labels and tensor data for input tensors,
-* `iy` and `y` are labels and tensor data for the output tensor,
+### Arguments
+* `code` is the contraction code in the current step,
+* `env` is the environment tensor,
 * `samples` is the samples generated for eliminated variables,
 * `size_dict` is a key-value map from tensor label to dimension size.
 """
-function backward_sampling!(ixs, @nospecialize(xs::Tuple), iy, @nospecialize(y), samples::Samples, size_dict)
-    eliminated_variables = setdiff(vcat(ixs...), iy)
-    eliminated_locs = idx4labels(samples.labels, eliminated_variables)
-    setmask!(samples, eliminated_variables)
+function backward_sampling!(code::EinCode, @nospecialize(xs::Tuple), @nospecialize(y), @nospecialize(env), samples::Samples, size_dict)
+    ixs, iy = getixsv(code), getiyv(code)
+    el = setdiff(vcat(ixs...), iy)
+    # get the probability tensor
+    prob_code = optimize_code(EinCode([ixs..., iy], el), size_dict, GreedyMethod(; nrepeat=1))
+    probabilities = einsum(prob_code, (xs..., env), size_dict)
 
-    # the contraction code to get probability
-    newixs = map(ix->setdiff(ix, iy), ixs)
-    ix_in_sample = map(ix->idx4labels(samples.labels, ix ∩ iy), ixs)
-    slice_xs_dim = map(ix->idx4labels(ix, ix ∩ iy), ixs)
+    # sample from the probability tensor
+    totalset = CartesianIndices((map(x->size_dict[x], el)...,))
+    eliminated_locs = idx4labels(samples.labels, el)
+    for i=axes(samples.samples, 2)
+        config = StatsBase.sample(totalset, Weights(vec(selectdim(probabilities, ndims(probabilities), i))))
+        samples.samples[eliminated_locs, i] .= config.I .- 1
+    end
 
-    # relabel and compute probabilities
+    # eliminate the sampled variables
+    set_eliminated!(samples, el)
+    for l in el
+        size_dict[l] = 1
+    end
+    for sample in samples
+        map(x->eliminate_dimensions!(x, el=>sample), xs)
+    end
+
+    # update environment
+    for (i, ix) in enumerate(ixs)
+    end
+    return envs
+end
+
+function addbatch(samples::Samples, eliminated_variables)
     uniquelabels = unique!(vcat(ixs..., iy))
     labelmap = Dict(zip(uniquelabels, 1:length(uniquelabels)))
     batchdim = length(labelmap) + 1
     newnewixs = [Int[getindex.(Ref(labelmap), ix)..., batchdim] for ix in newixs]
     newnewiy = Int[getindex.(Ref(labelmap), eliminated_variables)..., batchdim]
     newnewxs = [get_slice(x, dimx, samples.samples[ixloc, :]) for (x, dimx, ixloc) in zip(xs, slice_xs_dim, ix_in_sample)]
-    code = DynamicEinCode(newnewixs, newnewiy)
-    probabilities = code(newnewxs...)
-
-    totalset = CartesianIndices((map(x->size_dict[x], eliminated_variables)...,))
-    for i=axes(samples.samples, 2)
-        config = StatsBase.sample(totalset, Weights(vec(selectdim(probabilities, ndims(probabilities), i))))
-        # update the samples
-        samples.samples[eliminated_locs, i] .= config.I .- 1
-    end
-    return samples
 end
 
 # type unstable
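
The sampling step introduced above is easy to check in isolation: an unnormalized probability tensor is flattened into a weight vector, and one Cartesian index is drawn per batch column. A self-contained sketch with toy sizes (the trailing batch dimension mirrors the diff; not part of the package API):

using StatsBase: sample, Weights

probabilities = rand(2, 3, 4)       # weights over two variables, batch of 4 samples
totalset = CartesianIndices((2, 3)) # all joint configurations of the two variables
for i in 1:4
    config = sample(totalset, Weights(vec(selectdim(probabilities, 3, i))))
    println(config.I .- 1)          # 0-based configuration, as stored in samples.samples
end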
@@ -137,12 +149,12 @@ function generate_samples(se::SlicedEinsum, cache::CacheTree{T}, samples, size_d
     end
     return generate_samples(se.eins, cache, samples, size_dict)
 end
-function generate_samples(code::NestedEinsum, cache::CacheTree{T}, samples, size_dict::Dict) where {T}
+function generate_samples(code::NestedEinsum, cache::CacheTree{T}, env::AbstractArray{T}, samples, size_dict::Dict) where {T}
     if !OMEinsum.isleaf(code)
         xs = ntuple(i -> cache.siblings[i].content, length(cache.siblings))
-        backward_sampling!(OMEinsum.getixs(code.eins), xs, OMEinsum.getiy(code.eins), cache.content, samples, size_dict)
-        for (arg, sib) in zip(code.args, cache.siblings)
-            generate_samples(arg, sib, samples, size_dict)
+        envs = backward_sampling!(code.eins, xs, cache.content, env, samples, copy(size_dict))
+        for (arg, sib, env) in zip(code.args, cache.siblings, envs)
+            generate_samples(arg, sib, env, samples, size_dict)
         end
     end
 end
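
Note that in this commit the `# update environment` loop in `backward_sampling!` is still a stub (`envs` is never assigned), while `generate_samples` already expects one environment tensor per child. One plausible shape for that computation, sketched with OMEinsum's dynamic interface; the helper name and the trivial root environment are assumptions, not the package implementation:

using OMEinsum

# The environment of input i is the contraction of the parent environment
# with all other inputs, keeping the labels of input i as the output.
function child_environments(ixs, xs, iy, env)
    map(1:length(ixs)) do i
        rest = setdiff(1:length(ixs), i)
        code = DynamicEinCode([ixs[rest]..., iy], ixs[i])
        code(xs[rest]..., env)
    end
end

A, B = rand(2, 3), rand(3, 4)
env = ones(2, 4)  # trivial environment at the root
envs = child_environments([['i', 'j'], ['j', 'k']], (A, B), ['i', 'k'], env)
size.(envs)       # one environment per child, matching its labels: [(2, 3), (3, 4)]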

test/sampling.jl

Lines changed: 24 additions & 0 deletions
@@ -1,4 +1,5 @@
 using TensorInference, Test
+using StatsBase: kldivergence
 
 @testset "sampling" begin
     model = TensorInference.read_model_from_string("""MARKOV
@@ -59,4 +60,27 @@ using TensorInference, Test
     mars = marginals(tnet)
     mars_sample = [count(s->s[k]==(0), samples) for k=1:8] ./ n
     @test isapprox([[mars[[i]][1] for i=1:6]..., mars[[8]][1]], [mars_sample[1:6]..., mars_sample[8]], atol=0.05)
 end
+
+@testset "sample MPS" begin
+    tensors = [
+        rand(2, 2),
+        rand(2, 2, 2),
+        rand(2, 2, 2),
+        rand(2, 2),
+    ]
+    ixs = [[1, 5], [5, 2, 6], [6, 3, 7], [7, 4], [1, 8], [8, 2, 9], [9, 3, 10], [10, 4]]
+    mps = TensorNetworkModel(
+        collect(1:10),
+        DynamicEinCode(ixs, Int[]),
+        [tensors..., conj.(tensors)...],
+        Dict{Int, Int}(),
+        collect(5:10)
+    )
+    samples = sample(mps, 1000)
+    indices = [LinearIndices((2, 2, 2, 2))[(Int.(samples.samples[1:4, i]) .+ 1)...] for i in axes(samples.samples, 2)]
+    probs = vec(DynamicEinCode(ixs, collect(1:4))(tensors..., conj.(tensors)...))
+    negative_loglikelihood(indices, probs) = -sum(log.(probs[indices] ./ sum(probs)))
+    @test isfinite(negative_loglikelihood(indices, probs))
+end
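
The `kldivergence` import added at the top of the file is not yet exercised by the testset above. A sketch of how it could tighten the check, reusing `indices` and `probs` from the testset (the 0.1 threshold is an assumption, loose for 1000 samples over 16 outcomes):

using Test
using StatsBase: kldivergence

# empirical distribution over the 2^4 = 16 physical configurations
empirical = [count(==(k), indices) for k in 1:16] ./ length(indices)
# exact distribution from contracting the MPS with its conjugate
exact = probs ./ sum(probs)
@test kldivergence(empirical, exact) < 0.1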
