Skip to content

Commit 7c4c28e

Browse files
committed
refactor
1 parent 5b5aaba commit 7c4c28e

File tree

10 files changed

+267
-285
lines changed

10 files changed

+267
-285
lines changed

src/GraphTensorNetworks.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ OMEinsum.dynamic_einsum(::EinCode{ixs, iy}, xs; kwargs...) where {ixs, iy} = dyn
1313
project_relative_path(xs...) = normpath(joinpath(dirname(dirname(pathof(@__MODULE__))), xs...))
1414

1515
include("arithematics.jl")
16-
include("independence_polynomial.jl")
16+
include("graph_polynomials.jl")
1717
include("configurations.jl")
1818
include("graphs.jl")
1919
include("bounding.jl")

src/arithematics.jl

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,14 @@
11
export is_commutative_semiring
2+
export Max2Poly, Polynomial, Tropical, CountingTropical, ConfigTropical, StaticBitVector, Mod, ConfigEnumerator
3+
4+
using Polynomials: Polynomial
5+
using TropicalNumbers: Tropical, CountingTropical, ConfigTropical, StaticBitVector
6+
using Mods, Primes
7+
8+
# pirate
9+
Base.abs(x::Mod) = x
10+
Base.isless(x::Mod{N}, y::Mod{N}) where N = mod(x.val, N) < mod(y.val, N)
11+
212

313
# this function is used for testing
414
function is_commutative_semiring(a::T, b::T, c::T) where T
@@ -48,8 +58,6 @@ function is_commutative_semiring(a::T, b::T, c::T) where T
4858
return true
4959
end
5060

51-
export Max2Poly
52-
5361
# get maximum two countings (polynomial truncated to largest two orders)
5462
struct Max2Poly{T} <: Number
5563
a::T
@@ -81,3 +89,41 @@ Base.one(::Type{Max2Poly{T}}) where T = Max2Poly(zero(T), one(T), 0.0)
8189
Base.zero(::Max2Poly{T}) where T = zero(Max2Poly{T})
8290
Base.one(::Max2Poly{T}) where T = one(Max2Poly{T})
8391

92+
struct ConfigEnumerator{N,C}
93+
data::Vector{StaticBitVector{N,C}}
94+
end
95+
96+
Base.length(x::ConfigEnumerator{N}) where N = length(x.data)
97+
Base.:(==)(x::ConfigEnumerator{N,C}, y::ConfigEnumerator{N,C}) where {N,C} = x.data == y.data
98+
99+
function Base.:+(x::ConfigEnumerator{N,C}, y::ConfigEnumerator{N,C}) where {N,C}
100+
length(x) == 0 && return y
101+
length(y) == 0 && return x
102+
return ConfigEnumerator{N,C}(vcat(x.data, y.data))
103+
end
104+
105+
function Base.:*(x::ConfigEnumerator{L,C}, y::ConfigEnumerator{L,C}) where {L,C}
106+
M, N = length(x), length(y)
107+
M == 0 && return x
108+
N == 0 && return y
109+
z = Vector{StaticBitVector{L,C}}(undef, M*N)
110+
@inbounds for j=1:N, i=1:M
111+
z[(j-1)*M+i] = x.data[i] | y.data[j]
112+
end
113+
return ConfigEnumerator{L,C}(z)
114+
end
115+
116+
Base.zero(::Type{ConfigEnumerator{N,C}}) where {N,C} = ConfigEnumerator{N,C}(StaticBitVector{N,C}[])
117+
Base.one(::Type{ConfigEnumerator{N,C}}) where {N,C} = ConfigEnumerator{N,C}([TropicalNumbers.staticfalses(StaticBitVector{N,C})])
118+
Base.zero(::ConfigEnumerator{N,C}) where {N,C} = zero(ConfigEnumerator{N,C})
119+
Base.one(::ConfigEnumerator{N,C}) where {N,C} = one(ConfigEnumerator{N,C})
120+
Base.show(io::IO, x::ConfigEnumerator) = print(io, "{", join(x.data, ", "), "}")
121+
Base.show(io::IO, ::MIME"text/plain", x::ConfigEnumerator) = Base.show(io, x)
122+
123+
# patch
124+
125+
function Base.:*(a::Int, y::ConfigEnumerator)
126+
a == 0 && return zero(y)
127+
a == 1 && return y
128+
error("multiplication between int and config enumerator is not defined.")
129+
end

