Skip to content

Commit cb59646

Browse files
committed
update
1 parent 98ac255 commit cb59646

File tree

1 file changed

+16
-12
lines changed

1 file changed

+16
-12
lines changed

src/sampling.jl

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -103,17 +103,19 @@ function generate_samples!(code::NestedEinsum, cache::CacheTree{T}, env::Abstrac
103103
envcode = optimize_code(EinCode([siblings_ixs..., iy], ix), size_dict, GreedyMethod(; nrepeat=1))
104104
subenv = einsum(envcode, (getfield.(siblings, :content)..., env), size_dict)
105105

106-
# get samples
106+
# generate samples
107107
sample_vars = ix pool
108-
probabilities = einsum(DynamicEinCode([ix, ix], sample_vars), (child.content, subenv), size_dict)
109-
update_samples!(samples, sample_vars, probabilities)
110-
setdiff!(pool, sample_vars)
108+
if !isempty(sample_vars)
109+
probabilities = einsum(DynamicEinCode([ix, ix], sample_vars), (child.content, subenv), size_dict)
110+
update_samples!(samples, sample_vars, probabilities)
111+
setdiff!(pool, sample_vars)
111112

112-
# eliminate the sampled variables
113-
setindex!.(Ref(size_dict), 1, sample_vars)
114-
subsamples = subset(samples, sample_vars)[:, 1]
115-
udpate_cache_tree!(code, cache, sample_vars=>subsamples, size_dict)
116-
subenv = eliminate_dimensions(subenv, ix, sample_vars=>subsamples)
113+
# eliminate the sampled variables
114+
setindex!.(Ref(size_dict), 1, sample_vars)
115+
subsamples = subset(samples, sample_vars)[:, 1]
116+
udpate_cache_tree!(code, cache, sample_vars=>subsamples, size_dict)
117+
subenv = eliminate_dimensions(subenv, ix, sample_vars=>subsamples)
118+
end
117119

118120
# recurse
119121
generate_samples!(subcode, child, subenv, samples, setdiff(pool, sample_vars), size_dict)
@@ -134,9 +136,11 @@ function udpate_cache_tree!(ne::NestedEinsum, cache::CacheTree{T}, el::Pair{<:Ab
134136
OMEinsum.isleaf(ne) && return
135137
updated = false
136138
for (subcode, child, ix) in zip(ne.args, cache.children, getixsv(ne.eins))
137-
updated = updated || any(x->x el.first, ix)
138-
child.content = eliminate_dimensions(child.content, ix, el)
139-
udpate_cache_tree!(subcode, child, el, size_dict)
139+
if any(x->x el.first, ix)
140+
updated = true
141+
child.content = eliminate_dimensions(child.content, ix, el)
142+
udpate_cache_tree!(subcode, child, el, size_dict)
143+
end
140144
end
141145
updated && (cache.content = einsum(ne.eins, (getfield.(cache.children, :content)...,), size_dict))
142146
end

0 commit comments

Comments
 (0)