11using TupleTools
2+ using OMEinsum: DynamicEinCode
23
34"""
45 backward_tropical(mode, ixs, xs, iy, y, ymask, size_dict)
@@ -10,12 +11,12 @@ The backward rule for tropical einsum.
1011* `ymask` is the boolean mask for gradients,
1112* `size_dict` is a key-value map from tensor label to dimension size.
1213"""
13- function backward_tropical (mode, @nospecialize ( ixs) , @nospecialize (xs), @nospecialize (iy) , @nospecialize (y), @nospecialize (ymask), size_dict)
14+ function backward_tropical (mode, ixs, @nospecialize (xs:: Tuple ), iy , @nospecialize (y), @nospecialize (ymask), size_dict)
1415 y .= inv .(y) .* ymask
1516 masks = []
1617 for i= 1 : length (ixs)
17- nixs = TupleTools . insertat (ixs, i, (iy,) )
18- nxs = TupleTools . insertat ( xs, i, (y,) )
18+ nixs = OMEinsum . _insertat (ixs, i, iy )
19+ nxs = OMEinsum . _insertat ( xs, i, y )
1920 niy = ixs[i]
2021 if mode == :all
2122 mask = zeros (Bool, size (xs[i]))
@@ -53,34 +54,39 @@ struct CacheTree{T}
5354 content:: AbstractArray{T}
5455 siblings:: Vector{CacheTree{T}}
5556end
56- function cached_einsum (code:: Int , @nospecialize (xs), size_dict)
57- y = xs[code]
58- CacheTree (y, CacheTree{eltype (y)}[])
59- end
6057function cached_einsum (code:: NestedEinsum , @nospecialize (xs), size_dict)
61- caches = [cached_einsum (arg, xs, size_dict) for arg in code. args]
62- y = code. eins (getfield .(caches, :content )... ; size_info= size_dict)
63- CacheTree (y, caches)
58+ if OMEinsum. isleaf (code)
59+ y = xs[code. tensorindex]
60+ return CacheTree (y, CacheTree{eltype (y)}[])
61+ else
62+ caches = [cached_einsum (arg, xs, size_dict) for arg in code. args]
63+ y = einsum (code. eins, ntuple (i-> caches[i]. content, length (caches)), size_dict)
64+ return CacheTree (y, caches)
65+ end
6466end
6567
6668# computed mask tree by back propagation
67- function generate_masktree (code:: Int , cache, mask, size_dict, mode= :all )
68- CacheTree (mask, CacheTree{Bool}[])
69- end
7069function generate_masktree (code:: NestedEinsum , cache, mask, size_dict, mode= :all )
71- submasks = backward_tropical (mode, getixs (code. eins), (getfield .(cache. siblings, :content )... ,), OMEinsum. getiy (code. eins), cache. content, mask, size_dict)
72- return CacheTree (mask, generate_masktree .(code. args, cache. siblings, submasks, Ref (size_dict), mode))
70+ if OMEinsum. isleaf (code)
71+ return CacheTree (mask, CacheTree{Bool}[])
72+ else
73+ submasks = backward_tropical (mode, getixs (code. eins), (getfield .(cache. siblings, :content )... ,), OMEinsum. getiy (code. eins), cache. content, mask, size_dict)
74+ return CacheTree (mask, generate_masktree .(code. args, cache. siblings, submasks, Ref (size_dict), mode))
75+ end
7376end
7477
7578# The masked einsum contraction
76- function masked_einsum (code:: Int , @nospecialize (xs), masks, size_dict)
77- y = copy (xs[code])
78- y[OMEinsum. asarray (.! masks. content)] .= Ref (zero (eltype (y))); y
79- end
8079function masked_einsum (code:: NestedEinsum , @nospecialize (xs), masks, size_dict)
81- xs = [masked_einsum (arg, xs, mask, size_dict) for (arg, mask) in zip (code. args, masks. siblings)]
82- y = einsum (code. eins, (xs... ,), size_dict)
83- y[OMEinsum. asarray (.! masks. content)] .= Ref (zero (eltype (y))); y
80+ if OMEinsum. isleaf (code)
81+ y = copy (xs[code. tensorindex])
82+ y[OMEinsum. asarray (.! masks. content)] .= Ref (zero (eltype (y)))
83+ return y
84+ else
85+ xs = [masked_einsum (arg, xs, mask, size_dict) for (arg, mask) in zip (code. args, masks. siblings)]
86+ y = einsum (code. eins, (xs... ,), size_dict)
87+ y[OMEinsum. asarray (.! masks. content)] .= Ref (zero (eltype (y)))
88+ return y
89+ end
8490end
8591
8692"""
@@ -92,8 +98,9 @@ Contraction method with bounding.
9298 * `xsb` are input tensors for computing, e.g. tensors elements are counting tropical with set algebra,
9399 * `ymask` is the initial gradient mask for the output tensor.
94100"""
95- function bounding_contract (@nospecialize (code:: EinCode ), @nospecialize (xsa), ymask, @nospecialize (xsb); size_info= nothing )
96- bounding_contract (NestedEinsum ((1 : length (xsa)), code), xsa, ymask, xsb; size_info= size_info)
101+ function bounding_contract (code:: EinCode , @nospecialize (xsa), ymask, @nospecialize (xsb); size_info= nothing )
102+ LT = OMEinsum. labeltype (code)
103+ bounding_contract (NestedEinsum (NestedEinsum {DynamicEinCode{LT}} .(1 : length (xsa)), code), xsa, ymask, xsb; size_info= size_info)
97104end
98105function bounding_contract (code:: NestedEinsum , @nospecialize (xsa), ymask, @nospecialize (xsb); size_info= nothing )
99106 size_dict = size_info=== nothing ? Dict {OMEinsum.labeltype(code.eins),Int} () : copy (size_info)
@@ -109,8 +116,9 @@ function bounding_contract(code::NestedEinsum, @nospecialize(xsa), ymask, @nospe
109116end
110117
111118# get the optimal solution with automatic differentiation.
112- function solution_ad (@nospecialize (code:: EinCode ), @nospecialize (xsa), ymask; size_info= nothing )
113- solution_ad (NestedEinsum ((1 : length (xsa)), code), xsa, ymask; size_info= size_info)
119+ function solution_ad (code:: EinCode , @nospecialize (xsa), ymask; size_info= nothing )
120+ LT = OMEinsum. labeltype (code)
121+ solution_ad (NestedEinsum (NestedEinsum {DynamicEinCode{LT}} .(1 : length (xsa)), code), xsa, ymask; size_info= size_info)
114122end
115123
116124function solution_ad (code:: NestedEinsum , @nospecialize (xsa), ymask; size_info= nothing )
128136
129137function read_config! (code:: NestedEinsum , mt, out)
130138 for (arg, ix, sibling) in zip (code. args, getixs (code. eins), mt. siblings)
131- if arg isa Int
139+ if OMEinsum . isleaf ( arg)
132140 assign = convert (Array, sibling. content) # note: the content can be CuArray
133141 if length (ix) == 1
134142 if ! assign[1 ] && assign[2 ]
0 commit comments