Commit 8c3c7e5

update

1 parent 4166973 · commit 8c3c7e5

File tree

6 files changed: +88 -44 lines changed

Project.toml
src/RescaledArray.jl
src/TensorInference.jl
src/mar.jl
src/sampling.jl
test/sampling.jl


Project.toml

Lines changed: 2 additions & 0 deletions

@@ -12,6 +12,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 OMEinsum = "ebe7aa44-baf0-506c-a96f-8464559b3922"
 Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
 PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
+PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
 Requires = "ae029012-a4dd-5104-9daa-d747884805df"
 StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
 TropicalNumbers = "b3a74e9c-7526-4576-a4eb-79c0d4c32334"
@@ -25,6 +26,7 @@ LinearAlgebra = "1"
 OMEinsum = "0.8"
 Pkg = "1"
 PrecompileTools = "1"
+PrettyTables = "2"
 Requires = "1"
 StatsBase = "0.34"
 TropicalNumbers = "0.5.4, 0.6"

src/RescaledArray.jl

Lines changed: 2 additions & 2 deletions

@@ -23,12 +23,12 @@ $(TYPEDSIGNATURES)
 Returns a rescaled array that is equivalent to the input tensor.
 """
 function rescale_array(tensor::AbstractArray{T})::RescaledArray where {T}
-    maxf = maximum(tensor)
+    maxf = maximum(abs, tensor)
     if iszero(maxf)
         @warn("The maximum value of the array to rescale is 0!")
         return RescaledArray(zero(T), tensor)
     end
-    return RescaledArray(log(maxf), OMEinsum.asarray(tensor ./ maxf, tensor))
+    return RescaledArray(T(log(maxf)), OMEinsum.asarray(tensor ./ maxf, tensor))
 end

 for CT in [:DynamicEinCode, :StaticEinCode]
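An editorial aside, not part of the diff: the switch from `maximum` to `maximum(abs, ...)` matters whenever tensor entries can be negative or complex, since the log-rescaling needs the largest magnitude, not the largest signed value. A minimal sketch in plain Julia:

t = [-3.0 0.5; 1.0 -2.0]
maximum(t)             # 1.0: ignores the dominant entry -3.0, so log(maxf) would mis-scale
maximum(abs, t)        # 3.0: the true magnitude, giving the correct log-scale shift
c = ComplexF64[1+2im 0; 0 1]
maximum(abs, c)        # ≈ 2.236: also works for complex tensors, where maximum(c) errors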

src/TensorInference.jl

Lines changed: 1 addition & 0 deletions

@@ -11,6 +11,7 @@ using OMEinsum, LinearAlgebra
 using DocStringExtensions, TropicalNumbers
 # The Tropical GEMM support
 using StatsBase
+using PrettyTables
 import Pkg

 # reexport OMEinsum functions

src/mar.jl

Lines changed: 5 additions & 5 deletions

@@ -16,10 +16,10 @@ end
 # `CacheTree` stores intermediate `NestedEinsum` contraction results.
 # It is a tree structure that is isomorphic to the contraction tree,
 # `content` is the cached intermediate contraction result.
-# `siblings` are the siblings of current node.
+# `children` are the children of the current node, i.e. the tensors that are contracted to get `content`.
 struct CacheTree{T}
     content::AbstractArray{T}
-    siblings::Vector{CacheTree{T}}
+    children::Vector{CacheTree{T}}
 end

 function cached_einsum(se::SlicedEinsum, @nospecialize(xs), size_dict)
@@ -62,7 +62,7 @@ function generate_gradient_tree(code::NestedEinsum, cache::CacheTree{T}, dy::Abs
     if OMEinsum.isleaf(code)
         return CacheTree(dy, CacheTree{T}[])
     else
-        xs = ntuple(i -> cache.siblings[i].content, length(cache.siblings))
+        xs = ntuple(i -> cache.children[i].content, length(cache.children))
         # `einsum_grad` is the back-propagation rule for einsum function.
         # If the forward pass is `y = einsum(EinCode(inputs_labels, output_labels), (A, B, ...), size_dict)`
         # Then the back-propagation pass is
@@ -73,7 +73,7 @@ function generate_gradient_tree(code::NestedEinsum, cache::CacheTree{T}, dy::Abs
         # ```
        # Let `L` be the loss, we will have `y̅ := ∂L/∂y`, `A̅ := ∂L/∂A`...
         dxs = einsum_backward_rule(code.eins, xs, cache.content, size_dict, dy)
-        return CacheTree(dy, generate_gradient_tree.(code.args, cache.siblings, dxs, Ref(size_dict)))
+        return CacheTree(dy, generate_gradient_tree.(code.args, cache.children, dxs, Ref(size_dict)))
     end
 end

@@ -116,7 +116,7 @@ function extract_leaves!(code, cache, res)
         res[code.tensorindex] = cache.content
     else
         # recurse deeper
-        extract_leaves!.(code.args, cache.siblings, Ref(res))
+        extract_leaves!.(code.args, cache.children, Ref(res))
     end
     return res
 end
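For intuition about the rename (a hand-built sketch, not the package's actual constructors): each node of the cache tree stores the tensor obtained by contracting the tensors of its child nodes, so `children` is the accurate name where `siblings` was not:

struct MiniCacheTree{T}
    content::Array{T}                  # cached contraction result at this node
    children::Vector{MiniCacheTree{T}}
end
A, B = randn(2, 3), randn(3, 2)
leafA = MiniCacheTree(A, MiniCacheTree{Float64}[])
leafB = MiniCacheTree(B, MiniCacheTree{Float64}[])
root  = MiniCacheTree(A * B, [leafA, leafB])   # content is contracted from its children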

src/sampling.jl

Lines changed: 63 additions & 28 deletions

@@ -25,8 +25,22 @@
 Base.getindex(s::Samples, i::Int) = view(s.samples, :, i)
 Base.length(s::Samples) = size(s.samples, 2)
 Base.size(s::Samples) = (size(s.samples, 2),)
+function Base.show(io::IO, s::Samples)  # display with PrettyTables
+    println(io, typeof(s))
+    PrettyTables.pretty_table(io, s.samples', header=s.labels)
+end
+num_samples(samples::Samples) = size(samples.samples, 2)
 eliminated_variables(samples::Samples) = samples.labels[samples.setmask]
-idx4labels(totalset, labels)::Vector{Int} = map(v->findfirst(==(v), totalset), labels)
+is_eliminated(samples::Samples{L}, var::L) where L = samples.setmask[findfirst(==(var), samples.labels)]
+function idx4labels(totalset::AbstractVector{L}, labels::AbstractVector{L})::Vector{Int} where L
+    map(v->findfirst(==(v), totalset), labels)
+end
+idx4labels(samples::Samples{L}, lb::L) where L = findfirst(==(lb), samples.labels)
+function subset(samples::Samples{L}, labels::AbstractVector{L}) where L
+    idx = idx4labels(samples.labels, labels)
+    @assert all(i->samples.setmask[i], idx)
+    return samples.samples[idx, :]
+end

 """
 $(TYPEDSIGNATURES)
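To make the helpers added above concrete, here is a sketch with plain arrays standing in for the `Samples` fields; all names below are illustrative:

using PrettyTables
samples_mat = [0 1 1; 1 0 1]    # (#variables × #samples): 2 variables, 3 samples
labels      = ["a", "b"]        # row labels, playing the role of Samples.labels
setmask     = [true, false]     # marks variables that have already been sampled
idx = map(v -> findfirst(==(v), labels), ["a"])   # what idx4labels computes
samples_mat[idx, :]                               # what subset returns: a 1×3 slice
pretty_table(stdout, samples_mat', header = labels)  # the new Base.show, in spirit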
@@ -39,34 +53,49 @@ The backward process for sampling configurations.
 * `samples` is the samples generated for eliminated variables,
 * `size_dict` is a key-value map from tensor label to dimension size.
 """
-function backward_sampling!(code::EinCode, @nospecialize(xs::Tuple), @nospecialize(y), @nospecialize(env), samples::Samples, size_dict)
+function backward_sampling!(code::EinCode, @nospecialize(xs::Tuple), @nospecialize(env), samples::Samples, size_dict)
     ixs, iy = getixsv(code), getiyv(code)
-    el = setdiff(vcat(ixs...), iy)
+    el = setdiff(vcat(ixs...), iy) ∩ samples.labels
+
     # get probability
     prob_code = optimize_code(EinCode([ixs..., iy], el), size_dict, GreedyMethod(; nrepeat=1))
+    el_prev = eliminated_variables(samples)
+    xs = [eliminate_dimensions(x, ix, el_prev=>subset(samples, el_prev)[:,1]) for (ix, x) in zip(ixs, xs)]
     probabilities = einsum(prob_code, (xs..., env), size_dict)

     # sample from the probability tensor
     totalset = CartesianIndices((map(x->size_dict[x], el)...,))
     eliminated_locs = idx4labels(samples.labels, el)
-    for i=axes(samples.samples, 2)
-        config = StatsBase.sample(totalset, Weights(vec(selectdim(probabilities, ndims(probabilities), i))))
-        samples.samples[eliminated_locs, i] .= config.I .- 1
-    end
+    config = StatsBase.sample(totalset, _Weights(vec(probabilities)))
+    samples.samples[eliminated_locs, 1] .= config.I .- 1

     # eliminate the sampled variables
     set_eliminated!(samples, el)
     for l in el
         size_dict[l] = 1
     end
-    for sample in sampels
-        map(x->eliminate_dimensions!(x, el=>sample), xs)
-    end
+    sub = subset(samples, el)[:, 1]
+    xs = [eliminate_dimensions(x, ix, el=>sub) for (ix, x) in zip(ixs, xs)]
+    env = eliminate_dimensions(env, iy, el=>sub)

     # update environment
-    for (i, ix) in enumerate(ixs)
+    return map(1:length(ixs)) do i
+        rest = setdiff(1:length(ixs), i)
+        code = optimize_code(EinCode([ixs[rest]..., iy], ixs[i]), size_dict, GreedyMethod(; nrepeat=1))
+        einsum(code, (xs[rest]..., env), size_dict)
     end
-    return envs
+end
+
+function eliminate_dimensions(x::AbstractArray{T, N}, ix::AbstractVector{L}, el::Pair{<:AbstractVector{L}}) where {T, N, L}
+    idx = ntuple(N) do i
+        if ix[i] ∈ el.first
+            k = el.second[findfirst(==(ix[i]), el.first)] + 1
+            k:k
+        else
+            1:size(x, i)
+        end
+    end
+    return asarray(x[idx...], x)
 end

 function addbatch(samples::Samples, eliminated_variables)
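The new `eliminate_dimensions` pins each already-sampled label to its drawn value (0-based, hence the `+ 1`) while keeping the axis at size 1, which is why the `size_dict[l] = 1` update above keeps later einsum codes valid. A standalone illustration of the same indexing logic:

x  = reshape(collect(1.0:8.0), 2, 2, 2)  # a tensor whose axes carry labels ix
ix = [1, 2, 3]
el = [2] => [1]                          # pin label 2 to the sampled value 1 (0-based)
idx = ntuple(3) do i
    if ix[i] ∈ el.first
        k = el.second[findfirst(==(ix[i]), el.first)] + 1
        k:k                              # keep the axis, but with a single entry
    else
        1:size(x, i)
    end
end
x[idx...]                                # 2×1×2: axis 2 is pinned, not dropped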
@@ -113,48 +142,54 @@ Returns a vector of vector, each element being a configurations defined on `get_
 * `tn` is the tensor network model.
 * `n` is the number of samples to be returned.
 """
-function sample(tn::TensorNetworkModel, n::Int; usecuda = false)::Samples
+function sample(tn::TensorNetworkModel, n::Int; usecuda = false, queryvars = get_vars(tn))::Samples
     # generate tropical tensors with its elements being log(p).
     xs = adapt_tensors(tn; usecuda, rescale = false)
     # infer size from the contraction code and the input tensors `xs`, returns a label-size dictionary.
     size_dict = OMEinsum.get_size_dict!(getixsv(tn.code), xs, Dict{Int, Int}())
     # forward compute and cache intermediate results.
     cache = cached_einsum(tn.code, xs, size_dict)
     # initialize `y̅` as the initial batch of samples.
-    labels = get_vars(tn)
     iy = getiyv(tn.code)
-    setmask = falses(length(labels))
-    idx = map(l->findfirst(==(l), labels), iy)
+    setmask = falses(length(queryvars))
+    idx = map(l->findfirst(==(l), queryvars), iy ∩ queryvars)
     setmask[idx] .= true
-    indices = StatsBase.sample(CartesianIndices(size(cache.content)), Weights(normalize!(vec(LinearAlgebra.normalize!(cache.content)))), n)
-    configs = zeros(Int, length(labels), n)
+    indices = StatsBase.sample(CartesianIndices(size(cache.content)), _Weights(vec(cache.content)), n)
+    configs = zeros(Int, length(queryvars), n)
     for i=1:n
         configs[idx, i] .= indices[i].I .- 1
     end
-    samples = Samples(configs, labels, setmask)
+    samples = Samples(configs, queryvars, setmask)
     # back-propagate
-    generate_samples(tn.code, cache, samples, size_dict)
+    env = copy(cache.content)
+    fill!(env, one(eltype(env)))
+    generate_samples!(tn.code, cache, env, samples, size_dict)
     # set evidence variables
     for (k, v) in tn.evidence
-        idx = findfirst(==(k), labels)
+        idx = findfirst(==(k), samples.labels)
         samples.samples[idx, :] .= v
     end
     return samples
 end
+_Weights(x::AbstractVector{<:Real}) = Weights(x)
+function _Weights(x::AbstractArray{<:Complex})
+    @assert all(e->abs(imag(e)) < 100*eps(abs(e)), x)
+    return Weights(real.(x))
+end

-function generate_samples(se::SlicedEinsum, cache::CacheTree{T}, samples, size_dict::Dict) where {T}
+function generate_samples!(se::SlicedEinsum, cache::CacheTree{T}, env::AbstractArray{T}, samples, size_dict::Dict) where {T}
     # slicing is not supported yet.
     if length(se.slicing) != 0
         @warn "Slicing is not supported for caching, got nslices = $(length(se.slicing))! Fallback to `NestedEinsum`."
     end
-    return generate_samples(se.eins, cache, samples, size_dict)
+    return generate_samples!(se.eins, cache, env, samples, size_dict)
 end
-function generate_samples(code::NestedEinsum, cache::CacheTree{T}, env::AbstractArray{T}, samples, size_dict::Dict) where {T}
+function generate_samples!(code::NestedEinsum, cache::CacheTree{T}, env::AbstractArray{T}, samples::Samples, size_dict::Dict) where {T}
     if !OMEinsum.isleaf(code)
-        xs = ntuple(i -> cache.siblings[i].content, length(cache.siblings))
-        envs = backward_sampling!(code.eins, xs, cache.content, env, samples, copy(size_dict))
-        for (arg, sib, env) in zip(code.args, cache.siblings, envs)
-            generate_samples(arg, sib, env, samples, size_dict)
+        xs = ntuple(i -> cache.children[i].content, length(cache.children))
+        envs = backward_sampling!(code.eins, xs, env, samples, size_dict)
+        for (arg, sib, env) in zip(code.args, cache.children, envs)
+            generate_samples!(arg, sib, env, samples, size_dict)
         end
     end
 end
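A note on the `_Weights` shim above, re-stated here outside the package under an illustrative name: it lets `StatsBase.sample` accept complex-valued probability tensors, which show up when probabilities are built as ψ .* conj.(ψ) and are real only up to floating-point noise:

using StatsBase
to_weights(x::AbstractVector{<:Real}) = Weights(x)     # mirrors _Weights
function to_weights(x::AbstractArray{<:Complex})
    # imaginary parts must be rounding noise, not signal
    @assert all(e -> abs(imag(e)) < 100 * eps(abs(e)), x)
    return Weights(real.(x))
end
p = [0.25 + 0.0im, 0.75 + 1.0e-18im]     # "complex" probabilities, real up to noise
StatsBase.sample([0, 1], to_weights(p))  # draws 1 with probability 0.75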

test/sampling.jl

Lines changed: 15 additions & 9 deletions

@@ -1,5 +1,6 @@
 using TensorInference, Test
 using StatsBase: kldivergence
+using OMEinsum

 @testset "sampling" begin
     model = TensorInference.read_model_from_string("""MARKOV
@@ -64,21 +65,26 @@ end

 @testset "sample MPS" begin
     tensors = [
-        [rand(2, 2) for i=1:2],
-        [rand(2, 2, 2) for i=1:2],
-        [rand(2, 2, 2) for i=1:2],
-        [rand(2, 2) for i=1:2],
+        randn(ComplexF64, 2, 3),
+        randn(ComplexF64, 3, 2, 3),
+        randn(ComplexF64, 3, 2, 3),
+        randn(ComplexF64, 3, 2),
     ]
+    tensors = [tensors..., conj.(tensors)...]
     ixs = [[1, 5], [5, 2, 6], [6, 3, 7], [7, 4], [1, 8], [8, 2, 9], [9, 3, 10], [10, 4]]
     mps = TensorNetworkModel(
         collect(1:10),
-        DynamicEinCode(ixs, Int[]),
-        [tensors..., conj.(tensors)...],
+        optimize_code(DynamicEinCode(ixs, Int[]), OMEinsum.get_size_dict(ixs, tensors), GreedyMethod()),
+        tensors,
         Dict{Int, Int}(),
-        collect(5:10)
+        [[i] for i=5:10]
     )
-    samples = sample(mps, 1000)
-    indices = samples.samples
+    num_samples = 1
+    samples = sample(mps, num_samples; queryvars=[1, 2, 3, 4])
+    indices = map(samples) do sample
+        sum(i->sample[i] * 2^(i-1), 1:4) + 1
+    end
+    @show samples
     @show indices
     probs = vec(DynamicEinCode(ixs, collect(1:4))(tensors...))
     negative_loglikelyhood(samples, probs) = -sum(log.(probs[indices]))
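Context for the rewritten MPS test (an illustrative sketch, not part of the commit): the eight tensors form a 4-site MPS and its conjugate, sharing physical labels 1-4 but carrying separate bond labels 5-7 and 8-10, so the network's marginal over labels 1-4 is the Born distribution p(s) = |ψ(s)|², real up to rounding, which is exactly the case `_Weights` tolerates:

using OMEinsum
A = randn(ComplexF64, 2, 3); B = randn(ComplexF64, 3, 2, 3)
C = randn(ComplexF64, 3, 2, 3); D = randn(ComplexF64, 3, 2)
psi = ein"ia,ajb,bkc,cl->ijkl"(A, B, C, D)  # ψ(s₁,s₂,s₃,s₄), a 2×2×2×2 tensor
p = psi .* conj.(psi)                       # matches the doubled network's marginal
maximum(abs, imag.(p))                      # 0 here; tiny but nonzero after a full network contraction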
