@@ -168,15 +168,12 @@ public static Tensor cast(float x, TF_DataType dtype = TF_DataType.DtInvalid, st
168168 }
169169
170170 public static Tensor cumsum < T > ( Tensor x , T axis = default , bool exclusive = false , bool reverse = false , string name = null )
171- {
172- return tf_with ( ops . name_scope ( name , "Cumsum" , new { x } ) , scope =>
173- {
174- name = scope ;
175- x = ops . convert_to_tensor ( x , name : "x" ) ;
176-
177- return gen_math_ops . cumsum ( x , axis : axis , exclusive : exclusive , reverse : reverse , name : name ) ;
178- } ) ;
179- }
171+ => tf_with ( ops . name_scope ( name , "Cumsum" , new { x } ) , scope =>
172+ {
173+ name = scope ;
174+ return tf . Context . ExecuteOp ( "Cumsum" , name , new ExecuteOpArgs ( x , axis )
175+ . SetAttributes ( new { exclusive , reverse } ) ) ;
176+ } ) ;
180177
181178 /// <summary>
182179 /// Computes Psi, the derivative of Lgamma (the log of the absolute value of
@@ -807,6 +804,31 @@ public static Tensor batch_matmul(Tensor x, Tensor y,
807804 . SetAttributes ( new { adj_x , adj_y } ) ) ;
808805 } ) ;
809806
807+ public static Tensor bincount ( Tensor arr , Tensor weights = null ,
808+ Tensor minlength = null ,
809+ Tensor maxlength = null ,
810+ TF_DataType dtype = TF_DataType . TF_INT32 ,
811+ string name = null ,
812+ TensorShape axis = null ,
813+ bool binary_output = false )
814+ => tf_with ( ops . name_scope ( name , "bincount" ) , scope =>
815+ {
816+ name = scope ;
817+ if ( ! binary_output && axis == null )
818+ {
819+ var array_is_nonempty = math_ops . reduce_prod ( array_ops . shape ( arr ) ) > 0 ;
820+ var output_size = math_ops . cast ( array_is_nonempty , dtypes . int32 ) * ( math_ops . reduce_max ( arr ) + 1 ) ;
821+ if ( minlength != null )
822+ output_size = math_ops . maximum ( minlength , output_size ) ;
823+ if ( maxlength != null )
824+ output_size = math_ops . minimum ( maxlength , output_size ) ;
825+ var weights = constant_op . constant ( new long [ 0 ] , dtype : dtype ) ;
826+ return tf . Context . ExecuteOp ( "Bincount" , name , new ExecuteOpArgs ( arr , output_size , weights ) ) ;
827+ }
828+
829+ throw new NotImplementedException ( "" ) ;
830+ } ) ;
831+
810832 /// <summary>
811833 /// Returns the complex conjugate of a complex number.
812834 /// </summary>
0 commit comments