src/configurations.jl

Lines changed: 0 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,44 +1,5 @@
11
export mis_config, ConfigEnumerator, ConfigTropical
22

3-
struct ConfigEnumerator{N,C}
4-
data::Vector{StaticBitVector{N,C}}
5-
end
6-
7-
Base.length(x::ConfigEnumerator{N}) where N = length(x.data)
8-
Base.:(==)(x::ConfigEnumerator{N,C}, y::ConfigEnumerator{N,C}) where {N,C} = x.data == y.data
9-
10-
function Base.:+(x::ConfigEnumerator{N,C}, y::ConfigEnumerator{N,C}) where {N,C}
11-
length(x) == 0 && return y
12-
length(y) == 0 && return x
13-
return ConfigEnumerator{N,C}(vcat(x.data, y.data))
14-
end
15-
16-
function Base.:*(x::ConfigEnumerator{L,C}, y::ConfigEnumerator{L,C}) where {L,C}
17-
M, N = length(x), length(y)
18-
M == 0 && return x
19-
N == 0 && return y
20-
z = Vector{StaticBitVector{L,C}}(undef, M*N)
21-
@inbounds for j=1:N, i=1:M
22-
z[(j-1)*M+i] = x.data[i] | y.data[j]
23-
end
24-
return ConfigEnumerator{L,C}(z)
25-
end
26-
27-
Base.zero(::Type{ConfigEnumerator{N,C}}) where {N,C} = ConfigEnumerator{N,C}(StaticBitVector{N,C}[])
28-
Base.one(::Type{ConfigEnumerator{N,C}}) where {N,C} = ConfigEnumerator{N,C}([TropicalNumbers.staticfalses(StaticBitVector{N,C})])
29-
Base.zero(::ConfigEnumerator{N,C}) where {N,C} = zero(ConfigEnumerator{N,C})
30-
Base.one(::ConfigEnumerator{N,C}) where {N,C} = one(ConfigEnumerator{N,C})
31-
Base.show(io::IO, x::ConfigEnumerator) = print(io, "{", join(x.data, ", "), "}")
32-
Base.show(io::IO, ::MIME"text/plain", x::ConfigEnumerator) = Base.show(io, x)
33-
34-
# patch
35-
36-
function Base.:*(a::Int, y::ConfigEnumerator)
37-
a == 0 && return zero(y)
38-
a == 1 && return y
39-
error("multiplication between int and config enumerator is not defined.")
40-
end
41-
423
function symbols(::EinCode{ixs}) where ixs
434
res = []
445
for ix in ixs

