Skip to content
Merged
8 changes: 2 additions & 6 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,10 @@ authors = ["Jin-Guo Liu", "Martin Roa Villescas"]
version = "0.5.0"

[deps]
Artifacts = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
OMEinsum = "ebe7aa44-baf0-506c-a96f-8464559b3922"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
ProblemReductions = "899c297d-f7d2-4ebf-8815-a35996def416"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Expand All @@ -22,15 +20,13 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
TensorInferenceCUDAExt = "CUDA"

[compat]
Artifacts = "1"
CUDA = "4, 5"
DocStringExtensions = "0.8.6, 0.9"
LinearAlgebra = "1"
OMEinsum = "0.8"
OMEinsum = "0.8.7"
Pkg = "1"
PrecompileTools = "1"
PrettyTables = "2"
ProblemReductions = "0.3"
StatsBase = "0.34"
TropicalNumbers = "0.5.4, 0.6"
julia = "1.9"
julia = "1.10"
4 changes: 4 additions & 0 deletions docs/src/api/public.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ RescaledArray
TensorNetworkModel
ArtifactProblemSpec
UAIModel
BeliefPropgation
```

## Functions
Expand All @@ -56,6 +57,7 @@ marginals
maximum_logp
most_probable_config
probability
belief_propagate
dataset_from_artifact
problem_from_artifact
read_model
Expand All @@ -69,4 +71,6 @@ sample
update_evidence!
update_temperature
random_matrix_product_state
random_matrix_product_uai
random_tensor_train_uai
```
19 changes: 19 additions & 0 deletions docs/src/tensor-networks.md
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,13 @@ Some of these have been implemented in the
[OMEinsum](https://github.com/under-Peter/OMEinsum.jl) package. Please check
[Performance Tips](@ref) for more details.

## Belief propagation

Belief propagation[^Yedidia2003] is a message passing algorithm that can be used to compute the marginals of a probabilistic graphical model. It has close connections with the tensor networks. It can be viewed as a way to gauge the tensor networks[^Tindall2023], and can be combined with tensor networks to achieve better performance[^Wang2024].

Belief propagation is an approximate method, and the quality of the approximation can be improved by the loop series expansion[^Evenbly2024].


## References

[^Orus2014]:
Expand All @@ -227,3 +234,15 @@ Some of these have been implemented in the

[^Liu2023]:
Liu J G, Gao X, Cain M, et al. Computing solution space properties of combinatorial optimization problems via generic tensor networks[J]. SIAM Journal on Scientific Computing, 2023, 45(3): A1239-A1270.

[^Yedidia2003]:
Yedidia, J.S., Freeman, W.T., Weiss, Y., 2003. Understanding belief propagation and its generalizations, in: Exploring Artificial Intelligence in the New Millennium. Morgan Kaufmann Publishers Inc., San Francisco, CA, USA, pp. 239–269.

[^Wang2024]:
Wang, Y., Zhang, Y.E., Pan, F., Zhang, P., 2024. Tensor Network Message Passing. Phys. Rev. Lett. 132, 117401. https://doi.org/10.1103/PhysRevLett.132.117401

[^Tindall2023]:
Tindall, J., Fishman, M.T., 2023. Gauging tensor networks with belief propagation. SciPost Phys. 15, 222. https://doi.org/10.21468/SciPostPhys.15.6.222

[^Evenbly2024]:
Evenbly, G., Pancotti, N., Milsted, A., Gray, J., Chan, G.K.-L., 2024. Loop Series Expansions for Tensor Networks. https://doi.org/10.48550/arXiv.2409.03108
2 changes: 1 addition & 1 deletion examples/hard-core-lattice-gas/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ mars = marginals(pmodel)
show_graph(SimpleGraph(graph), sites; vertex_colors=[(b = mars[[i]][2]; (1-b, 1-b, 1-b)) for i in 1:nv(graph)], texts=fill("", nv(graph)))
# The can see the sites at the corner is more likely to be occupied.
# To obtain two-site correlations, one can set the variables to query marginal probabilities manually.
pmodel2 = TensorNetworkModel(problem, β; mars=[[e.src, e.dst] for e in edges(graph)])
pmodel2 = TensorNetworkModel(problem, β; unity_tensors_labels = [[e.src, e.dst] for e in edges(graph)])
mars = marginals(pmodel2);

# We show the probability that both sites on an edge are not occupied
Expand Down
5 changes: 1 addition & 4 deletions ext/TensorInferenceCUDAExt.jl
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
module TensorInferenceCUDAExt
using CUDA: CuArray
import CUDA
import TensorInference: match_arraytype, keep_only!, onehot_like, togpu
import TensorInference: keep_only!, onehot_like, togpu

function onehot_like(A::CuArray, j)
mask = zero(A)
CUDA.@allowscalar mask[j] = one(eltype(mask))
return mask
end

# NOTE: this interface should be in OMEinsum
match_arraytype(::Type{<:CuArray{T, N}}, target::AbstractArray{T, N}) where {T, N} = CuArray(target)

function keep_only!(x::CuArray{T}, j) where T
CUDA.@allowscalar hotvalue = x[j]
fill!(x, zero(T))
Expand Down
101 changes: 19 additions & 82 deletions src/Core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,18 +45,18 @@ $(TYPEDEF)
Probabilistic modeling with a tensor network.

### Fields
* `vars` are the degrees of freedom in the tensor network.
* `nvars` are the number of variables in the tensor network.
* `code` is the tensor network contraction pattern.
* `tensors` are the tensors fed into the tensor network, the leading tensors are unity tensors associated with `mars`.
* `tensors` are the tensors fed into the tensor network, the leading tensors are unity tensors associated with `unity_tensors_labels`.
* `evidence` is a dictionary used to specify degrees of freedom that are fixed to certain values.
* `mars` is a vector, each element is a vector of variables to compute marginal probabilities.
* `unity_tensors_idx` is a vector of indices of the unity tensors in the `tensors` array. Unity tensors are dummy tensors used to obtain the marginal probabilities.
"""
struct TensorNetworkModel{LT, ET, MT <: AbstractArray}
vars::Vector{LT}
struct TensorNetworkModel{ET, MT <: AbstractArray}
nvars::Int
code::ET
tensors::Vector{MT}
evidence::Dict{LT, Int}
mars::Vector{Vector{LT}}
evidence::Dict{Int, Int}
unity_tensors_idx::Vector{Int}
end

"""
Expand All @@ -78,7 +78,7 @@ end

function Base.show(io::IO, tn::TensorNetworkModel)
open = getiyv(tn.code)
variables = join([string_var(var, open, tn.evidence) for var in tn.vars], ", ")
variables = join([string_var(var, open, tn.evidence) for var in get_vars(tn)], ", ")
tc, sc, rw = contraction_complexity(tn)
println(io, "$(typeof(tn))")
println(io, "variables: $variables")
Expand Down Expand Up @@ -110,102 +110,42 @@ $(TYPEDSIGNATURES)
* `evidence` is a dictionary of evidences, the values are integers start counting from 0.
* `optimizer` is the tensor network contraction order optimizer, please check the package [`OMEinsumContractionOrders.jl`](https://github.com/TensorBFS/OMEinsumContractionOrders.jl) for available algorithms.
* `simplifier` is some strategies for speeding up the `optimizer`, please refer the same link above.
* `mars` is a list of marginal probabilities. It is all single variables by default, i.e. `[[1], [2], ..., [n]]`. One can also specify multi-variables, which may increase the computational complexity.
* `unity_tensors_labels` is a list of labels for the unity tensors. It is all single variables by default, i.e. `[[1], [2], ..., [n]]`. One can also specify multi-variables, which may increase the computational complexity.
"""
function TensorNetworkModel(
model::UAIModel;
model::UAIModel{ET, FT};
openvars = (),
evidence = Dict{Int,Int}(),
optimizer = GreedyMethod(),
simplifier = nothing,
mars = [[i] for i=1:model.nvars]
)::TensorNetworkModel
return TensorNetworkModel(
1:(model.nvars),
model.cards,
model.factors;
openvars,
evidence,
optimizer,
simplifier,
mars
)
end

"""
$(TYPEDSIGNATURES)
"""
function TensorNetworkModel(
vars::AbstractVector{LT},
cards::AbstractVector{Int},
factors::Vector{<:Factor{T}};
openvars = (),
evidence = Dict{LT, Int}(),
optimizer = GreedyMethod(),
simplifier = nothing,
mars = [[v] for v in vars]
)::TensorNetworkModel where {T, LT}
# The 1st argument of `EinCode` is a vector of vector of labels for specifying the input tensors,
# The 2nd argument of `EinCode` is a vector of labels for specifying the output tensor,
# e.g.
# `EinCode([[1, 2], [2, 3]], [1, 3])` is the EinCode for matrix multiplication.
rawcode = EinCode([mars..., [[factor.vars...] for factor in factors]...], collect(LT, openvars)) # labels for vertex tensors (unity tensors) and edge tensors
tensors = Array{T}[[ones(T, [cards[i] for i in mar]...) for mar in mars]..., [t.vals for t in factors]...]
return TensorNetworkModel(collect(LT, vars), rawcode, tensors; evidence, optimizer, simplifier, mars)
end

"""
$(TYPEDSIGNATURES)
"""
function TensorNetworkModel(
vars::AbstractVector{LT},
rawcode::EinCode,
tensors::Vector{<:AbstractArray};
evidence = Dict{LT, Int}(),
optimizer = GreedyMethod(),
simplifier = nothing,
mars = [[v] for v in vars]
)::TensorNetworkModel where {LT}
unity_tensors_labels = [[i] for i=1:model.nvars]
) where {ET, FT}
# `optimize_code` optimizes the contraction order of a raw tensor network without a contraction order specified.
# The 1st argument is the contraction pattern to be optimized (without contraction order).
# The 2nd arugment is the size dictionary, which is a label-integer dictionary.
# The 3rd and 4th arguments are the optimizer and simplifier that configures which algorithm to use and simplify.
rawcode = EinCode([unity_tensors_labels..., [[factor.vars...] for factor in model.factors]...], collect(Int, openvars)) # labels for vertex tensors (unity tensors) and edge tensors
tensors = Array{ET}[[ones(ET, [model.cards[i] for i in lb]...) for lb in unity_tensors_labels]..., [t.vals for t in model.factors]...]
size_dict = OMEinsum.get_size_dict(getixsv(rawcode), tensors)
code = optimize_code(rawcode, size_dict, optimizer, simplifier)
TensorNetworkModel(collect(LT, vars), code, tensors, evidence, mars)
end

"""
$(TYPEDSIGNATURES)
"""
function TensorNetworkModel(
model::UAIModel{T}, code;
evidence = Dict{Int,Int}(),
mars = [[i] for i=1:model.nvars],
vars = [1:model.nvars...]
)::TensorNetworkModel where{T}
@debug "constructing tensor network model from code"
tensors = Array{T}[[ones(T, [model.cards[i] for i in mar]...) for mar in mars]..., [t.vals for t in model.factors]...]

return TensorNetworkModel(vars, code, tensors, evidence, mars)
return TensorNetworkModel(model.nvars, code, tensors, evidence, collect(Int, 1:length(unity_tensors_labels)))
end

"""
$(TYPEDSIGNATURES)

Get the variables in this tensor network, they are also known as legs, labels, or degree of freedoms.
"""
get_vars(tn::TensorNetworkModel)::Vector = tn.vars
get_vars(tn::TensorNetworkModel)::Vector = 1:tn.nvars

"""
$(TYPEDSIGNATURES)

Get the cardinalities of variables in this tensor network.
Get the ardinalities of variables in this tensor network.
"""
function get_cards(tn::TensorNetworkModel; fixedisone = false)::Vector
vars = get_vars(tn)
size_dict = OMEinsum.get_size_dict(getixsv(tn.code), tn.tensors)
[fixedisone && haskey(tn.evidence, vars[k]) ? 1 : size_dict[vars[k]] for k in eachindex(vars)]
[fixedisone && haskey(tn.evidence, k) ? 1 : size_dict[k] for k in 1:tn.nvars]
end

chevidence(tn::TensorNetworkModel, evidence) = TensorNetworkModel(tn.vars, tn.code, tn.tensors, evidence)
Expand Down Expand Up @@ -250,7 +190,4 @@ Returns the contraction complexity of a tensor newtork model.
"""
function OMEinsum.contraction_complexity(tn::TensorNetworkModel)
return contraction_complexity(tn.code, Dict(zip(get_vars(tn), get_cards(tn; fixedisone = true))))
end

# adapt array type with the target array type
match_arraytype(::Type{<:Array{T, N}}, target::AbstractArray{T, N}) where {T, N} = Array(target)
end
7 changes: 6 additions & 1 deletion src/RescaledArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,9 @@ end
Base.size(arr::RescaledArray) = size(arr.normalized_value)
Base.size(arr::RescaledArray, i::Int) = size(arr.normalized_value, i)

match_arraytype(::Type{<:RescaledArray{T, N, AT}}, target::AbstractArray{T, N}) where {T, N, AT} = rescale_array(match_arraytype(AT, target))
function OMEinsum.get_output_array(xs::NTuple{N, RescaledArray{T}}, size, fillzero::Bool) where {N, T}
return RescaledArray(zero(T), OMEinsum.get_output_array(getfield.(xs, :normalized_value), size, fillzero))
end
# The following two APIs are required by OMEinsum
Base.fill!(r::RescaledArray, x) = (fill!(r.normalized_value, x ./ exp(r.log_factor)); r)
Base.conj(r::RescaledArray) = RescaledArray(conj(r.log_factor), conj(r.normalized_value))
16 changes: 6 additions & 10 deletions src/TensorInference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ $(EXPORTS)
module TensorInference

using OMEinsum, LinearAlgebra
using OMEinsum: CacheTree, cached_einsum
using DocStringExtensions, TropicalNumbers
# The Tropical GEMM support
using StatsBase
Expand Down Expand Up @@ -40,8 +41,11 @@ export MMAPModel
# for ProblemReductions
export update_temperature

# belief propagation
export BeliefPropgation, belief_propagate

# utils
export random_matrix_product_state
export random_matrix_product_state, random_tensor_train_uai, random_matrix_product_uai

include("Core.jl")
include("RescaledArray.jl")
Expand All @@ -51,14 +55,6 @@ include("map.jl")
include("mmap.jl")
include("sampling.jl")
include("cspmodels.jl")

# import PrecompileTools
# PrecompileTools.@setup_workload begin
# # Putting some things in `@setup_workload` instead of `@compile_workload` can reduce the size of the
# # precompile file and potentially make loading faster.
# PrecompileTools.@compile_workload begin
# include("../example/asia-network/main.jl")
# end
# end
include("belief.jl")

end # module
Loading
Loading