Skip to content

Commit 6266d7c

Browse files
committed
Improve marginals docstring
1 parent a3cd47f commit 6266d7c

File tree

1 file changed

+27
-17
lines changed

1 file changed

+27
-17
lines changed

src/mar.jl

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -124,19 +124,23 @@ end
124124
"""
125125
$(TYPEDSIGNATURES)
126126
127-
Query the marginals of the variables in a [`TensorNetworkModel`](@ref).
128-
The returned value is a dictionary of variables and their marginals, where a marginal is a joint probability distribution over the associated variables.
129-
By default, the marginals of all individual variables are returned.
130-
The marginal variables to query can be specified when constructing [`TensorNetworkModel`](@ref) as its field `mars`.
131-
It will affect the contraction order of the tensor network.
127+
Queries the marginals of the variables in a [`TensorNetworkModel`](@ref). The
128+
function returns a dictionary, where the keys are the variables and the values
129+
are their respective marginals. A marginal is a probability distribution over
130+
a subset of variables, obtained by integrating or summing over the remaining
131+
variables in the model. By default, the function returns the marginals of all
132+
individual variables. To specify which marginal variables to query, set the
133+
`mars` field when constructing a [`TensorNetworkModel`](@ref). Note that
134+
the choice of marginal variables will affect the contraction order of the
135+
tensor network.
132136
133137
### Arguments
134-
- `tn`: the [`TensorNetworkModel`](@ref) to query.
135-
- `usecuda`: whether to use CUDA for tensor contraction.
136-
- `rescale`: whether to rescale the tensors during contraction.
138+
- `tn`: The [`TensorNetworkModel`](@ref) to query.
139+
- `usecuda`: Specifies whether to use CUDA for tensor contraction.
140+
- `rescale`: Specifies whether to rescale the tensors during contraction.
137141
138142
### Example
139-
The following example is from [`examples/asia/main.jl`](@ref).
143+
The following example is taken from [`examples/asia/main.jl`](@ref).
140144
141145
```jldoctest; setup = :(using TensorInference, Random; Random.seed!(0))
142146
julia> model = read_model_file(pkgdir(TensorInference, "examples", "asia", "asia.uai"));
@@ -168,15 +172,21 @@ Dict{Vector{Int64}, Matrix{Float64}} with 2 entries:
168172
[3, 4] => [0.05 0.45; 0.005 0.495]
169173
```
170174
171-
In this example, we first set the evidence of variable 1 to 0, then we query the marginals of all individual variables.
172-
The returned values is a dictionary, the key are query variables, and the value are the corresponding marginals.
173-
The marginals are vectors, with its entries corresponding to the probability of the variable taking the value 0 and 1, respectively.
174-
For evidence variable 1, the marginal is always `[1.0]`, since it is fixed to 0.
175+
In this example, we first set the evidence for variable 1 to 0 and then query
176+
the marginals of all individual variables. The returned dictionary has keys
177+
that correspond to the queried variables and values that represent their
178+
marginals. These marginals are vectors, with each entry corresponding to the
179+
probability of the variable taking a specific value. In this example, the
180+
possible values are 0 or 1. For the evidence variable 1, the marginal is
181+
always [1.0] since its value is fixed at 0.
182+
183+
Next, we specify the marginal variables to query as variables 2 and 3, and
184+
variables 3 and 4, respectively. The joint marginals may or may not affect the
185+
contraction time and space. In this example, the contraction space complexity
186+
increases from 2^{2.0} to 2^{5.0}, and the contraction time complexity
187+
increases from 2^{5.977} to 2^{7.781}. The output marginals are the joint
188+
probabilities of the queried variables, represented by tensors.
175189
176-
Then we set the marginal variables to query to be variable 2 and 3, and variable 3 and 4, respectively.
177-
The joint marginals may or may not increase the contraction time and space.
178-
Here, the contraction space complexity is increased from 2^2.0 to 2^5.0, and the contraction time complexity is increased from 2^5.977 to 2^7.781.
179-
The output marginals are joint probabilities of the query variables represented by tensors.
180190
"""
181191
function marginals(tn::TensorNetworkModel; usecuda = false, rescale = true)::Dict{Vector{Int}}
182192
# sometimes, the cost can overflow, then we need to rescale the tensors during contraction.

0 commit comments

Comments
 (0)