Skip to content

Commit 63a6c8e

Browse files
committed
tensor inference
1 parent 1568a63 commit 63a6c8e

File tree

2 files changed

+83
-32
lines changed

2 files changed

+83
-32
lines changed

src/sampling.jl

Lines changed: 72 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -28,17 +28,44 @@ function subset(samples::Samples{L}, labels::AbstractVector{L}) where L
2828
return view(samples.samples, idx, :)
2929
end
3030

31-
function eliminate_dimensions(x::AbstractArray{T, N}, ix::AbstractVector{L}, el::Pair{<:AbstractVector{L}}) where {T, N, L}
31+
function eliminate_dimensions(x::AbstractArray{T, N}, ix::AbstractVector{L}, el::Pair{<:AbstractVector{L}, <:AbstractVector}) where {T, N, L}
3232
@assert length(ix) == N
33-
idx = ntuple(N) do i
34-
if ix[i] ∈ el.first
35-
k = el.second[findfirst(==(ix[i]), el.first)] + 1
33+
return x[eliminated_selector(size(x), ix, el.first, el.second)...]
34+
end
35+
function eliminated_size(size0, ix, labels)
36+
@assert length(size0) == length(ix)
37+
return ntuple(length(ix)) do i
38+
ix[i] ∈ labels ? 1 : size0[i]
39+
end
40+
end
41+
function eliminated_selector(size0, ix, labels, config)
42+
return ntuple(length(ix)) do i
43+
if ix[i] ∈ labels
44+
k = config[findfirst(==(ix[i]), labels)] + 1
3645
k:k
3746
else
38-
1:size(x, i)
47+
1:size0[i]
3948
end
4049
end
41-
return asarray(x[idx...], x)
50+
end
51+
function eliminate_dimensions_addbatch!(x::AbstractArray{T, N}, ix::AbstractVector{L}, el::Pair{<:AbstractVector{L}, <:AbstractMatrix}, batch_label::L) where {T, N, L}
52+
nbatch = size(el.second, 2)
53+
@assert length(ix) == N
54+
res = similar(x, (eliminated_size(size(x), ix, el.first)..., nbatch))
55+
for ibatch in 1:nbatch
56+
selectdim(res, N+1, ibatch) .= eliminate_dimensions(x, ix, el.first=>view(el.second, :, ibatch))
57+
end
58+
push!(ix, batch_label)
59+
return res
60+
end
61+
function eliminate_dimensions_withbatch(x::AbstractArray{T, N}, ix::AbstractVector{L}, el::Pair{<:AbstractVector{L}, <:AbstractMatrix}) where {T, N, L}
62+
nbatch = size(el.second, 2)
63+
@assert length(ix) == N && size(x, N) == nbatch
64+
res = similar(x, (eliminated_size(size(x), ix, el.first)))
65+
for ibatch in 1:nbatch
66+
selectdim(res, N, ibatch) .= eliminate_dimensions(selectdim(x, N, ibatch), ix[1:end-1], el.first=>view(el.second, :, ibatch))
67+
end
68+
return res
4269
end
4370

