Skip to content

Commit cf9d2ba

Browse files
committed
initial
1 parent 07bd6f3 commit cf9d2ba

16 files changed

+1020
-4
lines changed

.github/workflows/CI.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ jobs:
1010
fail-fast: false
1111
matrix:
1212
version:
13-
- '1.0'
1413
- '1.6'
1514
os:
1615
- ubuntu-latest

Project.toml

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,39 @@ uuid = "0978c8c2-34f6-49c7-9826-ea2cc20dabd2"
33
authors = ["GiggleLiu <cacate0129@gmail.com> and contributors"]
44
version = "0.1.0"
55

6+
[deps]
7+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
8+
Compose = "a81c6b42-2e10-5240-aca2-a61377ecd94b"
9+
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
10+
LightGraphs = "093fc24a-ae57-5d10-9952-331d41423f4d"
11+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
12+
Mods = "7475f97c-0381-53b1-977b-4c60186c8d62"
13+
OMEinsum = "ebe7aa44-baf0-506c-a96f-8464559b3922"
14+
OMEinsumContractionOrders = "6f22d1fd-8eed-4bb7-9776-e7d684900715"
15+
Polynomials = "f27b6e38-b328-58d1-80ce-0feddd5e7a45"
16+
Primes = "27ebfcd6-29c5-5fa9-bf4b-fb8fc14df3ae"
17+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
18+
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
19+
TropicalGEMM = "a4ad3063-64a7-4bad-8738-34ed09bc0236"
20+
TropicalNumbers = "b3a74e9c-7526-4576-a4eb-79c0d4c32334"
21+
TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6"
22+
Viznet = "52a3aca4-6234-47fd-b74a-806bdf78ede9"
23+
624
[compat]
25+
CUDA = "3.3"
26+
Compose = "0.9"
27+
LightGraphs = "1.3"
28+
Mods = "1.3"
29+
OMEinsum = "0.4"
30+
OMEinsumContractionOrders = "0.1"
31+
Polynomials = "2.0"
32+
Primes = "0.5"
33+
Requires = "1"
34+
TropicalGEMM = "0.1"
35+
TropicalNumbers = "0.4"
36+
TupleTools = "1.2"
37+
Viznet = "0.3"
38+
FFTW = "1.4"
739
julia = "1"
840

941
[extras]

README.md

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,33 @@
11
# GraphTensorNetworks
22

