Skip to content

Commit c0f080e

Browse files
authored
refactor sum product tree (#26)
1 parent ebdbc1b commit c0f080e

File tree

4 files changed

+71
-52
lines changed

4 files changed

+71
-52
lines changed

docs/src/performancetips.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ The GEMM routine can speed up the computation on CPU for one order, with multi-t
8989
Benchmark shows the performance of `TropicalGEMM` is close to the theoretical optimal value.
9090

9191
## Sum product representation for configurations
92-
[`TreeConfigEnumerator`](@ref) can save a lot memory for you to store exponential number of configurations in polynomial space.
92+
[`TreeConfigEnumerator`](@ref) (an alias of [`SumProductTree`](@ref) with [`StaticElementVector`](@ref) as its data type) can save a lot memory for you to store exponential number of configurations in polynomial space.
9393
It is a sum-product expression tree to store [`ConfigEnumerator`](@ref) in a lazy style, configurations can be extracted by depth first searching the tree with the `Base.collect` method. Although it is space efficient, it is in general not easy to extract information from it.
9494
This tree structure supports directed sampling so that one can get some statistic properties from it with an intermediate effort.
9595

docs/src/ref.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ Polynomials.Polynomial
8686
TruncatedPoly
8787
Max2Poly
8888
ConfigEnumerator
89+
SumProductTree
8990
TreeConfigEnumerator
9091
ConfigSampler
9192
```

src/GraphTensorNetworks.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ export estimate_memory
1717
# Algebras
1818
export StaticBitVector, StaticElementVector, @bv_str
1919
export is_commutative_semiring
20-
export Max2Poly, TruncatedPoly, Polynomial, Tropical, CountingTropical, StaticElementVector, Mod, ConfigEnumerator, onehotv, ConfigSampler, TreeConfigEnumerator
20+
export Max2Poly, TruncatedPoly, Polynomial, Tropical, CountingTropical, StaticElementVector, Mod
21+
export ConfigEnumerator, onehotv, ConfigSampler, SumProductTree, TreeConfigEnumerator
2122
export CountingTropicalF64, CountingTropicalF32, TropicalF64, TropicalF32, ExtendedTropical
2223
export generate_samples
2324

src/arithematics.jl

Lines changed: 67 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -408,30 +408,27 @@ Base.one(::Type{ConfigSampler{N,S,C}}) where {N,S,C} = ConfigSampler{N,S,C}(zero
408408
Base.zero(::ConfigSampler{N,S,C}) where {N,S,C} = zero(ConfigSampler{N,S,C})
409409
Base.one(::ConfigSampler{N,S,C}) where {N,S,C} = one(ConfigSampler{N,S,C})
410410

411-
# tree config enumerator
412-
# it must be mutable, otherwise the `IdDict` trick for computing the length does not work.
413411
"""
414-
TreeConfigEnumerator{N,S,C} <: AbstractSetNumber
412+
SumProductTree{ET} <: AbstractSetNumber
415413
416414
Configuration enumerator encoded in a tree, it is the most natural representation given by a sum-product network
417415
and is often more memory efficient than putting the configurations in a vector.
418416
One can use [`generate_samples`](@ref) to sample configurations from this tree structure efficiently.
419-
`N`, `S` and `C` are type parameters from the [`StaticElementVector`](@ref){N,S,C}.
420417
421418
Fields
422419
-----------------------
423-
* `tag` is one of `ZERO`, `LEAF`, `SUM`, `PROD`.
420+
* `tag` is one of `ZERO`, `ONE`, `LEAF`, `SUM`, `PROD`.
424421
* `data` is the element stored in a `LEAF` node.
425422
* `left` and `right` are two operands of a `SUM` or `PROD` node.
426423
427424
Example
428425
------------------------
429426
```jldoctest; setup=:(using GraphTensorNetworks)
430-
julia> s = TreeConfigEnumerator(bv"00111")
427+
julia> s = SumProductTree(bv"00111")
431428
00111
432429
433430
434-
julia> q = TreeConfigEnumerator(bv"10000")
431+
julia> q = SumProductTree(bv"10000")
435432
10000
436433
437434
@@ -469,44 +466,64 @@ julia> one(s)
469466
470467
```
471468
"""
472-
mutable struct TreeConfigEnumerator{N,S,C} <: AbstractSetNumber
469+
mutable struct SumProductTree{ET} <: AbstractSetNumber
473470
tag::TreeTag
474-
data::StaticElementVector{N,S,C}
475-
left::TreeConfigEnumerator{N,S,C}
476-
right::TreeConfigEnumerator{N,S,C}
477-
TreeConfigEnumerator(tag::TreeTag, left::TreeConfigEnumerator{N,S,C}, right::TreeConfigEnumerator{N,S,C}) where {N,S,C} = new{N,S,C}(tag, zero(StaticElementVector{N,S,C}), left, right)
478-
function TreeConfigEnumerator(data::StaticElementVector{N,S,C}) where {N,S,C}
479-
new{N,S,C}(LEAF, data)
480-
end
481-
function TreeConfigEnumerator{N,S,C}(tag::TreeTag) where {N,S,C}
471+
data::ET
472+
left::SumProductTree{ET}
473+
right::SumProductTree{ET}
474+
# zero(ET) can be undef
475+
function SumProductTree(tag::TreeTag, left::SumProductTree{ET}, right::SumProductTree{ET}) where {ET}
476+
res = new{ET}(tag)
477+
res.left = left
478+
res.right = right
479+
return res
480+
end
481+
function SumProductTree(data::ET) where ET
482+
return new{ET}(LEAF, data)
483+
end
484+
function SumProductTree{ET}(tag::TreeTag) where {ET}
482485
@assert tag === ZERO || tag === ONE
483-
return new{N,S,C}(tag)
486+
return new{ET}(tag)
484487
end
485488
end
489+
# these two interfaces must be implemented in order to collect elements
490+
_data_mul(x::StaticElementVector, y::StaticElementVector) = x | y
491+
_data_one(::Type{T}) where T<:StaticElementVector = zero(T) # NOTE: might be optional
492+
493+
"""
494+
TreeConfigEnumerator{N,S,C}
495+
496+
An alias for [`SumProductTree`](@ref)`{StaticElementVector{N, S, C}}`,
497+
which is a useful element type for configuration enumeration.
498+
"""
499+
const TreeConfigEnumerator{N,S,C} = SumProductTree{StaticElementVector{N,S,C}}
500+
TreeConfigEnumerator(data::StaticElementVector) = SumProductTree(data)
501+
TreeConfigEnumerator(tag::TreeTag, left::TreeConfigEnumerator{N,S,C}, right::TreeConfigEnumerator{N,S,C}) where {N,S,C} = SumProductTree(tag, left, right)
486502

487503
# AbstractTree APIs
488-
function children(t::TreeConfigEnumerator)
504+
function children(t::SumProductTree)
489505
if t.tag == ZERO || t.tag == LEAF || t.tag == ONE
490506
return typeof(t)[]
491507
else
492508
return [t.left, t.right]
493509
end
494510
end
495-
function printnode(io::IO, t::TreeConfigEnumerator{N,S,C}) where {N,S,C}
511+
function printnode(io::IO, t::SumProductTree{ET}) where {ET}
496512
if t.tag === LEAF
497513
print(io, t.data)
498514
elseif t.tag === ZERO
499515
print(io, "")
500516
elseif t.tag === ONE
501-
print(io, zero(StaticElementVector{N,S,C}))
517+
print(io, _data_one(ET))
502518
elseif t.tag === SUM
503519
print(io, "+")
504520
else # PROD
505521
print(io, "*")
506522
end
507523
end
508524

509-
Base.length(x::TreeConfigEnumerator) = _length!(x, IdDict{typeof(x), Int}())
525+
# it must be mutable, otherwise the `IdDict` trick for computing the length does not work.
526+
Base.length(x::SumProductTree) = _length!(x, IdDict{typeof(x), Int}())
510527

511528
function _length!(x, d)
512529
haskey(d, x) && return d[x]
@@ -525,7 +542,7 @@ function _length!(x, d)
525542
end
526543
end
527544

528-
num_nodes(x::TreeConfigEnumerator) = _num_nodes(x, IdDict{typeof(x), Int}())
545+
num_nodes(x::SumProductTree) = _num_nodes(x, IdDict{typeof(x), Int}())
529546
function _num_nodes(x, d)
530547
haskey(d, x) && return 0
531548
if x.tag == ZERO || x.tag == ONE
@@ -539,37 +556,37 @@ function _num_nodes(x, d)
539556
return res
540557
end
541558

542-
function Base.:(==)(x::TreeConfigEnumerator{N,S,C}, y::TreeConfigEnumerator{N,S,C}) where {N,S,C}
559+
function Base.:(==)(x::SumProductTree{ET}, y::SumProductTree{ET}) where {ET}
543560
return Set(collect(x)) == Set(collect(y))
544561
end
545562

546-
Base.show(io::IO, t::TreeConfigEnumerator) = print_tree(io, t)
563+
Base.show(io::IO, t::SumProductTree) = print_tree(io, t)
547564

548-
function Base.collect(x::TreeConfigEnumerator{N,S,C}) where {N,S,C}
565+
function Base.collect(x::SumProductTree{ET}) where {ET}
549566
if x.tag == ZERO
550-
return StaticElementVector{N,S,C}[]
567+
return ET[]
551568
elseif x.tag == ONE
552-
return StaticElementVector{N,S,C}[zero(StaticElementVector{N,S,C})]
569+
return [_data_one(ET)]
553570
elseif x.tag == LEAF
554-
return StaticElementVector{N,S,C}[x.data]
571+
return [x.data]
555572
elseif x.tag == SUM
556573
return vcat(collect(x.left), collect(x.right))
557574
else # PROD
558-
return vec([reduce((x,y)->x|y, si) for si in Iterators.product(collect(x.left), collect(x.right))])
575+
return vec([reduce(_data_mul, si) for si in Iterators.product(collect(x.left), collect(x.right))])
559576
end
560577
end
561578

562-
function Base.:+(x::TreeConfigEnumerator{N,S,C}, y::TreeConfigEnumerator{N,S,C}) where {N,S,C}
579+
function Base.:+(x::SumProductTree{ET}, y::SumProductTree{ET}) where {ET}
563580
if x.tag == ZERO
564581
return y
565582
elseif y.tag == ZERO
566583
return x
567584
else
568-
return TreeConfigEnumerator(SUM, x, y)
585+
return SumProductTree(SUM, x, y)
569586
end
570587
end
571588

572-
function Base.:*(x::TreeConfigEnumerator{L,S,C}, y::TreeConfigEnumerator{L,S,C}) where {L,S,C}
589+
function Base.:*(x::SumProductTree{ET}, y::SumProductTree{ET}) where {ET}
573590
if x.tag == ONE
574591
return y
575592
elseif y.tag == ONE
@@ -579,18 +596,18 @@ function Base.:*(x::TreeConfigEnumerator{L,S,C}, y::TreeConfigEnumerator{L,S,C})
579596
elseif y.tag == ZERO
580597
return y
581598
elseif x.tag == LEAF && y.tag == LEAF
582-
return TreeConfigEnumerator(x.data | y.data)
599+
return SumProductTree(_data_mul(x.data, y.data))
583600
else
584-
return TreeConfigEnumerator(PROD, x, y)
601+
return SumProductTree(PROD, x, y)
585602
end
586603
end
587604

588-
Base.zero(::Type{TreeConfigEnumerator{N,S,C}}) where {N,S,C} = TreeConfigEnumerator{N,S,C}(ZERO)
589-
Base.one(::Type{TreeConfigEnumerator{N,S,C}}) where {N,S,C} = TreeConfigEnumerator{N,S,C}(ONE)
590-
Base.zero(::TreeConfigEnumerator{N,S,C}) where {N,S,C} = zero(TreeConfigEnumerator{N,S,C})
591-
Base.one(::TreeConfigEnumerator{N,S,C}) where {N,S,C} = one(TreeConfigEnumerator{N,S,C})
605+
Base.zero(::Type{SumProductTree{ET}}) where {ET} = SumProductTree{ET}(ZERO)
606+
Base.one(::Type{SumProductTree{ET}}) where {ET} = SumProductTree{ET}(ONE)
607+
Base.zero(::SumProductTree{ET}) where {ET} = zero(SumProductTree{ET})
608+
Base.one(::SumProductTree{ET}) where {ET} = one(SumProductTree{ET})
592609
# todo, check siblings too?
593-
function Base.iszero(t::TreeConfigEnumerator)
610+
function Base.iszero(t::SumProductTree)
594611
if t.tag == SUM
595612
iszero(t.left) && iszero(t.right)
596613
elseif t.tag == ZERO
@@ -603,9 +620,9 @@ function Base.iszero(t::TreeConfigEnumerator)
603620
end
604621

605622
"""
606-
generate_samples(t::TreeConfigEnumerator, nsamples::Int)
623+
generate_samples(t::SumProductTree, nsamples::Int)
607624
608-
Direct sampling configurations from a [`TreeConfigEnumerator`](@ref) instance.
625+
Direct sampling configurations from a [`SumProductTree`](@ref) instance.
609626
610627
Example
611628
-----------------------------
@@ -623,15 +640,15 @@ julia> all(s->is_independent_set(g, s), samples)
623640
true
624641
```
625642
"""
626-
function generate_samples(t::TreeConfigEnumerator{N,S,C}, nsamples::Int) where {N,S,C}
643+
function generate_samples(t::SumProductTree{ET}, nsamples::Int) where {ET}
627644
# get length dict
628-
res = fill(zero(StaticElementVector{N,S,C}), nsamples)
645+
res = fill(_data_one(ET), nsamples)
629646
d = IdDict{typeof(t), Int}()
630647
sample_descend!(res, t, d)
631648
return res
632649
end
633650

634-
function sample_descend!(res::AbstractVector, t::TreeConfigEnumerator, d::IdDict)
651+
function sample_descend!(res::AbstractVector, t::SumProductTree, d::IdDict)
635652
length(res) == 0 && return res
636653
if t.tag == LEAF
637654
res .|= Ref(t.data)
@@ -695,14 +712,14 @@ onehotv(::Type{ConfigSampler{N,S,C}}, i::Integer, v) where {N,S,C} = ConfigSampl
695712
# just to make matrix transpose work
696713
Base.transpose(c::ConfigEnumerator) = c
697714
Base.copy(c::ConfigEnumerator) = ConfigEnumerator(copy(c.data))
698-
Base.transpose(c::TreeConfigEnumerator) = c
699-
function Base.copy(c::TreeConfigEnumerator{N,S,C}) where {N,S,C}
715+
Base.transpose(c::SumProductTree) = c
716+
function Base.copy(c::SumProductTree{ET}) where {ET}
700717
if c.tag == LEAF
701-
TreeConfigEnumerator(c.data)
718+
SumProductTree(c.data)
702719
elseif c.tag == ZERO || c.tag == ONE
703-
TreeConfigEnumerator{N,S,C}(c.tag)
720+
SumProductTree{ET}(c.tag)
704721
else
705-
TreeConfigEnumerator(c.tag, c.left, c.right)
722+
SumProductTree(c.tag, c.left, c.right)
706723
end
707724
end
708725

@@ -713,7 +730,7 @@ for TYPE in [:AbstractSetNumber, :TruncatedPoly, :ExtendedTropical]
713730
end
714731

715732
# to handle power of polynomials
716-
function Base.:^(x::TreeConfigEnumerator, y::Real)
733+
function Base.:^(x::SumProductTree, y::Real)
717734
if y == 0
718735
return one(x)
719736
elseif x.tag == LEAF || x.tag == ONE || x.tag == ZERO

0 commit comments

Comments
 (0)