@@ -43,24 +43,29 @@ struct BPState{T, VT<:AbstractVector{T}}
4343end
4444
4545# message_in -> message_out
46- function process_message! (bp:: BPState )
46+ function process_message! (bp:: BPState ; normalize, damping )
4747 for (ov, iv) in zip (bp. message_out, bp. message_in)
48- _process_message! (ov, iv)
48+ _process_message! (ov, iv, normalize, damping )
4949 end
5050end
51- function _process_message! (ov:: Vector , iv:: Vector )
51+ function _process_message! (ov:: Vector , iv:: Vector , normalize :: Bool , damping )
5252 # process the message, TODO : speed up if needed!
5353 for (i, v) in enumerate (ov)
54- fill! (v, one (eltype (v))) # clear the output vector
54+ w = similar (v)
55+ fill! (w, one (eltype (v))) # clear the output vector
5556 for (j, u) in enumerate (iv)
56- j != i && (v .*= u)
57+ j != i && (w .*= u)
5758 end
59+ normalize && normalize! (w, 1 )
60+ v .= v .* damping + (1 - damping) * w
5861 end
5962end
6063
61- function collect_message! (bp:: BeliefPropgation , state:: BPState )
64+ function collect_message! (bp:: BeliefPropgation , state:: BPState ; normalize :: Bool )
6265 for it in 1 : num_tensors (bp)
63- _collect_message! (vectors_on_tensor (state. message_in, bp, it), bp. tensors[it], vectors_on_tensor (state. message_out, bp, it))
66+ out = vectors_on_tensor (state. message_in, bp, it)
67+ _collect_message! (out, bp. tensors[it], vectors_on_tensor (state. message_out, bp, it))
68+ normalize && normalize! .(out, 1 )
6469 end
6570end
6671# collect the vectors associated with the target tensor
@@ -78,7 +83,7 @@ function _collect_message!(vectors_out::Vector, t::AbstractArray, vectors_in::Ve
7883 for (o, g) in zip (vectors_out, gradient[2 : end ])
7984 o .= g
8085 end
81- return cost
86+ return cost[]
8287end
8388
8489# star code: contract a tensor with multiple vectors, one for each dimension
@@ -112,20 +117,20 @@ Run the belief propagation algorithm, and return the final state and the informa
112117- `max_iter::Int=100`: the maximum number of iterations
113118- `tol::Float64=1e-6`: the tolerance for the convergence
114119"""
115- function belief_propagate (bp:: BeliefPropgation ; max_iter :: Int = 100 , tol :: Float64 = 1e-6 )
120+ function belief_propagate (bp:: BeliefPropgation ; kwargs ... )
116121 state = initial_state (bp)
117- info = belief_propagate! (bp, state; max_iter = max_iter, tol = tol )
122+ info = belief_propagate! (bp, state; kwargs ... )
118123 return state, info
119124end
120125struct BPInfo
121126 converged:: Bool
122127 iterations:: Int
123128end
124- function belief_propagate! (bp:: BeliefPropgation , state:: BPState{T} ; max_iter:: Int = 100 , tol:: Float64 = 1e-6 ) where T
129+ function belief_propagate! (bp:: BeliefPropgation , state:: BPState{T} ; max_iter:: Int = 100 , tol= 1e-6 , damping = 0.2 ) where T
125130 pre_message_in = deepcopy (state. message_in)
126131 for i in 1 : max_iter
127- collect_message! (bp, state)
128- process_message! (state)
132+ collect_message! (bp, state; normalize = true )
133+ process_message! (state; normalize = true , damping = damping )
129134 # check convergence
130135 if all (iv -> all (it -> isapprox (state. message_in[iv][it], pre_message_in[iv][it], atol= tol), 1 : length (bp. v2t[iv])), 1 : num_variables (bp))
131136 return BPInfo (true , i)
0 commit comments