33
[![Build Status](https://github.com/HappyDiode/GraphTensorNetworks.jl/workflows/CI/badge.svg)](https://github.com/HappyDiode/GraphTensorNetworks.jl/actions)
4+
5+
## Installation
6+
<p>
7+
GraphTensorNetworks is a &nbsp;
8+
<a href="https://julialang.org">
9+
<img src="https://julialang.org/favicon.ico" width="16em">
10+
Julia Language
11+
</a>
12+
&nbsp; package. To install GraphTensorNetworks,
13+
please <a href="https://docs.julialang.org/en/v1/manual/getting-started/">open
14+
Julia's interactive session (known as REPL)</a> and press <kbd>]</kbd> key in the REPL to use the package mode, then type the following command
15+
</p>
16+
17+
```julia
18+
pkg> add GraphTensorNetworks
19+
```
20+
21+
Please use Julia-1.7, otherwise you will suffer from huge overhead when contracting large tensor networks. If you have to use a lower version,
22+
you can avoid the overhead by overriding the `permutedims!` is `LinearAlgebra`.
23+
24+
```julia
25+
using TensorOperations, LinearAlgebra
26+
function LinearAlgebra.permutedims!(C::Array{T,N}, A::StridedArray{T,N}, perm) where {T,N}
27+
if isbitstype(T)
28+
TensorOperations.tensorcopy!(A, ntuple(identity,N), C, perm)
29+
else
30+
invoke(permutedims!, Tuple{Any,AbstractArray,Any}, C, A, perm)
31+
end
32+
end
33+
```

src/GraphTensorNetworks.jl

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,28 @@
11
module GraphTensorNetworks
22

3-
# Write your package code here.
3+
using OMEinsumContractionOrders: OMEinsum
4+
using Core: Argument
5+
using TropicalGEMM, TropicalNumbers
6+
using OMEinsum
7+
using OMEinsum: flatten
48

9+
# patches for OMEinsum
10+
OMEinsum.asarray(x, ::AbstractArray) = fill(x)
11+
OMEinsum.dynamic_einsum(::EinCode{ixs, iy}, xs; kwargs...) where {ixs, iy} = dynamic_einsum(ixs, xs, iy; kwargs...)
12+
13+
project_relative_path(xs...) = normpath(joinpath(dirname(dirname(pathof(@__MODULE__))), xs...))
14+
15+
include("arithematics.jl")
16+
include("independence_polynomial.jl")
17+
include("configurations.jl")
18+
include("graphs.jl")
19+
include("bounding.jl")
20+
include("viz.jl")
21+
include("interfaces.jl")
22+
23+
using Requires
24+
function __init__()
25+
@require CUDA="052768ef-5323-5732-b1bb-66c8b64840ba" include("cuda.jl")
526
end
27+
28+
end

src/arithematics.jl

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
export is_commutative_semiring
2+
3+
# this function is used for testing
4+
function is_commutative_semiring(a::T, b::T, c::T) where T
5+
# +
6+
if (a + b) + c != a + (b + c)
7+
@debug "(a + b) + c != a + (b + c)"
8+
return false
9+
end
10+
if !(a + zero(T) == zero(T) + a == a)
11+
@debug "!(a + zero(T) == zero(T) + a == a)"
12+
return false
13+
end
14+
if a + b != b + a
15+
@debug "a + b != b + a"
16+
return false
17+
end
18+
# *
19+
if (a * b) * c != a * (b * c)
20+
@debug "(a * b) * c != a * (b * c)"
21+
return false
22+
end
23+
if !(a * one(T) == one(T) * a == a)
24+
@debug "!(a * one(T) == one(T) * a == a)"
25+
return false
26+
end
27+
if a * b != b * a
28+
@debug "a * b != b * a"
29+
return false
30+
end
31+
# more
32+
if a * (b+c) != a*b + a*c
33+
@debug "a * (b+c) != a*b + a*c"
34+
return false
35+
end
36+
if (a+b) * c != a*c + b*c
37+
@debug "(a+b) * c != a*c + b*c"
38+
return false
39+
end
40+
if !(a * zero(T) == zero(T) * a == zero(T))
41+
@debug "!(a * zero(T) == zero(T) * a == zero(T))"
42+
return false
43+
end
44+
if !(a * zero(T) == zero(T) * a == zero(T))
45+
@debug "!(a * zero(T) == zero(T) * a == zero(T))"
46+
return false
47+
end
48+
return true
49+
end
50+
51+
export Max2Poly
52+
53+
# get maximum two countings (polynomial truncated to largest two orders)
54+
struct Max2Poly{T} <: Number
55+
a::T
56+
b::T
57+
maxorder::Float64
58+
end
59+
60+
function Base.:+(a::Max2Poly, b::Max2Poly)
61+
if a.maxorder == b.maxorder
62+
return Max2Poly(a.a+b.a, a.b+b.b, a.maxorder)
63+
elseif a.maxorder == b.maxorder-1
64+
return Max2Poly(a.b+b.a, b.b, b.maxorder)
65+
elseif a.maxorder == b.maxorder+1
66+
return Max2Poly(a.a+b.b, a.b, a.maxorder)
67+
elseif a.maxorder < b.maxorder
68+
return b
69+
else
70+
return a
71+
end
72+
end
73+
74+
function Base.:*(a::Max2Poly, b::Max2Poly)
75+
maxorder = a.maxorder + b.maxorder
76+
Max2Poly(a.a*b.b + a.b*b.a, a.b * b.b, maxorder)
77+
end
78+
79+
Base.zero(::Type{Max2Poly{T}}) where T = Max2Poly(zero(T), zero(T), -Inf)
80+
Base.one(::Type{Max2Poly{T}}) where T = Max2Poly(zero(T), one(T), 0.0)
81+
Base.zero(::Max2Poly{T}) where T = zero(Max2Poly{T})
82+
Base.one(::Max2Poly{T}) where T = one(Max2Poly{T})
83+

src/bounding.jl

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
using TupleTools
2+
3+
export bounding_contract
4+
5+
Base.isnan(x::Tropical) = isnan(x.n)
6+
function backward_tropical(mode, @nospecialize(ixs), @nospecialize(xs), @nospecialize(iy), @nospecialize(y), @nospecialize(ymask), size_dict)
7+
y .= inv.(y) .* ymask
8+
masks = []
9+
for i=1:length(ixs)
10+
nixs = TupleTools.insertat(ixs, i, (iy,))
11+
nxs = TupleTools.insertat( xs, i, (y,))
12+
niy = ixs[i]
13+
if mode == :all
14+
mask = zeros(Bool, size(xs[i]))
15+
mask .= inv.(einsum(EinCode(nixs, niy), nxs, size_dict)) .== xs[i]
16+
push!(masks, mask)
17+
elseif mode == :single # wrong, need `B` matching `A`.
18+
A = zeros(eltype(xs[i]), size(xs[i]))
19+
A = einsum(EinCode(nixs, niy), nxs, size_dict)
20+
push!(masks, onehotmask(A, xs[i]))
21+
else
22+
error("unkown mode: $mod")
23+
end
24+
end
25+
return masks
26+
end
27+
28+
function onehotmask(A::AbstractArray{T}, X::AbstractArray{T}) where T
29+
@assert length(A) == length(X)
30+
mask = falses(size(A)...)
31+
found = false
32+
@inbounds for j=1:length(A)
33+
if X[j] == inv(A[j]) && !found
34+
mask[j] = true
35+
found = true
36+
else
37+
X[j] = zero(T)
38+
end
39+
end
40+
return mask
41+
end
42+
43+
struct CacheTree{T}
44+
content::AbstractArray{T}
45+
siblings::Vector{CacheTree{T}}
46+
end
47+
function cached_einsum(code::Int, @nospecialize(xs), size_dict)
48+
y = xs[code]
49+
CacheTree(y, CacheTree{eltype(y)}[])
50+
end
51+
function cached_einsum(code::NestedEinsum, @nospecialize(xs), size_dict)
52+
caches = [cached_einsum(arg, xs, size_dict) for arg in code.args]
53+
y = einsum(code.eins, (getfield.(caches, :content)...,), size_dict)
54+
CacheTree(y, caches)
55+
end
56+
57+
function generate_masktree(code::Int, cache, mask, size_dict, mode=:all)
58+
CacheTree(mask, CacheTree{Bool}[])
59+
end
60+
function generate_masktree(code::NestedEinsum, cache, mask, size_dict, mode=:all)
61+
submasks = backward_tropical(mode, OMEinsum.getixs(code.eins), (getfield.(cache.siblings, :content)...,), OMEinsum.getiy(code.eins), cache.content, mask, size_dict)
62+
return CacheTree(mask, generate_masktree.(code.args, cache.siblings, submasks, Ref(size_dict), mode))
63+
end
64+
65+
function masked_einsum(code::Int, @nospecialize(xs), masks, size_dict)
66+
y = copy(xs[code])
67+
y[OMEinsum.asarray(.!masks.content)] .= Ref(zero(eltype(y))); y
68+
end
69+
function masked_einsum(code::NestedEinsum, @nospecialize(xs), masks, size_dict)
70+
xs = [masked_einsum(arg, xs, mask, size_dict) for (arg, mask) in zip(code.args, masks.siblings)]
71+
y = einsum(code.eins, (xs...,), size_dict)
72+
y[OMEinsum.asarray(.!masks.content)] .= Ref(zero(eltype(y))); y
73+
end
74+
75+
function bounding_contract(@nospecialize(code::EinCode), @nospecialize(xsa), ymask, @nospecialize(xsb); size_info=nothing)
76+
bounding_contract(NestedEinsum((1:length(xsa)), code), xsa, ymask, xsb; size_info=size_info)
77+
end
78+
function bounding_contract(code::NestedEinsum, @nospecialize(xsa), ymask, @nospecialize(xsb); size_info=nothing)
79+
size_dict = OMEinsum.get_size_dict(getixs(flatten(code)), xsa, size_info)
80+
# compute intermediate tensors
81+
@debug "caching einsum..."
82+
c = cached_einsum(code, xsa, size_dict)
83+
# compute masks from cached tensors
84+
@debug "generating masked tree..."
85+
mt = generate_masktree(code, c, ymask, size_dict, :all)
86+
# compute results with masks
87+
masked_einsum(code, xsb, mt, size_dict)
88+
end
89+
90+
function mis_config_ad(@nospecialize(code::EinCode), @nospecialize(xsa), ymask; size_info=nothing)
91+
mis_config_ad(NestedEinsum((1:length(xsa)), code), xsa, ymask; size_info=size_info)
92+
end
93+
94+
function mis_config_ad(code::NestedEinsum, @nospecialize(xsa), ymask; size_info=nothing)
95+
size_dict = OMEinsum.get_size_dict(getixs(flatten(code)), xsa, size_info)
96+
# compute intermediate tensors
97+
@debug "caching einsum..."
98+
c = cached_einsum(code, xsa, size_dict)
99+
n = asscalar(c.content)
100+
# compute masks from cached tensors
101+
@debug "generating masked tree..."
102+
mt = generate_masktree(code, c, ymask, size_dict, :single)
103+
n, read_config!(code, mt, Dict())
104+
end
105+
106+
function read_config!(code::NestedEinsum, mt, out)
107+
for (arg, ix, sibling) in zip(code.args, OMEinsum.getixs(code.eins), mt.siblings)
108+
if arg isa Int
109+
assign = convert(Array, sibling.content) # note: the content can be CuArray
110+
if length(ix) == 1
111+
if !assign[1] && assign[2]
112+
out[ix[1]] = 1
113+
elseif !assign[2] && assign[1]
114+
out[ix[1]] = 0
115+
else
116+
error("invalid assign $(assign)")
117+
end
118+
end
119+
else # nested
120+
read_config!(arg, sibling, out)
121+
end
122+
end
123+
return out
124+
end

0 commit comments

Comments
 (0)