@@ -124,16 +124,77 @@ end
 """
 $(TYPEDSIGNATURES)
 
-Returns the marginal probability distribution of variables.
-One can use `get_vars(tn)` to get the full list of variables in this tensor network.
+Queries the marginals of the variables in a [`TensorNetworkModel`](@ref). The
+function returns a dictionary, where the keys are the variables and the values
+are their respective marginals. A marginal is a probability distribution over
+a subset of variables, obtained by integrating or summing over the remaining
+variables in the model. By default, the function returns the marginals of all
+individual variables. To specify which marginal variables to query, set the
+`mars` field when constructing a [`TensorNetworkModel`](@ref). Note that
+the choice of marginal variables will affect the contraction order of the
+tensor network.
+
+### Arguments
+- `tn`: The [`TensorNetworkModel`](@ref) to query.
+- `usecuda`: Specifies whether to use CUDA for tensor contraction.
+- `rescale`: Specifies whether to rescale the tensors during contraction.
+
+### Example
+The following example is taken from [`examples/asia/main.jl`](@ref).
+
+```jldoctest; setup = :(using TensorInference, Random; Random.seed!(0))
+julia> model = read_model_file(pkgdir(TensorInference, "examples", "asia", "asia.uai"));
+
+julia> tn = TensorNetworkModel(model; evidence=Dict(1=>0))
+TensorNetworkModel{Int64, DynamicNestedEinsum{Int64}, Array{Float64}}
+variables: 1 (evidence → 0), 2, 3, 4, 5, 6, 7, 8
+contraction time = 2^6.022, space = 2^2.0, read-write = 2^7.077
+
+julia> marginals(tn)
+Dict{Vector{Int64}, Vector{Float64}} with 8 entries:
+  [8] => [0.450138, 0.549863]
+  [3] => [0.5, 0.5]
+  [1] => [1.0]
+  [5] => [0.45, 0.55]
+  [4] => [0.055, 0.945]
+  [6] => [0.10225, 0.89775]
+  [7] => [0.145092, 0.854908]
+  [2] => [0.05, 0.95]
+
+julia> tn2 = TensorNetworkModel(model; evidence=Dict(1=>0), mars=[[2, 3], [3, 4]])
+TensorNetworkModel{Int64, DynamicNestedEinsum{Int64}, Array{Float64}}
+variables: 1 (evidence → 0), 2, 3, 4, 5, 6, 7, 8
+contraction time = 2^7.781, space = 2^5.0, read-write = 2^8.443
+
+julia> marginals(tn2)
+Dict{Vector{Int64}, Matrix{Float64}} with 2 entries:
+  [2, 3] => [0.025 0.025; 0.475 0.475]
+  [3, 4] => [0.05 0.45; 0.005 0.495]
+```
+
+In this example, we first set the evidence for variable 1 to 0 and then query
+the marginals of all individual variables. The returned dictionary has keys
+that correspond to the queried variables and values that represent their
+marginals. These marginals are vectors, with each entry corresponding to the
+probability of the variable taking a specific value. In this example, the
+possible values are 0 or 1. For the evidence variable 1, the marginal is
+always `[1.0]`, since its value is fixed at 0.
+
+Next, we query the joint marginals of variables 2 and 3, and of variables 3
+and 4. Querying joint marginals can change the contraction time and space
+complexity of the tensor network. In this example, the contraction space
+complexity increases from 2^2.0 to 2^5.0, and the contraction time complexity
+increases from 2^6.022 to 2^7.781. The output marginals are the joint
+probabilities of the queried variables, represented by tensors.
+
 """
-function marginals(tn::TensorNetworkModel; usecuda = false, rescale = true)::Vector
+function marginals(tn::TensorNetworkModel; usecuda = false, rescale = true)::Dict{Vector{Int}}
     # sometimes, the cost can overflow, then we need to rescale the tensors during contraction.
     cost, grads = cost_and_gradient(tn.code, adapt_tensors(tn; usecuda, rescale))
     @debug "cost = $cost"
     if rescale
-        return LinearAlgebra.normalize!.(getfield.(grads[1:length(tn.mars)], :normalized_value), 1)
+        return Dict(zip(tn.mars, LinearAlgebra.normalize!.(getfield.(grads[1:length(tn.mars)], :normalized_value), 1)))
     else
-        return LinearAlgebra.normalize!.(grads[1:length(tn.mars)], 1)
+        return Dict(zip(tn.mars, LinearAlgebra.normalize!.(grads[1:length(tn.mars)], 1)))
     end
 end
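The docstring above defines a marginal as the distribution obtained by summing out the remaining variables. A minimal sketch of that operation in plain Julia (independent of TensorInference's internals), using the `[3, 4]` joint marginal from the doctest output:

```julia
# Recover single-variable marginals from a joint marginal by summing over
# the other variable, as described in the docstring.
joint = [0.05 0.45; 0.005 0.495]  # joint marginal of variables 3 and 4 (from the example)

mar3 = vec(sum(joint, dims=2))    # sum out variable 4 → marginal of variable 3
mar4 = vec(sum(joint, dims=1))    # sum out variable 3 → marginal of variable 4

mar3  # [0.5, 0.5]      — matches `[3] => [0.5, 0.5]` above
mar4  # [0.055, 0.945]  — matches `[4] => [0.055, 0.945]` above
```

The recovered vectors agree with the single-variable marginals returned by `marginals(tn)` in the doctest, which is expected: a joint marginal already carries the individual marginals of its variables.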