Skip to content

Commit c105012

Browse files
committed
update matrix product state
1 parent cb59646 commit c105012

File tree

4 files changed

+61
-20
lines changed

4 files changed

+61
-20
lines changed

src/TensorInference.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@ export sample
3535
# MMAP
3636
export MMAPModel
3737

38+
# utils
39+
export matrix_product_state
40+
3841
include("Core.jl")
3942
include("RescaledArray.jl")
4043
include("utils.jl")

src/sampling.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ function sample(tn::TensorNetworkModel, n::Int; usecuda = false, queryvars = get
7070
# back-propagate
7171
env = copy(cache.content)
7272
fill!(env, one(eltype(env)))
73-
generate_samples!(tn.code, cache, env, samples, samples.labels, size_dict)
73+
generate_samples!(tn.code, cache, env, samples, copy(samples.labels), size_dict) # note: `copy` is necessary
7474
# set evidence variables
7575
for (k, v) in tn.evidence
7676
idx = findfirst(==(k), samples.labels)
@@ -80,7 +80,7 @@ function sample(tn::TensorNetworkModel, n::Int; usecuda = false, queryvars = get
8080
end
8181
_Weights(x::AbstractVector{<:Real}) = Weights(x)
8282
function _Weights(x::AbstractArray{<:Complex})
83-
@assert all(e->abs(imag(e)) < 100*eps(abs(e)), x)
83+
@assert all(e->abs(imag(e)) < 100*eps(abs(e)), x) "Complex probability encountered: $x"
8484
return Weights(real.(x))
8585
end
8686

@@ -118,7 +118,7 @@ function generate_samples!(code::NestedEinsum, cache::CacheTree{T}, env::Abstrac
118118
end
119119

120120
# recurse
121-
generate_samples!(subcode, child, subenv, samples, setdiff(pool, sample_vars), size_dict)
121+
generate_samples!(subcode, child, subenv, samples, pool, size_dict)
122122
end
123123
end
124124
end

src/utils.jl

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,3 +305,53 @@ function get_artifact_path(artifact_name::String)
305305
artifact_hash = Pkg.Artifacts.artifact_hash(artifact_name, artifact_toml)
306306
return Pkg.Artifacts.artifact_path(artifact_hash)
307307
end
308+
309+
"""
310+
$TYPEDSIGNATURES
311+
312+
Matrix product state (MPS) is a tensor network model that is widely used in
313+
quantum many-body physics. It is a special case of tensor network model where
314+
the tensors are rank-3 tensors and the physical indices are connected in a
315+
chain. The MPS is defined as:
316+
317+
```math
318+
\\begin{align*}
319+
\\left| \\psi \\right\\rangle &= \\sum_{x_1, x_2, \\ldots, x_n} \\text{Tr}(A_1^{x_1} A_2^{x_2} \\cdots A_n^{x_n}) \\left| x_1, x_2, \\ldots, x_n \\right\\rangle \\\\
320+
\\left\\langle \\psi \\right| &= \\sum_{x_1, x_2, \\ldots, x_n} \\text{Tr}(A_n^{x_n} \\cdots A_2^{x_2} A_1^{x_1}) \\left\\langle x_1, x_2, \\ldots, x_n \\right|
321+
\\end{align*}
322+
```
323+
324+
where \$A_i^{x_i}\$ is a rank-3 tensor with physical index \$x_i\$ and two virtual
325+
indices connecting to the next tensor. The MPS is a special case of the tensor
326+
network model where the tensors are rank-3 tensors and the physical indices are
327+
connected in a chain.
328+
329+
### Arguments
330+
- `n` is the number of physical indices.
331+
- `chi` is the bond dimension of the virtual indices.
332+
- `d` is the dimension of the physical indices.
333+
"""
334+
function matrix_product_state(n::Int, chi::Int, d::Int=2)
335+
tensors = Any[randn(ComplexF64, d, chi)]
336+
physical_indices = collect(1:n)
337+
virtual_indices_ket = collect(n+1:2n-1)
338+
virtual_indices_bra = collect(2n:3n-2)
339+
ixs_ket = [[physical_indices[1], virtual_indices_ket[1]]]
340+
ixs_bra = [[physical_indices[1], virtual_indices_bra[1]]]
341+
for i = 2:n-1
342+
push!(tensors, randn(ComplexF64, chi, d, chi))
343+
push!(ixs_ket, [virtual_indices_ket[i-1], physical_indices[i], virtual_indices_ket[i]])
344+
push!(ixs_bra, [virtual_indices_bra[i-1], physical_indices[i], virtual_indices_bra[i]])
345+
end
346+
push!(tensors, randn(ComplexF64, chi, d))
347+
push!(ixs_ket, [virtual_indices_ket[n-1], physical_indices[n]])
348+
push!(ixs_bra, [virtual_indices_bra[n-1], physical_indices[n]])
349+
tensors, ixs = [tensors..., conj.(tensors)...], [ixs_ket..., ixs_bra...]
350+
return TensorNetworkModel(
351+
collect(1:3n-2),
352+
optimize_code(DynamicEinCode(ixs, Int[]), OMEinsum.get_size_dict(ixs, tensors), GreedyMethod()),
353+
tensors,
354+
Dict{Int, Int}(),
355+
Vector{Int}[[i] for i=1:n]
356+
)
357+
end

test/sampling.jl

Lines changed: 5 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -64,32 +64,20 @@ using OMEinsum
6464
end
6565

6666
@testset "sample MPS" begin
67-
tensors = [
68-
randn(ComplexF64, 2, 3),
69-
randn(ComplexF64, 3, 2, 3),
70-
randn(ComplexF64, 3, 2, 3),
71-
randn(ComplexF64, 3, 2),
72-
]
73-
tensors = [tensors..., conj.(tensors)...]
74-
ixs = [[1, 5], [5, 2, 6], [6, 3, 7], [7, 4], [1, 8], [8, 2, 9], [9, 3, 10], [10, 4]]
75-
mps = TensorNetworkModel(
76-
collect(1:10),
77-
optimize_code(DynamicEinCode(ixs, Int[]), OMEinsum.get_size_dict(ixs, tensors), GreedyMethod()),
78-
tensors,
79-
Dict{Int, Int}(),
80-
Vector{Int}[]
81-
)
67+
n = 4
68+
chi = 3
69+
mps = matrix_product_state(n, chi)
8270
num_samples = 1000
8371
samples = map(1:num_samples) do i
84-
sample(mps, 1; queryvars=[1, 2, 3, 4]).samples[:, 1]
72+
sample(mps, 1; queryvars=vcat(mps.mars...)).samples[:, 1]
8573
end
8674
indices = map(samples) do sample
8775
sum(i->sample[i] * 2^(i-1), 1:4) + 1
8876
end
8977
distribution = map(1:16) do i
9078
count(j->j==i, indices) / num_samples
9179
end
92-
probs = normalize!(real.(vec(DynamicEinCode(ixs, collect(1:4))(tensors...))), 1)
80+
probs = normalize!(real.(vec(DynamicEinCode(ixs, collect(1:4))(mps.tensors...))), 1)
9381
negative_loglikelyhood(probs, samples) = -sum(log.(probs[samples]))/length(samples)
9482
entropy(probs) = -sum(probs .* log.(probs))
9583
@test negative_loglikelyhood(probs, indices) entropy(probs) atol=1e-1

0 commit comments

Comments
 (0)