Skip to content

Commit 472cf42

Browse files
committed
update
1 parent 374bcb3 commit 472cf42

File tree

5 files changed

+41
-9
lines changed

5 files changed

+41
-9
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ Compose = "0.9"
2828
FFTW = "1.4"
2929
LightGraphs = "1.3"
3030
Mods = "1.3"
31-
OMEinsum = "0.4"
31+
OMEinsum = "0.4, 0.5"
3232
OMEinsumContractionOrders = "0.2, 0.3, 0.4"
3333
Polynomials = "2.0"
3434
Primes = "0.5"

src/GraphTensorNetworks.jl

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,38 @@ export timespace_complexity, @ein_str
1111

1212
project_relative_path(xs...) = normpath(joinpath(dirname(dirname(pathof(@__MODULE__))), xs...))
1313

14+
# patch to permutedims
15+
using Base.Cartesian
16+
using Base: size_to_strides, checkdims_perm
17+
for (V, PT, BT) in Any[((:N,), BitArray, BitArray), ((:T,:N), Array, StridedArray)]
18+
@eval @generated function Base.permutedims!(P::$PT{$(V...)}, B::$BT{$(V...)}, perm) where $(V...)
19+
quote
20+
checkdims_perm(P, B, perm)
21+
22+
#calculates all the strides
23+
native_strides = size_to_strides(1, size(B)...)
24+
strides_1 = 0
25+
@nexprs $N d->(strides_{d+1} = native_strides[perm[d]])
26+
27+
#Creates offset, because indexing starts at 1
28+
offset = 1 - sum(@ntuple $N d->strides_{d+1})
29+
30+
sumc = 0
31+
ind = 1
32+
@nexprs 1 d->(counts_{$N+1} = strides_{$N+1}) # a trick to set counts_($N+1)
33+
@nloops($N, i, P,
34+
d->(df_d=i_d*strides_{d+1} ;sumc += df_d), # PRE
35+
d->(sumc -= df_d), # POST
36+
begin # BODY
37+
@inbounds P[ind] = B[sumc+offset]
38+
ind += 1
39+
end)
40+
41+
return P
42+
end
43+
end
44+
end
45+
1446
include("bitvector.jl")
1547
include("arithematics.jl")
1648
include("networks.jl")
@@ -26,4 +58,4 @@ function __init__()
2658
@require CUDA="052768ef-5323-5732-b1bb-66c8b64840ba" include("cuda.jl")
2759
end
2860

29-
end
61+
end

src/bounding.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ function cached_einsum(code::Int, @nospecialize(xs), size_dict)
5959
end
6060
function cached_einsum(code::NestedEinsum, @nospecialize(xs), size_dict)
6161
caches = [cached_einsum(arg, xs, size_dict) for arg in code.args]
62-
y = dynamic_einsum(code.eins, (getfield.(caches, :content)...,); size_info=size_dict)
62+
y = code.eins(getfield.(caches, :content)...; size_info=size_dict)
6363
CacheTree(y, caches)
6464
end
6565

@@ -96,7 +96,7 @@ function bounding_contract(@nospecialize(code::EinCode), @nospecialize(xsa), yma
9696
bounding_contract(NestedEinsum((1:length(xsa)), code), xsa, ymask, xsb; size_info=size_info)
9797
end
9898
function bounding_contract(code::NestedEinsum, @nospecialize(xsa), ymask, @nospecialize(xsb); size_info=nothing)
99-
size_dict = OMEinsum.get_size_dict(getixs(flatten(code)), xsa, size_info)
99+
size_dict = OMEinsum.get_size_dict(getixs(flatten(code)), (xsa...,), size_info)
100100
# compute intermediate tensors
101101
@debug "caching einsum..."
102102
c = cached_einsum(code, xsa, size_dict)
@@ -142,4 +142,4 @@ function read_config!(code::NestedEinsum, mt, out)
142142
end
143143
end
144144
return out
145-
end
145+
end

src/graph_polynomials.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ function contractf(f, gp::GraphProblem; usecuda=false)
9797
if usecuda
9898
xs = CuArray.(xs)
9999
end
100-
dynamic_einsum(gp.code, xs)
100+
gp.code(xs...)
101101
end
102102

103103
############### Problem specific implementations ################

src/viz.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,8 @@ function vizeinsum(nodes, edges; config=zeros(Int, length(nodes)), unit=1.0, gra
6565
zoom_into(img, XMIN, XMAX, YMIN, YMAX; graphsize=graphsize, rescale=rescale)
6666
end
6767

68-
function vizeinsum(::EinCode{ixs, iy}, locs::AbstractVector{<:Pair}; kwargs...) where {ixs, iy}
69-
vizeinsum(ixs, iy, Dict(locs); kwargs...)
68+
function vizeinsum(code::EinCode, locs::AbstractVector{<:Pair}; kwargs...)
69+
vizeinsum(getixs(code), getiy(code), Dict(locs); kwargs...)
7070
end
7171
function vizeinsum(code::NestedEinsum, locs::AbstractVector{<:Pair}; kwargs...)
7272
vizeinsum(flatten(code), locs; kwargs...)
@@ -76,4 +76,4 @@ function vizeinsum(ixs::Tuple, iy::Tuple, locs::Dict; kwargs...)
7676
nodes = [l=>locs[l] for l in legs]
7777
edges = [map(i->findfirst(==(i), legs), ix) for ix in ixs]
7878
vizeinsum(nodes, edges; config=[l iy for l in legs], kwargs...)
79-
end
79+
end

0 commit comments

Comments
 (0)