
Commit ebdbc1b

sample tree config (#25)
* sample tree config * hamming dist test * fix doctest
1 parent 8edc945 commit ebdbc1b

File tree

8 files changed

+207
-37
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,7 @@ OMEinsum = "ebe7aa44-baf0-506c-a96f-8464559b3922"
1717
OMEinsumContractionOrders = "6f22d1fd-8eed-4bb7-9776-e7d684900715"
1818
Polynomials = "f27b6e38-b328-58d1-80ce-0feddd5e7a45"
1919
Primes = "27ebfcd6-29c5-5fa9-bf4b-fb8fc14df3ae"
20+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2021
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
2122
SIMDTypes = "94e857df-77ce-4151-89e5-788b33177be4"
2223
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
@@ -44,8 +45,7 @@ julia = "1"
4445

4546
[extras]
4647
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
47-
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
4848
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
4949

5050
[targets]
51-
test = ["Test", "Documenter", "Random"]
51+
test = ["Test", "Documenter"]

docs/src/performancetips.md

Lines changed: 42 additions & 0 deletions
@@ -93,6 +93,48 @@ Benchmark shows the performance of `TropicalGEMM` is close to the theoretical op
9393
It is a sum-product expression tree that stores a [`ConfigEnumerator`](@ref) lazily; configurations can be extracted by a depth-first search of the tree with the `Base.collect` method. Although it is space efficient, it is in general not easy to extract information from it.
9494
This tree structure supports direct sampling, so one can obtain some statistical properties from it with moderate effort.
9595

96+
For example, to check some property of an intermediate-scale graph, one can type
97+
```julia
98+
julia> graph = random_regular_graph(70, 3)
99+
100+
julia> problem = IndependentSet(graph; optimizer=TreeSA());
101+
102+
julia> tree = solve(problem, ConfigsAll(; tree_storage=true))[];
103+
julia> length(tree)
16633909006371
104+
```
105+
If one wanted to store all of these configurations, one would need a hard disk of size 256 TB!
106+
However, this sum-product binary tree structure supports efficient and unbiased direct sampling.
107+
108+
```julia
109+
samples = generate_samples(tree, 1000);
110+
```
111+
112+
With these samples, one can already compute useful properties such as the distribution of Hamming distances (see [`hamming_distribution`](@ref)).
113+
114+
```julia
115+
julia> using UnicodePlots
116+
117+
julia> lineplot(hamming_distribution(samples, samples))
118+
┌────────────────────────────────────────┐
119+
100000 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢰⠹⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
120+
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡎⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
121+
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡇⠀⢣⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
122+
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢸⠀⠀⢸⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
123+
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢸⠀⠀⠸⡄⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
124+
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
125+
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⣇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
126+
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢰⠃⠀⠀⠀⢸⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
127+
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢸⠀⠀⠀⠀⢸⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
128+
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡞⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
129+
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⣇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
130+
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⢰⠁⠀⠀⠀⠀⠀⢸⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
131+
│⠀⠀⠀⠀⠀⠀⠀⠀⠀⡼⠀⠀⠀⠀⠀⠀⠈⡆⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
132+
│⠀⠀⠀⠀⠀⠀⠀⠀⢠⠇⠀⠀⠀⠀⠀⠀⠀⢳⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
133+
0 │⢀⣀⣀⣀⣀⣀⣀⣀⠎⠀⠀⠀⠀⠀⠀⠀⠀⠀⠓⢄⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⠀⠀⠀⠀│
134+
└────────────────────────────────────────┘
135+
⠀0⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀80⠀
136+
```
137+
96138
(To be written.)
97139

98140
## Make use of GPUs

docs/src/ref.md

Lines changed: 3 additions & 0 deletions
@@ -97,6 +97,9 @@ save_configs
9797
load_configs
9898
@bv_str
9999
onehotv
100+
101+
generate_samples
102+
hamming_distribution
100103
```
101104

102105
## Tensor Network

src/GraphTensorNetworks.jl

Lines changed: 3 additions & 2 deletions
@@ -5,7 +5,7 @@ using Core: Argument
55
using TropicalNumbers
66
using OMEinsum
77
using OMEinsum: timespace_complexity, getixsv
8-
using Graphs
8+
using Graphs, Random
99

1010
# OMEinsum
1111
export timespace_complexity, timespacereadwrite_complexity, @ein_str, getixsv, getiyv
@@ -19,6 +19,7 @@ export StaticBitVector, StaticElementVector, @bv_str
1919
export is_commutative_semiring
2020
export Max2Poly, TruncatedPoly, Polynomial, Tropical, CountingTropical, StaticElementVector, Mod, ConfigEnumerator, onehotv, ConfigSampler, TreeConfigEnumerator
2121
export CountingTropicalF64, CountingTropicalF32, TropicalF64, TropicalF32, ExtendedTropical
22+
export generate_samples
2223

2324
# Lower level APIs
2425
export AllConfigs, SingleConfig
@@ -45,7 +46,7 @@ export is_dominating_set
4546
export solve, SizeMax, SizeMin, CountingAll, CountingMax, CountingMin, GraphPolynomial, SingleConfigMax, SingleConfigMin, ConfigsAll, ConfigsMax, ConfigsMin
4647

4748
# Utilities
48-
export save_configs, load_configs
49+
export save_configs, load_configs, hamming_distribution
4950

5051
# Visualization
5152
export show_graph, spring_layout

src/arithematics.jl

Lines changed: 109 additions & 17 deletions
@@ -4,7 +4,7 @@ using Mods, Primes
44
using Base.Cartesian
55
import AbstractTrees: children, printnode, print_tree
66

7-
@enum TreeTag LEAF SUM PROD ZERO
7+
@enum TreeTag LEAF SUM PROD ZERO ONE
88

99
# pirate
1010
Base.abs(x::Mod) = x
@@ -415,6 +415,7 @@ Base.one(::ConfigSampler{N,S,C}) where {N,S,C} = one(ConfigSampler{N,S,C})
415415
416416
A configuration enumerator encoded in a tree; it is the most natural representation given by a sum-product network
417417
and is often more memory efficient than putting the configurations in a vector.
418+
One can use [`generate_samples`](@ref) to sample configurations from this tree structure efficiently.
418419
`N`, `S` and `C` are type parameters from the [`StaticElementVector`](@ref){N,S,C}.
419420
420421
Fields
@@ -478,41 +479,43 @@ mutable struct TreeConfigEnumerator{N,S,C} <: AbstractSetNumber
478479
new{N,S,C}(LEAF, data)
479480
end
480481
function TreeConfigEnumerator{N,S,C}(tag::TreeTag) where {N,S,C}
481-
@assert tag === ZERO
482+
@assert tag === ZERO || tag === ONE
482483
return new{N,S,C}(tag)
483484
end
484485
end
485486

486487
# AbstractTree APIs
487488
function children(t::TreeConfigEnumerator)
488-
if t.tag == ZERO || t.tag == LEAF
489+
if t.tag == ZERO || t.tag == LEAF || t.tag == ONE
489490
return typeof(t)[]
490491
else
491492
return [t.left, t.right]
492493
end
493494
end
494-
function printnode(io::IO, t::TreeConfigEnumerator)
495+
function printnode(io::IO, t::TreeConfigEnumerator{N,S,C}) where {N,S,C}
495496
if t.tag === LEAF
496497
print(io, t.data)
497498
elseif t.tag === ZERO
498499
print(io, "")
500+
elseif t.tag === ONE
501+
print(io, zero(StaticElementVector{N,S,C}))
499502
elseif t.tag === SUM
500503
print(io, "+")
501504
else # PROD
502505
print(io, "*")
503506
end
504507
end
505508

506-
Base.length(x::TreeConfigEnumerator) = _length(x, IdDict{typeof(x), Int}())
509+
Base.length(x::TreeConfigEnumerator) = _length!(x, IdDict{typeof(x), Int}())
507510

508-
function _length(x, d)
511+
function _length!(x, d)
509512
haskey(d, x) && return d[x]
510513
if x.tag === SUM
511-
l = _length(x.left, d) + _length(x.right, d)
514+
l = _length!(x.left, d) + _length!(x.right, d)
512515
d[x] = l
513516
return l
514517
elseif x.tag === PROD
515-
l = _length(x.left, d) * _length(x.right, d)
518+
l = _length!(x.left, d) * _length!(x.right, d)
516519
d[x] = l
517520
return l
518521
elseif x.tag === ZERO
@@ -525,7 +528,7 @@ end
525528
num_nodes(x::TreeConfigEnumerator) = _num_nodes(x, IdDict{typeof(x), Int}())
526529
function _num_nodes(x, d)
527530
haskey(d, x) && return 0
528-
if x.tag == ZERO
531+
if x.tag == ZERO || x.tag == ONE
529532
res = 1
530533
elseif x.tag == LEAF
531534
res = 1
@@ -545,6 +548,8 @@ Base.show(io::IO, t::TreeConfigEnumerator) = print_tree(io, t)
545548
function Base.collect(x::TreeConfigEnumerator{N,S,C}) where {N,S,C}
546549
if x.tag == ZERO
547550
return StaticElementVector{N,S,C}[]
551+
elseif x.tag == ONE
552+
return StaticElementVector{N,S,C}[zero(StaticElementVector{N,S,C})]
548553
elseif x.tag == LEAF
549554
return StaticElementVector{N,S,C}[x.data]
550555
elseif x.tag == SUM
@@ -555,15 +560,33 @@ function Base.collect(x::TreeConfigEnumerator{N,S,C}) where {N,S,C}
555560
end
556561

557562
function Base.:+(x::TreeConfigEnumerator{N,S,C}, y::TreeConfigEnumerator{N,S,C}) where {N,S,C}
558-
TreeConfigEnumerator(SUM, x, y)
563+
if x.tag == ZERO
564+
return y
565+
elseif y.tag == ZERO
566+
return x
567+
else
568+
return TreeConfigEnumerator(SUM, x, y)
569+
end
559570
end
560571

561572
function Base.:*(x::TreeConfigEnumerator{L,S,C}, y::TreeConfigEnumerator{L,S,C}) where {L,S,C}
562-
TreeConfigEnumerator(PROD, x, y)
573+
if x.tag == ONE
574+
return y
575+
elseif y.tag == ONE
576+
return x
577+
elseif x.tag == ZERO
578+
return x
579+
elseif y.tag == ZERO
580+
return y
581+
elseif x.tag == LEAF && y.tag == LEAF
582+
return TreeConfigEnumerator(x.data | y.data)
583+
else
584+
return TreeConfigEnumerator(PROD, x, y)
585+
end
563586
end
564587

565588
Base.zero(::Type{TreeConfigEnumerator{N,S,C}}) where {N,S,C} = TreeConfigEnumerator{N,S,C}(ZERO)
566-
Base.one(::Type{TreeConfigEnumerator{N,S,C}}) where {N,S,C} = TreeConfigEnumerator(zero(StaticElementVector{N,S,C}))
589+
Base.one(::Type{TreeConfigEnumerator{N,S,C}}) where {N,S,C} = TreeConfigEnumerator{N,S,C}(ONE)
567590
Base.zero(::TreeConfigEnumerator{N,S,C}) where {N,S,C} = zero(TreeConfigEnumerator{N,S,C})
568591
Base.one(::TreeConfigEnumerator{N,S,C}) where {N,S,C} = one(TreeConfigEnumerator{N,S,C})
569592
# todo, check siblings too?
@@ -572,13 +595,68 @@ function Base.iszero(t::TreeConfigEnumerator)
572595
iszero(t.left) && iszero(t.right)
573596
elseif t.tag == ZERO
574597
true
575-
elseif t.tag == LEAF
598+
elseif t.tag == LEAF || t.tag == ONE
576599
false
577600
else
578601
iszero(t.left) || iszero(t.right)
579602
end
580603
end
581604

605+
"""
606+
generate_samples(t::TreeConfigEnumerator, nsamples::Int)
607+
608+
Directly sample configurations from a [`TreeConfigEnumerator`](@ref) instance.
609+
610+
Example
611+
-----------------------------
612+
```jldoctest; setup=:(using GraphTensorNetworks)
613+
julia> using Graphs
614+
615+
julia> g = smallgraph(:petersen)
616+
{10, 15} undirected simple Int64 graph
617+
618+
julia> t = solve(IndependentSet(g), ConfigsAll(; tree_storage=true))[];
619+
620+
julia> samples = generate_samples(t, 1000);
621+
622+
julia> all(s->is_independent_set(g, s), samples)
623+
true
624+
```
625+
"""
626+
function generate_samples(t::TreeConfigEnumerator{N,S,C}, nsamples::Int) where {N,S,C}
627+
# get length dict
628+
res = fill(zero(StaticElementVector{N,S,C}), nsamples)
629+
d = IdDict{typeof(t), Int}()
630+
sample_descend!(res, t, d)
631+
return res
632+
end
633+
634+
function sample_descend!(res::AbstractVector, t::TreeConfigEnumerator, d::IdDict)
635+
length(res) == 0 && return res
636+
if t.tag == LEAF
637+
res .|= Ref(t.data)
638+
elseif t.tag == SUM
639+
ratio = _length!(t.left, d)/_length!(t, d)
640+
nleft = 0
641+
for _ = 1:length(res)
642+
if rand() < ratio
643+
nleft += 1
644+
end
645+
end
646+
shuffle!(res) # shuffle the `res` to avoid biased sampling, very important.
647+
sample_descend!(view(res,1:nleft), t.left, d)
648+
sample_descend!(view(res,nleft+1:length(res)), t.right, d)
649+
elseif t.tag == PROD
650+
sample_descend!(res, t.right, d)
651+
sample_descend!(res, t.left, d)
652+
elseif t.tag == ZERO
653+
error("Meet zero when descending.")
654+
else
655+
# ONE node: it contributes the all-zero vector, so there is nothing to do
656+
end
657+
return res
658+
end
659+
582660
# A patch to make `Polynomial{ConfigEnumerator}` work
583661
function Base.:*(a::Int, y::AbstractSetNumber)
584662
a == 0 && return zero(y)
@@ -611,7 +689,8 @@ end
611689

612690
# utilities for creating onehot vectors
613691
onehotv(::Type{ConfigEnumerator{N,S,C}}, i::Integer, v) where {N,S,C} = ConfigEnumerator([onehotv(StaticElementVector{N,S,C}, i, v)])
614-
onehotv(::Type{TreeConfigEnumerator{N,S,C}}, i::Integer, v) where {N,S,C} = TreeConfigEnumerator(onehotv(StaticElementVector{N,S,C}, i, v))
692+
# We treat `v == 0` specially so that the final result does not contain leaves that are equal to one.
693+
onehotv(::Type{TreeConfigEnumerator{N,S,C}}, i::Integer, v) where {N,S,C} = v == 0 ? one(TreeConfigEnumerator{N,S,C}) : TreeConfigEnumerator(onehotv(StaticElementVector{N,S,C}, i, v))
615694
onehotv(::Type{ConfigSampler{N,S,C}}, i::Integer, v) where {N,S,C} = ConfigSampler(onehotv(StaticElementVector{N,S,C}, i, v))
616695
# just to make matrix transpose work
617696
Base.transpose(c::ConfigEnumerator) = c
@@ -620,7 +699,7 @@ Base.transpose(c::TreeConfigEnumerator) = c
620699
function Base.copy(c::TreeConfigEnumerator{N,S,C}) where {N,S,C}
621700
if c.tag == LEAF
622701
TreeConfigEnumerator(c.data)
623-
elseif c.tag == ZERO
702+
elseif c.tag == ZERO || c.tag == ONE
624703
TreeConfigEnumerator{N,S,C}(c.tag)
625704
else
626705
TreeConfigEnumerator(c.tag, c.left, c.right)
@@ -635,9 +714,9 @@ end
635714

636715
# to handle power of polynomials
637716
function Base.:^(x::TreeConfigEnumerator, y::Real)
638-
if y <= 0
717+
if y == 0
639718
return one(x)
640-
elseif x.tag == LEAF
719+
elseif x.tag == LEAF || x.tag == ONE || x.tag == ZERO
641720
return x
642721
else
643722
error("pow of non-leaf nodes is forbidden!")
@@ -688,6 +767,19 @@ function _x(::Type{T}; invert) where {T<:AbstractSetNumber}
688767
invert ? pre_invert_exponent(ret) : ret
689768
end
690769

770+
function _onehotv(::Type{Polynomial{BS,X}}, x, v) where {BS,X}
771+
Polynomial{BS,X}([onehotv(BS, x, v)])
772+
end
773+
function _onehotv(::Type{TruncatedPoly{K,BS,OS}}, x, v) where {K,BS,OS}
774+
TruncatedPoly{K,BS,OS}(ntuple(i->i != K ? zero(BS) : onehotv(BS, x, v), K),zero(OS))
775+
end
776+
function _onehotv(::Type{CountingTropical{TV,BS}}, x, v) where {TV,BS}
777+
CountingTropical{TV,BS}(zero(TV), onehotv(BS, x, v))
778+
end
779+
function _onehotv(::Type{BS}, x, v) where {BS<:AbstractSetNumber}
780+
onehotv(BS, x, v)
781+
end
782+
691783
# negate the exponents before entering the solver
692784
pre_invert_exponent(t::TruncatedPoly{K}) where K = TruncatedPoly(t.coeffs, -t.maxorder)
693785
pre_invert_exponent(t::TropicalNumbers.TropicalTypes) = inv(t)
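
The sampling idea behind `generate_samples` can be illustrated outside the package. The following is a minimal sketch, not the commit's implementation. The node types `Leaf`, `One`, `Sum` and `Prod` and the functions `count_configs` and `sample_one` are hypothetical names. Configurations are plain `UInt64` bit masks instead of `StaticElementVector`, samples are drawn one at a time rather than in batches, and subtree sizes are recomputed on every call instead of being cached in an `IdDict`. At a sum node, a branch is chosen with probability proportional to the number of configurations it holds; at a product node, both sides are sampled independently and merged with `|`, which assumes the two sides act on disjoint bits.

```julia
using Random

# Hypothetical stand-ins for the tree nodes; configurations are UInt64 bit masks.
abstract type Node end
struct Leaf <: Node; data::UInt64; end      # a single configuration
struct One <: Node end                      # the all-zero configuration
struct Sum <: Node; left::Node; right::Node; end
struct Prod <: Node; left::Node; right::Node; end

# Number of configurations stored below a node (no caching in this sketch).
count_configs(::Leaf) = 1
count_configs(::One) = 1
count_configs(n::Sum) = count_configs(n.left) + count_configs(n.right)
count_configs(n::Prod) = count_configs(n.left) * count_configs(n.right)

# Draw one configuration uniformly at random.
sample_one(rng, n::Leaf) = n.data
sample_one(rng, ::One) = zero(UInt64)
function sample_one(rng, n::Sum)
    nleft = count_configs(n.left)
    total = nleft + count_configs(n.right)
    # Pick the left branch with probability nleft / total.
    return rand(rng, 1:total) <= nleft ? sample_one(rng, n.left) : sample_one(rng, n.right)
end
# Both factors contribute; merge their bits.
sample_one(rng, n::Prod) = sample_one(rng, n.left) | sample_one(rng, n.right)

# A small tree holding five configurations: 000, 001, 010, 011 and 100.
tree = Sum(Prod(Sum(Leaf(0b001), One()), Sum(Leaf(0b010), One())), Leaf(0b100))
rng = MersenneTwister(1)
samples = [sample_one(rng, tree) for _ in 1:2000]
```

Because every sum node weights its branches by their sizes, each of the five configurations is drawn with probability 1/5. The batched version in the commit achieves the same distribution while descending the tree only once for all samples.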

src/configurations.jl

Lines changed: 24 additions & 9 deletions
@@ -90,15 +90,30 @@ e.g. when the problem is [`MaximalIS`](@ref), it computes all maximal independen
9090
"""
9191
all_solutions(gp::GraphProblem; T=Float64) = solutions(gp, Polynomial{T,:x}, all=true, usecuda=false, tree_storage=false)
9292

93-
function _onehotv(::Type{Polynomial{BS,X}}, x, v) where {BS,X}
94-
Polynomial{BS,X}([onehotv(BS, x, v)])
95-
end
96-
function _onehotv(::Type{TruncatedPoly{K,BS,OS}}, x, v) where {K,BS,OS}
97-
TruncatedPoly{K,BS,OS}(ntuple(i->i != K ? zero(BS) : onehotv(BS, x, v), K),zero(OS))
93+
# NOTE: is there a more efficient way to compute this?
94+
# NOTE: could computing pair-wise Hamming distances be biased?
95+
"""
96+
hamming_distribution(S, T)
97+
98+
Compute the distribution of pair-wise Hamming distances, which is defined as:
99+
```math
100+
c(k) := \\sum_{\\sigma\\in S, \\tau\\in T} \\delta({\\rm dist}(\\sigma, \\tau), k)
101+
```
102+
where ``\\delta`` is a function that returns 1 if its two arguments are equal and 0 otherwise, and
103+
``{\\rm dist}`` is the Hamming distance function.
104+
105+
Returns the counts as a vector, where entry `k+1` holds ``c(k)``.
106+
"""
107+
function hamming_distribution(t1::ConfigEnumerator, t2::ConfigEnumerator)
108+
return hamming_distribution(t1.data, t2.data)
98109
end
99-
function _onehotv(::Type{CountingTropical{TV,BS}}, x, v) where {TV,BS}
100-
CountingTropical{TV,BS}(zero(TV), onehotv(BS, x, v))
110+
function hamming_distribution(s1::AbstractVector{StaticElementVector{N,S,C}}, s2::AbstractVector{StaticElementVector{N,S,C}}) where {N,S,C}
111+
return hamming_distribution!(zeros(Int, N+1), s1, s2)
101112
end
102-
function _onehotv(::Type{BS}, x, v) where {BS<:AbstractSetNumber}
103-
onehotv(BS, x, v)
113+
function hamming_distribution!(out::AbstractVector, s1::AbstractVector{StaticElementVector{N,S,C}}, s2::AbstractVector{StaticElementVector{N,S,C}}) where {N,S,C}
114+
@assert length(out) == N+1
115+
@inbounds for a in s1, b in s2
116+
out[count_ones(a ⊻ b)+1] += 1
117+
end
118+
return out
104119
end
