Skip to content

Commit a7c59d3

Browse files
committed
bit vector
1 parent 4d3713c commit a7c59d3

File tree

9 files changed

+131
-12
lines changed

9 files changed

+131
-12
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,3 @@
11
/Manifest.toml
2+
*.swp
3+
_*

src/GraphTensorNetworks.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ OMEinsum.dynamic_einsum(::EinCode{ixs, iy}, xs; kwargs...) where {ixs, iy} = dyn
1515

1616
project_relative_path(xs...) = normpath(joinpath(dirname(dirname(pathof(@__MODULE__))), xs...))
1717

18+
include("bitvector.jl")
1819
include("arithematics.jl")
1920
include("networks.jl")
2021
include("graph_polynomials.jl")

src/arithematics.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ export Max2Poly, Polynomial, Tropical, CountingTropical, StaticBitVector, Mod, C
33
export bitstringset_type, bitstringsampler_type
44

55
using Polynomials: Polynomial
6-
using TropicalNumbers: Tropical, CountingTropical, StaticBitVector
6+
using TropicalNumbers: Tropical, CountingTropical
77
using Mods, Primes
88

99
# patch for Tropical numbers
@@ -127,7 +127,7 @@ function Base.:*(x::ConfigEnumerator{L,C}, y::ConfigEnumerator{L,C}) where {L,C}
127127
end
128128

129129
Base.zero(::Type{ConfigEnumerator{N,C}}) where {N,C} = ConfigEnumerator{N,C}(StaticBitVector{N,C}[])
130-
Base.one(::Type{ConfigEnumerator{N,C}}) where {N,C} = ConfigEnumerator{N,C}([TropicalNumbers.staticfalses(StaticBitVector{N,C})])
130+
Base.one(::Type{ConfigEnumerator{N,C}}) where {N,C} = ConfigEnumerator{N,C}([staticfalses(StaticBitVector{N,C})])
131131
Base.zero(::ConfigEnumerator{N,C}) where {N,C} = zero(ConfigEnumerator{N,C})
132132
Base.one(::ConfigEnumerator{N,C}) where {N,C} = one(ConfigEnumerator{N,C})
133133
Base.show(io::IO, x::ConfigEnumerator) = print(io, "{", join(x.data, ", "), "}")
@@ -148,8 +148,8 @@ function Base.:*(x::ConfigSampler{L,C}, y::ConfigSampler{L,C}) where {L,C}
148148
ConfigSampler(x.data | y.data)
149149
end
150150

151-
Base.zero(::Type{ConfigSampler{N,C}}) where {N,C} = ConfigSampler{N,C}(TropicalNumbers.statictrues(StaticBitVector{N,C}))
152-
Base.one(::Type{ConfigSampler{N,C}}) where {N,C} = ConfigSampler{N,C}(TropicalNumbers.staticfalses(StaticBitVector{N,C}))
151+
Base.zero(::Type{ConfigSampler{N,C}}) where {N,C} = ConfigSampler{N,C}(statictrues(StaticBitVector{N,C}))
152+
Base.one(::Type{ConfigSampler{N,C}}) where {N,C} = ConfigSampler{N,C}(staticfalses(StaticBitVector{N,C}))
153153
Base.zero(::ConfigSampler{N,C}) where {N,C} = zero(ConfigSampler{N,C})
154154
Base.one(::ConfigSampler{N,C}) where {N,C} = one(ConfigSampler{N,C})
155155

