Skip to content

Commit 20549d6

Browse files
committed
...
1 parent 374bcb3 commit 20549d6

File tree

6 files changed

+60
-42
lines changed

6 files changed

+60
-42
lines changed

src/GraphTensorNetworks.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using OMEinsumContractionOrders: OMEinsum
44
using Core: Argument
55
using TropicalGEMM, TropicalNumbers
66
using OMEinsum
7-
using OMEinsum: flatten, timespace_complexity
7+
using OMEinsum: timespace_complexity
88
using LightGraphs
99

1010
export timespace_complexity, @ein_str

src/bounding.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ function generate_masktree(code::Int, cache, mask, size_dict, mode=:all)
6868
CacheTree(mask, CacheTree{Bool}[])
6969
end
7070
function generate_masktree(code::NestedEinsum, cache, mask, size_dict, mode=:all)
71-
submasks = backward_tropical(mode, OMEinsum.getixs(code.eins), (getfield.(cache.siblings, :content)...,), OMEinsum.getiy(code.eins), cache.content, mask, size_dict)
71+
submasks = backward_tropical(mode, getixs(code.eins), (getfield.(cache.siblings, :content)...,), OMEinsum.getiy(code.eins), cache.content, mask, size_dict)
7272
return CacheTree(mask, generate_masktree.(code.args, cache.siblings, submasks, Ref(size_dict), mode))
7373
end
7474

