@@ -212,26 +212,29 @@ public static Tensor __case__(Tensor x, TF_DataType dtype, string name = null)
212212 throw new NotImplementedException ( ) ;
213213 }
214214
215- public static Tensor reduce_sum ( Tensor input_tensor , Tensor axis = null , bool keepdims = false )
215+ public static Tensor reduce_sum ( Tensor input_tensor , Tensor axis = null , bool keepdims = false , string name = null )
216216 {
217217 var r = _ReductionDims ( input_tensor , axis ) ;
218- var m = gen_math_ops . sum ( input_tensor , r ) ;
219- return _may_reduce_to_scalar ( keepdims , m ) ;
218+ var m = gen_math_ops . _sum ( input_tensor , r , keep_dims : keepdims , name : name ) ;
219+ return _may_reduce_to_scalar ( keepdims , axis , m ) ;
220220 }
221221
222222 public static Tensor reduce_sum ( Tensor input_tensor , int axis , bool keepdims = false )
223223 {
224- var m = gen_math_ops . sum ( input_tensor , axis ) ;
225- return _may_reduce_to_scalar ( keepdims , m ) ;
224+ var m = gen_math_ops . _sum ( input_tensor , axis ) ;
225+ return _may_reduce_to_scalar ( keepdims , new int [ ] { axis } , m ) ;
226226 }
227227
228- private static Tensor _may_reduce_to_scalar ( bool keepdims , Tensor output )
228+ private static Tensor _may_reduce_to_scalar ( bool keepdims , Tensor axis , Tensor output )
229229 {
230- output . shape = new long [ 0 ] ;
230+ if ( ! common_shapes . has_fully_defined_shape ( output ) &&
231+ ! keepdims &&
232+ axis == null )
233+ output . shape = new long [ 0 ] ;
231234 return output ;
232235 }
233236
234- private static Tensor _may_reduce_to_scalar ( bool keepdims , int [ ] axos , Tensor output )
237+ private static Tensor _may_reduce_to_scalar ( bool keepdims , int [ ] axis , Tensor output )
235238 {
236239 output . shape = new long [ 0 ] ;
237240 return output ;
0 commit comments