 # `CacheTree` stores intermediate `NestedEinsum` contraction results.
 # It is a tree structure that is isomorphic to the contraction tree;
 # `content` is the cached intermediate contraction result.
-# `siblings` are the siblings of current node.
-struct CacheTree{T}
+# `children` are the children of the current node, i.e. the tensors that are contracted to obtain `content`.
+mutable struct CacheTree{T}
     content::AbstractArray{T}
-    siblings::Vector{CacheTree{T}}
+    const children::Vector{CacheTree{T}}
 end
 
 function cached_einsum(se::SlicedEinsum, @nospecialize(xs), size_dict)
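For orientation, the snippet below is a standalone sketch of how the new `CacheTree` layout caches the intermediates of a single pairwise contraction `C = A * B`. It re-declares the struct so it runs outside the package (Julia ≥ 1.8 for the `const` field); names such as `leaf_A` are illustrative and do not appear in the diff.

```julia
# Standalone sketch, not package code: a two-leaf CacheTree for C = A * B.
mutable struct CacheTree{T}
    content::AbstractArray{T}            # cached intermediate contraction result
    const children::Vector{CacheTree{T}} # subtrees whose contents were contracted
end

A, B = rand(2, 3), rand(3, 4)
leaf_A = CacheTree(A, CacheTree{Float64}[])  # leaves cache the input tensors
leaf_B = CacheTree(B, CacheTree{Float64}[])
root   = CacheTree(A * B, [leaf_A, leaf_B])  # the root caches the contraction result

root.content == root.children[1].content * root.children[2].content  # true
```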
@@ -62,7 +62,7 @@ function generate_gradient_tree(code::NestedEinsum, cache::CacheTree{T}, dy::Abs
     if OMEinsum.isleaf(code)
         return CacheTree(dy, CacheTree{T}[])
     else
-        xs = ntuple(i -> cache.siblings[i].content, length(cache.siblings))
+        xs = ntuple(i -> cache.children[i].content, length(cache.children))
         # `einsum_grad` is the back-propagation rule for the einsum function.
         # If the forward pass is `y = einsum(EinCode(inputs_labels, output_labels), (A, B, ...), size_dict)`,
         # then the back-propagation pass is
@@ -73,7 +73,7 @@ function generate_gradient_tree(code::NestedEinsum, cache::CacheTree{T}, dy::Abs
         # ```
         # Let `L` be the loss; then we have `y̅ := ∂L/∂y`, `A̅ := ∂L/∂A`, ...
         dxs = einsum_backward_rule(code.eins, xs, cache.content, size_dict, dy)
-        return CacheTree(dy, generate_gradient_tree.(code.args, cache.siblings, dxs, Ref(size_dict)))
+        return CacheTree(dy, generate_gradient_tree.(code.args, cache.children, dxs, Ref(size_dict)))
     end
 end
 
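To make the back-propagation rule described in the comments above concrete, here is a hedged, standalone sketch for the simplest case, `y = einsum("ij,jk->ik", A, B)` (ordinary matrix multiplication): each input's adjoint is again an einsum that contracts `dy` with the other input. This is not the package's `einsum_backward_rule`, just the underlying identity written with plain linear algebra.

```julia
# Sketch of the einsum adjoint rule specialised to y = A * B ("ij,jk->ik").
A, B = rand(2, 3), rand(3, 4)
dy = rand(2, 4)     # incoming gradient ∂L/∂y
dA = dy * B'        # ∂L/∂A, i.e. einsum "ik,jk->ij"(dy, B)
dB = A' * dy        # ∂L/∂B, i.e. einsum "ij,ik->jk"(A, dy)

# Directional check: for L(A, B) = sum(dy .* (A * B)), ∂L/∂A is exactly dA.
L(A, B) = sum(dy .* (A * B))
δA = 1e-6 .* rand(2, 3)
isapprox(L(A + δA, B) - L(A, B), sum(dA .* δA); rtol=1e-6)  # true
```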
@@ -116,7 +116,7 @@ function extract_leaves!(code, cache, res)
         res[code.tensorindex] = cache.content
     else
         # recurse deeper
-        extract_leaves!.(code.args, cache.siblings, Ref(res))
+        extract_leaves!.(code.args, cache.children, Ref(res))
     end
     return res
 end
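The recursion in `extract_leaves!` walks the cache tree in lock-step with the contraction code and writes each leaf tensor into `res` at its `tensorindex`. As a simplified illustration (assuming the `CacheTree` and `root` from the first sketch are in scope, and using `isempty(t.children)` in place of `OMEinsum.isleaf`), the same traversal pattern can be written as:

```julia
# Illustrative only: collect leaf tensors in traversal order, without the
# tensorindex bookkeeping that the real extract_leaves! performs.
collect_leaves(t::CacheTree) =
    isempty(t.children) ? Any[t.content] : reduce(vcat, collect_leaves.(t.children))

length(collect_leaves(root))  # 2: one cached tensor per input of the contraction
```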
@@ -145,10 +145,7 @@ The following example is taken from [`examples/asia-network/main.jl`](https://te
 ```jldoctest; setup = :(using TensorInference, Random; Random.seed!(0))
 julia> model = read_model_file(pkgdir(TensorInference, "examples", "asia-network", "model.uai"));
 
-julia> tn = TensorNetworkModel(model; evidence=Dict(1=>0))
-TensorNetworkModel{Int64, DynamicNestedEinsum{Int64}, Array{Float64}}
-variables: 1 (evidence → 0), 2, 3, 4, 5, 6, 7, 8
-contraction time = 2^6.022, space = 2^2.0, read-write = 2^7.077
+julia> tn = TensorNetworkModel(model; evidence=Dict(1=>0));
 
 julia> marginals(tn)
 Dict{Vector{Int64}, Vector{Float64}} with 8 entries:
@@ -161,10 +158,7 @@ Dict{Vector{Int64}, Vector{Float64}} with 8 entries:
   [7] => [0.145092, 0.854908]
   [2] => [0.05, 0.95]
 
-julia> tn2 = TensorNetworkModel(model; evidence=Dict(1=>0), mars=[[2, 3], [3, 4]])
-TensorNetworkModel{Int64, DynamicNestedEinsum{Int64}, Array{Float64}}
-variables: 1 (evidence → 0), 2, 3, 4, 5, 6, 7, 8
-contraction time = 2^7.781, space = 2^5.0, read-write = 2^8.443
+julia> tn2 = TensorNetworkModel(model; evidence=Dict(1=>0), mars=[[2, 3], [3, 4]]);
 
 julia> marginals(tn2)
 Dict{Vector{Int64}, Matrix{Float64}} with 2 entries:
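As a usage note (an assumption on my part, since the hunk above truncates the doctest output): the keys of the returned dictionary are the variable groups passed via `mars`, so the pairwise marginal of variables 2 and 3 can be looked up directly.

```julia
# Assumes the `tn2` defined in the doctest above.
m = marginals(tn2)
m23 = m[[2, 3]]   # joint marginal table of variables 2 and 3, a 2×2 matrix
sum(m23)          # expected ≈ 1, matching the normalized single-variable marginals
```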