Skip to content

Commit d2262ff

Browse files
committed
update
1 parent 2c41d54 commit d2262ff

File tree

4 files changed

+56
-30
lines changed

4 files changed

+56
-30
lines changed

src/belief.jl

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -43,24 +43,29 @@ struct BPState{T, VT<:AbstractVector{T}}
4343
end
4444

4545
# message_in -> message_out
46-
function process_message!(bp::BPState)
46+
function process_message!(bp::BPState; normalize, damping)
4747
for (ov, iv) in zip(bp.message_out, bp.message_in)
48-
_process_message!(ov, iv)
48+
_process_message!(ov, iv, normalize, damping)
4949
end
5050
end
51-
function _process_message!(ov::Vector, iv::Vector)
51+
function _process_message!(ov::Vector, iv::Vector, normalize::Bool, damping)
5252
# process the message, TODO: speed up if needed!
5353
for (i, v) in enumerate(ov)
54-
fill!(v, one(eltype(v))) # clear the output vector
54+
w = similar(v)
55+
fill!(w, one(eltype(v))) # clear the output vector
5556
for (j, u) in enumerate(iv)
56-
j != i && (v .*= u)
57+
j != i && (w .*= u)
5758
end
59+
normalize && normalize!(w, 1)
60+
v .= v .* damping + (1 - damping) * w
5861
end
5962
end
6063