src/graph_polynomials.jl

Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
using Polynomials
2+
using OMEinsum: NestedEinsum, getixs, getiy
3+
using FFTW
4+
using LightGraphs
5+
6+
export contractx, graph_polynomial, contraction_code
7+
export Independence, MaximalIndependence, Matching, Coloring
8+
const EinTypes = Union{EinCode,NestedEinsum}
9+
10+
struct Independence end
11+
struct MaximalIndependence end
12+
struct Matching end
13+
struct Coloring end
14+
15+
function graph_polynomial(which, approach::Val, g::SimpleGraph; method=:kahypar, sc_target=17, max_group_size=40, nrepeat=10, imbalances=0.0:0.01:0.2, kwargs...)
16+
code = contraction_code(which, g; method=method, sc_target=sc_target, max_group_size=max_group_size, nrepeat=nrepeat, imbalances=imbalances)
17+
graph_polynomial(which, approach, code; kwargs...)
18+
end
19+
20+
function graph_polynomial(which, ::Val{:fft}, code::EinTypes; usecuda=false, maxorder=graph_polynomial_maxorder(which, code; usecuda=usecuda), r=1.0)
21+
ω = exp(-2im*π/(maxorder+1))
22+
xs = r .* collect.^ (0:maxorder))
23+
ys = [asscalar(contractx(which, x, code; usecuda=usecuda)) for x in xs]
24+
Polynomial(ifft(ys) ./ (r .^ (0:maxorder)))
25+
end
26+
27+
function graph_polynomial(which, ::Val{:fitting}, code::EinTypes; usecuda=false,
28+
maxorder = graph_polynomial_maxorder(which, code; usecuda=usecuda))
29+
xs = (0:maxorder)
30+
ys = [asscalar(contractx(which, x, code; usecuda=usecuda)) for x in xs]
31+
fit(xs, ys, maxorder)
32+
end
33+
34+
function graph_polynomial(which, ::Val{:polynomial}, code::EinTypes; usecuda=false)
35+
@assert !usecuda "Polynomial type can not be computed on GPU!"
36+
contractx(which, Polynomial([0, 1.0]), code)
37+
end
38+
39+
function _polynomial_single(which, ::Type{T}, code::EinTypes; usecuda, maxorder) where T
40+
xs = 0:maxorder
41+
ys = [asscalar(contractx(which, T(x), code; usecuda=usecuda)) for x in xs]
42+
A = zeros(T, maxorder+1, maxorder+1)
43+
for j=1:maxorder+1, i=1:maxorder+1
44+
A[j,i] = T(xs[j])^(i-1)
45+
end
46+
A \ T.(ys)
47+
end
48+
49+
function graph_polynomial(which, ::Val{:finitefield}, code::EinTypes; usecuda=false, maxorder=graph_polynomial_maxorder(which, code; usecuda=usecuda), max_iter=100)
50+
TI = Int32 # Int 32 is faster
51+
N = typemax(TI)
52+
YS = []
53+
local res
54+
for k = 1:max_iter
55+
N = prevprime(N-TI(1))
56+
T = Mods.Mod{N,TI}
57+
rk = _polynomial_single(which, T, code; usecuda=usecuda, maxorder=maxorder)
58+
push!(YS, rk)
59+
if maxorder==1
60+
return Polynomial(Mods.value.(YS[1]))
61+
elseif k != 1
62+
ra = improved_counting(YS[1:end-1])
63+
res = improved_counting(YS)
64+
ra == res && return Polynomial(res)
65+
end
66+
end
67+
@warn "result is potentially inconsistent."
68+
return Polynomial(res)
69+
end
70+
function improved_counting(sequences)
71+
map(yi->Mods.CRT(yi...), zip(sequences...))
72+
end
73+
74+
function contraction_code(which, g::SimpleGraph; method=:kahypar, sc_target=17, max_group_size=40, nrepeat=10, imbalances=0.0:0.001:0.8)
75+
_optimize_code(_code(which, g), method, sc_target, max_group_size, nrepeat, imbalances)
76+
end
77+
function _optimize_code(code, method, sc_target, max_group_size, nrepeat, imbalances)
78+
size_dict = Dict([s=>2 for s in symbols(code)])
79+
optcode = if method == :kahypar
80+
optimize_kahypar(code, size_dict; sc_target=sc_target, max_group_size=max_group_size, imbalances=imbalances)
81+
elseif method == :greedy
82+
optimize_greedy(code, size_dict; nrepeat=nrepeat)
83+
else
84+
ArgumentError("optimizer `$method` not defined.")
85+
end
86+
println("time/space complexity is $(OMEinsum.timespace_complexity(optcode, size_dict))")
87+
return optcode
88+
end
89+
90+
############### Problem specific implementations ################
91+
### independent set ###
92+
function _code(::Independence, g::SimpleGraph)
93+
EinCode(([(i,) for i in LightGraphs.vertices(g)]..., # labels for edge tensors
94+
[minmax(e.src,e.dst) for e in LightGraphs.edges(g)]...), ()) # labels for vertex tensors
95+
end
96+
97+
function contractx(::Independence, x::T, code::EinTypes; usecuda=false) where {T}
98+
tensors = map(getixs(flatten(code))) do ix
99+
# if the tensor rank is 1, create a vertex tensor.
100+
# otherwise the tensor rank must be 2, create a bond tensor.
101+
t = length(ix)==1 ? misv(T, x) : misb(T)
102+
usecuda ? CuArray(t) : t
103+
end
104+
dynamic_einsum(code, tensors)
105+
end
106+
misb(::Type{T}) where T = [one(T) one(T); one(T) zero(T)]
107+
misv(::Type{T}, val) where T = [one(T), convert(T, val)]
108+
109+
graph_polynomial_maxorder(::Independence, code; usecuda) = Int(sum(contractx(Independence(), TropicalF64(1.0), code; usecuda=usecuda)).n)
110+
111+
### coloring ###
112+
_code(::Coloring, g::SimpleGraph) = independence_code(args...; kwargs...)
113+
function contractx(::Coloring, xs, code::EinTypes; usecuda=false)
114+
tensors = map(getixs(flatten(code))) do ix
115+
# if the tensor rank is 1, create a vertex tensor.
116+
# otherwise the tensor rank must be 2, create a bond tensor.
117+
t = length(ix)==1 ? coloringv(collect(xs)) : coloringb(eltype(xs), length(xs))
118+
usecuda ? CuArray(t) : t
119+
end
120+
dynamic_einsum(code, tensors)
121+
end
122+
123+
# coloring bond tensor
124+
function coloringb(::Type{T}, k::Int) where T
125+
x = ones(T, k, k)
126+
for i=1:k
127+
x[i,i] = zero(T)
128+
end
129+
return x
130+
end
131+
# coloring vertex tensor
132+
coloringv(vals::Vector{T}) where T = vals
133+
134+
### matching ###
135+
function _code(::Matching, g::SimpleGraph)
136+
EinCode(([(minmax(e.src,e.dst),) for e in LightGraphs.edges(g)]..., # labels for edge tensors
137+
[([minmax(i,j) for j in neighbors(g, i)]...,) for i in LightGraphs.vertices(g)]...,), ()) # labels for vertex tensors
138+
end
139+
140+
function contractx(::Matching, x::T, optcode::EinTypes; usecuda=false) where T
141+
ixs = OMEinsum.getixs(flatten(optcode))
142+
n = length(unique(Iterators.flatten(ixs))) # number of vertices
143+
tensors = []
144+
for i=1:length(ixs)
145+
if i<=n
146+
@assert length(ixs[i]) == 1
147+
t = T[one(T), x]
148+
else
149+
t = match_tensor(T, length(ixs[i]))
150+
end
151+
push!(tensors, usecuda ? CuArray(t) : t)
152+
end
153+
optcode(tensors...)
154+
end
155+
function match_tensor(::Type{T}, n::Int) where T
156+
t = zeros(T, fill(2, n)...)
157+
for ci in CartesianIndices(t)
158+
if sum(ci.I .- 1) <= 1
159+
t[ci] = one(T)
160+
end
161+
end
162+
return t
163+
end
164+
165+
graph_polynomial_maxorder(::Matching, code; usecuda) = Int(sum(contractx(Matching(), TropicalF64(1.0), code; usecuda=usecuda)).n)
166+
167+
### maximal independent set ###
168+
function _code(::MaximalIndependence, g::SimpleGraph)
169+
EinCode(([(LightGraphs.neighbors(g, v)..., v) for v in LightGraphs.vertices(g)]...,), ())
170+
end
171+
172+
function contractx(::MaximalIndependence, x::T, optcode::EinTypes; usecuda=false) where T
173+
ixs = OMEinsum.getixs(flatten(optcode))
174+
tensors = map(ixs) do ix
175+
t = neighbortensor(x, length(ix))
176+
usecuda ? CuArray(t) : t
177+
end
178+
dynamic_einsum(optcode, tensors)
179+
end
180+
function neighbortensor(x::T, d::Int) where T
181+
t = zeros(T, fill(2, d)...)
182+
for i = 2:1<<(d-1)
183+
t[i] = one(T)
184+
end
185+
t[1<<(d-1)+1] = x
186+
return t
187+
end
188+
189+
graph_polynomial_maxorder(::MaximalIndependence, code; usecuda) = Int(sum(contractx(MaximalIndependence(), TropicalF64(1.0), code; usecuda=usecuda)).n)

0 commit comments

Comments
 (0)