1+ using TupleTools
2+
3+ export bounding_contract
4+
5+ Base. isnan (x:: Tropical ) = isnan (x. n)
6+ function backward_tropical (mode, @nospecialize (ixs), @nospecialize (xs), @nospecialize (iy), @nospecialize (y), @nospecialize (ymask), size_dict)
7+ y .= inv .(y) .* ymask
8+ masks = []
9+ for i= 1 : length (ixs)
10+ nixs = TupleTools. insertat (ixs, i, (iy,))
11+ nxs = TupleTools. insertat ( xs, i, (y,))
12+ niy = ixs[i]
13+ if mode == :all
14+ mask = zeros (Bool, size (xs[i]))
15+ mask .= inv .(einsum (EinCode (nixs, niy), nxs, size_dict)) .== xs[i]
16+ push! (masks, mask)
17+ elseif mode == :single # wrong, need `B` matching `A`.
18+ A = zeros (eltype (xs[i]), size (xs[i]))
19+ A = einsum (EinCode (nixs, niy), nxs, size_dict)
20+ push! (masks, onehotmask (A, xs[i]))
21+ else
22+ error (" unkown mode: $mod " )
23+ end
24+ end
25+ return masks
26+ end
27+
28+ function onehotmask (A:: AbstractArray{T} , X:: AbstractArray{T} ) where T
29+ @assert length (A) == length (X)
30+ mask = falses (size (A)... )
31+ found = false
32+ @inbounds for j= 1 : length (A)
33+ if X[j] == inv (A[j]) && ! found
34+ mask[j] = true
35+ found = true
36+ else
37+ X[j] = zero (T)
38+ end
39+ end
40+ return mask
41+ end
42+
43+ struct CacheTree{T}
44+ content:: AbstractArray{T}
45+ siblings:: Vector{CacheTree{T}}
46+ end
47+ function cached_einsum (code:: Int , @nospecialize (xs), size_dict)
48+ y = xs[code]
49+ CacheTree (y, CacheTree{eltype (y)}[])
50+ end
51+ function cached_einsum (code:: NestedEinsum , @nospecialize (xs), size_dict)
52+ caches = [cached_einsum (arg, xs, size_dict) for arg in code. args]
53+ y = einsum (code. eins, (getfield .(caches, :content )... ,), size_dict)
54+ CacheTree (y, caches)
55+ end
56+
57+ function generate_masktree (code:: Int , cache, mask, size_dict, mode= :all )
58+ CacheTree (mask, CacheTree{Bool}[])
59+ end
60+ function generate_masktree (code:: NestedEinsum , cache, mask, size_dict, mode= :all )
61+ submasks = backward_tropical (mode, OMEinsum. getixs (code. eins), (getfield .(cache. siblings, :content )... ,), OMEinsum. getiy (code. eins), cache. content, mask, size_dict)
62+ return CacheTree (mask, generate_masktree .(code. args, cache. siblings, submasks, Ref (size_dict), mode))
63+ end
64+
65+ function masked_einsum (code:: Int , @nospecialize (xs), masks, size_dict)
66+ y = copy (xs[code])
67+ y[OMEinsum. asarray (.! masks. content)] .= Ref (zero (eltype (y))); y
68+ end
69+ function masked_einsum (code:: NestedEinsum , @nospecialize (xs), masks, size_dict)
70+ xs = [masked_einsum (arg, xs, mask, size_dict) for (arg, mask) in zip (code. args, masks. siblings)]
71+ y = einsum (code. eins, (xs... ,), size_dict)
72+ y[OMEinsum. asarray (.! masks. content)] .= Ref (zero (eltype (y))); y
73+ end
74+
75+ function bounding_contract (@nospecialize (code:: EinCode ), @nospecialize (xsa), ymask, @nospecialize (xsb); size_info= nothing )
76+ bounding_contract (NestedEinsum ((1 : length (xsa)), code), xsa, ymask, xsb; size_info= size_info)
77+ end
78+ function bounding_contract (code:: NestedEinsum , @nospecialize (xsa), ymask, @nospecialize (xsb); size_info= nothing )
79+ size_dict = OMEinsum. get_size_dict (getixs (flatten (code)), xsa, size_info)
80+ # compute intermediate tensors
81+ @debug " caching einsum..."
82+ c = cached_einsum (code, xsa, size_dict)
83+ # compute masks from cached tensors
84+ @debug " generating masked tree..."
85+ mt = generate_masktree (code, c, ymask, size_dict, :all )
86+ # compute results with masks
87+ masked_einsum (code, xsb, mt, size_dict)
88+ end
89+
90+ function mis_config_ad (@nospecialize (code:: EinCode ), @nospecialize (xsa), ymask; size_info= nothing )
91+ mis_config_ad (NestedEinsum ((1 : length (xsa)), code), xsa, ymask; size_info= size_info)
92+ end
93+
94+ function mis_config_ad (code:: NestedEinsum , @nospecialize (xsa), ymask; size_info= nothing )
95+ size_dict = OMEinsum. get_size_dict (getixs (flatten (code)), xsa, size_info)
96+ # compute intermediate tensors
97+ @debug " caching einsum..."
98+ c = cached_einsum (code, xsa, size_dict)
99+ n = asscalar (c. content)
100+ # compute masks from cached tensors
101+ @debug " generating masked tree..."
102+ mt = generate_masktree (code, c, ymask, size_dict, :single )
103+ n, read_config! (code, mt, Dict ())
104+ end
105+
106+ function read_config! (code:: NestedEinsum , mt, out)
107+ for (arg, ix, sibling) in zip (code. args, OMEinsum. getixs (code. eins), mt. siblings)
108+ if arg isa Int
109+ assign = convert (Array, sibling. content) # note: the content can be CuArray
110+ if length (ix) == 1
111+ if ! assign[1 ] && assign[2 ]
112+ out[ix[1 ]] = 1
113+ elseif ! assign[2 ] && assign[1 ]
114+ out[ix[1 ]] = 0
115+ else
116+ error (" invalid assign $(assign) " )
117+ end
118+ end
119+ else # nested
120+ read_config! (arg, sibling, out)
121+ end
122+ end
123+ return out
124+ end
0 commit comments