
Commit f0896af

save
1 parent 08a1ce4 commit f0896af

File tree

example/asia/asia.jl
src/inference.jl
src/sampling.jl
test/sampling.jl

4 files changed: +19 −6 lines changed

example/asia/asia.jl

Lines changed: 3 additions & 0 deletions
@@ -22,6 +22,9 @@ tnet = TensorNetworkModel(instance)
 # Get the maximum log-probabilities (MAP)
 maximum_logp(tnet)
 
+# To sample from the probability model
+sample(tnet, 10)
+
 # Get not only the maximum log-probability, but also the most probable configuration
 # In the most probable configuration, the patient smokes (3) and has lung cancer (4)
 logp, cfg = most_probable_config(tnet)
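The added lines draw 10 configurations from the model right after the MAP queries. Together with the `src/sampling.jl` change below, `sample` hands back the configurations directly, so each sample can be indexed by variable position, e.g. positions 3 (smoking) and 4 (lung cancer) mentioned in the example's comment. A minimal sketch, assuming `instance` has been loaded earlier in the example file as the diff context shows:

# Minimal sketch; assumes `instance` was loaded earlier in example/asia/asia.jl
# and that `sample` returns a Vector{Vector{Int}}, as this commit makes it do.
using TensorInference

tnet = TensorNetworkModel(instance)
configs = sample(tnet, 10)   # 10 sampled configurations
configs[1][3]                # state of variable 3 (smoking) in the first sample
configs[1][4]                # state of variable 4 (lung cancer) in the first sample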

src/inference.jl

Lines changed: 2 additions & 2 deletions
@@ -49,7 +49,7 @@ function cached_einsum(code::NestedEinsum, @nospecialize(xs), size_dict)
 end
 
 # computed gradient tree by back propagation
-function generate_gradient_tree(se::SlicedEinsum, cache::CacheTree{T}, dy, size_dict::Dict) where {T}
+function generate_gradient_tree(se::SlicedEinsum, cache::CacheTree{T}, dy::AbstractArray{T}, size_dict::Dict) where {T}
     if length(se.slicing) != 0
         @warn "Slicing is not supported for generating masked tree! Fallback to `NestedEinsum`."
     end
@@ -58,7 +58,7 @@ end
 
 # recursively compute the gradients and store it into a tree.
 # also known as the back-propagation algorithm.
-function generate_gradient_tree(code::NestedEinsum, cache::CacheTree{T}, dy, size_dict::Dict) where {T}
+function generate_gradient_tree(code::NestedEinsum, cache::CacheTree{T}, dy::AbstractArray{T}, size_dict::Dict) where {T}
     if OMEinsum.isleaf(code)
         return CacheTree(dy, CacheTree{T}[])
     else
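Both `generate_gradient_tree` methods now constrain the adjoint to `dy::AbstractArray{T}`, tying it to the element type `T` of the cached forward tree. The sketch below is a generic illustration of that Julia pattern rather than the package's code: sharing a type parameter between arguments turns an element-type mismatch into a `MethodError` at dispatch time instead of a failure deep inside the recursion.

# Generic illustration of the signature pattern (hypothetical names, not the
# package's actual types): `dy` must share the element type `T` of the cache.
struct ToyCache{T}
    content::Vector{T}
end

accumulate_adjoint(cache::ToyCache{T}, dy::AbstractArray{T}) where {T} =
    cache.content .+ dy

accumulate_adjoint(ToyCache([1.0, 2.0]), [0.5, 0.5])  # ok: both Float64
# accumulate_adjoint(ToyCache([1.0, 2.0]), [1, 1])    # MethodError: Int adjoint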

src/sampling.jl

Lines changed: 13 additions & 3 deletions
@@ -1,4 +1,14 @@
 ############ Sampling ############
+"""
+$TYPEDEF
+
+### Fields
+$TYPEDFIELDS
+
+The sampled configurations are stored in `samples`, which is a vector of vectors.
+`labels` is a vector of variable names for labeling configurations.
+The `setmask` is a boolean indicator that denotes whether the sampling process of a variable is complete.
+"""
 struct Samples{L}
     samples::Vector{Vector{Int}}
     labels::Vector{L}
@@ -18,7 +28,7 @@ idx4labels(totalset, labels) = map(v->findfirst(==(v), totalset), labels)
 """
 $(TYPEDSIGNATURES)
 
-The backward rule for tropical einsum.
+The backward process for sampling configurations.
 
 * `ixs` and `xs` are labels and tensor data for input tensors,
 * `iy` and `y` are labels and tensor data for the output tensor,
@@ -72,7 +82,7 @@ function sample(tn::TensorNetworkModel, n::Int; usecuda = false)::Samples
     # forward compute and cache intermediate results.
     cache = cached_einsum(tn.code, xs, size_dict)
     # initialize `y̅` as the initial batch of samples.
-    labels = OMEinsum.uniquelabels(tn.code)
+    labels = get_vars(tn)
     iy = getiyv(tn.code)
     setmask = falses(length(labels))
     idx = map(l->findfirst(==(l), labels), iy)
@@ -86,7 +96,7 @@ function sample(tn::TensorNetworkModel, n::Int; usecuda = false)::Samples
     samples = Samples(configs, labels, setmask)
     # back-propagate
     generate_samples(tn.code, cache, samples, size_dict)
-    return samples
+    return samples.samples
 end
 
 function generate_samples(code::NestedEinsum, cache::CacheTree{T}, samples, size_dict::Dict) where {T}
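Two caller-visible changes here: the sampling labels now come from `get_vars(tn)` instead of `OMEinsum.uniquelabels(tn.code)`, and `sample` returns `samples.samples`, i.e. the vector of configurations rather than the `Samples` wrapper. A rough sketch of what the diff implies for calling code (variable names are illustrative only):

# Before this commit the configurations sat one field deeper in the wrapper:
#   configs = sample(tnet, 100).samples
# After this commit the vector of configurations is returned directly:
configs = sample(tnet, 100)
length(configs)    # 100 sampled configurations
eltype(configs)    # Vector{Int}, one entry per variable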

test/sampling.jl

Lines changed: 1 addition & 1 deletion
@@ -50,6 +50,6 @@ using TensorInference, Test
     tnet = TensorNetworkModel(instance)
     samples = sample(tnet, n)
     mars = getindex.(marginals(tnet), 2)
-    mars_sample = [count(s->s[k]==(1), samples.samples) for k=1:8] ./ n
+    mars_sample = [count(s->s[k]==(1), samples) for k=1:8] ./ n
     @test isapprox(mars, mars_sample, atol=0.05)
 end
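With the new return type, the test counts states directly over the returned configurations: it checks that the exact marginals from `marginals(tnet)` agree with the empirical frequencies over `n` samples to within `atol = 0.05`. A self-contained toy of that frequency check, run on synthetic configurations instead of samples from a tensor network:

# Toy version of the test's frequency check, using fake uniform configurations
# (8 binary variables) instead of samples drawn from a tensor network.
n = 10_000
nvars = 8
configs = [rand(1:2, nvars) for _ in 1:n]

# empirical probability of state 1 for each variable, as in the test
mars_sample = [count(s -> s[k] == 1, configs) for k in 1:nvars] ./ n

# for uniform fake data every entry should be close to 0.5
all(x -> isapprox(x, 0.5; atol = 0.05), mars_sample)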
