Skip to content

Commit 98ac255

Browse files
committed
update sampling
1 parent 24c8601 commit 98ac255

File tree

2 files changed

+9
-59
lines changed

2 files changed

+9
-59
lines changed

src/sampling.jl

Lines changed: 9 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -7,20 +7,10 @@ $TYPEDFIELDS
77
88
The sampled configurations are stored in `samples`, which is a vector of vectors.
99
`labels` is a vector of variable names for labeling configurations.
10-
The `setmask` is a boolean indicator to denote whether the sampling process of a variable is complete.
1110
"""
1211
struct Samples{L} <: AbstractVector{SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}}
1312
samples::Matrix{Int} # size is nvars × nsample
1413
labels::Vector{L}
15-
setmask::BitVector
16-
end
17-
function set_eliminated!(samples::Samples, eliminated_variables)
18-
for var in eliminated_variables
19-
loc = findfirst(==(var), samples.labels)
20-
samples.setmask[loc] && error("variable `$var` is already eliminated.")
21-
samples.setmask[loc] = true
22-
end
23-
return samples
2414
end
2515
Base.getindex(s::Samples, i::Int) = view(s.samples, :, i)
2616
Base.length(s::Samples) = size(s.samples, 2)
@@ -30,16 +20,12 @@ function Base.show(io::IO, s::Samples) # display with PrettyTables
3020
PrettyTables.pretty_table(io, s.samples', header=s.labels)
3121
end
3222
num_samples(samples::Samples) = size(samples.samples, 2)
33-
eliminated_variables(samples::Samples) = samples.labels[samples.setmask]
34-
is_eliminated(samples::Samples{L}, var::L) where L = samples.setmask[findfirst(==(var), samples.labels)]
3523
function idx4labels(totalset::AbstractVector{L}, labels::AbstractVector{L})::Vector{Int} where L
3624
map(v->findfirst(==(v), totalset), labels)
3725
end
38-
idx4labels(samples::Samples{L}, lb::L) where L = findfirst(==(lb), samples.labels)
3926
function subset(samples::Samples{L}, labels::AbstractVector{L}) where L
4027
idx = idx4labels(samples.labels, labels)
41-
@assert all(i->samples.setmask[i], idx)
42-
return samples.samples[idx, :]
28+
return view(samples.samples, idx, :)
4329
end
4430

4531
function eliminate_dimensions(x::AbstractArray{T, N}, ix::AbstractVector{L}, el::Pair{<:AbstractVector{L}}) where {T, N, L}
@@ -55,40 +41,6 @@ function eliminate_dimensions(x::AbstractArray{T, N}, ix::AbstractVector{L}, el:
5541
return asarray(x[idx...], x)
5642
end
5743

58-
function addbatch(samples::Samples, eliminated_variables)
59-
uniquelabels = unique!(vcat(ixs..., iy))
60-
labelmap = Dict(zip(uniquelabels, 1:length(uniquelabels)))
61-
batchdim = length(labelmap) + 1
62-
newnewixs = [Int[getindex.(Ref(labelmap), ix)..., batchdim] for ix in newixs]
63-
newnewiy = Int[getindex.(Ref(labelmap), eliminated_variables)..., batchdim]
64-
newnewxs = [get_slice(x, dimx, samples.samples[ixloc, :]) for (x, dimx, ixloc) in zip(xs, slice_xs_dim, ix_in_sample)]
65-
end
66-
67-
# type unstable
68-
function get_slice(x::AbstractArray{T}, slicedim, configs::AbstractMatrix) where T
69-
outdim = setdiff(1:ndims(x), slicedim)
70-
res = similar(x, [size(x, d) for d in outdim]..., size(configs, 2))
71-
return get_slice!(res, x, outdim, slicedim, configs)
72-
end
73-
74-
function get_slice!(res, x::AbstractArray{T}, outdim, slicedim, configs::AbstractMatrix) where T
75-
xstrides = strides(x)
76-
@inbounds for ci in CartesianIndices(res)
77-
idx = 1
78-
# the output dimension part
79-
for (dim, k) in zip(outdim, ci.I)
80-
idx += (k-1) * xstrides[dim]
81-
end
82-
# the sliced part
83-
batchidx = ci.I[end]
84-
for (dim, k) in zip(slicedim, view(configs, :, batchidx))
85-
idx += k * xstrides[dim]
86-
end
87-
res[ci] = x[idx]
88-
end
89-
return res
90-
end
91-
9244
"""
9345
$(TYPEDSIGNATURES)
9446
@@ -108,15 +60,13 @@ function sample(tn::TensorNetworkModel, n::Int; usecuda = false, queryvars = get
10860
cache = cached_einsum(tn.code, xs, size_dict)
10961
# initialize `y̅` as the initial batch of samples.
11062
iy = getiyv(tn.code)
111-
setmask = falses(length(queryvars))
11263
idx = map(l->findfirst(==(l), queryvars), iy ∩ queryvars)
113-
setmask[idx] .= true
11464
indices = StatsBase.sample(CartesianIndices(size(cache.content)), _Weights(vec(cache.content)), n)
11565
configs = zeros(Int, length(queryvars), n)
11666
for i=1:n
11767
configs[idx, i] .= indices[i].I .- 1
11868
end
119-
samples = Samples(configs, queryvars, setmask)
69+
samples = Samples(configs, queryvars)
12070
# back-propagate
12171
env = copy(cache.content)
12272
fill!(env, one(eltype(env)))
@@ -157,12 +107,12 @@ function generate_samples!(code::NestedEinsum, cache::CacheTree{T}, env::Abstrac
157107
sample_vars = ix ∩ pool
158108
probabilities = einsum(DynamicEinCode([ix, ix], sample_vars), (child.content, subenv), size_dict)
159109
update_samples!(samples, sample_vars, probabilities)
160-
pool = setdiff(pool, sample_vars)
110+
setdiff!(pool, sample_vars)
161111

162112
# eliminate the sampled variables
163113
setindex!.(Ref(size_dict), 1, sample_vars)
164114
subsamples = subset(samples, sample_vars)[:, 1]
165-
udpate_cache_tree!(code, cache, sample_vars=>subsamples)
115+
udpate_cache_tree!(code, cache, sample_vars=>subsamples, size_dict)
166116
subenv = eliminate_dimensions(subenv, ix, sample_vars=>subsamples)
167117

168118
# recurse
@@ -178,13 +128,15 @@ function update_samples!(samples::Samples, vars::AbstractVector{L}, probabilitie
178128
eliminated_locs = idx4labels(samples.labels, vars)
179129
config = StatsBase.sample(totalset, _Weights(vec(probabilities)))
180130
samples.samples[eliminated_locs, 1] .= config.I .- 1
181-
set_eliminated!(samples, vars)
182131
end
183132

184-
function udpate_cache_tree!(ne::NestedEinsum, cache::CacheTree{T}, el::Pair{<:AbstractVector{L}}) where {T, L}
133+
function udpate_cache_tree!(ne::NestedEinsum, cache::CacheTree{T}, el::Pair{<:AbstractVector{L}}, size_dict::Dict{L}) where {T, L}
185134
OMEinsum.isleaf(ne) && return
135+
updated = false
186136
for (subcode, child, ix) in zip(ne.args, cache.children, getixsv(ne.eins))
137+
updated = updated || any(x->x ∈ el.first, ix)
187138
child.content = eliminate_dimensions(child.content, ix, el)
188-
udpate_cache_tree!(subcode, child, el)
139+
udpate_cache_tree!(subcode, child, el, size_dict)
189140
end
141+
updated && (cache.content = einsum(ne.eins, (getfield.(cache.children, :content)...,), size_dict))
190142
end

test/sampling.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,9 +90,7 @@ end
9090
count(j->j==i, indices) / num_samples
9191
end
9292
probs = normalize!(real.(vec(DynamicEinCode(ixs, collect(1:4))(tensors...))), 1)
93-
#indices = StatsBase.sample(1:16, StatsBase.Weights(probs), 1000)
9493
negative_loglikelyhood(probs, samples) = -sum(log.(probs[samples]))/length(samples)
9594
entropy(probs) = -sum(probs .* log.(probs))
96-
@show distribution, probs
9795
@test negative_loglikelyhood(probs, indices) ≈ entropy(probs) atol=1e-1
9896
end

0 commit comments

Comments
 (0)