@@ -17,7 +17,9 @@ limitations under the License.
1717using NumSharp ;
1818using System ;
1919using System . Collections . Generic ;
20+ using System . Diagnostics ;
2021using System . Linq ;
22+ using static Tensorflow . Binding ;
2123
2224namespace Tensorflow
2325{
@@ -76,7 +78,14 @@ public Tensor check_numerics(Tensor tensor, string message, string name = null)
7678 public Tensor concat ( IList < Tensor > values , int axis , string name = "concat" )
7779 {
7880 if ( values . Count == 1 )
79- throw new NotImplementedException ( "tf.concat length is 1" ) ;
81+ {
82+ return tf_with ( ops . name_scope ( name ) , scope =>
83+ {
84+ var tensor = ops . convert_to_tensor ( axis , name : "concat_dim" , dtype : dtypes . int32 ) ;
85+ Debug . Assert ( tensor . TensorShape . ndim == 0 ) ;
86+ return identity ( values [ 0 ] , name : scope ) ;
87+ } ) ;
88+ }
8089
8190 return gen_array_ops . concat_v2 ( values . ToArray ( ) , axis , name : name ) ;
8291 }
@@ -111,7 +120,7 @@ public Tensor fill<T>(Tensor dims, T value, string name = null)
111120 /// <param name="input"></param>
112121 /// <param name="name"></param>
113122 /// <returns></returns>
114- public static Tensor identity ( Tensor input , string name = null )
123+ public Tensor identity ( Tensor input , string name = null )
115124 => array_ops . identity ( input , name : name ) ;
116125
117126 /// <summary>
@@ -150,10 +159,10 @@ public Tensor transpose<T1>(T1 a, int[] perm = null, string name = "transpose",
150159 /// <param name="axis"></param>
151160 /// <param name="name"></param>
152161 /// <returns></returns>
153- public static Tensor reverse ( Tensor tensor , int [ ] axis , string name = null )
162+ public Tensor reverse ( Tensor tensor , int [ ] axis , string name = null )
154163 => gen_array_ops . reverse ( tensor , axis , name : name ) ;
155164
156- public static Tensor reverse ( Tensor tensor , Tensor axis , string name = null )
165+ public Tensor reverse ( Tensor tensor , Tensor axis , string name = null )
157166 => gen_array_ops . reverse ( tensor , axis , name : name ) ;
158167
159168 /// <summary>
@@ -277,5 +286,14 @@ public Tensor[] unstack(Tensor value, int? num = null, int axis = 0, string name
277286 /// <returns>A `Tensor` with all elements set to zero.</returns>
278287 public Tensor zeros_like ( Tensor tensor , TF_DataType dtype = TF_DataType . DtInvalid , string name = null , bool optimize = true )
279288 => array_ops . zeros_like ( tensor , dtype : dtype , name : name , optimize : optimize ) ;
289+
290+ /// <summary>
291+ /// Stops gradient computation.
292+ /// </summary>
293+ /// <param name="x"></param>
294+ /// <param name="name"></param>
295+ /// <returns></returns>
296+ public Tensor stop_gradient ( Tensor x , string name = null )
297+ => gen_array_ops . stop_gradient ( x , name : name ) ;
280298 }
281299}
0 commit comments