@@ -408,30 +408,27 @@ Base.one(::Type{ConfigSampler{N,S,C}}) where {N,S,C} = ConfigSampler{N,S,C}(zero
408408Base. zero (:: ConfigSampler{N,S,C} ) where {N,S,C} = zero (ConfigSampler{N,S,C})
409409Base. one (:: ConfigSampler{N,S,C} ) where {N,S,C} = one (ConfigSampler{N,S,C})
410410
411- # tree config enumerator
412- # it must be mutable, otherwise the `IdDict` trick for computing the length does not work.
413411"""
414- TreeConfigEnumerator{N,S,C } <: AbstractSetNumber
412+ SumProductTree{ET } <: AbstractSetNumber
415413
416414Configuration enumerator encoded in a tree, it is the most natural representation given by a sum-product network
417415and is often more memory efficient than putting the configurations in a vector.
418416One can use [`generate_samples`](@ref) to sample configurations from this tree structure efficiently.
419- `N`, `S` and `C` are type parameters from the [`StaticElementVector`](@ref){N,S,C}.
420417
421418Fields
422419-----------------------
423- * `tag` is one of `ZERO`, `LEAF`, `SUM`, `PROD`.
420+ * `tag` is one of `ZERO`, `ONE`, ` LEAF`, `SUM`, `PROD`.
424421* `data` is the element stored in a `LEAF` node.
425422* `left` and `right` are two operands of a `SUM` or `PROD` node.
426423
427424Example
428425------------------------
429426```jldoctest; setup=:(using GraphTensorNetworks)
430- julia> s = TreeConfigEnumerator (bv"00111")
427+ julia> s = SumProductTree (bv"00111")
43142800111
432429
433430
434- julia> q = TreeConfigEnumerator (bv"10000")
431+ julia> q = SumProductTree (bv"10000")
43543210000
436433
437434
@@ -469,44 +466,64 @@ julia> one(s)
469466
470467```
471468"""
472- mutable struct TreeConfigEnumerator{N,S,C } <: AbstractSetNumber
469+ mutable struct SumProductTree{ET } <: AbstractSetNumber
473470 tag:: TreeTag
474- data:: StaticElementVector{N,S,C}
475- left:: TreeConfigEnumerator{N,S,C}
476- right:: TreeConfigEnumerator{N,S,C}
477- TreeConfigEnumerator (tag:: TreeTag , left:: TreeConfigEnumerator{N,S,C} , right:: TreeConfigEnumerator{N,S,C} ) where {N,S,C} = new {N,S,C} (tag, zero (StaticElementVector{N,S,C}), left, right)
478- function TreeConfigEnumerator (data:: StaticElementVector{N,S,C} ) where {N,S,C}
479- new {N,S,C} (LEAF, data)
480- end
481- function TreeConfigEnumerator {N,S,C} (tag:: TreeTag ) where {N,S,C}
471+ data:: ET
472+ left:: SumProductTree{ET}
473+ right:: SumProductTree{ET}
474+ # zero(ET) can be undef
475+ function SumProductTree (tag:: TreeTag , left:: SumProductTree{ET} , right:: SumProductTree{ET} ) where {ET}
476+ res = new {ET} (tag)
477+ res. left = left
478+ res. right = right
479+ return res
480+ end
481+ function SumProductTree (data:: ET ) where ET
482+ return new {ET} (LEAF, data)
483+ end
484+ function SumProductTree {ET} (tag:: TreeTag ) where {ET}
482485 @assert tag === ZERO || tag === ONE
483- return new {N,S,C } (tag)
486+ return new {ET } (tag)
484487 end
485488end
489+ # these two interfaces must be implemented in order to collect elements
490+ _data_mul (x:: StaticElementVector , y:: StaticElementVector ) = x | y
491+ _data_one (:: Type{T} ) where T<: StaticElementVector = zero (T) # NOTE: might be optional
492+
493+ """
494+ TreeConfigEnumerator{N,S,C}
495+
496+ An alias for [`SumProductTree`](@ref)`{StaticElementVector{N, S, C}}`,
497+ which is a useful element type for configuration enumeration.
498+ """
499+ const TreeConfigEnumerator{N,S,C} = SumProductTree{StaticElementVector{N,S,C}}
500+ TreeConfigEnumerator (data:: StaticElementVector ) = SumProductTree (data)
501+ TreeConfigEnumerator (tag:: TreeTag , left:: TreeConfigEnumerator{N,S,C} , right:: TreeConfigEnumerator{N,S,C} ) where {N,S,C} = SumProductTree (tag, left, right)
486502
487503# AbstractTree APIs
488- function children (t:: TreeConfigEnumerator )
504+ function children (t:: SumProductTree )
489505 if t. tag == ZERO || t. tag == LEAF || t. tag == ONE
490506 return typeof (t)[]
491507 else
492508 return [t. left, t. right]
493509 end
494510end
495- function printnode (io:: IO , t:: TreeConfigEnumerator{N,S,C } ) where {N,S,C }
511+ function printnode (io:: IO , t:: SumProductTree{ET } ) where {ET }
496512 if t. tag === LEAF
497513 print (io, t. data)
498514 elseif t. tag === ZERO
499515 print (io, " ∅" )
500516 elseif t. tag === ONE
501- print (io, zero (StaticElementVector{N,S,C} ))
517+ print (io, _data_one (ET ))
502518 elseif t. tag === SUM
503519 print (io, " +" )
504520 else # PROD
505521 print (io, " *" )
506522 end
507523end
508524
509- Base. length (x:: TreeConfigEnumerator ) = _length! (x, IdDict {typeof(x), Int} ())
525+ # it must be mutable, otherwise the `IdDict` trick for computing the length does not work.
526+ Base. length (x:: SumProductTree ) = _length! (x, IdDict {typeof(x), Int} ())
510527
511528function _length! (x, d)
512529 haskey (d, x) && return d[x]
@@ -525,7 +542,7 @@ function _length!(x, d)
525542 end
526543end
527544
528- num_nodes (x:: TreeConfigEnumerator ) = _num_nodes (x, IdDict {typeof(x), Int} ())
545+ num_nodes (x:: SumProductTree ) = _num_nodes (x, IdDict {typeof(x), Int} ())
529546function _num_nodes (x, d)
530547 haskey (d, x) && return 0
531548 if x. tag == ZERO || x. tag == ONE
@@ -539,37 +556,37 @@ function _num_nodes(x, d)
539556 return res
540557end
541558
542- function Base.:(== )(x:: TreeConfigEnumerator{N,S,C } , y:: TreeConfigEnumerator{N,S,C } ) where {N,S,C }
559+ function Base.:(== )(x:: SumProductTree{ET } , y:: SumProductTree{ET } ) where {ET }
543560 return Set (collect (x)) == Set (collect (y))
544561end
545562
546- Base. show (io:: IO , t:: TreeConfigEnumerator ) = print_tree (io, t)
563+ Base. show (io:: IO , t:: SumProductTree ) = print_tree (io, t)
547564
548- function Base. collect (x:: TreeConfigEnumerator{N,S,C } ) where {N,S,C }
565+ function Base. collect (x:: SumProductTree{ET } ) where {ET }
549566 if x. tag == ZERO
550- return StaticElementVector{N,S,C} []
567+ return ET []
551568 elseif x. tag == ONE
552- return StaticElementVector{N,S,C}[ zero (StaticElementVector{N,S,C} )]
569+ return [ _data_one (ET )]
553570 elseif x. tag == LEAF
554- return StaticElementVector{N,S,C} [x. data]
571+ return [x. data]
555572 elseif x. tag == SUM
556573 return vcat (collect (x. left), collect (x. right))
557574 else # PROD
558- return vec ([reduce ((x,y) -> x | y , si) for si in Iterators. product (collect (x. left), collect (x. right))])
575+ return vec ([reduce (_data_mul , si) for si in Iterators. product (collect (x. left), collect (x. right))])
559576 end
560577end
561578
562- function Base.:+ (x:: TreeConfigEnumerator{N,S,C } , y:: TreeConfigEnumerator{N,S,C } ) where {N,S,C }
579+ function Base.:+ (x:: SumProductTree{ET } , y:: SumProductTree{ET } ) where {ET }
563580 if x. tag == ZERO
564581 return y
565582 elseif y. tag == ZERO
566583 return x
567584 else
568- return TreeConfigEnumerator (SUM, x, y)
585+ return SumProductTree (SUM, x, y)
569586 end
570587end
571588
572- function Base.:* (x:: TreeConfigEnumerator{L,S,C } , y:: TreeConfigEnumerator{L,S,C } ) where {L,S,C }
589+ function Base.:* (x:: SumProductTree{ET } , y:: SumProductTree{ET } ) where {ET }
573590 if x. tag == ONE
574591 return y
575592 elseif y. tag == ONE
@@ -579,18 +596,18 @@ function Base.:*(x::TreeConfigEnumerator{L,S,C}, y::TreeConfigEnumerator{L,S,C})
579596 elseif y. tag == ZERO
580597 return y
581598 elseif x. tag == LEAF && y. tag == LEAF
582- return TreeConfigEnumerator ( x. data | y. data)
599+ return SumProductTree ( _data_mul ( x. data, y. data) )
583600 else
584- return TreeConfigEnumerator (PROD, x, y)
601+ return SumProductTree (PROD, x, y)
585602 end
586603end
587604
588- Base. zero (:: Type{TreeConfigEnumerator{N,S,C }} ) where {N,S,C } = TreeConfigEnumerator {N,S,C } (ZERO)
589- Base. one (:: Type{TreeConfigEnumerator{N,S,C }} ) where {N,S,C } = TreeConfigEnumerator {N,S,C } (ONE)
590- Base. zero (:: TreeConfigEnumerator{N,S,C } ) where {N,S,C } = zero (TreeConfigEnumerator{N,S,C })
591- Base. one (:: TreeConfigEnumerator{N,S,C } ) where {N,S,C } = one (TreeConfigEnumerator{N,S,C })
605+ Base. zero (:: Type{SumProductTree{ET }} ) where {ET } = SumProductTree {ET } (ZERO)
606+ Base. one (:: Type{SumProductTree{ET }} ) where {ET } = SumProductTree {ET } (ONE)
607+ Base. zero (:: SumProductTree{ET } ) where {ET } = zero (SumProductTree{ET })
608+ Base. one (:: SumProductTree{ET } ) where {ET } = one (SumProductTree{ET })
592609# todo, check siblings too?
593- function Base. iszero (t:: TreeConfigEnumerator )
610+ function Base. iszero (t:: SumProductTree )
594611 if t. tag == SUM
595612 iszero (t. left) && iszero (t. right)
596613 elseif t. tag == ZERO
@@ -603,9 +620,9 @@ function Base.iszero(t::TreeConfigEnumerator)
603620end
604621
605622"""
606- generate_samples(t::TreeConfigEnumerator , nsamples::Int)
623+ generate_samples(t::SumProductTree , nsamples::Int)
607624
608- Direct sampling configurations from a [`TreeConfigEnumerator `](@ref) instance.
625+ Direct sampling configurations from a [`SumProductTree `](@ref) instance.
609626
610627Example
611628-----------------------------
@@ -623,15 +640,15 @@ julia> all(s->is_independent_set(g, s), samples)
623640true
624641```
625642"""
626- function generate_samples (t:: TreeConfigEnumerator{N,S,C } , nsamples:: Int ) where {N,S,C }
643+ function generate_samples (t:: SumProductTree{ET } , nsamples:: Int ) where {ET }
627644 # get length dict
628- res = fill (zero (StaticElementVector{N,S,C} ), nsamples)
645+ res = fill (_data_one (ET ), nsamples)
629646 d = IdDict {typeof(t), Int} ()
630647 sample_descend! (res, t, d)
631648 return res
632649end
633650
634- function sample_descend! (res:: AbstractVector , t:: TreeConfigEnumerator , d:: IdDict )
651+ function sample_descend! (res:: AbstractVector , t:: SumProductTree , d:: IdDict )
635652 length (res) == 0 && return res
636653 if t. tag == LEAF
637654 res .| = Ref (t. data)
@@ -695,14 +712,14 @@ onehotv(::Type{ConfigSampler{N,S,C}}, i::Integer, v) where {N,S,C} = ConfigSampl
695712# just to make matrix transpose work
696713Base. transpose (c:: ConfigEnumerator ) = c
697714Base. copy (c:: ConfigEnumerator ) = ConfigEnumerator (copy (c. data))
698- Base. transpose (c:: TreeConfigEnumerator ) = c
699- function Base. copy (c:: TreeConfigEnumerator{N,S,C } ) where {N,S,C }
715+ Base. transpose (c:: SumProductTree ) = c
716+ function Base. copy (c:: SumProductTree{ET } ) where {ET }
700717 if c. tag == LEAF
701- TreeConfigEnumerator (c. data)
718+ SumProductTree (c. data)
702719 elseif c. tag == ZERO || c. tag == ONE
703- TreeConfigEnumerator {N,S,C } (c. tag)
720+ SumProductTree {ET } (c. tag)
704721 else
705- TreeConfigEnumerator (c. tag, c. left, c. right)
722+ SumProductTree (c. tag, c. left, c. right)
706723 end
707724end
708725
@@ -713,7 +730,7 @@ for TYPE in [:AbstractSetNumber, :TruncatedPoly, :ExtendedTropical]
713730end
714731
715732# to handle power of polynomials
716- function Base.:^ (x:: TreeConfigEnumerator , y:: Real )
733+ function Base.:^ (x:: SumProductTree , y:: Real )
717734 if y == 0
718735 return one (x)
719736 elseif x. tag == LEAF || x. tag == ONE || x. tag == ZERO
0 commit comments