@@ -103,17 +103,19 @@ function generate_samples!(code::NestedEinsum, cache::CacheTree{T}, env::Abstrac
103103 envcode = optimize_code (EinCode ([siblings_ixs... , iy], ix), size_dict, GreedyMethod (; nrepeat= 1 ))
104104 subenv = einsum (envcode, (getfield .(siblings, :content )... , env), size_dict)
105105
106- # get samples
106+ # generate samples
107107 sample_vars = ix ∩ pool
108- probabilities = einsum (DynamicEinCode ([ix, ix], sample_vars), (child. content, subenv), size_dict)
109- update_samples! (samples, sample_vars, probabilities)
110- setdiff! (pool, sample_vars)
108+ if ! isempty (sample_vars)
109+ probabilities = einsum (DynamicEinCode ([ix, ix], sample_vars), (child. content, subenv), size_dict)
110+ update_samples! (samples, sample_vars, probabilities)
111+ setdiff! (pool, sample_vars)
111112
112- # eliminate the sampled variables
113- setindex! .(Ref (size_dict), 1 , sample_vars)
114- subsamples = subset (samples, sample_vars)[:, 1 ]
115- udpate_cache_tree! (code, cache, sample_vars=> subsamples, size_dict)
116- subenv = eliminate_dimensions (subenv, ix, sample_vars=> subsamples)
113+ # eliminate the sampled variables
114+ setindex! .(Ref (size_dict), 1 , sample_vars)
115+ subsamples = subset (samples, sample_vars)[:, 1 ]
116+ udpate_cache_tree! (code, cache, sample_vars=> subsamples, size_dict)
117+ subenv = eliminate_dimensions (subenv, ix, sample_vars=> subsamples)
118+ end
117119
118120 # recurse
119121 generate_samples! (subcode, child, subenv, samples, setdiff (pool, sample_vars), size_dict)
@@ -134,9 +136,11 @@ function udpate_cache_tree!(ne::NestedEinsum, cache::CacheTree{T}, el::Pair{<:Ab
134136 OMEinsum. isleaf (ne) && return
135137 updated = false
136138 for (subcode, child, ix) in zip (ne. args, cache. children, getixsv (ne. eins))
137- updated = updated || any (x-> x ∈ el. first, ix)
138- child. content = eliminate_dimensions (child. content, ix, el)
139- udpate_cache_tree! (subcode, child, el, size_dict)
139+ if any (x-> x ∈ el. first, ix)
140+ updated = true
141+ child. content = eliminate_dimensions (child. content, ix, el)
142+ udpate_cache_tree! (subcode, child, el, size_dict)
143+ end
140144 end
141145 updated && (cache. content = einsum (ne. eins, (getfield .(cache. children, :content )... ,), size_dict))
142146end
0 commit comments