@@ -22,8 +22,10 @@ public static Tensor[] _AddGrad(Operation op, Tensor[] grads)
2222 var sy = array_ops . shape ( y ) ;
2323 var ( rx , ry ) = gen_array_ops . broadcast_gradient_args ( sx , sy ) ;
2424
25- var r1 = gen_array_ops . reshape ( math_ops . reduce_sum ( grad , rx ) , sx ) ;
26- var r2 = gen_array_ops . reshape ( math_ops . reduce_sum ( grad , ry ) , sy ) ;
25+ var sum1 = math_ops . reduce_sum ( grad , rx ) ;
26+ var r1 = gen_array_ops . reshape ( sum1 , sx ) ;
27+ var sum2 = math_ops . reduce_sum ( grad , ry ) ;
28+ var r2 = gen_array_ops . reshape ( sum2 , sy ) ;
2729
2830 return new Tensor [ ] { r1 , r2 } ;
2931 }
@@ -48,7 +50,8 @@ public static Tensor[] _MulGrad(Operation op, Tensor[] grads)
4850 var x = op . inputs [ 0 ] ;
4951 var y = op . inputs [ 1 ] ;
5052 var grad = grads [ 0 ] ;
51- if ( grad is Tensor && _ShapesFullySpecifiedAndEqual ( x , y , grad ) &&
53+ if ( grad is Tensor &&
54+ _ShapesFullySpecifiedAndEqual ( x , y , grad ) &&
5255 new TF_DataType [ ] { tf . int32 , tf . float32 } . Contains ( grad . dtype ) )
5356 return new Tensor [ ] { gen_math_ops . mul ( grad , y ) , gen_math_ops . mul ( grad , x ) } ;
5457
@@ -60,10 +63,11 @@ public static Tensor[] _MulGrad(Operation op, Tensor[] grads)
6063 y = math_ops . conj ( y ) ;
6164
6265 var mul1 = gen_math_ops . mul ( grad , y ) ;
63- var mul2 = gen_math_ops . mul ( x , grad ) ;
6466 var reduce_sum1 = math_ops . reduce_sum ( mul1 , rx ) ;
65- var reduce_sum2 = math_ops . reduce_sum ( mul2 , ry ) ;
6667 var reshape1 = gen_array_ops . reshape ( reduce_sum1 , sx ) ;
68+
69+ var mul2 = gen_math_ops . mul ( x , grad ) ;
70+ var reduce_sum2 = math_ops . reduce_sum ( mul2 , ry ) ;
6771 var reshape2 = gen_array_ops . reshape ( reduce_sum2 , sy ) ;
6872
6973 return new Tensor [ ] { reshape1 , reshape2 } ;
@@ -146,7 +150,13 @@ public static Tensor[] _SubGrad(Operation op, Tensor[] grads)
146150
147151 public static bool _ShapesFullySpecifiedAndEqual ( Tensor x , Tensor y , Tensor grad )
148152 {
149- return x . NDims == y . NDims && y . NDims == grad . NDims && x . NDims > - 1 ;
153+ var x_shape = x . _shape_tuple ( ) ;
154+ var y_shape = y . _shape_tuple ( ) ;
155+ var grad_shape = grad . _shape_tuple ( ) ;
156+ return Enumerable . SequenceEqual ( x_shape , y_shape ) &&
157+ Enumerable . SequenceEqual ( y_shape , grad_shape ) &&
158+ x . NDims != - 1 &&
159+ ! x_shape . Contains ( - 1 ) ;
150160 }
151161
152162 public static Tensor [ ] _SumGrad ( Operation op , Tensor [ ] grads )
0 commit comments