@@ -212,7 +212,7 @@ public static Tensor[] _MulGrad(Operation op, Tensor[] grads)
212212 } ;
213213 }
214214
215- var broads = SmartBroadcastGradientArgs ( x , y ) ;
215+ var broads = SmartBroadcastGradientArgs ( x , y , grad ) ;
216216 var ( sx , rx , must_reduce_x ) = broads [ 0 ] ;
217217 var ( sy , ry , must_reduce_y ) = broads [ 1 ] ;
218218
@@ -468,7 +468,7 @@ public static Tensor[] _SubGrad(Operation op, Tensor[] grads)
468468 _ShapesFullySpecifiedAndEqual ( x , y , grad ) )
469469 return new Tensor [ ] { grad , - grad } ;
470470
471- var broads = SmartBroadcastGradientArgs ( x , y ) ;
471+ var broads = SmartBroadcastGradientArgs ( x , y , grad ) ;
472472 var ( sx , rx , must_reduce_x ) = broads [ 0 ] ;
473473 var ( sy , ry , must_reduce_y ) = broads [ 1 ] ;
474474
@@ -718,7 +718,7 @@ public static Tensor[] _PowGrad(Operation op, Tensor[] grads)
718718
719719 var z = op . outputs [ 0 ] ;
720720
721- var broads = SmartBroadcastGradientArgs ( x , y ) ;
721+ var broads = SmartBroadcastGradientArgs ( x , y , grad ) ;
722722 var ( sx , rx , must_reduce_x ) = broads [ 0 ] ;
723723 var ( sy , ry , must_reduce_y ) = broads [ 1 ] ;
724724
@@ -753,7 +753,7 @@ public static Tensor[] _PowGrad(Operation op, Tensor[] grads)
753753 /// <param name="x"></param>
754754 /// <param name="y"></param>
755755 /// <returns></returns>
756- private static ( Tensor , Tensor , bool ) [ ] SmartBroadcastGradientArgs ( Tensor x , Tensor y )
756+ private static ( Tensor , Tensor , bool ) [ ] SmartBroadcastGradientArgs ( Tensor x , Tensor y , Tensor grad )
757757 {
758758 Tensor sx , sy ;
759759 if ( x . TensorShape . is_fully_defined ( ) &&
@@ -771,8 +771,8 @@ private static (Tensor, Tensor, bool)[] SmartBroadcastGradientArgs(Tensor x, Ten
771771 var ( rx , ry ) = gen_array_ops . broadcast_gradient_args ( sx , sy ) ;
772772 return new [ ]
773773 {
774- ( sx , rx , true ) ,
775- ( sy , ry , true )
774+ ( sx , rx , ! x . TensorShape . Equals ( grad . TensorShape ) ) ,
775+ ( sy , ry , ! y . TensorShape . Equals ( grad . TensorShape ) )
776776 } ;
777777 }
778778 }
0 commit comments