
Commit 02cc203

Merge pull request #28 from TensorBFS/jg/sampling
Sampling configurations
2 parents fc806bf + a9379a7 · commit 02cc203

8 files changed (+199 lines, -7 lines)

.github/workflows/CI.yml

Lines changed: 1 addition & 2 deletions
@@ -18,9 +18,8 @@ jobs:
       fail-fast: false
       matrix:
         version:
-          - '1.8'
           - '1'
-          #- 'nightly'
+          - 'nightly'
         os:
           - ubuntu-latest
         arch:

Project.toml

Lines changed: 4 additions & 1 deletion
@@ -11,15 +11,18 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 OMEinsum = "ebe7aa44-baf0-506c-a96f-8464559b3922"
 PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
 Requires = "ae029012-a4dd-5104-9daa-d747884805df"
+StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
 TropicalGEMM = "a4ad3063-64a7-4bad-8738-34ed09bc0236"
 TropicalNumbers = "b3a74e9c-7526-4576-a4eb-79c0d4c32334"

 [compat]
+Artifacts = "1"
 CUDA = "4"
 DocStringExtensions = "0.8.6, 0.9"
 OMEinsum = "0.7"
-Requires = "1"
 PrecompileTools = "1"
+Requires = "1"
+StatsBase = "0.34"
 TropicalGEMM = "0.1"
 TropicalNumbers = "0.5.4"
 julia = "1.3"

example/asia/asia.jl

Lines changed: 6 additions & 0 deletions
@@ -10,6 +10,9 @@ probability(tnet)
 # Get the marginal probabilities (MAR)
 marginals(tnet) .|> first

+# The corresponding variables are
+get_vars(tnet)
+
 # Set the evidence variables "X-ray" (7) to be positive.
 set_evidence!(instance, 7=>0)
@@ -19,6 +22,9 @@ tnet = TensorNetworkModel(instance)
 # Get the maximum log-probabilities (MAP)
 maximum_logp(tnet)

+# To sample from the probability model
+sample(tnet, 10)
+
 # Get not only the maximum log-probability, but also the most probable configuration
 # In the most probable configuration, the patient smokes (3) and has lung cancer (4)
 logp, cfg = most_probable_config(tnet)

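As context for the new `sample(tnet, 10)` call above, here is a minimal sketch of how the returned configurations can be turned into empirical marginals, mirroring what the new test file below does. It reuses the `tnet` built in the example; the sample count and variable names are illustrative, not part of this commit.

using TensorInference

# `sample` returns a Vector of integer configuration vectors aligned with `get_vars(tnet)`
configs = sample(tnet, 10_000)
vars = get_vars(tnet)
# empirical frequency of configuration value 1 (the second state) for each variable
empirical = [count(c -> c[k] == 1, configs) / length(configs) for k in eachindex(vars)]
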
src/TensorInference.jl

Lines changed: 6 additions & 1 deletion
@@ -5,6 +5,7 @@ using DocStringExtensions, TropicalNumbers
 using Artifacts
 # The Tropical GEMM support
 using TropicalGEMM
+using StatsBase

 # reexport OMEinsum functions
 export RescaledArray
@@ -20,6 +21,9 @@ export TensorNetworkModel, get_vars, get_cards, log_probability, probability, ma
 # MAP
 export most_probable_config, maximum_logp

+# sampling
+export sample
+
 # MMAP
 export MMAPModel
@@ -29,6 +33,7 @@ include("utils.jl")
 include("inference.jl")
 include("maxprob.jl")
 include("mmap.jl")
+include("sampling.jl")

 using Requires
 function __init__()
@@ -40,7 +45,7 @@ PrecompileTools.@setup_workload begin
     # Putting some things in `@setup_workload` instead of `@compile_workload` can reduce the size of the
     # precompile file and potentially make loading faster.
     #PrecompileTools.@compile_workload begin
-    #include("../example/asia/asia.jl")
+    # include("../example/asia/asia.jl")
     #end
 end

src/sampling.jl

Lines changed: 113 additions & 0 deletions
@@ -0,0 +1,113 @@
+############ Sampling ############
+"""
+$TYPEDEF
+
+### Fields
+$TYPEDFIELDS
+
+The sampled configurations are stored in `samples`, which is a vector of vectors.
+`labels` is a vector of variable names for labeling configurations.
+The `setmask` is a boolean indicator to denote whether the sampling process of a variable is complete.
+"""
+struct Samples{L}
+    samples::Vector{Vector{Int}}
+    labels::Vector{L}
+    setmask::BitVector
+end
+function setmask!(samples::Samples, eliminated_variables)
+    for var in eliminated_variables
+        loc = findfirst(==(var), samples.labels)
+        samples.setmask[loc] && error("variable `$var` is already eliminated.")
+        samples.setmask[loc] = true
+    end
+    return samples
+end
+
+idx4labels(totalset, labels) = map(v->findfirst(==(v), totalset), labels)
+
+"""
+$(TYPEDSIGNATURES)
+
+The backward process for sampling configurations.
+
+* `ixs` and `xs` are labels and tensor data for input tensors,
+* `iy` and `y` are labels and tensor data for the output tensor,
+* `samples` is the samples generated for eliminated variables,
+* `size_dict` is a key-value map from tensor label to dimension size.
+"""
+function backward_sampling!(ixs, @nospecialize(xs::Tuple), iy, @nospecialize(y), samples::Samples, size_dict)
+    eliminated_variables = setdiff(vcat(ixs...), iy)
+    eliminated_locs = idx4labels(samples.labels, eliminated_variables)
+    setmask!(samples, eliminated_variables)
+
+    # the contraction code to get probability
+    newiy = eliminated_variables
+    iy_in_sample = idx4labels(samples.labels, iy)
+    slice_y_dim = collect(1:length(iy))
+    newixs = map(ix->setdiff(ix, iy), ixs)
+    ix_in_sample = map(ix->idx4labels(samples.labels, ix ∩ iy), ixs)
+    slice_xs_dim = map(ix->idx4labels(ix, ix ∩ iy), ixs)
+    code = DynamicEinCode(newixs, newiy)
+
+    totalset = CartesianIndices((map(x->size_dict[x], eliminated_variables)...,))
+    for (i, sample) in enumerate(samples.samples)
+        newxs = [get_slice(x, dimx, sample[ixloc]) for (x, dimx, ixloc) in zip(xs, slice_xs_dim, ix_in_sample)]
+        newy = get_element(y, slice_y_dim, sample[iy_in_sample])
+        probabilities = einsum(code, (newxs...,), size_dict) / newy
+        config = StatsBase.sample(totalset, Weights(vec(probabilities)))
+        # update the samples
+        samples.samples[i][eliminated_locs] .= config.I .- 1
+    end
+    return samples
+end
+
+# type unstable
+function get_slice(x, dim, config)
+    asarray(x[[i ∈ dim ? config[findfirst(==(i), dim)]+1 : Colon() for i in 1:ndims(x)]...], x)
+end
+function get_element(x, dim, config)
+    x[[config[findfirst(==(i), dim)]+1 for i in 1:ndims(x)]...]
+end
+
+"""
+$(TYPEDSIGNATURES)
+
+Generate samples from a tensor network based probabilistic model.
+Returns a vector of vectors, each element being a configuration defined on `get_vars(tn)`.
+
+### Arguments
+* `tn` is the tensor network model.
+* `n` is the number of samples to be returned.
+"""
+function sample(tn::TensorNetworkModel, n::Int; usecuda = false)::Vector{Vector{Int}}
+    # generate tensors with their elements being log(p).
+    xs = adapt_tensors(tn; usecuda, rescale = false)
+    # infer size from the contraction code and the input tensors `xs`, returns a label-size dictionary.
+    size_dict = OMEinsum.get_size_dict!(getixsv(tn.code), xs, Dict{Int, Int}())
+    # forward compute and cache intermediate results.
+    cache = cached_einsum(tn.code, xs, size_dict)
+    # initialize `y̅` as the initial batch of samples.
+    labels = get_vars(tn)
+    iy = getiyv(tn.code)
+    setmask = falses(length(labels))
+    idx = map(l->findfirst(==(l), labels), iy)
+    setmask[idx] .= true
+    indices = StatsBase.sample(CartesianIndices(size(cache.content)), Weights(normalize!(vec(LinearAlgebra.normalize!(cache.content)))), n)
+    configs = map(indices) do ind
+        c = zeros(Int, length(labels))
+        c[idx] .= ind.I .- 1
+        c
+    end
+    samples = Samples(configs, labels, setmask)
+    # back-propagate
+    generate_samples(tn.code, cache, samples, size_dict)
+    return samples.samples
+end
+
+function generate_samples(code::NestedEinsum, cache::CacheTree{T}, samples, size_dict::Dict) where {T}
+    if !OMEinsum.isleaf(code)
+        xs = ntuple(i -> cache.siblings[i].content, length(cache.siblings))
+        backward_sampling!(OMEinsum.getixs(code.eins), xs, OMEinsum.getiy(code.eins), cache.content, samples, size_dict)
+        generate_samples.(code.args, cache.siblings, Ref(samples), Ref(size_dict))
+    end
+end

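A note on the mechanics of `backward_sampling!` above: at each contraction step it fixes the labels that are already sampled, slices the input tensors at those indices, contracts the remainder to obtain conditional weights over the eliminated variables, and draws them with `StatsBase.sample`. Below is a standalone sketch of that conditional-draw step only; the 2x2 probability table and the variable names are made up for illustration and are not part of the package.

using StatsBase

# made-up joint table p(a, b) over two binary variables; rows index `a`, columns index `b`
p = [0.1 0.3;
     0.2 0.4]
a = 2                          # pretend `a` was fixed by an earlier sampling step
w = Weights(vec(p[a, :]))      # unnormalized conditional weights for p(b | a = 2)
b = StatsBase.sample(1:2, w)   # draw the remaining variable, as backward_sampling! does per sample
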
src/utils.jl

Lines changed: 11 additions & 3 deletions
@@ -8,12 +8,15 @@ The UAI file formats are defined in:
 https://personal.utdallas.edu/~vibhav.gogate/uai16-evaluation/uaiformat.html
 """
 function read_uai_file(uai_filepath; factor_eltype = Float64)
-
     # Read the uai file into an array of lines
-    rawlines = open(uai_filepath) do file
-        readlines(file)
+    str = open(uai_filepath) do file
+        read(file, String)
     end
+    return read_uai_string(str; factor_eltype)
+end

+function read_uai_string(str; factor_eltype = Float64)
+    rawlines = split(str, "\n")
     # Filter out empty lines
     lines = filter(!isempty, rawlines)
@@ -193,5 +196,10 @@ function uai_problem_from_file(uai_filepath::String; uai_evid_filepath="", uai_m
     return UAIInstance(nvars, ncliques, cards, factors, obsvars, obsvals, reference_marginals)
 end

+function uai_problem_from_string(uai::String; eltype=Float64)::UAIInstance
+    nvars, cards, ncliques, factors = read_uai_string(uai; factor_eltype = eltype)
+    return UAIInstance(nvars, ncliques, cards, factors, Int[], Int[], Vector{eltype}[])
+end
+
 # patch to get content by broadcasting into array, while keep array size unchanged.
 broadcasted_content(x) = asarray(content.(x), x)

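The new `read_uai_string`/`uai_problem_from_string` pair allows building a model without going through a file, which is what the new test below relies on. A minimal sketch of that usage follows; the two-variable MARKOV payload is made up for illustration and is not part of this commit.

using TensorInference

# a tiny made-up network in UAI MARKOV format: two binary variables,
# one unary factor and one pairwise factor
uai = """MARKOV
2
2 2
2
1 0
2 0 1

2
0.4
0.6

4
0.3 0.7
0.7 0.3
"""
instance = TensorInference.uai_problem_from_string(uai)
tnet = TensorNetworkModel(instance)
marginals(tnet)
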
test/runtests.jl

Lines changed: 3 additions & 0 deletions
@@ -17,6 +17,9 @@ end
 @testset "MMAP" begin
     include("mmap.jl")
 end
+@testset "sampling" begin
+    include("sampling.jl")
+end

 using CUDA
 if CUDA.functional()

test/sampling.jl

Lines changed: 55 additions & 0 deletions
@@ -0,0 +1,55 @@
+using TensorInference, Test
+
+@testset "sampling" begin
+    instance = TensorInference.uai_problem_from_string("""MARKOV
+8
+2 2 2 2 2 2 2 2
+8
+1 0
+2 1 0
+1 2
+2 3 2
+2 4 2
+3 5 3 1
+2 6 5
+3 7 5 4
+
+2
+0.01
+0.99
+
+4
+0.05 0.01
+0.95 0.99
+
+2
+0.5
+0.5
+
+4
+0.1 0.01
+0.9 0.99
+
+4
+0.6 0.3
+0.4 0.7
+
+8
+1 1 1 0
+0 0 0 1
+
+4
+0.98 0.05
+0.02 0.95
+
+8
+0.9 0.7 0.8 0.1
+0.1 0.3 0.2 0.9
+""")
+    n = 10000
+    tnet = TensorNetworkModel(instance)
+    samples = sample(tnet, n)
+    mars = getindex.(marginals(tnet), 2)
+    mars_sample = [count(s->s[k]==(1), samples) for k=1:8] ./ n
+    @test isapprox(mars, mars_sample, atol=0.05)
+end
