From b32bdfb35a85568af96cf42b524c99958c8a603f Mon Sep 17 00:00:00 2001 From: GiggleLiu Date: Mon, 28 Apr 2025 19:24:39 +0800 Subject: [PATCH 01/15] clean up --- Project.toml | 2 - examples/hard-core-lattice-gas/main.jl | 2 +- src/Core.jl | 79 ++++---------------------- src/TensorInference.jl | 9 --- src/cspmodels.jl | 10 ++-- src/map.jl | 8 ++- src/mar.jl | 9 +-- src/utils.jl | 2 +- test/cspmodels.jl | 4 +- test/map.jl | 6 +- test/mar.jl | 6 +- test/sampling.jl | 6 +- 12 files changed, 39 insertions(+), 104 deletions(-) diff --git a/Project.toml b/Project.toml index 90c66c5..280abc7 100644 --- a/Project.toml +++ b/Project.toml @@ -9,7 +9,6 @@ DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" OMEinsum = "ebe7aa44-baf0-506c-a96f-8464559b3922" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" -PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" ProblemReductions = "899c297d-f7d2-4ebf-8815-a35996def416" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" @@ -28,7 +27,6 @@ DocStringExtensions = "0.8.6, 0.9" LinearAlgebra = "1" OMEinsum = "0.8" Pkg = "1" -PrecompileTools = "1" PrettyTables = "2" ProblemReductions = "0.3" StatsBase = "0.34" diff --git a/examples/hard-core-lattice-gas/main.jl b/examples/hard-core-lattice-gas/main.jl index 14cc289..0739442 100644 --- a/examples/hard-core-lattice-gas/main.jl +++ b/examples/hard-core-lattice-gas/main.jl @@ -62,7 +62,7 @@ mars = marginals(pmodel) show_graph(SimpleGraph(graph), sites; vertex_colors=[(b = mars[[i]][2]; (1-b, 1-b, 1-b)) for i in 1:nv(graph)], texts=fill("", nv(graph))) # The can see the sites at the corner is more likely to be occupied. # To obtain two-site correlations, one can set the variables to query marginal probabilities manually. -pmodel2 = TensorNetworkModel(problem, β; mars=[[e.src, e.dst] for e in edges(graph)]) +pmodel2 = TensorNetworkModel(problem, β; unity_tensors_labels = [[e.src, e.dst] for e in edges(graph)]) mars = marginals(pmodel2); # We show the probability that both sites on an edge are not occupied diff --git a/src/Core.jl b/src/Core.jl index dcaf26f..c8c83d3 100644 --- a/src/Core.jl +++ b/src/Core.jl @@ -47,16 +47,16 @@ Probabilistic modeling with a tensor network. ### Fields * `vars` are the degrees of freedom in the tensor network. * `code` is the tensor network contraction pattern. -* `tensors` are the tensors fed into the tensor network, the leading tensors are unity tensors associated with `mars`. +* `tensors` are the tensors fed into the tensor network, the leading tensors are unity tensors associated with `unity_tensors_labels`. * `evidence` is a dictionary used to specify degrees of freedom that are fixed to certain values. -* `mars` is a vector, each element is a vector of variables to compute marginal probabilities. +* `unity_tensors_idx` is a vector of indices of the unity tensors in the `tensors` array. Unity tensors are dummy tensors used to obtain the marginal probabilities. """ struct TensorNetworkModel{LT, ET, MT <: AbstractArray} vars::Vector{LT} code::ET tensors::Vector{MT} evidence::Dict{LT, Int} - mars::Vector{Vector{LT}} + unity_tensors_idx::Vector{Int} end """ @@ -110,84 +110,25 @@ $(TYPEDSIGNATURES) * `evidence` is a dictionary of evidences, the values are integers start counting from 0. * `optimizer` is the tensor network contraction order optimizer, please check the package [`OMEinsumContractionOrders.jl`](https://github.com/TensorBFS/OMEinsumContractionOrders.jl) for available algorithms. * `simplifier` is some strategies for speeding up the `optimizer`, please refer the same link above. -* `mars` is a list of marginal probabilities. It is all single variables by default, i.e. `[[1], [2], ..., [n]]`. One can also specify multi-variables, which may increase the computational complexity. +* `unity_tensors_labels` is a list of labels for the unity tensors. It is all single variables by default, i.e. `[[1], [2], ..., [n]]`. One can also specify multi-variables, which may increase the computational complexity. """ function TensorNetworkModel( - model::UAIModel; + model::UAIModel{ET, FT}; openvars = (), evidence = Dict{Int,Int}(), optimizer = GreedyMethod(), simplifier = nothing, - mars = [[i] for i=1:model.nvars] -)::TensorNetworkModel - return TensorNetworkModel( - 1:(model.nvars), - model.cards, - model.factors; - openvars, - evidence, - optimizer, - simplifier, - mars - ) -end - -""" -$(TYPEDSIGNATURES) -""" -function TensorNetworkModel( - vars::AbstractVector{LT}, - cards::AbstractVector{Int}, - factors::Vector{<:Factor{T}}; - openvars = (), - evidence = Dict{LT, Int}(), - optimizer = GreedyMethod(), - simplifier = nothing, - mars = [[v] for v in vars] -)::TensorNetworkModel where {T, LT} - # The 1st argument of `EinCode` is a vector of vector of labels for specifying the input tensors, - # The 2nd argument of `EinCode` is a vector of labels for specifying the output tensor, - # e.g. - # `EinCode([[1, 2], [2, 3]], [1, 3])` is the EinCode for matrix multiplication. - rawcode = EinCode([mars..., [[factor.vars...] for factor in factors]...], collect(LT, openvars)) # labels for vertex tensors (unity tensors) and edge tensors - tensors = Array{T}[[ones(T, [cards[i] for i in mar]...) for mar in mars]..., [t.vals for t in factors]...] - return TensorNetworkModel(collect(LT, vars), rawcode, tensors; evidence, optimizer, simplifier, mars) -end - -""" -$(TYPEDSIGNATURES) -""" -function TensorNetworkModel( - vars::AbstractVector{LT}, - rawcode::EinCode, - tensors::Vector{<:AbstractArray}; - evidence = Dict{LT, Int}(), - optimizer = GreedyMethod(), - simplifier = nothing, - mars = [[v] for v in vars] -)::TensorNetworkModel where {LT} + unity_tensors_labels = [[i] for i=1:model.nvars] +) where {ET, FT} # `optimize_code` optimizes the contraction order of a raw tensor network without a contraction order specified. # The 1st argument is the contraction pattern to be optimized (without contraction order). # The 2nd arugment is the size dictionary, which is a label-integer dictionary. # The 3rd and 4th arguments are the optimizer and simplifier that configures which algorithm to use and simplify. + rawcode = EinCode([unity_tensors_labels..., [[factor.vars...] for factor in model.factors]...], collect(Int, openvars)) # labels for vertex tensors (unity tensors) and edge tensors + tensors = Array{ET}[[ones(ET, [model.cards[i] for i in lb]...) for lb in unity_tensors_labels]..., [t.vals for t in model.factors]...] size_dict = OMEinsum.get_size_dict(getixsv(rawcode), tensors) code = optimize_code(rawcode, size_dict, optimizer, simplifier) - TensorNetworkModel(collect(LT, vars), code, tensors, evidence, mars) -end - -""" -$(TYPEDSIGNATURES) -""" -function TensorNetworkModel( - model::UAIModel{T}, code; - evidence = Dict{Int,Int}(), - mars = [[i] for i=1:model.nvars], - vars = [1:model.nvars...] -)::TensorNetworkModel where{T} - @debug "constructing tensor network model from code" - tensors = Array{T}[[ones(T, [model.cards[i] for i in mar]...) for mar in mars]..., [t.vals for t in model.factors]...] - - return TensorNetworkModel(vars, code, tensors, evidence, mars) + return TensorNetworkModel(collect(Int, 1:model.nvars), code, tensors, evidence, collect(Int, 1:length(unity_tensors_labels))) end """ diff --git a/src/TensorInference.jl b/src/TensorInference.jl index 19125ba..31d53bd 100644 --- a/src/TensorInference.jl +++ b/src/TensorInference.jl @@ -52,13 +52,4 @@ include("mmap.jl") include("sampling.jl") include("cspmodels.jl") -# import PrecompileTools -# PrecompileTools.@setup_workload begin -# # Putting some things in `@setup_workload` instead of `@compile_workload` can reduce the size of the -# # precompile file and potentially make loading faster. -# PrecompileTools.@compile_workload begin -# include("../example/asia-network/main.jl") -# end -# end - end # module diff --git a/src/cspmodels.jl b/src/cspmodels.jl index 128f684..0ce4dd8 100644 --- a/src/cspmodels.jl +++ b/src/cspmodels.jl @@ -28,10 +28,11 @@ Convert a constraint satisfiability problem (or energy model) to a probabilistic * `mars` is the list of variables to be marginalized. """ function TensorNetworkModel(problem::ConstraintSatisfactionProblem, β::T; evidence::Dict=Dict{Int,Int}(), - optimizer=GreedyMethod(), openvars=Int[], simplifier=nothing, mars=[[l] for l in variables(problem)]) where T <: Real + optimizer=GreedyMethod(), openvars=Int[], simplifier=nothing, unity_tensors_labels = [[l] for l in variables(problem)]) where T <: Real tensors, ixs = generate_tensors(β, problem) factors = [Factor((ix...,), t) for (ix, t) in zip(ixs, tensors)] - return TensorNetworkModel(variables(problem), fill(num_flavors(problem), num_variables(problem)), factors; openvars, evidence, optimizer, simplifier, mars) + model = UAIModel(num_variables(problem), fill(num_flavors(problem), num_variables(problem)), factors) + return TensorNetworkModel(model; openvars, evidence, optimizer, simplifier, unity_tensors_labels) end """ @@ -47,8 +48,9 @@ The program will regenerate tensors from the problem, without repeated optimizin """ function update_temperature(tnet::TensorNetworkModel, problem::ConstraintSatisfactionProblem, β::Real) tensors, ixs = generate_tensors(β, problem) - alltensors = [tnet.tensors[1:length(tnet.mars)]..., tensors...] - return TensorNetworkModel(tnet.vars, tnet.code, alltensors, tnet.evidence, tnet.mars) + @assert tnet.unity_tensors_idx == collect(1:length(tnet.unity_tensors_idx)) "The target tensor network can not be updated! Got `unity_tensors_idx = $(tnet.unity_tensors_idx)`" + alltensors = [tnet.tensors[tnet.unity_tensors_idx]..., tensors...] + return TensorNetworkModel(tnet.vars, tnet.code, alltensors, tnet.evidence, tnet.unity_tensors_idx) end function MMAPModel(problem::ConstraintSatisfactionProblem, β::Real; diff --git a/src/map.jl b/src/map.jl index 372133f..e2f2987 100644 --- a/src/map.jl +++ b/src/map.jl @@ -53,13 +53,15 @@ $(TYPEDSIGNATURES) Returns the largest log-probability and the most probable configuration. """ function most_probable_config(tn::TensorNetworkModel; usecuda = false)::Tuple{Real, Vector} - expected_mars = [[l] for l in get_vars(tn)] - @assert tn.mars[1:length(expected_mars)] == expected_mars "To get the the most probable configuration, the leading elements of `tn.vars` must be `$expected_mars`" + ixs = OMEinsum.getixsv(tn.code) + unity_labels = ixs[tn.unity_tensors_idx] + indices = [findfirst(==([l]), unity_labels) for l in get_vars(tn)] + @assert !any(isnothing, indices) "To get the the most probable configuration, the unity tensors labels must include all variables" vars = get_vars(tn) tensors = map(t -> Tropical.(log.(t)), adapt_tensors(tn; usecuda, rescale = false)) logp, grads = cost_and_gradient(tn.code, tensors) # use Array to convert CuArray to CPU arrays - return content(Array(logp)[]), map(k -> haskey(tn.evidence, vars[k]) ? tn.evidence[vars[k]] : argmax(grads[k]) - 1, 1:length(vars)) + return content(Array(logp)[]), map(k -> haskey(tn.evidence, vars[k]) ? tn.evidence[vars[k]] : argmax(grads[tn.unity_tensors_idx[indices[k]]]) - 1, 1:length(vars)) end """ diff --git a/src/mar.jl b/src/mar.jl index a436d3d..6ecac80 100644 --- a/src/mar.jl +++ b/src/mar.jl @@ -130,7 +130,7 @@ are their respective marginals. A marginal is a probability distribution over a subset of variables, obtained by integrating or summing over the remaining variables in the model. By default, the function returns the marginals of all individual variables. To specify which marginal variables to query, set the -`mars` field when constructing a [`TensorNetworkModel`](@ref). Note that +`unity_tensors_labels` field when constructing a [`TensorNetworkModel`](@ref). Note that the choice of marginal variables will affect the contraction order of the tensor network. @@ -158,7 +158,7 @@ Dict{Vector{Int64}, Vector{Float64}} with 8 entries: [7] => [0.145092, 0.854908] [2] => [0.05, 0.95] -julia> tn2 = TensorNetworkModel(model; evidence=Dict(1=>0), mars=[[2, 3], [3, 4]]); +julia> tn2 = TensorNetworkModel(model; evidence=Dict(1=>0), unity_tensors_labels = [[2, 3], [3, 4]]); julia> marginals(tn2) Dict{Vector{Int64}, Matrix{Float64}} with 2 entries: @@ -186,9 +186,10 @@ function marginals(tn::TensorNetworkModel; usecuda = false, rescale = true)::Dic # sometimes, the cost can overflow, then we need to rescale the tensors during contraction. cost, grads = cost_and_gradient(tn.code, adapt_tensors(tn; usecuda, rescale)) @debug "cost = $cost" + ixs = OMEinsum.getixsv(tn.code) if rescale - return Dict(zip(tn.mars, LinearAlgebra.normalize!.(getfield.(grads[1:length(tn.mars)], :normalized_value), 1))) + return Dict(zip(ixs[tn.unity_tensors_idx], LinearAlgebra.normalize!.(getfield.(grads[1:length(tn.unity_tensors_idx)], :normalized_value), 1))) else - return Dict(zip(tn.mars, LinearAlgebra.normalize!.(grads[1:length(tn.mars)], 1))) + return Dict(zip(ixs[tn.unity_tensors_idx], LinearAlgebra.normalize!.(grads[1:length(tn.unity_tensors_idx)], 1))) end end diff --git a/src/utils.jl b/src/utils.jl index ce90819..b1f2f28 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -356,7 +356,7 @@ function random_matrix_product_state(::Type{T}, n::Int, chi::Int, d::Int=2) wher optimize_code(DynamicEinCode(ixs, Int[]), OMEinsum.get_size_dict(ixs, tensors), GreedyMethod()), tensors, Dict{Int, Int}(), - Vector{Int}[[i] for i=1:n] + collect(1:n) ) end random_matrix_product_state(n::Int, chi::Int, d::Int=2) = random_matrix_product_state(ComplexF64, n, chi, d) diff --git a/test/cspmodels.jl b/test/cspmodels.jl index c7b559b..cc26457 100644 --- a/test/cspmodels.jl +++ b/test/cspmodels.jl @@ -7,7 +7,7 @@ using GenericTensorNetworks β = 2.0 g = GenericTensorNetworks.Graphs.smallgraph(:petersen) problem = IndependentSet(g) - model = TensorNetworkModel(problem, β; mars=[[2, 3]]) + model = TensorNetworkModel(problem, β; unity_tensors_labels = [[2, 3]]) mars = marginals(model)[[2, 3]] problem2 = IndependentSet(g) mars2 = TensorInference.normalize!(GenericTensorNetworks.solve(GenericTensorNetwork(problem2; openvertices=[2, 3]), PartitionFunction(β)), 1) @@ -28,7 +28,7 @@ using GenericTensorNetworks β = 1.0 problem = SpinGlass(g, -ones(Int, ne(g)), zeros(Int, nv(g))) - model = TensorNetworkModel(problem, β; mars=[[2, 3]]) + model = TensorNetworkModel(problem, β; unity_tensors_labels = [[2, 3]]) samples = sample(model, 100) @test sum(energy.(Ref(problem), samples))/100 <= -14 end \ No newline at end of file diff --git a/test/map.jl b/test/map.jl index abc0c98..2ee148d 100644 --- a/test/map.jl +++ b/test/map.jl @@ -2,16 +2,14 @@ using Test using OMEinsum using TensorInference -@testset "load from code" begin +@testset "load from model" begin model = problem_from_artifact("uai2014", "MAR", "Promedus", 14) tn1 = TensorNetworkModel(read_model(model); evidence=read_evidence(model), optimizer = TreeSA(ntrials = 3, niters = 2, βs = 1:0.1:80)) - tn2 = TensorNetworkModel(read_model(model), tn1.code, evidence=read_evidence(model)) - - @test tn1.code == tn2.code + @test tn1 isa TensorNetworkModel end @testset "gradient-based tensor network solvers" begin diff --git a/test/mar.jl b/test/mar.jl index 7aa10b4..f016ac6 100644 --- a/test/mar.jl +++ b/test/mar.jl @@ -116,7 +116,7 @@ end 0.1 0.3 0.2 0.9 """) n = 10000 - tnet = TensorNetworkModel(model; mars=[[2, 3], [3, 4]]) + tnet = TensorNetworkModel(model; unity_tensors_labels = [[2, 3], [3, 4]]) mars = marginals(tnet) tnet23 = TensorNetworkModel(model; openvars=[2,3]) tnet34 = TensorNetworkModel(model; openvars=[3,4]) @@ -124,8 +124,8 @@ end @test mars[[3, 4]] ≈ probability(tnet34) vars = [[2, 4], [3, 5]] - tnet1 = TensorNetworkModel(model; mars=vars, evidence=Dict(3=>1)) - tnet2 = TensorNetworkModel(model; mars=vars, evidence=Dict(3=>0)) + tnet1 = TensorNetworkModel(model; unity_tensors_labels = vars, evidence=Dict(3=>1)) + tnet2 = TensorNetworkModel(model; unity_tensors_labels = vars, evidence=Dict(3=>0)) mars1 = marginals(tnet1) mars2 = marginals(tnet2) update_evidence!(tnet1, Dict(3=>0)) diff --git a/test/sampling.jl b/test/sampling.jl index 0f06320..005dc87 100644 --- a/test/sampling.jl +++ b/test/sampling.jl @@ -70,10 +70,12 @@ end Random.seed!(140) mps = random_matrix_product_state(n, chi) num_samples = 10000 + ixs = OMEinsum.getixsv(mps.code) + @show ixs samples = map(1:num_samples) do i - sample(mps, 1; queryvars=vcat(mps.mars...)).samples[:,1] + sample(mps, 1; queryvars=collect(1:n)).samples[:,1] end - samples = sample(mps, num_samples; queryvars=vcat(mps.mars...)) + samples = sample(mps, num_samples; queryvars=collect(1:n)) indices = map(samples) do sample sum(i->sample[i] * 2^(i-1), 1:n) + 1 end From 563950b4083dcfed18cd2dd9bd5d3b49fb7e78c1 Mon Sep 17 00:00:00 2001 From: GiggleLiu Date: Mon, 28 Apr 2025 19:44:20 +0800 Subject: [PATCH 02/15] update --- src/map.jl | 14 +++++++++----- src/mar.jl | 13 ++++++++----- 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/src/map.jl b/src/map.jl index e2f2987..f4142af 100644 --- a/src/map.jl +++ b/src/map.jl @@ -53,15 +53,19 @@ $(TYPEDSIGNATURES) Returns the largest log-probability and the most probable configuration. """ function most_probable_config(tn::TensorNetworkModel; usecuda = false)::Tuple{Real, Vector} - ixs = OMEinsum.getixsv(tn.code) - unity_labels = ixs[tn.unity_tensors_idx] - indices = [findfirst(==([l]), unity_labels) for l in get_vars(tn)] - @assert !any(isnothing, indices) "To get the the most probable configuration, the unity tensors labels must include all variables" vars = get_vars(tn) + tensor_indices = check_queryvars(tn, [[v] for v in vars]) tensors = map(t -> Tropical.(log.(t)), adapt_tensors(tn; usecuda, rescale = false)) logp, grads = cost_and_gradient(tn.code, tensors) # use Array to convert CuArray to CPU arrays - return content(Array(logp)[]), map(k -> haskey(tn.evidence, vars[k]) ? tn.evidence[vars[k]] : argmax(grads[tn.unity_tensors_idx[indices[k]]]) - 1, 1:length(vars)) + return content(Array(logp)[]), map(k -> haskey(tn.evidence, vars[k]) ? tn.evidence[vars[k]] : argmax(grads[tensor_indices[k]]) - 1, 1:length(vars)) +end +# check if the queryvars are included in the unity tensors labels, if yes, return the indices of the unity tensors +function check_queryvars(tn::TensorNetworkModel, queryvars::AbstractVector{Vector{Int}}) + ixs = OMEinsum.getixsv(tn.code) + indices = [findfirst(==(l), ixs[tn.unity_tensors_idx]) for l in queryvars] + @assert !any(isnothing, indices) "To get the the most probable configuration, the unity tensors labels must include all variables. Query variables: $queryvars, Unity tensors labels: $(ixs[tn.unity_tensors_idx])" + return tn.unity_tensors_idx[indices] end """ diff --git a/src/mar.jl b/src/mar.jl index 6ecac80..b1b65ae 100644 --- a/src/mar.jl +++ b/src/mar.jl @@ -136,8 +136,10 @@ tensor network. ### Arguments - `tn`: The [`TensorNetworkModel`](@ref) to query. -- `usecuda`: Specifies whether to use CUDA for tensor contraction. -- `rescale`: Specifies whether to rescale the tensors during contraction. + +### Keyword Arguments +- `usecuda::Bool`: Specifies whether to use CUDA for tensor contraction. +- `rescale::Bool`: Specifies whether to rescale the tensors during contraction. ### Example The following example is taken from [`examples/asia-network/main.jl`](https://tensorbfs.github.io/TensorInference.jl/dev/generated/asia-network/main/). @@ -187,9 +189,10 @@ function marginals(tn::TensorNetworkModel; usecuda = false, rescale = true)::Dic cost, grads = cost_and_gradient(tn.code, adapt_tensors(tn; usecuda, rescale)) @debug "cost = $cost" ixs = OMEinsum.getixsv(tn.code) + queryvars = ixs[tn.unity_tensors_idx] if rescale - return Dict(zip(ixs[tn.unity_tensors_idx], LinearAlgebra.normalize!.(getfield.(grads[1:length(tn.unity_tensors_idx)], :normalized_value), 1))) + return Dict(zip(queryvars, LinearAlgebra.normalize!.(getfield.(grads[tn.unity_tensors_idx], :normalized_value), 1))) else - return Dict(zip(ixs[tn.unity_tensors_idx], LinearAlgebra.normalize!.(grads[1:length(tn.unity_tensors_idx)], 1))) + return Dict(zip(queryvars, LinearAlgebra.normalize!.(grads[tn.unity_tensors_idx], 1))) end -end +end \ No newline at end of file From c57bc8a546761561f6ede2cab147255ae543e077 Mon Sep 17 00:00:00 2001 From: GiggleLiu Date: Mon, 28 Apr 2025 20:35:54 +0800 Subject: [PATCH 03/15] vars -> nvars --- src/Core.jl | 19 +++++++++---------- src/cspmodels.jl | 2 +- src/map.jl | 5 ++--- src/utils.jl | 2 +- test/mar.jl | 1 - test/pr.jl | 1 - 6 files changed, 13 insertions(+), 17 deletions(-) diff --git a/src/Core.jl b/src/Core.jl index c8c83d3..39866fe 100644 --- a/src/Core.jl +++ b/src/Core.jl @@ -45,17 +45,17 @@ $(TYPEDEF) Probabilistic modeling with a tensor network. ### Fields -* `vars` are the degrees of freedom in the tensor network. +* `nvars` are the number of variables in the tensor network. * `code` is the tensor network contraction pattern. * `tensors` are the tensors fed into the tensor network, the leading tensors are unity tensors associated with `unity_tensors_labels`. * `evidence` is a dictionary used to specify degrees of freedom that are fixed to certain values. * `unity_tensors_idx` is a vector of indices of the unity tensors in the `tensors` array. Unity tensors are dummy tensors used to obtain the marginal probabilities. """ -struct TensorNetworkModel{LT, ET, MT <: AbstractArray} - vars::Vector{LT} +struct TensorNetworkModel{ET, MT <: AbstractArray} + nvars::Int code::ET tensors::Vector{MT} - evidence::Dict{LT, Int} + evidence::Dict{Int, Int} unity_tensors_idx::Vector{Int} end @@ -78,7 +78,7 @@ end function Base.show(io::IO, tn::TensorNetworkModel) open = getiyv(tn.code) - variables = join([string_var(var, open, tn.evidence) for var in tn.vars], ", ") + variables = join([string_var(var, open, tn.evidence) for var in get_vars(tn)], ", ") tc, sc, rw = contraction_complexity(tn) println(io, "$(typeof(tn))") println(io, "variables: $variables") @@ -128,7 +128,7 @@ function TensorNetworkModel( tensors = Array{ET}[[ones(ET, [model.cards[i] for i in lb]...) for lb in unity_tensors_labels]..., [t.vals for t in model.factors]...] size_dict = OMEinsum.get_size_dict(getixsv(rawcode), tensors) code = optimize_code(rawcode, size_dict, optimizer, simplifier) - return TensorNetworkModel(collect(Int, 1:model.nvars), code, tensors, evidence, collect(Int, 1:length(unity_tensors_labels))) + return TensorNetworkModel(model.nvars, code, tensors, evidence, collect(Int, 1:length(unity_tensors_labels))) end """ @@ -136,17 +136,16 @@ $(TYPEDSIGNATURES) Get the variables in this tensor network, they are also known as legs, labels, or degree of freedoms. """ -get_vars(tn::TensorNetworkModel)::Vector = tn.vars +get_vars(tn::TensorNetworkModel)::Vector = 1:tn.nvars """ $(TYPEDSIGNATURES) -Get the cardinalities of variables in this tensor network. +Get the ardinalities of variables in this tensor network. """ function get_cards(tn::TensorNetworkModel; fixedisone = false)::Vector - vars = get_vars(tn) size_dict = OMEinsum.get_size_dict(getixsv(tn.code), tn.tensors) - [fixedisone && haskey(tn.evidence, vars[k]) ? 1 : size_dict[vars[k]] for k in eachindex(vars)] + [fixedisone && haskey(tn.evidence, k) ? 1 : size_dict[k] for k in 1:tn.nvars] end chevidence(tn::TensorNetworkModel, evidence) = TensorNetworkModel(tn.vars, tn.code, tn.tensors, evidence) diff --git a/src/cspmodels.jl b/src/cspmodels.jl index 0ce4dd8..dbd7b58 100644 --- a/src/cspmodels.jl +++ b/src/cspmodels.jl @@ -50,7 +50,7 @@ function update_temperature(tnet::TensorNetworkModel, problem::ConstraintSatisfa tensors, ixs = generate_tensors(β, problem) @assert tnet.unity_tensors_idx == collect(1:length(tnet.unity_tensors_idx)) "The target tensor network can not be updated! Got `unity_tensors_idx = $(tnet.unity_tensors_idx)`" alltensors = [tnet.tensors[tnet.unity_tensors_idx]..., tensors...] - return TensorNetworkModel(tnet.vars, tnet.code, alltensors, tnet.evidence, tnet.unity_tensors_idx) + return TensorNetworkModel(tnet.nvars, tnet.code, alltensors, tnet.evidence, tnet.unity_tensors_idx) end function MMAPModel(problem::ConstraintSatisfactionProblem, β::Real; diff --git a/src/map.jl b/src/map.jl index f4142af..a1f5762 100644 --- a/src/map.jl +++ b/src/map.jl @@ -53,12 +53,11 @@ $(TYPEDSIGNATURES) Returns the largest log-probability and the most probable configuration. """ function most_probable_config(tn::TensorNetworkModel; usecuda = false)::Tuple{Real, Vector} - vars = get_vars(tn) - tensor_indices = check_queryvars(tn, [[v] for v in vars]) + tensor_indices = check_queryvars(tn, [[v] for v in 1:tn.nvars]) tensors = map(t -> Tropical.(log.(t)), adapt_tensors(tn; usecuda, rescale = false)) logp, grads = cost_and_gradient(tn.code, tensors) # use Array to convert CuArray to CPU arrays - return content(Array(logp)[]), map(k -> haskey(tn.evidence, vars[k]) ? tn.evidence[vars[k]] : argmax(grads[tensor_indices[k]]) - 1, 1:length(vars)) + return content(Array(logp)[]), map(k -> haskey(tn.evidence, k) ? tn.evidence[k] : argmax(grads[tensor_indices[k]]) - 1, 1:tn.nvars) end # check if the queryvars are included in the unity tensors labels, if yes, return the indices of the unity tensors function check_queryvars(tn::TensorNetworkModel, queryvars::AbstractVector{Vector{Int}}) diff --git a/src/utils.jl b/src/utils.jl index b1f2f28..0c67e3b 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -352,7 +352,7 @@ function random_matrix_product_state(::Type{T}, n::Int, chi::Int, d::Int=2) wher push!(ixs_bra, [virtual_indices_bra[n-1], physical_indices[n]]) tensors, ixs = [tensors..., conj.(tensors)...], [ixs_ket..., ixs_bra...] return TensorNetworkModel( - collect(1:3n-2), + 3n-2, optimize_code(DynamicEinCode(ixs, Int[]), OMEinsum.get_size_dict(ixs, tensors), GreedyMethod()), tensors, Dict{Int, Int}(), diff --git a/test/mar.jl b/test/mar.jl index f016ac6..1302ca7 100644 --- a/test/mar.jl +++ b/test/mar.jl @@ -1,6 +1,5 @@ using Test using OMEinsum -using KaHyPar using TensorInference @testset "composite number" begin diff --git a/test/pr.jl b/test/pr.jl index 53c1c1e..f6c0c08 100644 --- a/test/pr.jl +++ b/test/pr.jl @@ -1,6 +1,5 @@ using Test using OMEinsum -using KaHyPar using TensorInference @testset "UAI Reference Solution Comparison" begin From 6f96133e08aeca57af93a8ae091816bd1b000f04 Mon Sep 17 00:00:00 2001 From: GiggleLiu Date: Mon, 28 Apr 2025 22:03:50 +0800 Subject: [PATCH 04/15] update --- src/TensorInference.jl | 1 + src/belief.jl | 47 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 48 insertions(+) create mode 100644 src/belief.jl diff --git a/src/TensorInference.jl b/src/TensorInference.jl index 31d53bd..38667a8 100644 --- a/src/TensorInference.jl +++ b/src/TensorInference.jl @@ -51,5 +51,6 @@ include("map.jl") include("mmap.jl") include("sampling.jl") include("cspmodels.jl") +include("belief.jl") end # module diff --git a/src/belief.jl b/src/belief.jl new file mode 100644 index 0000000..c357069 --- /dev/null +++ b/src/belief.jl @@ -0,0 +1,47 @@ +struct BPState{T, VT<:AbstractVector{T}} + t2v::Vector{Vector{Int}} # a mapping from tensors to variables + v2t::Vector{Vector{Int}} # a mapping from variables to tensors + edges_vectors::Vector{Vector{VT}} # each tensor is associated with a vector of vectors, one for each neighbor +end + +function BPState(::Type{T}, n::Int, t2v::Vector{Vector{Int}}, size_dict::Dict{Int, Int}) where T + v2t = [Int[] for _ in 1:n] + edges_vectors = [Vector{VT}[] for _ in 1:n] + for (i, edge) in enumerate(t2v) + for v in edge + push!(v2t[v], i) + push!(edges_vectors[i], ones(T, size_dict[v])) + end + end + return BPState(t2v, v2t, edges_vectors) +end + +# belief propagation, update the tensors on the edges of the tensor network +function belief_propagation(tn::TensorNetworkModel{T}, bpstate::BPState{T}; max_iter::Int=100, tol::Float64=1e-6) where T + # collect the messages from the neighbors + messages = [similar(bpstate.edges_vectors[it]) for it in 1:length(bpstate.t2v)] + for (it, vs) in enumerate(bpstate.t2v) + for (iv, v) in enumerate(vs) + messages[it][iv] = tn.tensors[v] + end + end + # update the tensors on the edges of the tensor network + for (it, vs) in enumerate(bpstate.t2v) + # update the tensor + for (iv, v) in enumerate(vs) + bpstate.edges_vectors[it][iv] = zeros(T, size_dict[v]) + for (j, w) in enumerate(vs) + if j != iv + bpstate.edges_vectors[it][iv] += messages[j][iv] * messages[j][iv] + end + end + end + end +end + +function tensor_product() +end + +function belief_propagation(tn::TensorNetworkModel{T}) where T + return belief_propagation(tn, BPState(T, OMEinsum.get_ixsv(tn.code), tn.size_dict)) +end From 36e047e96b6fc4cc3cd6f9e9d840e069252b6410 Mon Sep 17 00:00:00 2001 From: GiggleLiu Date: Tue, 29 Apr 2025 00:23:35 +0800 Subject: [PATCH 05/15] update --- src/TensorInference.jl | 3 +++ src/belief.jl | 54 ++++++++++++++++++++++++++++++++++++++---- test/belief.jl | 36 ++++++++++++++++++++++++++++ 3 files changed, 89 insertions(+), 4 deletions(-) create mode 100644 test/belief.jl diff --git a/src/TensorInference.jl b/src/TensorInference.jl index 38667a8..8fc7377 100644 --- a/src/TensorInference.jl +++ b/src/TensorInference.jl @@ -40,6 +40,9 @@ export MMAPModel # for ProblemReductions export update_temperature +# belief propagation +export belief_propagation + # utils export random_matrix_product_state diff --git a/src/belief.jl b/src/belief.jl index c357069..beb2ce3 100644 --- a/src/belief.jl +++ b/src/belief.jl @@ -1,7 +1,56 @@ struct BPState{T, VT<:AbstractVector{T}} t2v::Vector{Vector{Int}} # a mapping from tensors to variables v2t::Vector{Vector{Int}} # a mapping from variables to tensors - edges_vectors::Vector{Vector{VT}} # each tensor is associated with a vector of vectors, one for each neighbor + tensors::Vector{AbstractArray{T}} # the tensors + message_in::Vector{Vector{VT}} # for each variable, we store the incoming messages + message_out::Vector{Vector{VT}} # the outgoing messages +end + +# message_in -> message_out +function process_message!(bp::BPState) + for (ov, iv) in zip(bp.message_out, bp.message_in) + _process_message!(ov, iv) + end +end +function _process_message!(ov::Vector, iv::Vector) + # process the message, TODO: speed up if needed! + for (i, v) in enumerate(ov) + fill!(v, one(eltype(v))) # clear the output vector + for (j, u) in enumerate(iv) + j != i && (v .*= u) + end + end +end + +function collect_message!(bp::BPState) + for (it, t) in enumerate(bp.t2v) + _collect_message!(vectors_on_tensor(bp.message_out, bp, it), t, vectors_on_tensor(bp.message_in, bp, it)) + end +end +# collect the vectors associated with the target tensor +function vectors_on_tensor(messages, bp::BPState, it::Int) + return map(bp.t2v[it]) do v + # the message goes to the idx-th tensor from variable v + messages[v][findfirst(==(it), bp.v2t[v])] + end +end +function _collect_message!(vectors_out::Vector, t::AbstractArray, vectors_in::Vector) + @assert length(vectors_out) == length(vectors_in) == ndims(t) + # TODO: speed up if needed! + code = star_code(length(vectors_in)) + cost, gradient = cost_and_gradient(code, [t, vectors_in...]) + for (o, g) in zip(vectors_out, gradient[2:end]) + o .= g + end + return cost +end +function star_code(n::Int) + ix1, ixrest = collect(1:n), [[i] for i in 1:n] + ne = DynamicNestedEinsum([DynamicNestedEinsum{Int}(1), DynamicNestedEinsum{Int}(2)], DynamicEinCode([ix1, ixrest[1]], collect(2:n))) + for i in 2:n + ne = DynamicNestedEinsum([ne, DynamicNestedEinsum{Int}(i + 1)], DynamicEinCode([ne.eins.iy, ixrest[i]], collect(i+1:n))) + end + return ne end function BPState(::Type{T}, n::Int, t2v::Vector{Vector{Int}}, size_dict::Dict{Int, Int}) where T @@ -39,9 +88,6 @@ function belief_propagation(tn::TensorNetworkModel{T}, bpstate::BPState{T}; max_ end end -function tensor_product() -end - function belief_propagation(tn::TensorNetworkModel{T}) where T return belief_propagation(tn, BPState(T, OMEinsum.get_ixsv(tn.code), tn.size_dict)) end diff --git a/test/belief.jl b/test/belief.jl new file mode 100644 index 0000000..3c1c7a7 --- /dev/null +++ b/test/belief.jl @@ -0,0 +1,36 @@ +using TensorInference, Test + +@testset "process message" begin + mi = [[1, 2, 3], [2, 3, 4], [3, 4, 5]] + mo_expected = [[6, 12, 20], [3, 8, 15], [2, 6, 12]] + mo = similar.(mi) + TensorInference._process_message!(mo, mi) + @test mo == mo_expected +end + +@testset "star code" begin + code = TensorInference.star_code(3) + c1, c2, c3, c4 = [DynamicNestedEinsum{Int}(i) for i in 1:4] + ne1 = DynamicNestedEinsum([c1, c2], DynamicEinCode([[1, 2, 3], [1]], [2, 3])) + ne2 = DynamicNestedEinsum([ne1, c3], DynamicEinCode([[2, 3], [2]], [3])) + ne3 = DynamicNestedEinsum([ne2, c4], DynamicEinCode([[3], [3]], Int[])) + @test code == ne3 + t = randn(2, 2, 2) + v1 = randn(2) + v2 = randn(2) + v3 = randn(2) + vectors_out = [similar(v1), similar(v2), similar(v3)] + TensorInference._collect_message!(vectors_out, t, [v1, v2, v3]) + @test vectors_out[1] ≈ reshape(t, 2, 4) * kron(v3, v2) # NOTE: v3 is the little end + @test vectors_out[2] ≈ vec(v1' * reshape(reshape(t, 4, 2) * v3, 2, 2)) + @test vectors_out[3] ≈ vec(kron(v2, v1)' * reshape(t, 4, 2)) +end + +@testset "belief propagation" begin + n = 5 + chi = 3 + Random.seed!(140) + mps = random_matrix_product_state(n, chi) + model = TensorNetworkModel(mps) + state = belief_propagation(model) +end \ No newline at end of file From 0806a5fa609e42943f53c5035c2e9ed213d5d87b Mon Sep 17 00:00:00 2001 From: GiggleLiu Date: Tue, 29 Apr 2025 01:43:28 +0800 Subject: [PATCH 06/15] update --- src/TensorInference.jl | 4 +- src/belief.jl | 90 +++++++++++++++++++++++++----------------- src/utils.jl | 49 +++++++++++++++++++---- test/belief.jl | 22 +++++++++-- test/runtests.jl | 8 ++++ test/sampling.jl | 1 - test/utils.jl | 12 ++++++ 7 files changed, 136 insertions(+), 50 deletions(-) create mode 100644 test/utils.jl diff --git a/src/TensorInference.jl b/src/TensorInference.jl index 8fc7377..72f17f1 100644 --- a/src/TensorInference.jl +++ b/src/TensorInference.jl @@ -41,10 +41,10 @@ export MMAPModel export update_temperature # belief propagation -export belief_propagation +export BeliefPropgation, belief_propagate # utils -export random_matrix_product_state +export random_matrix_product_state, random_tensor_train_uai, random_matrix_product_uai include("Core.jl") include("RescaledArray.jl") diff --git a/src/belief.jl b/src/belief.jl index beb2ce3..748ce48 100644 --- a/src/belief.jl +++ b/src/belief.jl @@ -1,7 +1,26 @@ -struct BPState{T, VT<:AbstractVector{T}} +struct BeliefPropgation{T} t2v::Vector{Vector{Int}} # a mapping from tensors to variables v2t::Vector{Vector{Int}} # a mapping from variables to tensors tensors::Vector{AbstractArray{T}} # the tensors +end +num_tensors(bp::BeliefPropgation) = length(bp.t2v) +ProblemReductions.num_variables(bp::BeliefPropgation) = length(bp.v2t) + +function BeliefPropgation(nvars::Int, t2v::AbstractVector{Vector{Int}}, tensors::AbstractVector{AbstractArray{T}}) where T + # initialize the inverse mapping + v2t = [Int[] for _ in 1:nvars] + for (i, edge) in enumerate(t2v) + for v in edge + push!(v2t[v], i) + end + end + return BeliefPropgation(t2v, v2t, tensors) +end +function BeliefPropgation(uai::UAIModel{T}) where T + return BeliefPropgation(uai.nvars, [collect(Int, f.vars) for f in uai.factors], AbstractArray{T}[f.vals for f in uai.factors]) +end + +struct BPState{T, VT<:AbstractVector{T}} message_in::Vector{Vector{VT}} # for each variable, we store the incoming messages message_out::Vector{Vector{VT}} # the outgoing messages end @@ -22,20 +41,20 @@ function _process_message!(ov::Vector, iv::Vector) end end -function collect_message!(bp::BPState) - for (it, t) in enumerate(bp.t2v) - _collect_message!(vectors_on_tensor(bp.message_out, bp, it), t, vectors_on_tensor(bp.message_in, bp, it)) +function collect_message!(bp::BeliefPropgation, state::BPState) + for it in 1:num_tensors(bp) + _collect_message!(vectors_on_tensor(state.message_out, bp, it), bp.tensors[it], vectors_on_tensor(state.message_in, bp, it)) end end # collect the vectors associated with the target tensor -function vectors_on_tensor(messages, bp::BPState, it::Int) +function vectors_on_tensor(messages, bp::BeliefPropgation, it::Int) return map(bp.t2v[it]) do v # the message goes to the idx-th tensor from variable v messages[v][findfirst(==(it), bp.v2t[v])] end end function _collect_message!(vectors_out::Vector, t::AbstractArray, vectors_in::Vector) - @assert length(vectors_out) == length(vectors_in) == ndims(t) + @assert length(vectors_out) == length(vectors_in) == ndims(t) "dimensions mismatch: $(length(vectors_out)), $(length(vectors_in)), $(ndims(t))" # TODO: speed up if needed! code = star_code(length(vectors_in)) cost, gradient = cost_and_gradient(code, [t, vectors_in...]) @@ -44,6 +63,8 @@ function _collect_message!(vectors_out::Vector, t::AbstractArray, vectors_in::Ve end return cost end + +# star code: contract a tensor with multiple vectors, one for each dimension function star_code(n::Int) ix1, ixrest = collect(1:n), [[i] for i in 1:n] ne = DynamicNestedEinsum([DynamicNestedEinsum{Int}(1), DynamicNestedEinsum{Int}(2)], DynamicEinCode([ix1, ixrest[1]], collect(2:n))) @@ -53,41 +74,38 @@ function star_code(n::Int) return ne end -function BPState(::Type{T}, n::Int, t2v::Vector{Vector{Int}}, size_dict::Dict{Int, Int}) where T - v2t = [Int[] for _ in 1:n] - edges_vectors = [Vector{VT}[] for _ in 1:n] - for (i, edge) in enumerate(t2v) - for v in edge - push!(v2t[v], i) - push!(edges_vectors[i], ones(T, size_dict[v])) - end +function initial_state(bp::BeliefPropgation{T}) where T + size_dict = OMEinsum.get_size_dict(bp.t2v, bp.tensors) + edges_vectors = Vector{Vector{T}}[] + for (i, tids) in enumerate(bp.v2t) + push!(edges_vectors, [ones(T, size_dict[i]) for _ in 1:length(tids)]) end - return BPState(t2v, v2t, edges_vectors) + return BPState(deepcopy(edges_vectors), edges_vectors) end # belief propagation, update the tensors on the edges of the tensor network -function belief_propagation(tn::TensorNetworkModel{T}, bpstate::BPState{T}; max_iter::Int=100, tol::Float64=1e-6) where T - # collect the messages from the neighbors - messages = [similar(bpstate.edges_vectors[it]) for it in 1:length(bpstate.t2v)] - for (it, vs) in enumerate(bpstate.t2v) - for (iv, v) in enumerate(vs) - messages[it][iv] = tn.tensors[v] - end - end - # update the tensors on the edges of the tensor network - for (it, vs) in enumerate(bpstate.t2v) - # update the tensor - for (iv, v) in enumerate(vs) - bpstate.edges_vectors[it][iv] = zeros(T, size_dict[v]) - for (j, w) in enumerate(vs) - if j != iv - bpstate.edges_vectors[it][iv] += messages[j][iv] * messages[j][iv] - end - end +function belief_propagate(bp::BeliefPropgation; max_iter::Int=100, tol::Float64=1e-6) + state = initial_state(bp) + info = belief_propagate!(bp, state; max_iter=max_iter, tol=tol) + return state, info +end +struct BPInfo + converged::Bool + iterations::Int +end +function belief_propagate!(bp::BeliefPropgation, state::BPState{T}; max_iter::Int=100, tol::Float64=1e-6) where T + for i in 1:max_iter + process_message!(state) + collect_message!(bp, state) + # check convergence + if all(iv -> all(it -> isapprox(state.message_out[iv][it], state.message_in[iv][it], atol=tol), 1:length(bp.v2t[iv])), 1:num_variables(bp)) + return BPInfo(true, i) end end + return BPInfo(false, max_iter) end -function belief_propagation(tn::TensorNetworkModel{T}) where T - return belief_propagation(tn, BPState(T, OMEinsum.get_ixsv(tn.code), tn.size_dict)) -end +# if BP is exact and converged (e.g. tree like), the result should be the same as the tensor network contraction +function contraction_results(state::BPState{T}) where T + return [sum(reduce((x, y) -> x .* y, mi)) for mi in state.message_in] +end \ No newline at end of file diff --git a/src/utils.jl b/src/utils.jl index 0c67e3b..f9c1794 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -334,6 +334,12 @@ connected in a chain. - `d` is the dimension of the physical indices. """ function random_matrix_product_state(::Type{T}, n::Int, chi::Int, d::Int=2) where T + uai = random_matrix_product_uai(T, n, chi, d) + return TensorNetworkModel(uai; optimizer=GreedyMethod()) +end +random_matrix_product_state(n::Int, chi::Int, d::Int=2) = random_matrix_product_state(ComplexF64, n, chi, d) + +function random_matrix_product_uai(::Type{T}, n::Int, chi::Int, d::Int=2) where T # chi ^ (n-1) * (variance^n)^2 == 1/d^n variance = d^(-1/2) * chi^(-1/2+1/2n) tensors = Any[randn(T, d, chi) .* variance] @@ -351,12 +357,41 @@ function random_matrix_product_state(::Type{T}, n::Int, chi::Int, d::Int=2) wher push!(ixs_ket, [virtual_indices_ket[n-1], physical_indices[n]]) push!(ixs_bra, [virtual_indices_bra[n-1], physical_indices[n]]) tensors, ixs = [tensors..., conj.(tensors)...], [ixs_ket..., ixs_bra...] - return TensorNetworkModel( - 3n-2, - optimize_code(DynamicEinCode(ixs, Int[]), OMEinsum.get_size_dict(ixs, tensors), GreedyMethod()), - tensors, - Dict{Int, Int}(), - collect(1:n) + size_dict = OMEinsum.get_size_dict(ixs, tensors) + nvars = 3n-2 + return UAIModel( + nvars, + [size_dict[i] for i=1:nvars], + [Factor((ixs[i]...,), tensors[i]) for i in 1:length(tensors)] ) end -random_matrix_product_state(n::Int, chi::Int, d::Int=2) = random_matrix_product_state(ComplexF64, n, chi, d) + + +""" +$TYPEDSIGNATURES + +Tensor train (TT) is a tensor network model that is widely used in quantum +many-body physics. This model is different from the matrix product state (MPS) +in that it does not have an extra copy for representing the bra state. +""" +function random_tensor_train_uai(::Type{T}, n::Int, chi::Int, d::Int=2) where T + # chi ^ (n-1) * (variance^n)^2 == 1/d^n + variance = d^(-1/2) * chi^(-1/2+1/2n) + tensors = Any[randn(T, d, chi) .* variance] + physical_indices = collect(1:n) + virtual_indices = collect(n+1:2n-1) + ixs = [[physical_indices[1], virtual_indices[1]]] + for i = 2:n-1 + push!(tensors, randn(T, chi, d, chi) .* variance) + push!(ixs, [virtual_indices[i-1], physical_indices[i], virtual_indices[i]]) + end + push!(tensors, randn(T, chi, d) .* variance) + push!(ixs, [virtual_indices[n-1], physical_indices[n]]) + size_dict = OMEinsum.get_size_dict(ixs, tensors) + nvars = 2n-1 + return UAIModel( + nvars, + [size_dict[i] for i=1:nvars], + [Factor((ixs[i]...,), tensors[i]) for i in 1:length(tensors)] + ) +end \ No newline at end of file diff --git a/test/belief.jl b/test/belief.jl index 3c1c7a7..10633ee 100644 --- a/test/belief.jl +++ b/test/belief.jl @@ -1,4 +1,5 @@ using TensorInference, Test +using OMEinsum @testset "process message" begin mi = [[1, 2, 3], [2, 3, 4], [3, 4, 5]] @@ -26,11 +27,24 @@ end @test vectors_out[3] ≈ vec(kron(v2, v1)' * reshape(t, 4, 2)) end +@testset "constructor" begin + problem = problem_from_artifact("uai2014", "MAR", "Promedus", 14) + uai = read_model(problem) + bp = BeliefPropgation(uai) + @test length(bp.v2t) == 414 + @test TensorInference.num_tensors(bp) == 414 + @test TensorInference.num_variables(bp) == length(unique(vcat([collect(Int, f.vars) for f in uai.factors]...))) +end + @testset "belief propagation" begin n = 5 chi = 3 - Random.seed!(140) - mps = random_matrix_product_state(n, chi) - model = TensorNetworkModel(mps) - state = belief_propagation(model) + mps_uai = TensorInference.random_tensor_train_uai(Float64, n, chi) + bp = BeliefPropgation(mps_uai) + @test TensorInference.initial_state(bp) isa TensorInference.BPState + state, info = belief_propagate(bp) + @show TensorInference.contraction_results(state) + @test info.converged + tnet = TensorNetworkModel(mps_uai) + @show probability(tnet)[] end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 6bd7da2..c32af7a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -24,6 +24,14 @@ end include("cspmodels.jl") end +@testset "utils" begin + include("utils.jl") +end + +@testset "belief propagation" begin + include("belief.jl") +end + using CUDA if CUDA.functional() include("cuda.jl") diff --git a/test/sampling.jl b/test/sampling.jl index 005dc87..d95b8a2 100644 --- a/test/sampling.jl +++ b/test/sampling.jl @@ -71,7 +71,6 @@ end mps = random_matrix_product_state(n, chi) num_samples = 10000 ixs = OMEinsum.getixsv(mps.code) - @show ixs samples = map(1:num_samples) do i sample(mps, 1; queryvars=collect(1:n)).samples[:,1] end diff --git a/test/utils.jl b/test/utils.jl new file mode 100644 index 0000000..35038ac --- /dev/null +++ b/test/utils.jl @@ -0,0 +1,12 @@ +using TensorInference, Test + +@testset "tensor train" begin + tt = random_tensor_train_uai(Float64, 5, 3) + @test tt.nvars == length(unique(vcat([collect(Int, f.vars) for f in tt.factors]...))) +end + +@testset "mps" begin + tt = random_matrix_product_uai(Float64, 5, 3) + @test tt.nvars == length(unique(vcat([collect(Int, f.vars) for f in tt.factors]...))) +end + From 0350a1716fc948365683258057c99e35744a2e68 Mon Sep 17 00:00:00 2001 From: GiggleLiu Date: Tue, 29 Apr 2025 01:57:04 +0800 Subject: [PATCH 07/15] update --- src/belief.jl | 8 +++++--- test/belief.jl | 6 ++++-- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/src/belief.jl b/src/belief.jl index 748ce48..6049914 100644 --- a/src/belief.jl +++ b/src/belief.jl @@ -43,7 +43,7 @@ end function collect_message!(bp::BeliefPropgation, state::BPState) for it in 1:num_tensors(bp) - _collect_message!(vectors_on_tensor(state.message_out, bp, it), bp.tensors[it], vectors_on_tensor(state.message_in, bp, it)) + _collect_message!(vectors_on_tensor(state.message_in, bp, it), bp.tensors[it], vectors_on_tensor(state.message_out, bp, it)) end end # collect the vectors associated with the target tensor @@ -94,13 +94,15 @@ struct BPInfo iterations::Int end function belief_propagate!(bp::BeliefPropgation, state::BPState{T}; max_iter::Int=100, tol::Float64=1e-6) where T + pre_message_in = deepcopy(state.message_in) for i in 1:max_iter - process_message!(state) collect_message!(bp, state) + process_message!(state) # check convergence - if all(iv -> all(it -> isapprox(state.message_out[iv][it], state.message_in[iv][it], atol=tol), 1:length(bp.v2t[iv])), 1:num_variables(bp)) + if all(iv -> all(it -> isapprox(state.message_in[iv][it], pre_message_in[iv][it], atol=tol), 1:length(bp.v2t[iv])), 1:num_variables(bp)) return BPInfo(true, i) end + pre_message_in = deepcopy(state.message_in) end return BPInfo(false, max_iter) end diff --git a/test/belief.jl b/test/belief.jl index 10633ee..f5d1317 100644 --- a/test/belief.jl +++ b/test/belief.jl @@ -43,8 +43,10 @@ end bp = BeliefPropgation(mps_uai) @test TensorInference.initial_state(bp) isa TensorInference.BPState state, info = belief_propagate(bp) - @show TensorInference.contraction_results(state) @test info.converged + @test info.iterations < 10 + contraction_res = TensorInference.contraction_results(state) tnet = TensorNetworkModel(mps_uai) - @show probability(tnet)[] + expected_result = probability(tnet)[] + @test all(r -> isapprox(r, expected_result), contraction_res) end \ No newline at end of file From b14253fffb7bf0bc19fb21732b9b41001f19c0fc Mon Sep 17 00:00:00 2001 From: GiggleLiu Date: Tue, 29 Apr 2025 10:01:22 +0800 Subject: [PATCH 08/15] implement marginals --- src/belief.jl | 4 ++++ test/belief.jl | 5 +++++ 2 files changed, 9 insertions(+) diff --git a/src/belief.jl b/src/belief.jl index 6049914..4524288 100644 --- a/src/belief.jl +++ b/src/belief.jl @@ -110,4 +110,8 @@ end # if BP is exact and converged (e.g. tree like), the result should be the same as the tensor network contraction function contraction_results(state::BPState{T}) where T return [sum(reduce((x, y) -> x .* y, mi)) for mi in state.message_in] +end + +function marginals(state::BPState{T}) where T + return Dict([v] => normalize!(reduce((x, y) -> x .* y, mi), 1) for (v, mi) in enumerate(state.message_in)) end \ No newline at end of file diff --git a/test/belief.jl b/test/belief.jl index f5d1317..9abd0e0 100644 --- a/test/belief.jl +++ b/test/belief.jl @@ -49,4 +49,9 @@ end tnet = TensorNetworkModel(mps_uai) expected_result = probability(tnet)[] @test all(r -> isapprox(r, expected_result), contraction_res) + mars = marginals(state) + mars_tnet = marginals(tnet) + for v in 1:TensorInference.num_variables(bp) + @test mars[[v]] ≈ mars_tnet[[v]] + end end \ No newline at end of file From 2c41d54c9327e19338e9d723893ef629e808396b Mon Sep 17 00:00:00 2001 From: GiggleLiu Date: Tue, 29 Apr 2025 10:24:23 +0800 Subject: [PATCH 09/15] fix docs --- docs/src/api/public.md | 4 ++++ docs/src/tensor-networks.md | 19 +++++++++++++++++++ src/belief.jl | 33 ++++++++++++++++++++++++++++++++- src/utils.jl | 6 ++++++ 4 files changed, 61 insertions(+), 1 deletion(-) diff --git a/docs/src/api/public.md b/docs/src/api/public.md index 5616b95..ca1e718 100644 --- a/docs/src/api/public.md +++ b/docs/src/api/public.md @@ -43,6 +43,7 @@ RescaledArray TensorNetworkModel ArtifactProblemSpec UAIModel +BeliefPropgation ``` ## Functions @@ -56,6 +57,7 @@ marginals maximum_logp most_probable_config probability +belief_propagate dataset_from_artifact problem_from_artifact read_model @@ -69,4 +71,6 @@ sample update_evidence! update_temperature random_matrix_product_state +random_matrix_product_uai +random_tensor_train_uai ``` diff --git a/docs/src/tensor-networks.md b/docs/src/tensor-networks.md index a86442b..f1f5836 100644 --- a/docs/src/tensor-networks.md +++ b/docs/src/tensor-networks.md @@ -205,6 +205,13 @@ Some of these have been implemented in the [OMEinsum](https://github.com/under-Peter/OMEinsum.jl) package. Please check [Performance Tips](@ref) for more details. +## Belief propagation + +Belief propagation[^Yedidia2003] is a message passing algorithm that can be used to compute the marginals of a probabilistic graphical model. It has close connections with the tensor networks. It can be viewed as a way to gauge the tensor networks[^Tindall2023], and can be combined with tensor networks to achieve better performance[^Wang2024]. + +Belief propagation is an approximate method, and the quality of the approximation can be improved by the loop series expansion[^Evenbly2024]. + + ## References [^Orus2014]: @@ -227,3 +234,15 @@ Some of these have been implemented in the [^Liu2023]: Liu J G, Gao X, Cain M, et al. Computing solution space properties of combinatorial optimization problems via generic tensor networks[J]. SIAM Journal on Scientific Computing, 2023, 45(3): A1239-A1270. + +[^Yedidia2003]: + Yedidia, J.S., Freeman, W.T., Weiss, Y., 2003. Understanding belief propagation and its generalizations, in: Exploring Artificial Intelligence in the New Millennium. Morgan Kaufmann Publishers Inc., San Francisco, CA, USA, pp. 239–269. + +[^Wang2024]: + Wang, Y., Zhang, Y.E., Pan, F., Zhang, P., 2024. Tensor Network Message Passing. Phys. Rev. Lett. 132, 117401. https://doi.org/10.1103/PhysRevLett.132.117401 + +[^Tindall2023]: + Tindall, J., Fishman, M.T., 2023. Gauging tensor networks with belief propagation. SciPost Phys. 15, 222. https://doi.org/10.21468/SciPostPhys.15.6.222 + +[^Evenbly2024]: + Evenbly, G., Pancotti, N., Milsted, A., Gray, J., Chan, G.K.-L., 2024. Loop Series Expansions for Tensor Networks. https://doi.org/10.48550/arXiv.2409.03108 diff --git a/src/belief.jl b/src/belief.jl index 4524288..b7c27b9 100644 --- a/src/belief.jl +++ b/src/belief.jl @@ -1,3 +1,14 @@ +""" +$TYPEDEF + BeliefPropgation(nvars::Int, t2v::AbstractVector{Vector{Int}}, tensors::AbstractVector{AbstractArray{T}}) where T + +A belief propagation object. + +### Fields +- `t2v::Vector{Vector{Int}}`: a mapping from tensors to variables +- `v2t::Vector{Vector{Int}}`: a mapping from variables to tensors +- `tensors::Vector{AbstractArray{T}}`: the tensors +""" struct BeliefPropgation{T} t2v::Vector{Vector{Int}} # a mapping from tensors to variables v2t::Vector{Vector{Int}} # a mapping from variables to tensors @@ -16,6 +27,12 @@ function BeliefPropgation(nvars::Int, t2v::AbstractVector{Vector{Int}}, tensors: end return BeliefPropgation(t2v, v2t, tensors) end + +""" +$(TYPEDSIGNATURES) + +Construct a belief propagation object from a [`UAIModel`](@ref). +""" function BeliefPropgation(uai::UAIModel{T}) where T return BeliefPropgation(uai.nvars, [collect(Int, f.vars) for f in uai.factors], AbstractArray{T}[f.vals for f in uai.factors]) end @@ -83,7 +100,18 @@ function initial_state(bp::BeliefPropgation{T}) where T return BPState(deepcopy(edges_vectors), edges_vectors) end -# belief propagation, update the tensors on the edges of the tensor network +""" +$(TYPEDSIGNATURES) + +Run the belief propagation algorithm, and return the final state and the information about the convergence. + +### Arguments +- `bp::BeliefPropgation`: the belief propagation object + +### Keyword Arguments +- `max_iter::Int=100`: the maximum number of iterations +- `tol::Float64=1e-6`: the tolerance for the convergence +""" function belief_propagate(bp::BeliefPropgation; max_iter::Int=100, tol::Float64=1e-6) state = initial_state(bp) info = belief_propagate!(bp, state; max_iter=max_iter, tol=tol) @@ -112,6 +140,9 @@ function contraction_results(state::BPState{T}) where T return [sum(reduce((x, y) -> x .* y, mi)) for mi in state.message_in] end +""" +$(TYPEDSIGNATURES) +""" function marginals(state::BPState{T}) where T return Dict([v] => normalize!(reduce((x, y) -> x .* y, mi), 1) for (v, mi) in enumerate(state.message_in)) end \ No newline at end of file diff --git a/src/utils.jl b/src/utils.jl index f9c1794..8d7d1f1 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -339,6 +339,12 @@ function random_matrix_product_state(::Type{T}, n::Int, chi::Int, d::Int=2) wher end random_matrix_product_state(n::Int, chi::Int, d::Int=2) = random_matrix_product_state(ComplexF64, n, chi, d) +""" +$TYPEDSIGNATURES + +Generate a random UAIModel that represents a matrix product state (MPS). +Similar to [`random_matrix_product_state`](@ref), but returns the UAIModel directly. +""" function random_matrix_product_uai(::Type{T}, n::Int, chi::Int, d::Int=2) where T # chi ^ (n-1) * (variance^n)^2 == 1/d^n variance = d^(-1/2) * chi^(-1/2+1/2n) From d2262fffc7a0d0eda6ee879907851e1304abb5a5 Mon Sep 17 00:00:00 2001 From: GiggleLiu Date: Thu, 1 May 2025 00:38:19 +0800 Subject: [PATCH 10/15] update --- src/belief.jl | 31 ++++++++++++++++++------------- src/utils.jl | 16 ++++++++-------- test/belief.jl | 36 +++++++++++++++++++++++++++--------- test/utils.jl | 3 +++ 4 files changed, 56 insertions(+), 30 deletions(-) diff --git a/src/belief.jl b/src/belief.jl index b7c27b9..7e69a04 100644 --- a/src/belief.jl +++ b/src/belief.jl @@ -43,24 +43,29 @@ struct BPState{T, VT<:AbstractVector{T}} end # message_in -> message_out -function process_message!(bp::BPState) +function process_message!(bp::BPState; normalize, damping) for (ov, iv) in zip(bp.message_out, bp.message_in) - _process_message!(ov, iv) + _process_message!(ov, iv, normalize, damping) end end -function _process_message!(ov::Vector, iv::Vector) +function _process_message!(ov::Vector, iv::Vector, normalize::Bool, damping) # process the message, TODO: speed up if needed! for (i, v) in enumerate(ov) - fill!(v, one(eltype(v))) # clear the output vector + w = similar(v) + fill!(w, one(eltype(v))) # clear the output vector for (j, u) in enumerate(iv) - j != i && (v .*= u) + j != i && (w .*= u) end + normalize && normalize!(w, 1) + v .= v .* damping + (1 - damping) * w end end -function collect_message!(bp::BeliefPropgation, state::BPState) +function collect_message!(bp::BeliefPropgation, state::BPState; normalize::Bool) for it in 1:num_tensors(bp) - _collect_message!(vectors_on_tensor(state.message_in, bp, it), bp.tensors[it], vectors_on_tensor(state.message_out, bp, it)) + out = vectors_on_tensor(state.message_in, bp, it) + _collect_message!(out, bp.tensors[it], vectors_on_tensor(state.message_out, bp, it)) + normalize && normalize!.(out, 1) end end # collect the vectors associated with the target tensor @@ -78,7 +83,7 @@ function _collect_message!(vectors_out::Vector, t::AbstractArray, vectors_in::Ve for (o, g) in zip(vectors_out, gradient[2:end]) o .= g end - return cost + return cost[] end # star code: contract a tensor with multiple vectors, one for each dimension @@ -112,20 +117,20 @@ Run the belief propagation algorithm, and return the final state and the informa - `max_iter::Int=100`: the maximum number of iterations - `tol::Float64=1e-6`: the tolerance for the convergence """ -function belief_propagate(bp::BeliefPropgation; max_iter::Int=100, tol::Float64=1e-6) +function belief_propagate(bp::BeliefPropgation; kwargs...) state = initial_state(bp) - info = belief_propagate!(bp, state; max_iter=max_iter, tol=tol) + info = belief_propagate!(bp, state; kwargs...) return state, info end struct BPInfo converged::Bool iterations::Int end -function belief_propagate!(bp::BeliefPropgation, state::BPState{T}; max_iter::Int=100, tol::Float64=1e-6) where T +function belief_propagate!(bp::BeliefPropgation, state::BPState{T}; max_iter::Int=100, tol=1e-6, damping=0.2) where T pre_message_in = deepcopy(state.message_in) for i in 1:max_iter - collect_message!(bp, state) - process_message!(state) + collect_message!(bp, state; normalize=true) + process_message!(state; normalize=true, damping=damping) # check convergence if all(iv -> all(it -> isapprox(state.message_in[iv][it], pre_message_in[iv][it], atol=tol), 1:length(bp.v2t[iv])), 1:num_variables(bp)) return BPInfo(true, i) diff --git a/src/utils.jl b/src/utils.jl index 8d7d1f1..11064b6 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -380,21 +380,21 @@ Tensor train (TT) is a tensor network model that is widely used in quantum many-body physics. This model is different from the matrix product state (MPS) in that it does not have an extra copy for representing the bra state. """ -function random_tensor_train_uai(::Type{T}, n::Int, chi::Int, d::Int=2) where T +function random_tensor_train_uai(::Type{T}, n::Int, chi::Int, d::Int=2; periodic=false) where T # chi ^ (n-1) * (variance^n)^2 == 1/d^n variance = d^(-1/2) * chi^(-1/2+1/2n) - tensors = Any[randn(T, d, chi) .* variance] physical_indices = collect(1:n) - virtual_indices = collect(n+1:2n-1) - ixs = [[physical_indices[1], virtual_indices[1]]] + virtual_indices = collect(n+1:2n) + tensors = Any[(periodic ? rand(T, chi, d, chi) : rand(T, d, chi)) .* variance] + ixs = [periodic ? [virtual_indices[n], physical_indices[1], virtual_indices[1]] : [physical_indices[1], virtual_indices[1]]] for i = 2:n-1 - push!(tensors, randn(T, chi, d, chi) .* variance) + push!(tensors, rand(T, chi, d, chi) .* variance) push!(ixs, [virtual_indices[i-1], physical_indices[i], virtual_indices[i]]) end - push!(tensors, randn(T, chi, d) .* variance) - push!(ixs, [virtual_indices[n-1], physical_indices[n]]) + push!(tensors, (periodic ? rand(T, chi, d, chi) : rand(T, chi, d)) .* variance) + push!(ixs, periodic ? [virtual_indices[n-1], physical_indices[n], virtual_indices[n]] : [virtual_indices[n-1], physical_indices[n]]) size_dict = OMEinsum.get_size_dict(ixs, tensors) - nvars = 2n-1 + nvars = periodic ? 2n : 2n-1 return UAIModel( nvars, [size_dict[i] for i=1:nvars], diff --git a/test/belief.jl b/test/belief.jl index 9abd0e0..4ad0fad 100644 --- a/test/belief.jl +++ b/test/belief.jl @@ -1,12 +1,15 @@ using TensorInference, Test -using OMEinsum +using OMEinsum, LinearAlgebra @testset "process message" begin - mi = [[1, 2, 3], [2, 3, 4], [3, 4, 5]] - mo_expected = [[6, 12, 20], [3, 8, 15], [2, 6, 12]] + mi = [[1.0, 2, 3], [2.0, 3, 4], [3.0, 4, 5]] + mo_expected = [[6.0, 12, 20], [3.0, 8, 15], [2.0, 6, 12]] mo = similar.(mi) - TensorInference._process_message!(mo, mi) - @test mo == mo_expected + TensorInference._process_message!(mo, mi, false, 0) + @test all(mo .≈ mo_expected) + + TensorInference._process_message!(mo, mi, true, 0) + @test all(mo .≈ normalize!.(mo_expected, 1)) end @testset "star code" begin @@ -44,14 +47,29 @@ end @test TensorInference.initial_state(bp) isa TensorInference.BPState state, info = belief_propagate(bp) @test info.converged - @test info.iterations < 10 + @test info.iterations < 20 + mars = marginals(state) + tnet = TensorNetworkModel(mps_uai) + mars_tnet = marginals(tnet) + for v in 1:TensorInference.num_variables(bp) + @test mars[[v]] ≈ mars_tnet[[v]] atol=1e-6 + end +end + +@testset "belief propagation on circle" begin + n = 10 + chi = 3 + mps_uai = TensorInference.random_tensor_train_uai(Float64, n, chi; periodic=true) + bp = BeliefPropgation(mps_uai) + @test TensorInference.initial_state(bp) isa TensorInference.BPState + state, info = belief_propagate(bp; max_iter=100, tol=1e-6) + @test info.converged + @test info.iterations < 100 contraction_res = TensorInference.contraction_results(state) tnet = TensorNetworkModel(mps_uai) - expected_result = probability(tnet)[] - @test all(r -> isapprox(r, expected_result), contraction_res) mars = marginals(state) mars_tnet = marginals(tnet) for v in 1:TensorInference.num_variables(bp) - @test mars[[v]] ≈ mars_tnet[[v]] + @test mars[[v]] ≈ mars_tnet[[v]] atol=1e-4 end end \ No newline at end of file diff --git a/test/utils.jl b/test/utils.jl index 35038ac..18958c0 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -3,6 +3,9 @@ using TensorInference, Test @testset "tensor train" begin tt = random_tensor_train_uai(Float64, 5, 3) @test tt.nvars == length(unique(vcat([collect(Int, f.vars) for f in tt.factors]...))) + + tt = random_tensor_train_uai(Float64, 5, 3; periodic=true) + @test tt.nvars == length(unique(vcat([collect(Int, f.vars) for f in tt.factors]...))) end @testset "mps" begin From c2d30756a734b635962c0444a486d4cf9b0722af Mon Sep 17 00:00:00 2001 From: GiggleLiu Date: Thu, 1 May 2025 15:28:53 +0800 Subject: [PATCH 11/15] Change julia version to 1.10 in the compat file --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 280abc7..6d6ef21 100644 --- a/Project.toml +++ b/Project.toml @@ -31,4 +31,4 @@ PrettyTables = "2" ProblemReductions = "0.3" StatsBase = "0.34" TropicalNumbers = "0.5.4, 0.6" -julia = "1.9" +julia = "1.10" From e1a394560070b8cea319a7c21d4d5a52ea52a6b9 Mon Sep 17 00:00:00 2001 From: GiggleLiu Date: Thu, 1 May 2025 15:37:10 +0800 Subject: [PATCH 12/15] add uai test --- test/belief.jl | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/test/belief.jl b/test/belief.jl index 4ad0fad..88ec9c0 100644 --- a/test/belief.jl +++ b/test/belief.jl @@ -72,4 +72,28 @@ end for v in 1:TensorInference.num_variables(bp) @test mars[[v]] ≈ mars_tnet[[v]] atol=1e-4 end +end + +@testset "marginal uai2014" begin + for problem in [problem_from_artifact("uai2014", "MAR", "Promedus", 14), problem_from_artifact("uai2014", "MAR", "ObjectDetection", 42)] + optimizer = TreeSA(ntrials = 1, niters = 5, βs = 0.1:0.1:100) + evidence = Dict{Int, Int}() + model = read_model(problem) + + tn = TensorNetworkModel(model; optimizer, evidence) + mars_tnet = marginals(tn) + + code = tn.code.eins + tensors = tn.tensors + size_dict = Dict(i => d for (i, d) in enumerate(model.cards)) + + bp = BeliefPropgation(model) + state, info = belief_propagate(bp; max_iter=300, tol=1e-6) + @test info.converged + mars = marginals(state) + + for v in 1:TensorInference.num_variables(bp) + @test mars[[v]] ≈ mars_tnet[[v]] atol=1e-2 + end + end end \ No newline at end of file From 4309e965bba9bdbc77f5dc9fc271f5b237a15505 Mon Sep 17 00:00:00 2001 From: GiggleLiu Date: Thu, 1 May 2025 16:20:46 +0800 Subject: [PATCH 13/15] format document and fix tests --- src/belief.jl | 22 +++++++++++----------- test/belief.jl | 8 ++++++-- 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/src/belief.jl b/src/belief.jl index 7e69a04..574a2e7 100644 --- a/src/belief.jl +++ b/src/belief.jl @@ -17,7 +17,7 @@ end num_tensors(bp::BeliefPropgation) = length(bp.t2v) ProblemReductions.num_variables(bp::BeliefPropgation) = length(bp.v2t) -function BeliefPropgation(nvars::Int, t2v::AbstractVector{Vector{Int}}, tensors::AbstractVector{AbstractArray{T}}) where T +function BeliefPropgation(nvars::Int, t2v::AbstractVector{Vector{Int}}, tensors::AbstractVector{AbstractArray{T}}) where {T} # initialize the inverse mapping v2t = [Int[] for _ in 1:nvars] for (i, edge) in enumerate(t2v) @@ -33,11 +33,11 @@ $(TYPEDSIGNATURES) Construct a belief propagation object from a [`UAIModel`](@ref). """ -function BeliefPropgation(uai::UAIModel{T}) where T +function BeliefPropgation(uai::UAIModel{T}) where {T} return BeliefPropgation(uai.nvars, [collect(Int, f.vars) for f in uai.factors], AbstractArray{T}[f.vals for f in uai.factors]) end -struct BPState{T, VT<:AbstractVector{T}} +struct BPState{T, VT <: AbstractVector{T}} message_in::Vector{Vector{VT}} # for each variable, we store the incoming messages message_out::Vector{Vector{VT}} # the outgoing messages end @@ -91,12 +91,12 @@ function star_code(n::Int) ix1, ixrest = collect(1:n), [[i] for i in 1:n] ne = DynamicNestedEinsum([DynamicNestedEinsum{Int}(1), DynamicNestedEinsum{Int}(2)], DynamicEinCode([ix1, ixrest[1]], collect(2:n))) for i in 2:n - ne = DynamicNestedEinsum([ne, DynamicNestedEinsum{Int}(i + 1)], DynamicEinCode([ne.eins.iy, ixrest[i]], collect(i+1:n))) + ne = DynamicNestedEinsum([ne, DynamicNestedEinsum{Int}(i + 1)], DynamicEinCode([ne.eins.iy, ixrest[i]], collect((i + 1):n))) end return ne end -function initial_state(bp::BeliefPropgation{T}) where T +function initial_state(bp::BeliefPropgation{T}) where {T} size_dict = OMEinsum.get_size_dict(bp.t2v, bp.tensors) edges_vectors = Vector{Vector{T}}[] for (i, tids) in enumerate(bp.v2t) @@ -126,13 +126,13 @@ struct BPInfo converged::Bool iterations::Int end -function belief_propagate!(bp::BeliefPropgation, state::BPState{T}; max_iter::Int=100, tol=1e-6, damping=0.2) where T +function belief_propagate!(bp::BeliefPropgation, state::BPState{T}; max_iter::Int = 100, tol = 1e-6, damping = 0.2) where {T} pre_message_in = deepcopy(state.message_in) for i in 1:max_iter - collect_message!(bp, state; normalize=true) - process_message!(state; normalize=true, damping=damping) + collect_message!(bp, state; normalize = true) + process_message!(state; normalize = true, damping = damping) # check convergence - if all(iv -> all(it -> isapprox(state.message_in[iv][it], pre_message_in[iv][it], atol=tol), 1:length(bp.v2t[iv])), 1:num_variables(bp)) + if all(iv -> all(it -> isapprox(state.message_in[iv][it], pre_message_in[iv][it], atol = tol), 1:length(bp.v2t[iv])), 1:num_variables(bp)) return BPInfo(true, i) end pre_message_in = deepcopy(state.message_in) @@ -141,13 +141,13 @@ function belief_propagate!(bp::BeliefPropgation, state::BPState{T}; max_iter::In end # if BP is exact and converged (e.g. tree like), the result should be the same as the tensor network contraction -function contraction_results(state::BPState{T}) where T +function contraction_results(state::BPState{T}) where {T} return [sum(reduce((x, y) -> x .* y, mi)) for mi in state.message_in] end """ $(TYPEDSIGNATURES) """ -function marginals(state::BPState{T}) where T +function marginals(state::BPState{T}) where {T} return Dict([v] => normalize!(reduce((x, y) -> x .* y, mi), 1) for (v, mi) in enumerate(state.message_in)) end \ No newline at end of file diff --git a/test/belief.jl b/test/belief.jl index 88ec9c0..150c302 100644 --- a/test/belief.jl +++ b/test/belief.jl @@ -6,10 +6,14 @@ using OMEinsum, LinearAlgebra mo_expected = [[6.0, 12, 20], [3.0, 8, 15], [2.0, 6, 12]] mo = similar.(mi) TensorInference._process_message!(mo, mi, false, 0) - @test all(mo .≈ mo_expected) + for i in 1:length(mo) + @test mo[i] ≈ mo_expected[i] atol=1e-8 + end TensorInference._process_message!(mo, mi, true, 0) - @test all(mo .≈ normalize!.(mo_expected, 1)) + for i in 1:length(mo) + @test mo[i] ≈ normalize!(mo_expected[i], 1) atol=1e-8 + end end @testset "star code" begin From 370fad79cf3405ebe73b962497e815572ba60378 Mon Sep 17 00:00:00 2001 From: GiggleLiu Date: Thu, 1 May 2025 17:53:03 +0800 Subject: [PATCH 14/15] fix docstring --- src/belief.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/belief.jl b/src/belief.jl index 574a2e7..72a325a 100644 --- a/src/belief.jl +++ b/src/belief.jl @@ -116,6 +116,7 @@ Run the belief propagation algorithm, and return the final state and the informa ### Keyword Arguments - `max_iter::Int=100`: the maximum number of iterations - `tol::Float64=1e-6`: the tolerance for the convergence +- `damping::Float64=0.2`: the damping factor for the message update, updated-message = damping * old-message + (1 - damping) * new-message """ function belief_propagate(bp::BeliefPropgation; kwargs...) state = initial_state(bp) From 8295474fbec1265994b25b3636910d14e36a56a6 Mon Sep 17 00:00:00 2001 From: GiggleLiu Date: Mon, 5 May 2025 17:44:35 +0800 Subject: [PATCH 15/15] clean up --- Project.toml | 4 +- ext/TensorInferenceCUDAExt.jl | 5 +- src/Core.jl | 5 +- src/RescaledArray.jl | 7 ++- src/TensorInference.jl | 1 + src/belief.jl | 2 +- src/map.jl | 4 +- src/mar.jl | 111 +--------------------------------- src/mmap.jl | 2 +- src/sampling.jl | 8 +-- test/map.jl | 3 +- test/mar.jl | 11 +++- 12 files changed, 29 insertions(+), 134 deletions(-) diff --git a/Project.toml b/Project.toml index 6d6ef21..014b5e0 100644 --- a/Project.toml +++ b/Project.toml @@ -4,7 +4,6 @@ authors = ["Jin-Guo Liu", "Martin Roa Villescas"] version = "0.5.0" [deps] -Artifacts = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" OMEinsum = "ebe7aa44-baf0-506c-a96f-8464559b3922" @@ -21,11 +20,10 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" TensorInferenceCUDAExt = "CUDA" [compat] -Artifacts = "1" CUDA = "4, 5" DocStringExtensions = "0.8.6, 0.9" LinearAlgebra = "1" -OMEinsum = "0.8" +OMEinsum = "0.8.7" Pkg = "1" PrettyTables = "2" ProblemReductions = "0.3" diff --git a/ext/TensorInferenceCUDAExt.jl b/ext/TensorInferenceCUDAExt.jl index 40204fc..2f99883 100644 --- a/ext/TensorInferenceCUDAExt.jl +++ b/ext/TensorInferenceCUDAExt.jl @@ -1,7 +1,7 @@ module TensorInferenceCUDAExt using CUDA: CuArray import CUDA -import TensorInference: match_arraytype, keep_only!, onehot_like, togpu +import TensorInference: keep_only!, onehot_like, togpu function onehot_like(A::CuArray, j) mask = zero(A) @@ -9,9 +9,6 @@ function onehot_like(A::CuArray, j) return mask end -# NOTE: this interface should be in OMEinsum -match_arraytype(::Type{<:CuArray{T, N}}, target::AbstractArray{T, N}) where {T, N} = CuArray(target) - function keep_only!(x::CuArray{T}, j) where T CUDA.@allowscalar hotvalue = x[j] fill!(x, zero(T)) diff --git a/src/Core.jl b/src/Core.jl index 39866fe..2b623e6 100644 --- a/src/Core.jl +++ b/src/Core.jl @@ -190,7 +190,4 @@ Returns the contraction complexity of a tensor newtork model. """ function OMEinsum.contraction_complexity(tn::TensorNetworkModel) return contraction_complexity(tn.code, Dict(zip(get_vars(tn), get_cards(tn; fixedisone = true)))) -end - -# adapt array type with the target array type -match_arraytype(::Type{<:Array{T, N}}, target::AbstractArray{T, N}) where {T, N} = Array(target) +end \ No newline at end of file diff --git a/src/RescaledArray.jl b/src/RescaledArray.jl index 2e48e7a..8cbebec 100644 --- a/src/RescaledArray.jl +++ b/src/RescaledArray.jl @@ -46,4 +46,9 @@ end Base.size(arr::RescaledArray) = size(arr.normalized_value) Base.size(arr::RescaledArray, i::Int) = size(arr.normalized_value, i) -match_arraytype(::Type{<:RescaledArray{T, N, AT}}, target::AbstractArray{T, N}) where {T, N, AT} = rescale_array(match_arraytype(AT, target)) +function OMEinsum.get_output_array(xs::NTuple{N, RescaledArray{T}}, size, fillzero::Bool) where {N, T} + return RescaledArray(zero(T), OMEinsum.get_output_array(getfield.(xs, :normalized_value), size, fillzero)) +end +# The following two APIs are required by OMEinsum +Base.fill!(r::RescaledArray, x) = (fill!(r.normalized_value, x ./ exp(r.log_factor)); r) +Base.conj(r::RescaledArray) = RescaledArray(conj(r.log_factor), conj(r.normalized_value)) diff --git a/src/TensorInference.jl b/src/TensorInference.jl index 72f17f1..a1e7482 100644 --- a/src/TensorInference.jl +++ b/src/TensorInference.jl @@ -8,6 +8,7 @@ $(EXPORTS) module TensorInference using OMEinsum, LinearAlgebra +using OMEinsum: CacheTree, cached_einsum using DocStringExtensions, TropicalNumbers # The Tropical GEMM support using StatsBase diff --git a/src/belief.jl b/src/belief.jl index 72a325a..eede996 100644 --- a/src/belief.jl +++ b/src/belief.jl @@ -79,7 +79,7 @@ function _collect_message!(vectors_out::Vector, t::AbstractArray, vectors_in::Ve @assert length(vectors_out) == length(vectors_in) == ndims(t) "dimensions mismatch: $(length(vectors_out)), $(length(vectors_in)), $(ndims(t))" # TODO: speed up if needed! code = star_code(length(vectors_in)) - cost, gradient = cost_and_gradient(code, [t, vectors_in...]) + cost, gradient = cost_and_gradient(code, (t, vectors_in...)) for (o, g) in zip(vectors_out, gradient[2:end]) o .= g end diff --git a/src/map.jl b/src/map.jl index a1f5762..21a7c3c 100644 --- a/src/map.jl +++ b/src/map.jl @@ -2,7 +2,7 @@ ########### Backward tropical tensor contraction ############## # This part is copied from [`GenericTensorNetworks`](https://github.com/QuEraComputing/GenericTensorNetworks.jl). -function einsum_backward_rule(eins, xs::NTuple{M, AbstractArray{<:Tropical}} where {M}, y, size_dict, dy) +function OMEinsum.einsum_backward_rule(eins, xs::NTuple{M, AbstractArray{<:Tropical}} where {M}, y, size_dict, dy) return backward_tropical!(OMEinsum.getixs(eins), xs, OMEinsum.getiy(eins), y, dy, size_dict) end @@ -55,7 +55,7 @@ Returns the largest log-probability and the most probable configuration. function most_probable_config(tn::TensorNetworkModel; usecuda = false)::Tuple{Real, Vector} tensor_indices = check_queryvars(tn, [[v] for v in 1:tn.nvars]) tensors = map(t -> Tropical.(log.(t)), adapt_tensors(tn; usecuda, rescale = false)) - logp, grads = cost_and_gradient(tn.code, tensors) + logp, grads = cost_and_gradient(tn.code, (tensors...,)) # use Array to convert CuArray to CPU arrays return content(Array(logp)[]), map(k -> haskey(tn.evidence, k) ? tn.evidence[k] : argmax(grads[tensor_indices[k]]) - 1, 1:tn.nvars) end diff --git a/src/mar.jl b/src/mar.jl index b1b65ae..3e399b4 100644 --- a/src/mar.jl +++ b/src/mar.jl @@ -12,115 +12,6 @@ function adapt_tensors(code, tensors, evidence; usecuda, rescale) end end -# ######### Inference by back propagation ############ -# `CacheTree` stores intermediate `NestedEinsum` contraction results. -# It is a tree structure that isomorphic to the contraction tree, -# `content` is the cached intermediate contraction result. -# `children` are the children of current node, e.g. tensors that are contracted to get `content`. -mutable struct CacheTree{T} - content::AbstractArray{T} - const children::Vector{CacheTree{T}} -end - -function cached_einsum(se::SlicedEinsum, @nospecialize(xs), size_dict) - # slicing is not supported yet. - if length(se.slicing) != 0 - @warn "Slicing is not supported for caching, got nslices = $(length(se.slicing))! Fallback to `NestedEinsum`." - end - return cached_einsum(se.eins, xs, size_dict) -end - -# recursively contract and cache a tensor network -function cached_einsum(code::NestedEinsum, @nospecialize(xs), size_dict) - if OMEinsum.isleaf(code) - # For a leaf node, cache the input tensor - y = xs[code.tensorindex] - return CacheTree(y, CacheTree{eltype(y)}[]) - else - # For a non-leaf node, compute the einsum and cache the contraction result - caches = [cached_einsum(arg, xs, size_dict) for arg in code.args] - # `einsum` evaluates the einsum contraction, - # Its 1st argument is the contraction pattern, - # Its 2nd one is a tuple of input tensors, - # Its 3rd argument is the size dictionary (label as the key, size as the value). - y = einsum(code.eins, ntuple(i -> caches[i].content, length(caches)), size_dict) - return CacheTree(y, caches) - end -end - -# computed gradient tree by back propagation -function generate_gradient_tree(se::SlicedEinsum, cache::CacheTree{T}, dy::AbstractArray{T}, size_dict::Dict) where {T} - if length(se.slicing) != 0 - @warn "Slicing is not supported for generating masked tree! Fallback to `NestedEinsum`." - end - return generate_gradient_tree(se.eins, cache, dy, size_dict) -end - -# recursively compute the gradients and store it into a tree. -# also known as the back-propagation algorithm. -function generate_gradient_tree(code::NestedEinsum, cache::CacheTree{T}, dy::AbstractArray{T}, size_dict::Dict) where {T} - if OMEinsum.isleaf(code) - return CacheTree(dy, CacheTree{T}[]) - else - xs = ntuple(i -> cache.children[i].content, length(cache.children)) - # `einsum_grad` is the back-propagation rule for einsum function. - # If the forward pass is `y = einsum(EinCode(inputs_labels, output_labels), (A, B, ...), size_dict)` - # Then the back-propagation pass is - # ``` - # A̅ = einsum_grad(inputs_labels, (A, B, ...), output_labels, size_dict, y̅, 1) - # B̅ = einsum_grad(inputs_labels, (A, B, ...), output_labels, size_dict, y̅, 2) - # ... - # ``` - # Let `L` be the loss, we will have `y̅ := ∂L/∂y`, `A̅ := ∂L/∂A`... - dxs = einsum_backward_rule(code.eins, xs, cache.content, size_dict, dy) - return CacheTree(dy, generate_gradient_tree.(code.args, cache.children, dxs, Ref(size_dict))) - end -end - -# a unified interface of the backward rules for real numbers and tropical numbers -function einsum_backward_rule(eins, xs::NTuple{M, AbstractArray{<:Real}} where {M}, y, size_dict, dy) - return ntuple(i -> OMEinsum.einsum_grad(OMEinsum.getixs(eins), xs, OMEinsum.getiy(eins), size_dict, dy, i), length(xs)) -end - -# the main function for generating the gradient tree. -function gradient_tree(code, xs) - # infer size from the contraction code and the input tensors `xs`, returns a label-size dictionary. - size_dict = OMEinsum.get_size_dict!(getixsv(code), xs, Dict{Int, Int}()) - # forward compute and cache intermediate results. - cache = cached_einsum(code, xs, size_dict) - # initialize `y̅` as `1`. Note we always start from `L̅ := 1`. - dy = match_arraytype(typeof(cache.content), ones(eltype(cache.content), size(cache.content))) - # back-propagate - return copy(cache.content), generate_gradient_tree(code, cache, dy, size_dict) -end - -# evaluate the cost and the gradient of leaves -function cost_and_gradient(code, xs) - cost, tree = gradient_tree(code, xs) - # extract the gradients on leaves (i.e. the input tensors). - return cost, extract_leaves(code, tree) -end - -# since slicing is not supported, we forward it to NestedEinsum. -extract_leaves(code::SlicedEinsum, cache::CacheTree) = extract_leaves(code.eins, cache) - -# extract gradients on leaf nodes. -function extract_leaves(code::NestedEinsum, cache::CacheTree) - res = Vector{Any}(undef, length(getixsv(code))) - return extract_leaves!(code, cache, res) -end - -function extract_leaves!(code, cache, res) - if OMEinsum.isleaf(code) - # extract - res[code.tensorindex] = cache.content - else - # resurse deeper - extract_leaves!.(code.args, cache.children, Ref(res)) - end - return res -end - """ $(TYPEDSIGNATURES) @@ -186,7 +77,7 @@ probabilities of the queried variables, represented by tensors. """ function marginals(tn::TensorNetworkModel; usecuda = false, rescale = true)::Dict{Vector{Int}} # sometimes, the cost can overflow, then we need to rescale the tensors during contraction. - cost, grads = cost_and_gradient(tn.code, adapt_tensors(tn; usecuda, rescale)) + cost, grads = cost_and_gradient(tn.code, (adapt_tensors(tn; usecuda, rescale)...,)) @debug "cost = $cost" ixs = OMEinsum.getixsv(tn.code) queryvars = ixs[tn.unity_tensors_idx] diff --git a/src/mmap.jl b/src/mmap.jl index 05d2deb..0d877e4 100644 --- a/src/mmap.jl +++ b/src/mmap.jl @@ -178,7 +178,7 @@ end function most_probable_config(mmap::MMAPModel; usecuda = false)::Tuple{Real, Vector} vars = get_vars(mmap) tensors = map(t -> OMEinsum.asarray(Tropical.(log.(t)), t), adapt_tensors(mmap; usecuda, rescale = false)) - logp, grads = cost_and_gradient(mmap.code, tensors) + logp, grads = cost_and_gradient(mmap.code, (tensors...,)) # use Array to convert CuArray to CPU arrays return content(Array(logp)[]), map(k -> haskey(mmap.evidence, vars[k]) ? mmap.evidence[vars[k]] : argmax(grads[k]) - 1, 1:length(vars)) end diff --git a/src/sampling.jl b/src/sampling.jl index 588ccc3..b94d880 100644 --- a/src/sampling.jl +++ b/src/sampling.jl @@ -134,9 +134,9 @@ function generate_samples!(code::DynamicNestedEinsum, cache::CacheTree{T}, iy_en @assert length(iy_env) == ndims(env) if !(OMEinsum.isleaf(code)) ixs, iy = getixsv(code.eins), getiyv(code.eins) - for (subcode, child, ix) in zip(code.args, cache.children, ixs) + for (subcode, child, ix) in zip(code.args, cache.siblings, ixs) # subenv for the current child, use it to sample and update its cache - siblings = filter(x->x !== child, cache.children) + siblings = filter(x->x !== child, cache.siblings) siblings_ixs = filter(x->x !== ix, ixs) iy_subenv = batch_label ∈ ix ? ix : [ix..., batch_label] envcode = optimize_code(EinCode([siblings_ixs..., iy_env], iy_subenv), size_dict, GreedyMethod(; nrepeat=1)) @@ -184,12 +184,12 @@ end function udpate_cache_tree!(ne::NestedEinsum, cache::CacheTree{T}, el::Pair{<:AbstractVector{L}}, batch_label::L, size_dict::Dict{L}) where {T, L} OMEinsum.isleaf(ne) && return updated = false - for (subcode, child, ix) in zip(ne.args, cache.children, getixsv(ne.eins)) + for (subcode, child, ix) in zip(ne.args, cache.siblings, getixsv(ne.eins)) if any(x->x ∈ el.first, ix) updated = true child.content = _eliminate!(child.content, ix, el, batch_label) udpate_cache_tree!(subcode, child, el, batch_label, size_dict) end end - updated && (cache.content = einsum(ne.eins, (getfield.(cache.children, :content)...,), size_dict)) + updated && (cache.content = einsum(ne.eins, (getfield.(cache.siblings, :content)...,), size_dict)) end \ No newline at end of file diff --git a/test/map.jl b/test/map.jl index 2ee148d..c5c95e0 100644 --- a/test/map.jl +++ b/test/map.jl @@ -20,8 +20,7 @@ end evidence=read_evidence(model), optimizer = TreeSA(ntrials = 3, niters = 2, βs = 1:0.1:80)) @debug contraction_complexity(tn) - most_probable_config(tn) - @time logp, config = most_probable_config(tn) + logp, config = most_probable_config(tn) @test log_probability(tn, config) ≈ logp @test maximum_logp(tn)[] ≈ logp end diff --git a/test/mar.jl b/test/mar.jl index 1302ca7..01843eb 100644 --- a/test/mar.jl +++ b/test/mar.jl @@ -8,6 +8,13 @@ using TensorInference op = ein"ij, j -> i" @test Array(x) ≈ exp(2.0) .* [2.0, 3.0] @test op(Array(A), Array(x)) ≈ Array(op(A, x)) + + @test OMEinsum.get_output_array((A,), (2,), true) ≈ RescaledArray(0.0, [0.0, 0.0]) + @test fill!(RescaledArray(0.0, [0.0, 0.0]), 5.0) ≈ [5.0, 5.0] + + C = RescaledArray(2.0 + 1im, [2.0im 3.0; 5.0 6.0]) + @test conj(C) isa RescaledArray + @test conj(C) ≈ RescaledArray(2.0 - 1im, [-2.0im 3.0; 5.0 6.0]) end @testset "cached, rescaled contract" begin @@ -23,12 +30,12 @@ end # cached contract xs = TensorInference.adapt_tensors(tn; usecuda = false, rescale = true) size_dict = OMEinsum.get_size_dict!(getixsv(tn.code), xs, Dict{Int, Int}()) - cache = TensorInference.cached_einsum(tn.code, xs, size_dict) + cache = OMEinsum.cached_einsum(tn.code, xs, size_dict) @test cache.content isa RescaledArray @test Array(cache.content) ≈ p1 # compute marginals - ti_sol = marginals(tn) + ti_sol = marginals(tn; rescale = true) ref_sol[collect(keys(evidence))] .= fill([1.0], length(evidence)) # imitate dummy vars @test isapprox([ti_sol[[i]] for i=1:length(ref_sol)], ref_sol; atol = 1e-5) end