Commit 17ab0ad

Commit message: save
1 parent 6170325, commit 17ab0ad


src/sampling.jl

Lines changed: 60 additions & 66 deletions
@@ -42,53 +42,6 @@ function subset(samples::Samples{L}, labels::AbstractVector{L}) where L
     return samples.samples[idx, :]
 end
 
-"""
-$(TYPEDSIGNATURES)
-
-The backward process for sampling configurations.
-
-### Arguments
-* `code` is the contraction code in the current step,
-* `env` is the environment tensor,
-* `samples` is the samples generated for eliminated variables,
-* `size_dict` is a key-value map from tensor label to dimension size.
-"""
-function backward_sampling!(code::EinCode, @nospecialize(xs::Tuple), @nospecialize(env), samples::Samples, size_dict)
-    ixs, iy = getixsv(code), getiyv(code)
-    el = setdiff(vcat(ixs...), iy) ∩ samples.labels
-
-    # get probability
-    prob_code = optimize_code(EinCode([ixs..., iy], el), size_dict, GreedyMethod(; nrepeat=1))
-    el_prev = eliminated_variables(samples)
-    @show el_prev=>subset(samples, el_prev)[:,1]
-    xs = [eliminate_dimensions(x, ix, el_prev=>subset(samples, el_prev)[:,1]) for (ix, x) in zip(ixs, xs)]
-    probabilities = einsum(prob_code, (xs..., env), size_dict)
-    @show el
-    @show normalize(real.(vec(probabilities)), 1)
-
-    # sample from the probability tensor
-    totalset = CartesianIndices((map(x->size_dict[x], el)...,))
-    eliminated_locs = idx4labels(samples.labels, el)
-    config = StatsBase.sample(totalset, _Weights(vec(probabilities)))
-    @show eliminated_locs, config.I .- 1
-    samples.samples[eliminated_locs, 1] .= config.I .- 1
-
-    # eliminate the sampled variables
-    set_eliminated!(samples, el)
-    setindex!.(Ref(size_dict), 1, el)
-    sub = subset(samples, el)[:, 1]
-    @show ixs, el=>sub
-    xs = [eliminate_dimensions(x, ix, el=>sub) for (ix, x) in zip(ixs, xs)]
-
-    # update environment
-    envs = map(1:length(ixs)) do i
-        rest = setdiff(1:length(ixs), i)
-        code = optimize_code(EinCode([ixs[rest]..., iy], ixs[i]), size_dict, GreedyMethod(; nrepeat=1))
-        einsum(code, (xs[rest]..., env), size_dict)
-    end
-    @show envs
-end
-
 function eliminate_dimensions(x::AbstractArray{T, N}, ix::AbstractVector{L}, el::Pair{<:AbstractVector{L}}) where {T, N, L}
     idx = ntuple(N) do i
         if ix[i] ∈ el.first
@@ -167,7 +120,7 @@ function sample(tn::TensorNetworkModel, n::Int; usecuda = false, queryvars = get
     # back-propagate
     env = copy(cache.content)
     fill!(env, one(eltype(env)))
-    generate_samples!(tn.code, cache, env, samples, size_dict)
+    generate_samples!(tn.code, cache, env, samples, samples.labels, size_dict)
     # set evidence variables
     for (k, v) in tn.evidence
         idx = findfirst(==(k), samples.labels)
@@ -181,30 +134,71 @@ function _Weights(x::AbstractArray{<:Complex})
     return Weights(real.(x))
 end
 
-function generate_samples!(se::SlicedEinsum, cache::CacheTree{T}, env::AbstractArray{T}, samples, size_dict::Dict) where {T}
+function generate_samples!(se::SlicedEinsum, cache::CacheTree{T}, env::AbstractArray{T}, samples, pool, size_dict::Dict) where {T}
     # slicing is not supported yet.
     if length(se.slicing) != 0
         @warn "Slicing is not supported for caching, got nslices = $(length(se.slicing))! Fallback to `NestedEinsum`."
     end
-    return generate_samples!(se.eins, cache, env, samples, size_dict)
+    return generate_samples!(se.eins, cache, env, samples, pool, size_dict)
 end
-function generate_samples!(code::NestedEinsum, cache::CacheTree{T}, env::AbstractArray{T}, samples::Samples, size_dict::Dict) where {T}
-    @info "@"
+
+# pool is a vector of labels that are not eliminated yet.
+function generate_samples!(code::NestedEinsum, cache::CacheTree{T}, env::AbstractArray{T}, samples::Samples{L}, pool::Vector{L}, size_dict::Dict) where {T, L}
     if !(OMEinsum.isleaf(code))
-        @info "non-leaf node"
-        @show env
-        xs = ntuple(i -> cache.children[i].content, length(cache.children))
-        envs = backward_sampling!(code.eins, xs, env, samples, size_dict)
-        @show envs
-        fucks = map(1:length(code.args)) do k
-            @info k
-            generate_samples!(code.args[k], cache.children[k], envs[k], samples, size_dict)
-            return "fuck"
+        ixs, iy = getixsv(code), getiyv(code)
+        for (subcode, child, ix) in zip(code.args, cache.children, ixs)
+            # subenv for the current child, use it to sample and update its cache
+            siblings = filter(x->x !== child, cache.children)
+            siblings_ixs = filter(x->x !== ix, ixs)
+            envcode = optimize_code(EinCode([siblings_ixs..., iy], ix), size_dict, GreedyMethod(; nrepeat=1))
+            subenv = einsum(envcode, (getfield.(siblings, :content)..., env), size_dict)
+
+            # sample
+            sample_vars = ix ∩ pool
+            update_samples!(child.content, subenv, samples, ix, sample_vars, size_dict)
+
+            generate_samples!(subcode, child, subenv, samples, setdiff(pool, sample_vars), size_dict)
         end
-        @info fucks
-        return
-    else
-        @info "leaf node"
-        return
+    end
+end
+
+"""
+$(TYPEDSIGNATURES)
+
+The backward process for sampling configurations.
+
+### Arguments
+* `code` is the contraction code in the current step,
+* `env` is the environment tensor,
+* `samples` is the samples generated for eliminated variables,
+* `size_dict` is a key-value map from tensor label to dimension size.
+"""
+function update_samples!(code::EinCode, @nospecialize(xs::Tuple), @nospecialize(env), samples::Samples, size_dict)
+    ixs, iy = getixsv(code), getiyv(code)
+    el = setdiff(vcat(ixs...), iy) ∩ samples.labels
+
+    # get probability
+    prob_code = optimize_code(EinCode([ixs..., iy], el), size_dict, GreedyMethod(; nrepeat=1))
+    el_prev = eliminated_variables(samples)
+    xs = [eliminate_dimensions(x, ix, el_prev=>subset(samples, el_prev)[:,1]) for (ix, x) in zip(ixs, xs)]
+    probabilities = einsum(prob_code, (xs..., env), size_dict)
+
+    # sample from the probability tensor
+    totalset = CartesianIndices((map(x->size_dict[x], el)...,))
+    eliminated_locs = idx4labels(samples.labels, el)
+    config = StatsBase.sample(totalset, _Weights(vec(probabilities)))
+    samples.samples[eliminated_locs, 1] .= config.I .- 1
+
+    # eliminate the sampled variables
+    set_eliminated!(samples, el)
+    setindex!.(Ref(size_dict), 1, el)
+    sub = subset(samples, el)[:, 1]
+    xs = [eliminate_dimensions(x, ix, el=>sub) for (ix, x) in zip(ixs, xs)]
+
+    # update environment
+    return map(1:length(ixs)) do i
+        rest = setdiff(1:length(ixs), i)
+        code = optimize_code(EinCode([ixs[rest]..., iy], ixs[i]), size_dict, GreedyMethod(; nrepeat=1))
+        einsum(code, (xs[rest]..., env), size_dict)
+    end
     end
 end
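
A note on the refactor: the new `pool` argument threads the set of not-yet-sampled labels through the recursion, replacing the implicit bookkeeping in the old `backward_sampling!`. A minimal sketch of that bookkeeping with made-up labels (`∩` and `setdiff` are the same Base set operations the diff itself uses):

    # At a tree node, the child tensor's labels that are still in the pool
    # get sampled, then the pool shrinks before recursing into that child.
    pool = ['a', 'b', 'c', 'd']        # variables not eliminated yet
    ix = ['b', 'c', 'x']               # labels of the current child tensor
    sample_vars = ix ∩ pool            # ['b', 'c'] are sampled at this node
    pool = setdiff(pool, sample_vars)  # ['a', 'd'] remain for the recursion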

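The sampling step in `update_samples!` draws one joint configuration from the (generally unnormalized) probability tensor. A self-contained sketch with an invented 2×2 tensor; `StatsBase.sample`, `Weights`, and `CartesianIndices` are the calls actually used above, and `_Weights` is just the package's wrapper that takes real parts of complex weights:

    using StatsBase

    # Draw one joint assignment of two binary variables from an
    # unnormalized probability tensor over their configurations.
    probabilities = [0.1 0.4; 0.2 0.3]
    totalset = CartesianIndices(size(probabilities))  # all configurations
    config = StatsBase.sample(totalset, Weights(vec(probabilities)))
    config.I .- 1  # 0-based configuration, the form stored in samples.samples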
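
For context, `eliminate_dimensions` (kept as context lines in the first hunk) fixes already-sampled labels of a tensor to their drawn values. The following is an illustrative reimplementation under that reading, not the package's exact code; the name `fix_labels` is hypothetical:

    # Fix the sampled labels of `x` to their 0-based values, keeping
    # singleton axes so existing einsum codes stay valid (consistent with
    # the diff setting the eliminated sizes in size_dict to 1).
    function fix_labels(x::AbstractArray{T,N}, ix, el::Pair) where {T,N}
        idx = ntuple(N) do i
            k = findfirst(==(ix[i]), el.first)
            k === nothing ? Colon() : (el.second[k] + 1:el.second[k] + 1)
        end
        return x[idx...]
    end

    fix_labels(rand(2, 2), ['a', 'b'], ['b'] => [0])  # 2×1 slice, 'b' fixed to 0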