Skip to content

Commit 24c8601

Browse files
committed
update
1 parent 17ab0ad commit 24c8601

File tree

2 files changed

+29
-43
lines changed

2 files changed

+29
-43
lines changed

src/mar.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@ end
1717
# It is a tree structure that isomorphic to the contraction tree,
1818
# `content` is the cached intermediate contraction result.
1919
# `children` are the children of current node, e.g. tensors that are contracted to get `content`.
20-
struct CacheTree{T}
20+
mutable struct CacheTree{T}
2121
content::AbstractArray{T}
22-
children::Vector{CacheTree{T}}
22+
const children::Vector{CacheTree{T}}
2323
end
2424

2525
function cached_einsum(se::SlicedEinsum, @nospecialize(xs), size_dict)

src/sampling.jl

Lines changed: 27 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ function subset(samples::Samples{L}, labels::AbstractVector{L}) where L
4343
end
4444

4545
function eliminate_dimensions(x::AbstractArray{T, N}, ix::AbstractVector{L}, el::Pair{<:AbstractVector{L}}) where {T, N, L}
46+
@assert length(ix) == N
4647
idx = ntuple(N) do i
4748
if ix[i] el.first
4849
k = el.second[findfirst(==(ix[i]), el.first)] + 1
@@ -51,7 +52,6 @@ function eliminate_dimensions(x::AbstractArray{T, N}, ix::AbstractVector{L}, el:
5152
1:size(x, i)
5253
end
5354
end
54-
@show idx
5555
return asarray(x[idx...], x)
5656
end
5757

@@ -143,62 +143,48 @@ function generate_samples!(se::SlicedEinsum, cache::CacheTree{T}, env::AbstractA
143143
end
144144

145145
# pool is a vector of labels that are not eliminated yet.
146-
function generate_samples!(code::NestedEinsum, cache::CacheTree{T}, env::AbstractArray{T}, samples::Samples{L}, pool::Vector{L}, size_dict::Dict) where {T, L}
146+
function generate_samples!(code::NestedEinsum, cache::CacheTree{T}, env::AbstractArray{T}, samples::Samples{L}, pool::Vector{L}, size_dict::Dict{L}) where {T, L}
147147
if !(OMEinsum.isleaf(code))
148-
ixs, iy = getixsv(code), getiyv(code)
148+
ixs, iy = getixsv(code.eins), getiyv(code.eins)
149149
for (subcode, child, ix) in zip(code.args, cache.children, ixs)
150150
# subenv for the current child, use it to sample and update its cache
151151
siblings = filter(x->x !== child, cache.children)
152152
siblings_ixs = filter(x->x !== ix, ixs)
153153
envcode = optimize_code(EinCode([siblings_ixs..., iy], ix), size_dict, GreedyMethod(; nrepeat=1))
154154
subenv = einsum(envcode, (getfield.(siblings, :content)..., env), size_dict)
155155

156-
# sample
156+
# get samples
157157
sample_vars = ix pool
158-
update_samples!(child.content, subenv, samples, ix, sample_vars, size_dict)
158+
probabilities = einsum(DynamicEinCode([ix, ix], sample_vars), (child.content, subenv), size_dict)
159+
update_samples!(samples, sample_vars, probabilities)
160+
pool = setdiff(pool, sample_vars)
159161

162+
# eliminate the sampled variables
163+
setindex!.(Ref(size_dict), 1, sample_vars)
164+
subsamples = subset(samples, sample_vars)[:, 1]
165+
udpate_cache_tree!(code, cache, sample_vars=>subsamples)
166+
subenv = eliminate_dimensions(subenv, ix, sample_vars=>subsamples)
167+
168+
# recurse
160169
generate_samples!(subcode, child, subenv, samples, setdiff(pool, sample_vars), size_dict)
161170
end
162171
end
163172
end
164173

165-
"""
166-
$(TYPEDSIGNATURES)
167-
168-
The backward process for sampling configurations.
169-
170-
### Arguments
171-
* `code` is the contraction code in the current step,
172-
* `env` is the environment tensor,
173-
* `samples` is the samples generated for eliminated variables,
174-
* `size_dict` is a key-value map from tensor label to dimension size.
175-
"""
176-
function update_samples!(code::EinCode, @nospecialize(xs::Tuple), @nospecialize(env), samples::Samples, size_dict)
177-
ixs, iy = getixsv(code), getiyv(code)
178-
el = setdiff(vcat(ixs...), iy) samples.labels
179-
180-
# get probability
181-
prob_code = optimize_code(EinCode([ixs..., iy], el), size_dict, GreedyMethod(; nrepeat=1))
182-
el_prev = eliminated_variables(samples)
183-
xs = [eliminate_dimensions(x, ix, el_prev=>subset(samples, el_prev)[:,1]) for (ix, x) in zip(ixs, xs)]
184-
probabilities = einsum(prob_code, (xs..., env), size_dict)
185-
186-
# sample from the probability tensor
187-
totalset = CartesianIndices((map(x->size_dict[x], el)...,))
188-
eliminated_locs = idx4labels(samples.labels, el)
174+
# probabilities is a tensor of probabilities for each variable in `vars`.
175+
function update_samples!(samples::Samples, vars::AbstractVector{L}, probabilities::AbstractArray{T, N}) where {L, T, N}
176+
@assert length(vars) == N
177+
totalset = CartesianIndices(probabilities)
178+
eliminated_locs = idx4labels(samples.labels, vars)
189179
config = StatsBase.sample(totalset, _Weights(vec(probabilities)))
190180
samples.samples[eliminated_locs, 1] .= config.I .- 1
181+
set_eliminated!(samples, vars)
182+
end
191183

192-
# eliminate the sampled variables
193-
set_eliminated!(samples, el)
194-
setindex!.(Ref(size_dict), 1, el)
195-
sub = subset(samples, el)[:, 1]
196-
xs = [eliminate_dimensions(x, ix, el=>sub) for (ix, x) in zip(ixs, xs)]
197-
198-
# update environment
199-
return map(1:length(ixs)) do i
200-
rest = setdiff(1:length(ixs), i)
201-
code = optimize_code(EinCode([ixs[rest]..., iy], ixs[i]), size_dict, GreedyMethod(; nrepeat=1))
202-
einsum(code, (xs[rest]..., env), size_dict)
184+
function udpate_cache_tree!(ne::NestedEinsum, cache::CacheTree{T}, el::Pair{<:AbstractVector{L}}) where {T, L}
185+
OMEinsum.isleaf(ne) && return
186+
for (subcode, child, ix) in zip(ne.args, cache.children, getixsv(ne.eins))
187+
child.content = eliminate_dimensions(child.content, ix, el)
188+
udpate_cache_tree!(subcode, child, el)
203189
end
204-
end
190+
end

0 commit comments

Comments
 (0)