@@ -117,6 +117,137 @@ public static Tensor[] _DivNoNanGrad(Operation op, Tensor[] grads)
117117 } ;
118118 }
119119
// Subscript token that marks broadcast dimensions in einsum equations
// (e.g. "...ij,jk->...ik").
public static string ellipsis = "...";

/// <summary>
/// Gradient function for the "Einsum" op. Reads the op's "equation"
/// attribute and returns the gradient with respect to each einsum input.
/// </summary>
/// <param name="op">The forward Einsum operation; must carry a string
/// "equation" attribute of the form "inputs->output".</param>
/// <param name="grads">Incoming gradient(s) w.r.t. the op's output.</param>
/// <returns>One gradient tensor per op input (unary or binary einsum).</returns>
[RegisterGradient("Einsum")]
public static Tensor[] _EinsumGrad(Operation op, Tensor[] grads)
{
    // Gradient for Einsum.
    string equation = (string)op.get_attr("equation");
    // Split "lhs->rhs": lhs holds the (comma-separated) input subscripts,
    // rhs the output subscripts. NOTE(review): assumes the equation always
    // contains an explicit "->"; implicit-output equations would index past
    // the split — confirm the op normalizes the equation upstream.
    string[] split_equation = equation.Split(new string[] { "->" }, StringSplitOptions.None);
    var input_subs = split_equation[0];
    var output_subs = split_equation[1];

    // Unary case: the VJP is itself an einsum mapping output subscripts back
    // to input subscripts, restoring any labels that were summed over.
    if (op.inputs.Length == 1)
    {
        var input_shape = array_ops.shape(op.inputs[0]);
        // Labels present in the input but absent from the output (ignoring
        // the ellipsis dots) were reduced in the forward pass.
        var reduced_label_set = new HashSet<char>(new HashSet<char>(input_subs).Except(new HashSet<char>(output_subs + ellipsis)));
        if (reduced_label_set.Count == 0)
            // No reduction happened: simply invert the equation.
            return new Tensor[] { math_ops.einsum(string.Format("{0}->{1}", output_subs, input_subs), new Tensors(grads)) };
        // Reduced labels need their dimensions re-broadcast into the gradient.
        return new Tensor[] { _GetGradReduced(new Tensors(grads), output_subs, input_subs, input_shape, reduced_label_set) };
    }

    // Binary case: split the two operands' subscripts.
    string[] split_input_subs = input_subs.Split(new string[] { "," }, StringSplitOptions.None);
    var x_subs = split_input_subs[0];
    var y_subs = split_input_subs[1];
    // Add ellipsis for broadcasted dimensions if any operand does not have it.
    // This is because the equation "...ij,jk->ik" may be valid if the 0th input's
    // batch shape is empty, but the VJP equation "jk,ik->...ij" is not valid
    // because only the output subscripts contain ellipsis.
    if (output_subs.Contains(ellipsis))
    {
        if (!x_subs.Contains(ellipsis))
            x_subs += ellipsis;
        if (!y_subs.Contains(ellipsis))
            y_subs += ellipsis;
    }
    // Obtain the gradients wrt the inputs x and y, without taking into account
    // the unbroadcasting.
    var x = op.inputs[0];
    var y = op.inputs[1];
    // For complex dtypes the VJP is taken w.r.t. the conjugated inputs.
    if (grads.GetDataType().is_complex())
    {
        x = math_ops.conj(x);
        y = math_ops.conj(y);
    }

    var x_shape = array_ops.shape(x);
    var y_shape = array_ops.shape(y);
    var grad_x = _GetGradWrt(grads, y, x_shape, x_subs, y_subs, output_subs);
    var grad_y = _GetGradWrt(grads, x, y_shape, y_subs, x_subs, output_subs);

    // Without an ellipsis in the output there was no batch broadcasting to undo.
    if (!output_subs.Contains(ellipsis))
        return new Tensor[] { grad_x, grad_y };
    // Slice bounds of each operand's broadcast (batch) dimensions.
    var bx = _GetBcastSubshape(x_subs);
    int bx_start = bx[0], bx_end = bx[1];
    var by = _GetBcastSubshape(y_subs);
    int by_start = by[0], by_end = by[1];

    // Fast path: when both static shapes are fully known and the batch
    // sub-shapes already match, no unbroadcasting is needed.
    var x_shape_static = x.shape;
    var y_shape_static = y.shape;
    if (x_shape_static.IsFullyDefined &&
        y_shape_static.IsFullyDefined &&
        x_shape_static[string.Format("{0}:{1}", bx_start, bx_end)] == y_shape_static[string.Format("{0}:{1}", by_start, by_end)])
        return new Tensor[] { grad_x, grad_y };

    // Sum each gradient over the axes that were broadcast in the forward
    // pass (per broadcast_gradient_args), then reshape back to the input shape.
    var r = gen_array_ops.broadcast_gradient_args(x_shape[string.Format("{0}:{1}", bx_start, bx_end)],
        y_shape[string.Format("{0}:{1}", by_start, by_end)]);
    var rx = r[0];
    var ry = r[1];
    grad_x = array_ops.reshape(math_ops.reduce_sum(grad_x, bx_start + rx), x_shape);
    grad_y = array_ops.reshape(math_ops.reduce_sum(grad_y, by_start + ry), y_shape);
    return new Tensor[] { grad_x, grad_y };
}
/// <summary>
/// Computes the (pre-unbroadcast) gradient of a binary einsum with respect to
/// one operand: einsum the output gradient against the other operand, then
/// restore any labels the forward einsum summed away.
/// </summary>
/// <param name="output_grads">Gradient(s) flowing into the einsum output.</param>
/// <param name="other_operand">The operand NOT being differentiated.</param>
/// <param name="input_shape">Dynamic shape of the operand being differentiated.</param>
/// <param name="input_subs">Subscripts of the operand being differentiated.</param>
/// <param name="other_subs">Subscripts of the other operand.</param>
/// <param name="output_subs">Subscripts of the einsum output.</param>
/// <returns>Gradient w.r.t. the operand described by <paramref name="input_subs"/>.</returns>
protected static Tensor _GetGradWrt(Tensor[] output_grads, Tensor other_operand, Tensor input_shape,
    string input_subs, string other_subs, string output_subs)
{
    // A label of `input_subs` survives the VJP einsum only if it occurs in the
    // output or in the other operand; '.' keeps the ellipsis dots out of the
    // reduced set.
    var visible_labels = new HashSet<char>(output_subs + other_subs + ".");
    var summed_labels = new HashSet<char>(input_subs.Where(c => !visible_labels.Contains(c)));
    var kept_subs = string.Concat(input_subs.Where(c => !summed_labels.Contains(c)));
    // VJP equation: (output_subs, other_subs) -> surviving input subscripts.
    var partial_grad = math_ops.einsum($"{output_subs},{other_subs}->{kept_subs}",
        new Tensors((Tensors)output_grads, other_operand));
    if (summed_labels.Count == 0)
        return partial_grad;
    // Labels summed out in the forward pass must be re-broadcast into place.
    return _GetGradReduced(partial_grad, kept_subs, input_subs, input_shape, summed_labels);
}
/// <summary>
/// Expands <paramref name="output_grad"/> (whose subscripts are
/// <paramref name="output_subs"/>) back to the shape described by
/// <paramref name="input_subs"/>, re-broadcasting the dimensions of the
/// labels in <paramref name="reduced_label_set"/> that the forward einsum
/// summed over.
/// </summary>
/// <param name="output_grad">Gradient w.r.t. the reduced result.</param>
/// <param name="output_subs">Subscripts of <paramref name="output_grad"/>.</param>
/// <param name="input_subs">Subscripts of the original einsum input.</param>
/// <param name="input_shape">Dynamic shape of the original einsum input.</param>
/// <param name="reduced_label_set">Labels that were summed over.</param>
protected static Tensor _GetGradReduced(Tensor output_grad, string output_subs, string input_subs, Tensor input_shape, HashSet<char> reduced_label_set)
{
    string reduced_subs;
    Tensor reduced_dims;
    List<int> reduced_axes;
    // Subscript string, axes and dimension sizes of the reduced labels.
    _GetReducedSubscripts(reduced_label_set, input_shape, input_subs, out reduced_subs, out reduced_dims, out reduced_axes);
    // If any label occurs more than once in either subscript string, the
    // distinct-character count is smaller than the string length.
    bool has_repeated_labels = (
        new HashSet<char>(input_subs).Count + new HashSet<char>(output_subs).Count <
        input_subs.Length + output_subs.Length);
    var input_subs_without_reduced_labels = string.Join("", input_subs.Where(s => !reduced_label_set.Contains(s)));

    if (!has_repeated_labels && input_subs_without_reduced_labels == output_subs)
    {
        // Simple case: the input is the output plus some reduced axes, in the
        // same order. Insert size-1 dims at the reduced axes and broadcast.
        var reduced_shape = math_ops.reduced_shape(input_shape, ops.convert_to_tensor(reduced_axes));
        return gen_array_ops.broadcast_to(array_ops.reshape(output_grad, reduced_shape), input_shape);
    }
    else
    {
        // General case: prepend the reduced dims (as size-1 axes), broadcast
        // them to their real sizes, then let einsum permute/diagonal-scatter
        // the result from "reduced + output" subscripts back to the input
        // subscripts.
        var grad_shape_with_reduced_labels = array_ops.concat(new Tensor[] { reduced_dims, array_ops.shape(new Tensors(output_grad)) }, axis: 0);
        var reduced_shape = array_ops.concat(new Tensor[] { array_ops.ones(reduced_label_set.Count, dtype: dtypes.int32), array_ops.shape(new Tensors(output_grad)) }, axis: 0);
        var broadcasted_grad = gen_array_ops.broadcast_to(array_ops.reshape(output_grad, reduced_shape), grad_shape_with_reduced_labels);
        return math_ops.einsum(string.Format("{0}->{1}", reduced_subs + output_subs, input_subs), new Tensors(broadcasted_grad));
    }
}
/// <summary>
/// For a set of reduced labels, produces (a) the labels concatenated into a
/// subscript string, (b) the axis of each label within
/// <paramref name="subscripts"/>, and (c) the sizes of those axes gathered
/// from <paramref name="input_shape"/> and stacked into one tensor.
/// </summary>
protected static void _GetReducedSubscripts(HashSet<char> reduced_label_set, Tensor input_shape, string subscripts, out string reduced_subs, out Tensor reduced_dims, out List<int> reduced_axes)
{
    // One subscript character per reduced label (set iteration order).
    reduced_subs = new string(reduced_label_set.ToArray());
    // Axis of each reduced label inside the full subscript string.
    var axes = new List<int>();
    foreach (var label in reduced_subs)
        axes.Add(_GetAxisFromLabel(subscripts, label));
    reduced_axes = axes;
    // Gather the dimension of each reduced axis and stack them into a tensor.
    var dims = new List<Tensor>();
    foreach (var axis in axes)
        dims.Add(input_shape[axis]);
    reduced_dims = array_ops.stack(dims);
}
/// <summary>
/// Returns the axis in <paramref name="subscripts"/> that corresponds to
/// <paramref name="label"/>. Labels before the ellipsis map to non-negative
/// axes; labels after it map to NEGATIVE axes (counted from the end), because
/// "..." spans an unknown number of dimensions.
/// </summary>
/// <param name="subscripts">Einsum subscripts for one operand, e.g. "ab...cd".</param>
/// <param name="label">The subscript character to locate.</param>
/// <returns>The (possibly negative) axis of the left-most occurrence.</returns>
/// <exception cref="OutOfRangeError">Label absent and there is no ellipsis.</exception>
/// <exception cref="ValueError">Label absent from both sides of the ellipsis.</exception>
protected static int _GetAxisFromLabel(string subscripts, char label)
{
    var splits = subscripts.Split(new string[] { ellipsis }, StringSplitOptions.None);
    var index = splits[0].IndexOf(label);
    if (index != -1) return index;
    if (splits.Length < 2) throw new OutOfRangeError();
    index = splits[1].IndexOf(label);
    // Fix: axes to the right of the ellipsis must be counted from the end —
    // e.g. in "ab...cd", 'd' is axis -1, not 1 (the original returned the raw
    // positive offset, which points at the wrong dimension whenever the
    // ellipsis covers one or more dims).
    if (index != -1) return index - splits[1].Length;
    throw new ValueError();
}
/// <summary>
/// Returns the slice bounds of the broadcast ("...") portion of an einsum
/// subscript string, for use as a "start:end" slice over the operand's shape.
/// </summary>
/// <param name="subscripts">Einsum subscripts for one operand, e.g. "ab...cd".</param>
/// <returns>
/// {0, 0} when there is no ellipsis; otherwise {start, end} where end is
/// NEGATIVE (counted from the end), so that shape[start:end] spans exactly
/// the dims covered by the ellipsis.
/// </returns>
protected static int[] _GetBcastSubshape(string subscripts)
{
    // "..." spelled locally so this helper is self-contained; it must match
    // the class-level `ellipsis` field.
    const string ellipsis = "...";
    int start = subscripts.IndexOf(ellipsis);
    if (start == -1) return new int[] { 0, 0 };
    int remaining = subscripts.Length - (start + ellipsis.Length);
    if (remaining > 0)
        // Fix: the end bound must be negative so the slice "start:-remaining"
        // covers the broadcast dims; the original returned +remaining, which
        // yields an empty or wrong slice (e.g. "ab...cd" gave "2:2").
        return new int[] { start, -remaining };
    // NOTE(review): the reference implementation returns (start, None) here,
    // i.e. "slice to the end". The int-pair contract cannot express that, so
    // a trailing ellipsis still throws — callers that append "..." at the end
    // of the subscripts will hit this; TODO confirm and extend the contract.
    throw new Exception("trailing ellipsis is not supported by _GetBcastSubshape");
}
250+
120251 /// <summary>
121252 /// Returns grad * exp(x).
122253 /// </summary>
0 commit comments