11using OMEinsum: DynamicEinCode
22
3+ struct AllConfigs{K} end
4+ largest_k (:: AllConfigs{K} ) where K = K
5+ struct SingleConfig end
6+
37"""
48 backward_tropical(mode, ixs, xs, iy, y, ymask, size_dict)
59
@@ -17,11 +21,11 @@ function backward_tropical(mode, ixs, @nospecialize(xs::Tuple), iy, @nospecializ
1721 nixs = OMEinsum. _insertat (ixs, i, iy)
1822 nxs = OMEinsum. _insertat ( xs, i, y)
1923 niy = ixs[i]
20- if mode == :all
24+ if mode isa AllConfigs
2125 mask = zeros (Bool, size (xs[i]))
22- mask .= inv .(einsum (EinCode (nixs, niy), nxs, size_dict)) .== xs[i]
26+ mask .= inv .(einsum (EinCode (nixs, niy), nxs, size_dict)) .<= xs[i] .* Tropical ( largest_k (mode) - 1 )
2327 push! (masks, mask)
24- elseif mode == :single # wrong, need `B` matching `A`.
28+ elseif mode isa SingleConfig
2529 A = zeros (eltype (xs[i]), size (xs[i]))
2630 A = einsum (EinCode (nixs, niy), nxs, size_dict)
2731 push! (masks, onehotmask (A, xs[i]))
@@ -65,12 +69,12 @@ function cached_einsum(code::NestedEinsum, @nospecialize(xs), size_dict)
6569end
6670
6771# computed mask tree by back propagation
68- function generate_masktree (code:: NestedEinsum , cache, mask, size_dict, mode = :all )
72+ function generate_masktree (mode, code:: NestedEinsum , cache, mask, size_dict)
6973 if OMEinsum. isleaf (code)
7074 return CacheTree (mask, CacheTree{Bool}[])
7175 else
7276 submasks = backward_tropical (mode, getixs (code. eins), (getfield .(cache. siblings, :content )... ,), OMEinsum. getiy (code. eins), cache. content, mask, size_dict)
73- return CacheTree (mask, generate_masktree .(code. args, cache. siblings, submasks, Ref (size_dict), mode ))
77+ return CacheTree (mask, generate_masktree .(Ref (mode), code. args, cache. siblings, submasks, Ref (size_dict)))
7478 end
7579end
7680
@@ -89,27 +93,28 @@ function masked_einsum(code::NestedEinsum, @nospecialize(xs), masks, size_dict)
8993end
9094
9195"""
92- bounding_contract(code, xsa, ymask, xsb; size_info=nothing)
96+ bounding_contract(mode, code, xsa, ymask, xsb; size_info=nothing)
9397
9498Contraction method with bounding.
9599
100+ * `mode` is a `AllConfigs{K}` instance, where `MIS-K+1` is the largest IS size that you care about.
96101 * `xsa` are input tensors for bounding, e.g. tropical tensors,
97102 * `xsb` are input tensors for computing, e.g. tensors elements are counting tropical with set algebra,
98103 * `ymask` is the initial gradient mask for the output tensor.
99104"""
100- function bounding_contract (code:: EinCode , @nospecialize (xsa), ymask, @nospecialize (xsb); size_info= nothing )
105+ function bounding_contract (mode :: AllConfigs , code:: EinCode , @nospecialize (xsa), ymask, @nospecialize (xsb); size_info= nothing )
101106 LT = OMEinsum. labeltype (code)
102- bounding_contract (NestedEinsum (NestedEinsum {DynamicEinCode{LT}} .(1 : length (xsa)), code), xsa, ymask, xsb; size_info= size_info)
107+ bounding_contract (mode, NestedEinsum (NestedEinsum {DynamicEinCode{LT}} .(1 : length (xsa)), code), xsa, ymask, xsb; size_info= size_info)
103108end
104- function bounding_contract (code:: NestedEinsum , @nospecialize (xsa), ymask, @nospecialize (xsb); size_info= nothing )
109+ function bounding_contract (mode :: AllConfigs , code:: NestedEinsum , @nospecialize (xsa), ymask, @nospecialize (xsb); size_info= nothing )
105110 size_dict = size_info=== nothing ? Dict {OMEinsum.labeltype(code.eins),Int} () : copy (size_info)
106111 OMEinsum. get_size_dict! (code, xsa, size_dict)
107112 # compute intermediate tensors
108113 @debug " caching einsum..."
109114 c = cached_einsum (code, xsa, size_dict)
110115 # compute masks from cached tensors
111116 @debug " generating masked tree..."
112- mt = generate_masktree (code, c, ymask, size_dict, :all )
117+ mt = generate_masktree (mode, code, c, ymask, size_dict)
113118 # compute results with masks
114119 masked_einsum (code, xsb, mt, size_dict)
115120end
@@ -129,7 +134,7 @@ function solution_ad(code::NestedEinsum, @nospecialize(xsa), ymask; size_info=no
129134 n = asscalar (c. content)
130135 # compute masks from cached tensors
131136 @debug " generating masked tree..."
132- mt = generate_masktree (code, c, ymask, size_dict, :single )
137+ mt = generate_masktree (SingleConfig (), code, c, ymask, size_dict)
133138 n, read_config! (code, mt, Dict ())
134139end
135140
0 commit comments