@@ -178,7 +178,7 @@ for (F,TP) in [(:bitstringset_type, :ConfigEnumerator), (:bitstringsampler_type,
178178
CountingTropical{TV, $F(n)}
179179
end
180180
function $F(n::Integer)
181-
C = TropicalNumbers._nints(n)
181+
C = _nints(n)
182182
return $TP{n, C}
183183
end
184184
end
@@ -194,5 +194,5 @@ end
194194
function onehotv(::Type{CountingTropical{TV,BS}}, x) where {TV,BS}
195195
CountingTropical{TV,BS}(one(TV), onehotv(BS, x))
196196
end
197-
onehotv(::Type{ConfigEnumerator{N,C}}, i::Integer) where {N,C} = ConfigEnumerator([TropicalNumbers.onehot(StaticBitVector{N,C}, i)])
198-
onehotv(::Type{ConfigSampler{N,C}}, i::Integer) where {N,C} = ConfigSampler(TropicalNumbers.onehot(StaticBitVector{N,C}, i))
197+
onehotv(::Type{ConfigEnumerator{N,C}}, i::Integer) where {N,C} = ConfigEnumerator([onehotv(StaticBitVector{N,C}, i)])
198+
onehotv(::Type{ConfigSampler{N,C}}, i::Integer) where {N,C} = ConfigSampler(onehotv(StaticBitVector{N,C}, i))

src/bitvector.jl

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# StaticBitVector
2+
export StaticBitVector
3+
4+
struct StaticBitVector{N,C}
5+
data::NTuple{C,UInt64}
6+
end
7+
function StaticBitVector(x::AbstractVector)
8+
N = length(x)
9+
StaticBitVector{N,_nints(N)}((convert(BitVector, x).chunks...,))
10+
end
11+
function Base.convert(::Type{StaticBitVector{N,C}}, x::AbstractVector) where {N,C}
12+
@assert length(x) == N
13+
StaticBitVector(x)
14+
end
15+
_nints(x) = (x-1)÷64+1
16+
Base.length(::StaticBitVector{N,C}) where {N,C} = N
17+
Base.:(==)(x::StaticBitVector, y::AbstractVector) = [x...] == [y...]
18+
Base.:(==)(x::AbstractVector, y::StaticBitVector) = [x...] == [y...]
19+
Base.:(==)(x::StaticBitVector, y::StaticBitVector) = [x...] == [y...]
20+
function Base.getindex(x::StaticBitVector{N,C}, i::Integer) where {N,C}
21+
i -= 1
22+
ii = i ÷ 64
23+
(x.data[ii+1] >> (i-ii*64)) & 1
24+
end
25+
Base.:(|)(x::StaticBitVector{N,C}, y::StaticBitVector{N,C}) where {N,C} = StaticBitVector{N,C}(x.data .| y.data)
26+
Base.:(&)(x::StaticBitVector{N,C}, y::StaticBitVector{N,C}) where {N,C} = StaticBitVector{N,C}(x.data .& y.data)
27+
Base.:()(x::StaticBitVector{N,C}, y::StaticBitVector{N,C}) where {N,C} = StaticBitVector{N,C}(x.data .⊻ y.data)
28+
@generated function staticfalses(::Type{StaticBitVector{N,C}}) where {N,C}
29+
Expr(:call, :(StaticBitVector{$N,$C}), Expr(:tuple, zeros(UInt64, C)...))
30+
end
31+
@generated function statictrues(::Type{StaticBitVector{N,C}}) where {N,C}
32+
Expr(:call, :(StaticBitVector{$N,$C}), Expr(:tuple, fill(typemax(UInt64), C)...))
33+
end
34+
function onehotv(::Type{StaticBitVector{N,C}}, i) where {N,C}
35+
x = falses(N)
36+
x[i] = true
37+
return StaticBitVector(x)
38+
end
39+
function Base.iterate(x::StaticBitVector{N,C}, state=1) where {N,C}
40+
if state > N
41+
return nothing
42+
else
43+
return x[state], state+1
44+
end
45+
end
46+

src/graph_polynomials.jl

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,4 +195,25 @@ function neighbortensor(x::T, d::Int) where T
195195
return t
196196
end
197197

198-
graph_polynomial_maxorder(mi::MaximalIndependence; usecuda) = Int(sum(contractx(mi, TropicalF64(1.0); usecuda=usecuda)).n)
198+
graph_polynomial_maxorder(mi::MaximalIndependence; usecuda) = Int(sum(contractx(mi, TropicalF64(1.0); usecuda=usecuda)).n)
199+
200+
### spin glass problem ###
201+
function generate_tensors(fx, gp::SpinGlass{2})
202+
flatten_code = flatten(gp.code)
203+
ixs = getixs(flatten_code)
204+
n = length(labels(flatten_code))
205+
T = typeof(fx(ixs[1][1]))
206+
return Tuple(map(enumerate(ixs)) do (i, ix)
207+
if i <= n
208+
spinglassv(one(T))
209+
else
210+
spinglassb(fx(ix)) # if n!=2, it corresponds to set packing problem.
211+
end
212+
end)
213+
end
214+
function spinglassb(expJ::T) where T
215+
return T[one(T) expJ; expJ one(T)]
216+
end
217+
spinglassv(h::T) where T = T[one(T), h]
218+
219+
graph_polynomial_maxorder(mi::SpinGlass; usecuda) = Int(sum(contractx(mi, TropicalF64(1.0); usecuda=usecuda)).n)

src/networks.jl

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
export Independence, MaximalIndependence, Matching, Coloring, optimize_code, set_packing
1+
export Independence, MaximalIndependence, Matching, Coloring, optimize_code, set_packing, SpinGlass
22
const EinTypes = Union{EinCode,NestedEinsum}
33

44
abstract type GraphProblem end
@@ -20,10 +20,26 @@ function Independence(g::SimpleGraph; outputs=(), kwargs...)
2020
end
2121

2222
"""
23-
Independence{CT<:EinTypes} <: GraphProblem
24-
Independence(graph; kwargs...)
23+
SpinGlass{Q, CT<:EinTypes} <: GraphProblem
24+
SpinGlass{Q}(graph; kwargs...)
25+
SpinGlass(graph; kwargs...)
2526
26-
Independent set problem. For `kwargs`, check `optimize_code` API.
27+
Q-state spin glass problem (or Potts model). For `kwargs`, check `optimize_code` API.
28+
When Q=2, it corresponds to the {0, 1} spin glass model.
29+
"""
30+
struct SpinGlass{Q, CT<:EinTypes} <: GraphProblem
31+
code::CT
32+
end
33+
34+
SpinGlass(g::SimpleGraph; outputs=(), kwargs...) = SpinGlass{2}(g; outputs=outputs, kwargs...)
35+
SpinGlass{Q}(g::SimpleGraph; outputs=(), kwargs...) where Q = SpinGlass{Q}(Independence(g; outputs=outputs, kwargs...).code)
36+
SpinGlass{Q}(code::EinTypes) where Q = SpinGlass{Q,typeof(code)}(code)
37+
38+
"""
39+
MaximalIndependence{CT<:EinTypes} <: GraphProblem
40+
MaximalIndependence(graph; kwargs...)
41+
42+
Maximal independent set problem. For `kwargs`, check `optimize_code` API.
2743
"""
2844
struct MaximalIndependence{CT<:EinTypes} <: GraphProblem
2945
code::CT
@@ -115,6 +131,7 @@ for T in [:Independence, :Matching, :MaximalIndependence]
115131
@eval bondsize(gp::$T) = 2
116132
end
117133
bondsize(gp::Coloring{K}) where K = K
134+
bondsize(gp::SpinGlass{Q}) where Q = Q
118135

119136
"""
120137
set_packing(sets; kwargs...)

test/bitvector.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
using Test, GraphTensorNetworks
2+
using GraphTensorNetworks: statictrues, staticfalses, StaticBitVector, onehotv
3+
4+
@testset "static bit vector" begin
5+
@test statictrues(StaticBitVector{3,1}) == trues(3)
6+
@test staticfalses(StaticBitVector{3,1}) == falses(3)
7+
x = rand(Bool, 131)
8+
y = rand(Bool, 131)
9+
a = StaticBitVector(x)
10+
b = StaticBitVector(y)
11+
a2 = BitVector(x)
12+
b2 = BitVector(y)
13+
for op in [|, &, ]
14+
@test op(a, b) == op.(a2, b2)
15+
end
16+
@test onehotv(StaticBitVector{133,3}, 5) == (x = falses(133); x[5]=true; x)
17+
end
18+

test/graph_polynomials.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using GraphTensorNetworks, Test, OMEinsum, OMEinsumContractionOrders
22
using Mods, Polynomials, TropicalNumbers
33
using LightGraphs, Random
4+
using GraphTensorNetworks: StaticBitVector
45

56
@testset "bond and vertex tensor" begin
67
@test GraphTensorNetworks.misb(TropicalF64) == [TropicalF64(0) TropicalF64(0); TropicalF64(0) TropicalF64(-Inf)]
@@ -66,4 +67,13 @@ end
6667
add_edge!(g, i, j)
6768
end
6869
@test graph_polynomial(Matching, Val(:polynomial), g)[] == Polynomial([1,7,13,5])
70+
end
71+
72+
@testset "spinglass" begin
73+
g = SimpleGraph(5)
74+
for (i,j) in [(1,2),(2,3),(3,4),(4,1),(1,5),(2,4)]
75+
add_edge!(g, i, j)
76+
end
77+
@test graph_polynomial(SpinGlass{2}, Val(:polynomial), g)[] == Polynomial([2,2,4,12,10,2])
78+
@test graph_polynomial(SpinGlass{2}, Val(:finitefield), g)[] == Polynomial([2,2,4,12,10,2])
6979
end

test/runtests.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
using GraphTensorNetworks
22
using Test
33

4+
@testset "bitvector" begin
5+
include("bitvector.jl")
6+
end
7+
48
@testset "independence polynomial" begin
59
include("graph_polynomials.jl")
610
end

0 commit comments

Comments
 (0)