@@ -64,11 +64,11 @@ function cached_einsum(se::SlicedEinsum, @nospecialize(xs), size_dict)
6464end
6565function cached_einsum (code:: NestedEinsum , @nospecialize (xs), size_dict)
6666 if OMEinsum. isleaf (code)
67- y = xs[code . tensorindex]
67+ y = xs[OMEinsum . tensorindex (code) ]
6868 return CacheTree (y, CacheTree{eltype (y)}[])
6969 else
70- caches = [cached_einsum (arg, xs, size_dict) for arg in code . args ]
71- y = einsum (code . eins , ntuple (i-> caches[i]. content, length (caches)), size_dict)
70+ caches = [cached_einsum (arg, xs, size_dict) for arg in OMEinsum . siblings (code) ]
71+ y = einsum (OMEinsum . rootcode (code) , ntuple (i-> caches[i]. content, length (caches)), size_dict)
7272 return CacheTree (y, caches)
7373 end
7474end
@@ -84,8 +84,9 @@ function generate_masktree(mode, code::NestedEinsum, cache, mask, size_dict)
8484 if OMEinsum. isleaf (code)
8585 return CacheTree (mask, CacheTree{Bool}[])
8686 else
87- submasks = backward_tropical (mode, getixs (code. eins), (getfield .(cache. siblings, :content )... ,), OMEinsum. getiy (code. eins), cache. content, mask, size_dict)
88- return CacheTree (mask, generate_masktree .(Ref (mode), code. args, cache. siblings, submasks, Ref (size_dict)))
87+ eins = OMEinsum. rootcode (code)
88+ submasks = backward_tropical (mode, getixs (eins), (getfield .(cache. siblings, :content )... ,), OMEinsum. getiy (eins), cache. content, mask, size_dict)
89+ return CacheTree (mask, generate_masktree .(Ref (mode), OMEinsum. siblings (code), cache. siblings, submasks, Ref (size_dict)))
8990 end
9091end
9192
@@ -98,12 +99,12 @@ function masked_einsum(se::SlicedEinsum, @nospecialize(xs), masks, size_dict)
9899end
99100function masked_einsum (code:: NestedEinsum , @nospecialize (xs), masks, size_dict)
100101 if OMEinsum. isleaf (code)
101- y = copy (xs[code . tensorindex])
102+ y = copy (xs[OMEinsum . tensorindex (code) ])
102103 y[OMEinsum. asarray (.! masks. content)] .= Ref (zero (eltype (y)))
103104 return y
104105 else
105- xs = [masked_einsum (arg, xs, mask, size_dict) for (arg, mask) in zip (code . args , masks. siblings)]
106- y = einsum (code . eins , (xs... ,), size_dict)
106+ xs = [masked_einsum (arg, xs, mask, size_dict) for (arg, mask) in zip (OMEinsum . siblings (code) , masks. siblings)]
107+ y = einsum (OMEinsum . rootcode (code) , (xs... ,), size_dict)
107108 y[OMEinsum. asarray (.! masks. content)] .= Ref (zero (eltype (y)))
108109 return y
109110 end
@@ -121,10 +122,10 @@ Contraction method with bounding.
121122"""
122123function bounding_contract (mode:: AllConfigs , code:: EinCode , @nospecialize (xsa), ymask, @nospecialize (xsb); size_info= nothing )
123124 LT = OMEinsum. labeltype (code)
124- bounding_contract (mode, NestedEinsum ( NestedEinsum {DynamicEinCode{LT} } .(1 : length (xsa)), code), xsa, ymask, xsb; size_info= size_info)
125+ bounding_contract (mode, DynamicNestedEinsum ( DynamicNestedEinsum {LT } .(1 : length (xsa)), code), xsa, ymask, xsb; size_info= size_info)
125126end
126127function bounding_contract (mode:: AllConfigs , code:: Union{NestedEinsum,SlicedEinsum} , @nospecialize (xsa), ymask, @nospecialize (xsb); size_info= nothing )
127- size_dict = size_info=== nothing ? Dict {OMEinsum.labeltype(code.eins ),Int} () : copy (size_info)
128+ size_dict = size_info=== nothing ? Dict {OMEinsum.labeltype(code),Int} () : copy (size_info)
128129 OMEinsum. get_size_dict! (code, xsa, size_dict)
129130 # compute intermediate tensors
130131 @debug " caching einsum..."
@@ -139,11 +140,11 @@ end
139140# get the optimal solution with automatic differentiation.
140141function solution_ad (code:: EinCode , @nospecialize (xsa), ymask; size_info= nothing )
141142 LT = OMEinsum. labeltype (code)
142- solution_ad (NestedEinsum ( NestedEinsum {DynamicEinCode{LT} } .(1 : length (xsa)), code), xsa, ymask; size_info= size_info)
143+ solution_ad (DynamicNestedEinsum ( DynamicNestedEinsum {LT } .(1 : length (xsa)), code), xsa, ymask; size_info= size_info)
143144end
144145
145146function solution_ad (code:: Union{NestedEinsum,SlicedEinsum} , @nospecialize (xsa), ymask; size_info= nothing )
146- size_dict = size_info=== nothing ? Dict {OMEinsum.labeltype(code.eins ),Int} () : copy (size_info)
147+ size_dict = size_info=== nothing ? Dict {OMEinsum.labeltype(code),Int} () : copy (size_info)
147148 OMEinsum. get_size_dict! (code, xsa, size_dict)
148149 # compute intermediate tensors
149150 @debug " caching einsum..."
@@ -165,7 +166,7 @@ function read_config!(code::SlicedEinsum, mt, out)
165166end
166167
167168function read_config! (code:: NestedEinsum , mt, out)
168- for (arg, ix, sibling) in zip (code . args , getixs (code . eins ), mt. siblings)
169+ for (arg, ix, sibling) in zip (OMEinsum . siblings (code) , getixs (OMEinsum . rootcode (code) ), mt. siblings)
169170 if OMEinsum. isleaf (arg)
170171 mask = convert (Array, sibling. content) # note: the content can be CuArray
171172 for ci in CartesianIndices (mask)
0 commit comments