@@ -2,6 +2,9 @@ using Polynomials: Polynomial
22using TropicalNumbers: Tropical, CountingTropical
33using Mods, Primes
44using Base. Cartesian
5+ import AbstractTrees: children, printnode, print_tree
6+
7+ @enum TreeTag LEAF SUM PROD ZERO
58
69# pirate
710Base. abs (x:: Mod ) = x
@@ -275,20 +278,179 @@ Base.one(::Type{ConfigSampler{N,S,C}}) where {N,S,C} = ConfigSampler{N,S,C}(zero
275278Base. zero (:: ConfigSampler{N,S,C} ) where {N,S,C} = zero (ConfigSampler{N,S,C})
276279Base. one (:: ConfigSampler{N,S,C} ) where {N,S,C} = one (ConfigSampler{N,S,C})
277280
278- # A patch to make `Polynomial{ConfigEnumerator}` work
279- function Base.:* (a:: Int , y:: ConfigEnumerator )
280- a == 0 && return zero (y)
281- a == 1 && return y
282- error (" multiplication between int and config enumerator is not defined." )
281+ # tree config enumerator
282+ """
283+ TreeConfigEnumerator{N,S,C}
284+
285+ Configuration enumerator encoded in a tree, it is the most natural representation given by a sum-product network
286+ and is often more memory efficient than putting the configurations in a vector.
287+ `N`, `S` and `C` are type parameters from the [`StaticElementVector`](@ref){N,S,C}.
288+
289+ Fields
290+ -----------------------
291+ * `tag` is one of `ZERO`, `LEAF`, `SUM`, `PROD`.
292+ * `data` is the element stored in a `LEAF` node.
293+ * `left` and `right` are two operands of a `SUM` or `PROD` node.
294+
295+ Example
296+ ------------------------
297+ ```jldoctest; setup=:(using GraphTensorNetworks)
298+ julia> s = TreeConfigEnumerator(bv"00111")
299+ 00111
300+
301+
302+ julia> q = TreeConfigEnumerator(bv"10000")
303+ 10000
304+
305+
306+ julia> x = s + q
307+ +
308+ ├─ 00111
309+ └─ 10000
310+
311+
312+ julia> y = x * x
313+ *
314+ ├─ +
315+ │ ├─ 00111
316+ │ └─ 10000
317+ └─ +
318+ ├─ 00111
319+ └─ 10000
320+
321+
322+ julia> collect(y)
323+ 4-element Vector{StaticBitVector{5, 1}}:
324+ 00111
325+ 10111
326+ 10111
327+ 10000
328+
329+ julia> zero(s)
330+
331+
332+
333+ julia> one(s)
334+ 00000
335+
336+
337+ ```
338+ """
339+ struct TreeConfigEnumerator{N,S,C}
340+ tag:: TreeTag
341+ data:: StaticElementVector{N,S,C}
342+ left:: TreeConfigEnumerator{N,S,C}
343+ right:: TreeConfigEnumerator{N,S,C}
344+ TreeConfigEnumerator (tag:: TreeTag , left:: TreeConfigEnumerator{N,S,C} , right:: TreeConfigEnumerator{N,S,C} ) where {N,S,C} = new {N,S,C} (tag, zero (StaticElementVector{N,S,C}), left, right)
345+ function TreeConfigEnumerator (data:: StaticElementVector{N,S,C} ) where {N,S,C}
346+ new {N,S,C} (LEAF, data)
347+ end
348+ function TreeConfigEnumerator {N,S,C} (tag:: TreeTag ) where {N,S,C}
349+ @assert tag === ZERO
350+ return new {N,S,C} (tag)
351+ end
283352end
284- function Base.:* (a:: Int , y:: ConfigSampler )
285- a == 0 && return zero (y)
286- a == 1 && return y
287- error (" multiplication between int and config sampler is not defined." )
353+
354+ # AbstractTree APIs
355+ function children (t:: TreeConfigEnumerator )
356+ if isdefined (t, :left )
357+ if isdefined (t, :right )
358+ return [t. left, t. right]
359+ else
360+ return [t. left]
361+ end
362+ else
363+ if isdefined (t, :right )
364+ return [t. right]
365+ else
366+ return typeof (t)[]
367+ end
368+ end
369+ end
370+ function printnode (io:: IO , t:: TreeConfigEnumerator )
371+ if t. tag === LEAF
372+ print (io, t. data)
373+ elseif t. tag === ZERO
374+ print (io, " " )
375+ elseif t. tag === SUM
376+ print (io, " +" )
377+ else # PROD
378+ print (io, " *" )
379+ end
380+ end
381+
382+ function Base. length (x:: TreeConfigEnumerator )
383+ if x. tag === SUM
384+ return length (x. left) + length (x. right)
385+ elseif x. tag === PROD
386+ return length (x. left) * length (x. right)
387+ elseif x. tag === ZERO
388+ return 0
389+ else
390+ return 1
391+ end
392+ end
393+
394+ function num_nodes (x:: TreeConfigEnumerator )
395+ x. tag == ZERO && return 1
396+ x. tag == LEAF && return 1
397+ return num_nodes (x. left) + num_nodes (x. right) + 1
398+ end
399+
400+ function Base.:(== )(x:: TreeConfigEnumerator{N,S,C} , y:: TreeConfigEnumerator{N,S,C} ) where {N,S,C}
401+ return Set (collect (x)) == Set (collect (y))
402+ end
403+
404+ Base. show (io:: IO , t:: TreeConfigEnumerator ) = print_tree (io, t)
405+
406+ function Base. collect (x:: TreeConfigEnumerator{N,S,C} ) where {N,S,C}
407+ if x. tag == ZERO
408+ return StaticElementVector{N,S,C}[]
409+ elseif x. tag == LEAF
410+ return StaticElementVector{N,S,C}[x. data]
411+ elseif x. tag == SUM
412+ return vcat (collect (x. left), collect (x. right))
413+ else # PROD
414+ return vec ([reduce ((x,y)-> x| y, si) for si in Iterators. product (collect (x. left), collect (x. right))])
415+ end
416+ end
417+
418+ function Base.:+ (x:: TreeConfigEnumerator{N,S,C} , y:: TreeConfigEnumerator{N,S,C} ) where {N,S,C}
419+ TreeConfigEnumerator (SUM, x, y)
420+ end
421+
422+ function Base.:* (x:: TreeConfigEnumerator{L,S,C} , y:: TreeConfigEnumerator{L,S,C} ) where {L,S,C}
423+ TreeConfigEnumerator (PROD, x, y)
424+ end
425+
426+ Base. zero (:: Type{TreeConfigEnumerator{N,S,C}} ) where {N,S,C} = TreeConfigEnumerator {N,S,C} (ZERO)
427+ Base. one (:: Type{TreeConfigEnumerator{N,S,C}} ) where {N,S,C} = TreeConfigEnumerator (zero (StaticElementVector{N,S,C}))
428+ Base. zero (:: TreeConfigEnumerator{N,S,C} ) where {N,S,C} = zero (TreeConfigEnumerator{N,S,C})
429+ Base. one (:: TreeConfigEnumerator{N,S,C} ) where {N,S,C} = one (TreeConfigEnumerator{N,S,C})
430+ # todo, check siblings too?
431+ function Base. iszero (t:: TreeConfigEnumerator )
432+ if t. TAG == SUM
433+ iszero (t. left) && iszero (t. right)
434+ elseif t. TAG == ZERO
435+ true
436+ elseif t. TAG == LEAF
437+ false
438+ else
439+ iszero (t. left) || iszero (t. right)
440+ end
441+ end
442+
443+ # A patch to make `Polynomial{ConfigEnumerator}` work
444+ for T in [:ConfigEnumerator , :ConfigSampler , :TreeConfigEnumerator ]
445+ @eval function Base.:* (a:: Int , y:: $T )
446+ a == 0 && return zero (y)
447+ a == 1 && return y
448+ error (" multiplication between int and `$(typeof (y)) ` is not defined." )
449+ end
288450end
289451
290452# convert from counting type to bitstring type
291- for (F,TP) in [(:set_type , :ConfigEnumerator ), (:sampler_type , :ConfigSampler )]
453+ for (F,TP) in [(:set_type , :ConfigEnumerator ), (:sampler_type , :ConfigSampler ), ( :treeset_type , :TreeConfigEnumerator ) ]
292454 @eval begin
293455 function $F (:: Type{T} , n:: Int , nflavor:: Int ) where {OT, K, T<: TruncatedPoly{K,C,OT} where C}
294456 TruncatedPoly{K, $ F (n,nflavor),OT}
@@ -312,12 +474,24 @@ end
312474
313475# utilities for creating onehot vectors
314476onehotv (:: Type{ConfigEnumerator{N,S,C}} , i:: Integer , v) where {N,S,C} = ConfigEnumerator ([onehotv (StaticElementVector{N,S,C}, i, v)])
477+ onehotv (:: Type{TreeConfigEnumerator{N,S,C}} , i:: Integer , v) where {N,S,C} = TreeConfigEnumerator (onehotv (StaticElementVector{N,S,C}, i, v))
315478onehotv (:: Type{ConfigSampler{N,S,C}} , i:: Integer , v) where {N,S,C} = ConfigSampler (onehotv (StaticElementVector{N,S,C}, i, v))
479+ # just to make matrix transpose work
316480Base. transpose (c:: ConfigEnumerator ) = c
317481Base. copy (c:: ConfigEnumerator ) = ConfigEnumerator (copy (c. data))
482+ Base. transpose (c:: TreeConfigEnumerator ) = c
483+ function Base. copy (c:: TreeConfigEnumerator )
484+ if c. tag == LEAF
485+ TreeConfigEnumerator (c. data)
486+ elseif c. tag == ZERO
487+ TreeConfigEnumerator (c. tag)
488+ else
489+ TreeConfigEnumerator (c. tag, c. left, c. right)
490+ end
491+ end
318492
319493# Handle boolean, this is a patch for CUDA matmul
320- for TYPE in [:ConfigEnumerator , :ConfigSampler , :TruncatedPoly ]
494+ for TYPE in [:ConfigEnumerator , :ConfigSampler , :TruncatedPoly , :TreeConfigEnumerator ]
321495 @eval Base.:* (a:: Bool , y:: $TYPE ) = a ? y : zero (y)
322496 @eval Base.:* (y:: $TYPE , a:: Bool ) = a ? y : zero (y)
323497end
0 commit comments