4471
"""
@@ -72,78 +99,96 @@ function sample(tn::TensorNetworkModel, n::Int; usecuda = false, queryvars = get
7299
end
73100
samples = Samples(configs, queryvars)
74101
# back-propagate
75-
env = copy(cache.content)
102+
env = similar(cache.content, (size(cache.content)..., n)) # batched env
76103
fill!(env, one(eltype(env)))
77-
generate_samples!(tn.code, cache, env, samples, copy(samples.labels), size_dict) # note: `copy` is necessary
104+
batch_label = _newindex(OMEinsum.uniquelabels(tn.code))
105+
code = deepcopy(tn.code)
106+
iy_env = [OMEinsum.getiyv(code)..., batch_label]
107+
size_dict[batch_label] = n
108+
generate_samples!(code, cache, iy_env, env, samples, copy(samples.labels), batch_label, size_dict) # note: `copy` is necessary
78109
# set evidence variables
79-
for (k, v) in tn.evidence
80-
idx = findfirst(==(k), samples.labels)
110+
for (k, v) in setdiff(tn.evidence, queryvars)
111+
idx = findfirst(==(k), queryvars)
81112
samples.samples[idx, :] .= v
82113
end
83114
return samples
84115
end
116+
_newindex(labels::AbstractVector{<:Union{Int, Char}}) = maximum(labels) + 1
117+
_newindex(::AbstractVector{Symbol}) = gensym(:batch)
85118
_Weights(x::AbstractVector{<:Real}) = Weights(x)
86119
function _Weights(x::AbstractArray{<:Complex})
87-
@assert all(e->abs(imag(e)) < max(100*eps(abs(e)), 1e-10), x) "Complex probability encountered: $x"
120+
@assert all(e->abs(imag(e)) < max(100*eps(abs(e)), 1e-8), x) "Complex probability encountered: $x"
88121
return Weights(real.(x))
89122
end
90123

91-
function generate_samples!(se::SlicedEinsum, cache::CacheTree{T}, env::AbstractArray{T}, samples, pool, size_dict::Dict) where {T}
124+
function generate_samples!(se::SlicedEinsum, cache::CacheTree{T}, iy_env::Vector{Int}, env::AbstractArray{T}, samples::Samples{L}, pool, batch_label::L, size_dict::Dict{L}) where {T, L}
92125
# slicing is not supported yet.
93126
if length(se.slicing) != 0
94127
@warn "Slicing is not supported for caching, got nslices = $(length(se.slicing))! Fallback to `NestedEinsum`."
95128
end
96-
return generate_samples!(se.eins, cache, env, samples, pool, size_dict)
129+
return generate_samples!(se.eins, cache, iy_env, env, samples, pool, batch_label, size_dict)
97130
end
98131

99132
# pool is a vector of labels that are not eliminated yet.
100-
function generate_samples!(code::NestedEinsum, cache::CacheTree{T}, env::AbstractArray{T}, samples::Samples{L}, pool::Vector{L}, size_dict::Dict{L}) where {T, L}
133+
function generate_samples!(code::DynamicNestedEinsum, cache::CacheTree{T}, iy_env::Vector{L}, env::AbstractArray{T}, samples::Samples{L}, pool::Vector{L}, batch_label::L, size_dict::Dict{L}) where {T, L}
134+
@assert length(iy_env) == ndims(env)
101135
if !(OMEinsum.isleaf(code))
102136
ixs, iy = getixsv(code.eins), getiyv(code.eins)
103137
for (subcode, child, ix) in zip(code.args, cache.children, ixs)
104138
# subenv for the current child, use it to sample and update its cache
105139
siblings = filter(x->x !== child, cache.children)
106140
siblings_ixs = filter(x->x !== ix, ixs)
107-
envcode = optimize_code(EinCode([siblings_ixs..., iy], ix), size_dict, GreedyMethod(; nrepeat=1))
141+
iy_subenv = batch_label ∈ ix ? ix : [ix..., batch_label]
142+
envcode = optimize_code(EinCode([siblings_ixs..., iy_env], iy_subenv), size_dict, GreedyMethod(; nrepeat=1))
108143
subenv = einsum(envcode, (getfield.(siblings, :content)..., env), size_dict)
109144

110145
# generate samples
111146
sample_vars = ix ∩ pool
112147
if !isempty(sample_vars)
113-
probabilities = einsum(DynamicEinCode([ix, ix], sample_vars), (child.content, subenv), size_dict)
114-
update_samples!(samples, sample_vars, probabilities)
148+
probabilities = einsum(DynamicEinCode([ix, iy_subenv], [sample_vars..., batch_label]), (child.content, subenv), size_dict)
149+
for ibatch in axes(probabilities, ndims(probabilities))
150+
update_samples!(samples.labels, samples[ibatch], sample_vars, selectdim(probabilities, ndims(probabilities), ibatch))
151+
end
115152
setdiff!(pool, sample_vars)
116153

117154
# eliminate the sampled variables
118155
setindex!.(Ref(size_dict), 1, sample_vars)
119-
subsamples = subset(samples, sample_vars)[:, 1]
120-
udpate_cache_tree!(code, cache, sample_vars=>subsamples, size_dict)
121-
subenv = eliminate_dimensions(subenv, ix, sample_vars=>subsamples)
156+
subsamples = subset(samples, sample_vars)
157+
udpate_cache_tree!(code, cache, sample_vars=>subsamples, batch_label, size_dict)
158+
subenv = _eliminate!(subenv, ix, sample_vars=>subsamples, batch_label)
122159
end
123160

124161
# recurse
125-
generate_samples!(subcode, child, subenv, samples, pool, size_dict)
162+
generate_samples!(subcode, child, iy_subenv, subenv, samples, pool, batch_label, size_dict)
126163
end
127164
end
128165
end
129166

167+
function _eliminate!(x, ix, el, batch_label)
168+
if batch_label ∈ ix
169+
eliminate_dimensions_withbatch(x, ix, el)
170+
else
171+
eliminate_dimensions_addbatch!(x, ix, el, batch_label)
172+
end
173+
end
174+
130175
# probabilities is a tensor of probabilities for each variable in `vars`.
131-
function update_samples!(samples::Samples, vars::AbstractVector{L}, probabilities::AbstractArray{T, N}) where {L, T, N}
176+
function update_samples!(labels, sample, vars::AbstractVector{L}, probabilities::AbstractArray{T, N}) where {L, T, N}
132177
@assert length(vars) == N
133178
totalset = CartesianIndices(probabilities)
134-
eliminated_locs = idx4labels(samples.labels, vars)
179+
eliminated_locs = idx4labels(labels, vars)
135180
config = StatsBase.sample(totalset, _Weights(vec(probabilities)))
136-
samples.samples[eliminated_locs, 1] .= config.I .- 1
181+
sample[eliminated_locs] .= config.I .- 1
137182
end
138183

139-
function udpate_cache_tree!(ne::NestedEinsum, cache::CacheTree{T}, el::Pair{<:AbstractVector{L}}, size_dict::Dict{L}) where {T, L}
184+
function udpate_cache_tree!(ne::NestedEinsum, cache::CacheTree{T}, el::Pair{<:AbstractVector{L}}, batch_label::L, size_dict::Dict{L}) where {T, L}
140185
OMEinsum.isleaf(ne) && return
141186
updated = false
142187
for (subcode, child, ix) in zip(ne.args, cache.children, getixsv(ne.eins))
143188
if any(x->x el.first, ix)
144189
updated = true
145-
child.content = eliminate_dimensions(child.content, ix, el)
146-
udpate_cache_tree!(subcode, child, el, size_dict)
190+
child.content = _eliminate!(child.content, ix, el, batch_label)
191+
udpate_cache_tree!(subcode, child, el, batch_label, size_dict)
147192
end
148193
end
149194
updated && (cache.content = einsum(ne.eins, (getfield.(cache.children, :content)...,), size_dict))

test/sampling.jl

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using TensorInference, Test, LinearAlgebra
22
import StatsBase
3-
using OMEinsum
3+
using OMEinsum, Random
44

55
@testset "sampling" begin
66
model = TensorInference.read_model_from_string("""MARKOV
@@ -50,6 +50,7 @@ using OMEinsum
5050
""")
5151
n = 10000
5252
tnet = TensorNetworkModel(model)
53+
@show sample(tnet, 10)
5354
samples = sample(tnet, n)
5455
mars = marginals(tnet)
5556
mars_sample = [count(s->s[k]==(1), samples) for k=1:8] ./ n
@@ -67,16 +68,21 @@ end
6768
n = 4
6869
chi = 3
6970
mps = random_matrix_product_state(n, chi)
70-
num_samples = 1000
71-
sample(mps, num_samples; queryvars=vcat(mps.mars...))
71+
Random.seed!(134)
72+
num_samples = 10000
73+
# samples = map(1:num_samples) do i
74+
# sample(mps, 1; queryvars=vcat(mps.mars...)).samples[:,1]
75+
# end
76+
samples = sample(mps, num_samples; queryvars=vcat(mps.mars...))
7277
indices = map(samples) do sample
73-
sum(i->sample[i] * 2^(i-1), 1:4) + 1
78+
sum(i->sample[i] * 2^(i-1), 1:n) + 1
7479
end
75-
distribution = map(1:16) do i
80+
distribution = map(1:2^n) do i
7681
count(j->j==i, indices) / num_samples
7782
end
7883
probs = normalize!(real.(vec(DynamicEinCode(ixs, collect(1:4))(mps.tensors...))), 1)
7984
negative_loglikelyhood(probs, samples) = -sum(log.(probs[samples]))/length(samples)
8085
entropy(probs) = -sum(probs .* log.(probs))
86+
@show negative_loglikelyhood(probs, indices), entropy(probs)
8187
@test negative_loglikelyhood(probs, indices) ≈ entropy(probs) atol=1e-1
8288
end

0 commit comments

Comments
 (0)