Skip to content

Commit bad349e

Browse files
committed
fix sampling for sliced einsum
1 parent 85b1c66 commit bad349e

File tree

2 files changed

+8
-1
lines changed

2 files changed

+8
-1
lines changed

src/sampling.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,13 @@ function sample(tn::TensorNetworkModel, n::Int; usecuda = false)::AbstractMatrix
128128
return samples.samples
129129
end
130130

131+
function generate_samples(se::SlicedEinsum, cache::CacheTree{T}, samples, size_dict::Dict) where {T}
132+
# slicing is not supported yet.
133+
if length(se.slicing) != 0
134+
@warn "Slicing is not supported for caching, got nslices = $(length(se.slicing))! Fallback to `NestedEinsum`."
135+
end
136+
return generate_samples(se.eins, cache, samples, size_dict)
137+
end
131138
function generate_samples(code::NestedEinsum, cache::CacheTree{T}, samples, size_dict::Dict) where {T}
132139
if !OMEinsum.isleaf(code)
133140
xs = ntuple(i -> cache.siblings[i].content, length(cache.siblings))

test/sampling.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ using TensorInference, Test
5555

5656
# fix the evidence
5757
set_evidence!(instance, 7=>1)
58-
tnet = TensorNetworkModel(instance)
58+
tnet = TensorNetworkModel(instance, optimizer=TreeSA())
5959
samples = sample(tnet, n)
6060
mars = getindex.(marginals(tnet), 1)
6161
mars_sample = [count(i->samples[k, i]==(0), axes(samples, 2)) for k=1:8] ./ n

0 commit comments

Comments
 (0)