@@ -96,7 +96,7 @@ function bounding_contract(@nospecialize(code::EinCode), @nospecialize(xsa), yma
9696
bounding_contract(NestedEinsum((1:length(xsa)), code), xsa, ymask, xsb; size_info=size_info)
9797
end
9898
function bounding_contract(code::NestedEinsum, @nospecialize(xsa), ymask, @nospecialize(xsb); size_info=nothing)
99-
size_dict = OMEinsum.get_size_dict(getixs(flatten(code)), xsa, size_info)
99+
size_dict = OMEinsum.get_size_dict(collect_ixs(code), xsa, size_info)
100100
# compute intermediate tensors
101101
@debug "caching einsum..."
102102
c = cached_einsum(code, xsa, size_dict)
@@ -113,7 +113,7 @@ function solution_ad(@nospecialize(code::EinCode), @nospecialize(xsa), ymask; si
113113
end
114114

115115
function solution_ad(code::NestedEinsum, @nospecialize(xsa), ymask; size_info=nothing)
116-
size_dict = OMEinsum.get_size_dict(getixs(flatten(code)), xsa, size_info)
116+
size_dict = OMEinsum.get_size_dict(collect_ixs(code), xsa, size_info)
117117
# compute intermediate tensors
118118
@debug "caching einsum..."
119119
c = cached_einsum(code, xsa, size_dict)
@@ -125,7 +125,7 @@ function solution_ad(code::NestedEinsum, @nospecialize(xsa), ymask; size_info=no
125125
end
126126

127127
function read_config!(code::NestedEinsum, mt, out)
128-
for (arg, ix, sibling) in zip(code.args, OMEinsum.getixs(code.eins), mt.siblings)
128+
for (arg, ix, sibling) in zip(code.args, getixs(code.eins), mt.siblings)
129129
if arg isa Int
130130
assign = convert(Array, sibling.content) # note: the content can be CuArray
131131
if length(ix) == 1

src/configurations.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ function best_solutions(gp::GraphProblem; all=false, usecuda=false)
1616
T = (all ? set_type : sampler_type)(CountingTropical{Int64}, length(syms), bondsize(gp))
1717
vertex_index = Dict([s=>i for (i, s) in enumerate(syms)])
1818
xst = generate_tensors(l->TropicalF64(1.0), gp)
19-
ymask = trues(fill(2, length(OMEinsum.getiy(flatten(gp.code))))...)
19+
ymask = trues(fill(2, length(_getiy(gp.code)))...)
2020
if usecuda
2121
xst = CuArray.(xst)
2222
ymask = CuArray(ymask)
@@ -89,4 +89,4 @@ end
8989
for GP in [:Independence, :Matching, :MaximalIndependence, :Coloring]
9090
@eval symbols(gp::$GP) = labels(gp.code)
9191
end
92-
symbols(gp::MaxCut) = collect(OMEinsum.getixs(OMEinsum.flatten(gp.code)))
92+
symbols(gp::MaxCut) = collect_ixs(gp.code)

src/graph_polynomials.jl

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -62,11 +62,13 @@ function _polynomial_single(gp::GraphProblem, ::Type{T}; usecuda, maxorder) wher
6262
return res
6363
end
6464

65+
_getiy(code::EinCode) = getiy(code)
66+
_getiy(code::NestedEinsum) = getiy(code.eins)
6567
function graph_polynomial(gp::GraphProblem, ::Val{:finitefield}; usecuda=false,
6668
maxorder=max_size(gp; usecuda=usecuda), max_iter=100)
6769
TI = Int32 # Int 32 is faster
6870
N = typemax(TI)
69-
YS = fill(Any[], (fill(bondsize(gp), length(getiy(flatten(gp.code))))...,))
71+
YS = fill(Any[], (fill(bondsize(gp), length(_getiy(gp.code)))...,))
7072
local res, respre
7173
for k = 1:max_iter
7274
N = prevprime(N-TI(1))
@@ -103,9 +105,8 @@ end
103105
############### Problem specific implementations ################
104106
### independent set ###
105107
function generate_tensors(fx, gp::Independence)
106-
flatten_code = flatten(gp.code)
107-
ixs = getixs(flatten_code)
108-
n = length(labels(flatten_code))
108+
ixs = collect_ixs(gp.code)
109+
n = length(unique!(vcat(ixs...)))
109110
T = typeof(fx(ixs[1][1]))
110111
return Tuple(map(enumerate(ixs)) do (i, ix)
111112
if i <= n
@@ -115,6 +116,23 @@ function generate_tensors(fx, gp::Independence)
115116
end
116117
end)
117118
end
119+
120+
function collect_ixs(ne::NestedEinsum)
121+
d = collect_ixs!(ne, Dict{Int,Vector{OMEinsum.labeltype(ne.eins)}}())
122+
return [d[i] for i=1:length(d)]
123+
end
124+
125+
function collect_ixs!(ne::NestedEinsum, d::Dict)
126+
for i=1:length(ne.args)
127+
if ne.args[i] isa Integer
128+
d[ne.args[i]] = collect(OMEinsum.getixs(ne.eins)[i])
129+
else
130+
collect_ixs!(ne.args[i], d)
131+
end
132+
end
133+
return d
134+
end
135+
118136
function misb(::Type{T}, n::Integer=2) where T
119137
res = zeros(T, fill(2, n)...)
120138
res[1] = one(T)
@@ -127,7 +145,7 @@ misv(val::T) where T = [one(T), val]
127145

128146
### coloring ###
129147
function generate_tensors(fx, c::Coloring{K}) where K
130-
ixs = getixs(flatten(c.code))
148+
ixs = collect_ixs(c.code)
131149
T = eltype(fx(ixs[1][1]))
132150
return map(ixs) do ix
133151
# if the tensor rank is 1, create a vertex tensor.
@@ -149,9 +167,9 @@ coloringv(vals::Vector{T}) where T = vals
149167

150168
### matching ###
151169
function generate_tensors(fx, m::Matching)
152-
ixs = OMEinsum.getixs(flatten(m.code))
170+
ixs = collect_ixs(m.code)
153171
T = typeof(fx(ixs[1][1]))
154-
n = length(unique(vcat(collect.(ixs)...))) # number of vertices
172+
n = length(unique!(vcat(ixs...))) # number of vertices
155173
tensors = []
156174
for i=1:length(ixs)
157175
if i<=n
@@ -176,7 +194,7 @@ end
176194

177195
### maximal independent set ###
178196
function generate_tensors(fx, mi::MaximalIndependence)
179-
ixs = OMEinsum.getixs(flatten(mi.code))
197+
ixs = collect_ixs(mi.code)
180198
T = eltype(fx(ixs[1][end]))
181199
return map(ixs) do ix
182200
neighbortensor(fx(ix[end]), length(ix))
@@ -193,11 +211,10 @@ end
193211

194212
### max cut/spin glass problem ###
195213
function generate_tensors(fx, gp::MaxCut)
196-
flatten_code = flatten(gp.code)
197-
ixs = getixs(flatten_code)
198-
return Tuple(map(enumerate(ixs)) do (i, ix)
214+
ixs = collect_ixs(gp.code)
215+
return map(enumerate(ixs)) do (i, ix)
199216
maxcutb(fx(ix))
200-
end)
217+
end
201218
end
202219
function maxcutb(expJ::T) where T
203220
return T[one(T) expJ; expJ one(T)]

src/networks.jl

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,52 +5,53 @@ abstract type GraphProblem end
55

66
"""
77
Independence{CT<:EinTypes} <: GraphProblem
8-
Independence(graph; kwargs...)
8+
Independence(graph; openvertices=(), kwargs...)
99
1010
Independent set problem. `kwargs` is forwarded to `optimize_code`.
1111
"""
1212
struct Independence{CT<:EinTypes} <: GraphProblem
1313
code::CT
1414
end
1515

16-
function Independence(g::SimpleGraph; outputs=(), kwargs...)
16+
function Independence(g::SimpleGraph; openvertices=(), kwargs...)
1717
rawcode = EinCode(([(i,) for i in LightGraphs.vertices(g)]..., # labels for vertex tensors
18-
[minmax(e.src,e.dst) for e in LightGraphs.edges(g)]...), outputs) # labels for edge tensors
19-
Independence(optimize_code(rawcode; kwargs...))
18+
[minmax(e.src,e.dst) for e in LightGraphs.edges(g)]...), openvertices) # labels for edge tensors
19+
code = optimize_code(rawcode; kwargs...)
20+
Independence(code)
2021
end
2122

2223
"""
2324
MaxCut{CT<:EinTypes} <: GraphProblem
24-
MaxCut(graph; kwargs...)
25+
MaxCut(graph; openvertices=(), kwargs...)
2526
2627
Max cut problem (or spin glass problem). `kwargs` is forwarded to `optimize_code`.
2728
"""
2829
struct MaxCut{CT<:EinTypes} <: GraphProblem
2930
code::CT
3031
end
31-
function MaxCut(g::SimpleGraph; outputs=(), kwargs...)
32-
rawcode = EinCode(([minmax(e.src,e.dst) for e in LightGraphs.edges(g)]...,), outputs) # labels for edge tensors
32+
function MaxCut(g::SimpleGraph; openvertices=(), kwargs...)
33+
rawcode = EinCode(([minmax(e.src,e.dst) for e in LightGraphs.edges(g)]...,), openvertices) # labels for edge tensors
3334
MaxCut(optimize_code(rawcode; kwargs...))
3435
end
3536

3637
"""
3738
MaximalIndependence{CT<:EinTypes} <: GraphProblem
38-
MaximalIndependence(graph; kwargs...)
39+
MaximalIndependence(graph; openvertices=(), kwargs...)
3940
4041
Maximal independent set problem. `kwargs` is forwarded to `optimize_code`.
4142
"""
4243
struct MaximalIndependence{CT<:EinTypes} <: GraphProblem
4344
code::CT
4445
end
4546

46-
function MaximalIndependence(g::SimpleGraph; outputs=(), kwargs...)
47-
rawcode = EinCode(([(LightGraphs.neighbors(g, v)..., v) for v in LightGraphs.vertices(g)]...,), outputs)
47+
function MaximalIndependence(g::SimpleGraph; openvertices=(), kwargs...)
48+
rawcode = EinCode(([(LightGraphs.neighbors(g, v)..., v) for v in LightGraphs.vertices(g)]...,), openvertices)
4849
MaximalIndependence(optimize_code(rawcode; kwargs...))
4950
end
5051

5152
"""
5253
Matching{CT<:EinTypes} <: GraphProblem
53-
Matching(graph; kwargs...)
54+
Matching(graph; openvertices=(), kwargs...)
5455
5556
Vertex matching problem. `kwargs` is forwarded to `optimize_code`.
5657
The matching polynomial adopts the first definition in wiki page: https://en.wikipedia.org/wiki/Matching_polynomial
@@ -63,15 +64,15 @@ struct Matching{CT<:EinTypes} <: GraphProblem
6364
code::CT
6465
end
6566

66-
function Matching(g::SimpleGraph; outputs=(), kwargs...)
67+
function Matching(g::SimpleGraph; openvertices=(), kwargs...)
6768
rawcode = EinCode(([(minmax(e.src,e.dst),) for e in LightGraphs.edges(g)]..., # labels for edge tensors
68-
[([minmax(i,j) for j in neighbors(g, i)]...,) for i in LightGraphs.vertices(g)]...,), outputs) # labels for vertex tensors
69+
[([minmax(i,j) for j in neighbors(g, i)]...,) for i in LightGraphs.vertices(g)]...,), openvertices) # labels for vertex tensors
6970
Matching(optimize_code(rawcode; kwargs...))
7071
end
7172

7273
"""
7374
Coloring{K,CT<:EinTypes} <: GraphProblem
74-
Coloring{K}(graph; kwargs...)
75+
Coloring{K}(graph; openvertices=(), kwargs...)
7576
7677
K-Coloring problem. `kwargs` is forwarded to `optimize_code`.
7778
"""
@@ -80,7 +81,7 @@ struct Coloring{K,CT<:EinTypes} <: GraphProblem
8081
end
8182
Coloring{K}(code::ET) where {K,ET<:EinTypes} = Coloring{K,ET}(code)
8283
# same network layout as independent set.
83-
Coloring{K}(g::SimpleGraph; outputs=(), kwargs...) where K = Coloring{K}(Independence(g; outputs=outputs, kwargs...).code)
84+
Coloring{K}(g::SimpleGraph; openvertices=(), kwargs...) where K = Coloring{K}(Independence(g; openvertices=openvertices, kwargs...).code)
8485

8586
"""
8687
labels(code)
@@ -89,7 +90,7 @@ Return a vector of unique labels in an Einsum token.
8990
"""
9091
function labels(code::EinTypes)
9192
res = []
92-
for ix in OMEinsum.getixs(OMEinsum.flatten(code))
93+
for ix in collect_ixs(code)
9394
for l in ix
9495
if l res
9596
push!(res, l)

src/viz.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,14 @@ using Viznet
22
export vizeinsum, vizconfig
33
using Compose
44

5-
function vizconfig(g::SimpleGraph; locs, config=zeros(Int, length(locs)), unit=1.0, graphsize=12cm)
6-
vizconfig([string(v)=>locs[v] for v in LightGraphs.vertices(g)], [(e.src, e.dst) for e in edges(g)]; config=config, unit=unit, graphsize=graphsize)
5+
function vizconfig(g::SimpleGraph; locs, config=zeros(Int, length(locs)), unit=1.0, graphsize=12cm, radius=0.03)
6+
vizconfig([string(v)=>locs[v] for v in LightGraphs.vertices(g)], [(e.src, e.dst) for e in edges(g)]; config=config, unit=unit, graphsize=graphsize, radius=radius)
77
end
88

9-
function vizconfig(nodes, edges; config=zeros(Int, length(nodes)), unit=1.0, graphsize=12cm)
9+
function vizconfig(nodes, edges; config=zeros(Int, length(nodes)), unit=1.0, graphsize=12cm, radius=0.03)
1010
tb = textstyle(:default, fill("white"), fontsize(10pt*unit))
11-
nb = nodestyle(:circle, fill("black"), r=0.03*unit)
12-
nb2 = nodestyle(:circle, fill("red"),r=0.03*unit)
11+
nb = nodestyle(:circle, fill("black"), r=radius*unit)
12+
nb2 = nodestyle(:circle, fill("red"),r=radius*unit)
1313
eb = bondstyle(:default, linewidth(0.4mm*unit))
1414
img = canvas() do
1515
for (i, (t, p)) in enumerate(nodes)
@@ -69,11 +69,11 @@ function vizeinsum(::EinCode{ixs, iy}, locs::AbstractVector{<:Pair}; kwargs...)
6969
vizeinsum(ixs, iy, Dict(locs); kwargs...)
7070
end
7171
function vizeinsum(code::NestedEinsum, locs::AbstractVector{<:Pair}; kwargs...)
72-
vizeinsum(flatten(code), locs; kwargs...)
72+
vizeinsum(OMEinsum.flatten(code), locs; kwargs...)
7373
end
7474
function vizeinsum(ixs::Tuple, iy::Tuple, locs::Dict; kwargs...)
7575
legs = unique!([Iterators.flatten(ixs)..., iy...])
7676
nodes = [l=>locs[l] for l in legs]
7777
edges = [map(i->findfirst(==(i), legs), ix) for ix in ixs]
7878
vizeinsum(nodes, edges; config=[l iy for l in legs], kwargs...)
79-
end
79+
end

0 commit comments

Comments
 (0)