61-
function collect_message!(bp::BeliefPropgation, state::BPState)
64+
function collect_message!(bp::BeliefPropgation, state::BPState; normalize::Bool)
6265
for it in 1:num_tensors(bp)
63-
_collect_message!(vectors_on_tensor(state.message_in, bp, it), bp.tensors[it], vectors_on_tensor(state.message_out, bp, it))
66+
out = vectors_on_tensor(state.message_in, bp, it)
67+
_collect_message!(out, bp.tensors[it], vectors_on_tensor(state.message_out, bp, it))
68+
normalize && normalize!.(out, 1)
6469
end
6570
end
6671
# collect the vectors associated with the target tensor
@@ -78,7 +83,7 @@ function _collect_message!(vectors_out::Vector, t::AbstractArray, vectors_in::Ve
7883
for (o, g) in zip(vectors_out, gradient[2:end])
7984
o .= g
8085
end
81-
return cost
86+
return cost[]
8287
end
8388

8489
# star code: contract a tensor with multiple vectors, one for each dimension
@@ -112,20 +117,20 @@ Run the belief propagation algorithm, and return the final state and the informa
112117
- `max_iter::Int=100`: the maximum number of iterations
113118
- `tol::Float64=1e-6`: the tolerance for the convergence
114119
"""
115-
function belief_propagate(bp::BeliefPropgation; max_iter::Int=100, tol::Float64=1e-6)
120+
function belief_propagate(bp::BeliefPropgation; kwargs...)
116121
state = initial_state(bp)
117-
info = belief_propagate!(bp, state; max_iter=max_iter, tol=tol)
122+
info = belief_propagate!(bp, state; kwargs...)
118123
return state, info
119124
end
120125
struct BPInfo
121126
converged::Bool
122127
iterations::Int
123128
end
124-
function belief_propagate!(bp::BeliefPropgation, state::BPState{T}; max_iter::Int=100, tol::Float64=1e-6) where T
129+
function belief_propagate!(bp::BeliefPropgation, state::BPState{T}; max_iter::Int=100, tol=1e-6, damping=0.2) where T
125130
pre_message_in = deepcopy(state.message_in)
126131
for i in 1:max_iter
127-
collect_message!(bp, state)
128-
process_message!(state)
132+
collect_message!(bp, state; normalize=true)
133+
process_message!(state; normalize=true, damping=damping)
129134
# check convergence
130135
if all(iv -> all(it -> isapprox(state.message_in[iv][it], pre_message_in[iv][it], atol=tol), 1:length(bp.v2t[iv])), 1:num_variables(bp))
131136
return BPInfo(true, i)

src/utils.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -380,21 +380,21 @@ Tensor train (TT) is a tensor network model that is widely used in quantum
380380
many-body physics. This model is different from the matrix product state (MPS)
381381
in that it does not have an extra copy for representing the bra state.
382382
"""
383-
function random_tensor_train_uai(::Type{T}, n::Int, chi::Int, d::Int=2) where T
383+
function random_tensor_train_uai(::Type{T}, n::Int, chi::Int, d::Int=2; periodic=false) where T
384384
# chi ^ (n-1) * (variance^n)^2 == 1/d^n
385385
variance = d^(-1/2) * chi^(-1/2+1/2n)
386-
tensors = Any[randn(T, d, chi) .* variance]
387386
physical_indices = collect(1:n)
388-
virtual_indices = collect(n+1:2n-1)
389-
ixs = [[physical_indices[1], virtual_indices[1]]]
387+
virtual_indices = collect(n+1:2n)
388+
tensors = Any[(periodic ? rand(T, chi, d, chi) : rand(T, d, chi)) .* variance]
389+
ixs = [periodic ? [virtual_indices[n], physical_indices[1], virtual_indices[1]] : [physical_indices[1], virtual_indices[1]]]
390390
for i = 2:n-1
391-
push!(tensors, randn(T, chi, d, chi) .* variance)
391+
push!(tensors, rand(T, chi, d, chi) .* variance)
392392
push!(ixs, [virtual_indices[i-1], physical_indices[i], virtual_indices[i]])
393393
end
394-
push!(tensors, randn(T, chi, d) .* variance)
395-
push!(ixs, [virtual_indices[n-1], physical_indices[n]])
394+
push!(tensors, (periodic ? rand(T, chi, d, chi) : rand(T, chi, d)) .* variance)
395+
push!(ixs, periodic ? [virtual_indices[n-1], physical_indices[n], virtual_indices[n]] : [virtual_indices[n-1], physical_indices[n]])
396396
size_dict = OMEinsum.get_size_dict(ixs, tensors)
397-
nvars = 2n-1
397+
nvars = periodic ? 2n : 2n-1
398398
return UAIModel(
399399
nvars,
400400
[size_dict[i] for i=1:nvars],

test/belief.jl

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
using TensorInference, Test
2-
using OMEinsum
2+
using OMEinsum, LinearAlgebra
33

44
@testset "process message" begin
5-
mi = [[1, 2, 3], [2, 3, 4], [3, 4, 5]]
6-
mo_expected = [[6, 12, 20], [3, 8, 15], [2, 6, 12]]
5+
mi = [[1.0, 2, 3], [2.0, 3, 4], [3.0, 4, 5]]
6+
mo_expected = [[6.0, 12, 20], [3.0, 8, 15], [2.0, 6, 12]]
77
mo = similar.(mi)
8-
TensorInference._process_message!(mo, mi)
9-
@test mo == mo_expected
8+
TensorInference._process_message!(mo, mi, false, 0)
9+
@test all(mo .≈ mo_expected)
10+
11+
TensorInference._process_message!(mo, mi, true, 0)
12+
@test all(mo .≈ normalize!.(mo_expected, 1))
1013
end
1114

1215
@testset "star code" begin
@@ -44,14 +47,29 @@ end
4447
@test TensorInference.initial_state(bp) isa TensorInference.BPState
4548
state, info = belief_propagate(bp)
4649
@test info.converged
47-
@test info.iterations < 10
50+
@test info.iterations < 20
51+
mars = marginals(state)
52+
tnet = TensorNetworkModel(mps_uai)
53+
mars_tnet = marginals(tnet)
54+
for v in 1:TensorInference.num_variables(bp)
55+
@test mars[[v]] mars_tnet[[v]] atol=1e-6
56+
end
57+
end
58+
59+
@testset "belief propagation on circle" begin
60+
n = 10
61+
chi = 3
62+
mps_uai = TensorInference.random_tensor_train_uai(Float64, n, chi; periodic=true)
63+
bp = BeliefPropgation(mps_uai)
64+
@test TensorInference.initial_state(bp) isa TensorInference.BPState
65+
state, info = belief_propagate(bp; max_iter=100, tol=1e-6)
66+
@test info.converged
67+
@test info.iterations < 100
4868
contraction_res = TensorInference.contraction_results(state)
4969
tnet = TensorNetworkModel(mps_uai)
50-
expected_result = probability(tnet)[]
51-
@test all(r -> isapprox(r, expected_result), contraction_res)
5270
mars = marginals(state)
5371
mars_tnet = marginals(tnet)
5472
for v in 1:TensorInference.num_variables(bp)
55-
@test mars[[v]] mars_tnet[[v]]
73+
@test mars[[v]] mars_tnet[[v]] atol=1e-4
5674
end
5775
end

test/utils.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@ using TensorInference, Test
33
@testset "tensor train" begin
44
tt = random_tensor_train_uai(Float64, 5, 3)
55
@test tt.nvars == length(unique(vcat([collect(Int, f.vars) for f in tt.factors]...)))
6+
7+
tt = random_tensor_train_uai(Float64, 5, 3; periodic=true)
8+
@test tt.nvars == length(unique(vcat([collect(Int, f.vars) for f in tt.factors]...)))
69
end
710

811
@testset "mps" begin

0 commit comments

Comments
 (0)