Skip to content

Commit 9c43716

Browse files
committed
update sampling with evidence
1 parent a9379a7 commit 9c43716

File tree

2 files changed

+16
-0
lines changed

2 files changed

+16
-0
lines changed

src/sampling.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,13 @@ function sample(tn::TensorNetworkModel, n::Int; usecuda = false)::Vector{Vector{
101101
samples = Samples(configs, labels, setmask)
102102
# back-propagate
103103
generate_samples(tn.code, cache, samples, size_dict)
104+
# set evidence variables
105+
for (k, v) in tn.fixedvertices
106+
idx = findfirst(==(k), labels)
107+
for c in samples.samples
108+
c[idx] = v
109+
end
110+
end
104111
return samples.samples
105112
end
106113

test/sampling.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,19 @@ using TensorInference, Test
4646
0.9 0.7 0.8 0.1
4747
0.1 0.3 0.2 0.9
4848
""")
49+
# general sampling
4950
n = 10000
5051
tnet = TensorNetworkModel(instance)
5152
samples = sample(tnet, n)
5253
mars = getindex.(marginals(tnet), 2)
5354
mars_sample = [count(s->s[k]==(1), samples) for k=1:8] ./ n
5455
@test isapprox(mars, mars_sample, atol=0.05)
56+
57+
# fix the evidence
58+
set_evidence!(instance, 7=>1)
59+
tnet = TensorNetworkModel(instance)
60+
samples = sample(tnet, n)
61+
mars = getindex.(marginals(tnet), 1)
62+
mars_sample = [count(s->s[k]==(0), samples) for k=1:8] ./ n
63+
@test isapprox([mars[1:6]..., mars[8]], [mars_sample[1:6]..., mars_sample[8]], atol=0.05)
5564
end

0 commit comments

Comments
 (0)