@@ -124,6 +124,9 @@ public static Tensor div_no_nan(Tensor x, Tensor y, string name = null)
124124 x , y ) . FirstOrDefault ( ) ,
125125 x , y ) ;
126126
127+ public static Tensor mean ( Tensor input , int axis , bool keep_dims = false , string name = null )
128+ => mean ( input , ops . convert_to_tensor ( axis ) , keep_dims : keep_dims , name : name ) ;
129+
127130 /// <summary>
128131 /// Computes the mean of elements across dimensions of a tensor.
129132 /// Reduces `input` along the dimensions given in `axis`. Unless
@@ -137,23 +140,30 @@ public static Tensor div_no_nan(Tensor x, Tensor y, string name = null)
137140 /// <param name="keep_dims"> An optional `bool`. Defaults to `False`. If true, retain reduced dimensions with length 1.</param>
138141 /// <param name="name"> A name for the operation (optional).</param>
139142 /// <returns> A `Tensor`. Has the same type as `input`.</returns>
140- public static Tensor mean < T1 , T2 > ( T1 input , T2 axis , bool keep_dims = false , string name = null )
141- {
142- if ( tf . Context . executing_eagerly ( ) )
143- {
144- var results = tf . Runner . TFE_FastPathExecute ( tf . Context , tf . Context . DeviceName ,
143+ public static Tensor mean ( Tensor input , Tensor axis , bool keep_dims = false , string name = null )
144+ => tf . Context . RunInAutoMode2 (
145+ ( ) => tf . OpDefLib . _apply_op_helper ( "Mean" , name , new
146+ {
147+ input ,
148+ reduction_indices = axis ,
149+ keep_dims = keep_dims
150+ } ) . output ,
151+ ( ) => tf . Runner . TFE_FastPathExecute ( tf . Context , tf . Context . DeviceName ,
145152 "Mean" , name ,
146153 null ,
147154 input , axis ,
148- "keep_dims" , keep_dims ) ;
149-
150- return results [ 0 ] ;
151- }
152-
153- var _op = tf . OpDefLib . _apply_op_helper ( "Mean" , name , args : new { input , reduction_indices = axis , keep_dims = keep_dims } ) ;
154-
155- return _op . output ;
156- }
155+ "keep_dims" , keep_dims ) . FirstOrDefault ( ) ,
156+ ( op ) =>
157+ {
158+ var attrs = new object [ ]
159+ {
160+ "T" , op . get_attr < TF_DataType > ( "T" ) ,
161+ "Tidx" , op . get_attr < TF_DataType > ( "Tidx" ) ,
162+ "keep_dims" , op . get_attr < bool > ( "keep_dims" )
163+ } ;
164+ tf . Runner . RecordGradient ( "Mean" , op . inputs , attrs , op . outputs ) ;
165+ } ,
166+ new Tensors ( input , axis ) ) ;
157167
158168 public static Tensor mean ( Tensor [ ] inputs , Tensor axis , bool keep_dims = false , string name = null )
159169 {
@@ -376,8 +386,18 @@ public static Tensor sinh(Tensor x, string name = null)
376386 return _op . outputs [ 0 ] ;
377387 }
378388
379- public static Tensor cos ( Tensor x , string name = null )
389+ public static Tensor cos < T > ( T x , string name = null )
380390 {
391+ if ( tf . executing_eagerly ( ) )
392+ {
393+ var results = tf . Runner . TFE_FastPathExecute ( tf . Context , tf . Context . DeviceName ,
394+ "Cos" , name ,
395+ null ,
396+ x ) ;
397+
398+ return results [ 0 ] ;
399+ }
400+
381401 var _op = tf . OpDefLib . _apply_op_helper ( "Cos" , name , args : new { x } ) ;
382402
383403 return _op . outputs [ 0 ] ;
@@ -776,20 +796,21 @@ public static Tensor sqrt(Tensor x, string name = null)
776796 }
777797
778798 public static Tensor sub ( Tensor x , Tensor y , string name = null )
779- {
780- if ( tf . Context . executing_eagerly ( ) )
781- {
782- var results = tf . Runner . TFE_FastPathExecute ( tf . Context , tf . Context . DeviceName ,
799+ => tf . Context . RunInAutoMode2 (
800+ ( ) => tf . OpDefLib . _apply_op_helper ( "Sub" , name , new { x , y } ) . output ,
801+ ( ) => tf . Runner . TFE_FastPathExecute ( tf . Context , tf . Context . DeviceName ,
783802 "Sub" , name ,
784803 null ,
785- x , y ) ;
786- return results [ 0 ] ;
787- }
788-
789- var _op = tf . OpDefLib . _apply_op_helper ( "Sub" , name , args : new { x , y } ) ;
790-
791- return _op . output ;
792- }
804+ x , y ) . FirstOrDefault ( ) ,
805+ ( op ) =>
806+ {
807+ var attrs = new object [ ]
808+ {
809+ "T" , op . get_attr < TF_DataType > ( "T" )
810+ } ;
811+ tf . Runner . RecordGradient ( "Sub" , op . inputs , attrs , op . outputs ) ;
812+ } ,
813+ new Tensors ( x , y ) ) ;
793814
794815 public static Tensor sub < Tx , Ty > ( Tx x , Ty y , string name = null )
795816 {
0 commit comments