@@ -274,7 +274,7 @@ public static Tensor _autopacking_helper(IEnumerable<object> list_or_tuple, TF_D
274274 {
275275 if ( elem is EagerTensor eager_tensor )
276276 {
277- if ( switch_to_graph )
277+ if ( switch_to_graph )
278278 elems_as_tensors . Add ( constant_op . constant ( eager_tensor . numpy ( ) , dtype : dtype , name : i . ToString ( ) ) ) ;
279279 else
280280 elems_as_tensors . Add ( eager_tensor ) ;
@@ -366,8 +366,30 @@ public static Tensor rank_internal(Tensor input, string name = null, bool optimi
366366 /// <param name="name"></param>
367367 /// <param name="optimize"></param>
368368 /// <returns></returns>
369- public static Tensor ones_like < T > ( T tensor , TF_DataType dtype = TF_DataType . DtInvalid , string name = null , bool optimize = true )
370- => ones_like_impl ( tensor , dtype , name , optimize ) ;
369+ public static Tensor ones_like ( Tensor tensor , TF_DataType dtype = TF_DataType . DtInvalid , string name = null , bool optimize = true )
370+ {
371+ return tf_with ( ops . name_scope ( name , "ones_like" , new Tensor [ ] { tensor } ) , scope =>
372+ {
373+ name = scope ;
374+ tensor = ops . convert_to_tensor ( tensor , name : "tensor" ) ;
375+
376+ // is_fully_defined return unexpected value.
377+ if ( optimize && tensor_util . to_shape ( tensor . shape ) . is_fully_defined ( ) && dtype != TF_DataType . TF_VARIANT )
378+ {
379+
380+ }
381+
382+ if ( dtype != TF_DataType . DtInvalid && dtype != tensor . dtype && dtype != TF_DataType . TF_VARIANT )
383+ {
384+ throw new NotImplementedException ( "ones_like" ) ;
385+ // return ones(shape_internal(tensor, optimize: optimize), dtype: dtype, name: name);
386+ }
387+ else
388+ {
389+ return gen_array_ops . ones_like ( tensor , name : name ) ;
390+ }
391+ } ) ;
392+ }
371393
372394 public static Tensor reshape ( Tensor tensor , Tensor shape , string name = null )
373395 => gen_array_ops . reshape ( tensor , shape , name : name ) ;
@@ -378,21 +400,6 @@ public static Tensor reshape(Tensor tensor, TensorShape shape, string name = nul
378400 public static Tensor reshape ( Tensor tensor , object [ ] shape , string name = null )
379401 => gen_array_ops . reshape ( tensor , shape , name : name ) ;
380402
381- private static Tensor ones_like_impl < T > ( T tensor , TF_DataType dtype , string name , bool optimize = true )
382- {
383- return tf_with ( ops . name_scope ( name , "ones_like" , new { tensor } ) , scope =>
384- {
385- name = scope ;
386- var tensor1 = ops . convert_to_tensor ( tensor , name : "tensor" ) ;
387- var ones_shape = shape_internal ( tensor1 , optimize : optimize ) ;
388- if ( dtype == TF_DataType . DtInvalid )
389- dtype = tensor1 . dtype ;
390- var ret = ones ( ones_shape , dtype : dtype , name : name ) ;
391- ret . shape = tensor1 . shape ;
392- return ret ;
393- } ) ;
394- }
395-
396403 public static Tensor ones ( Tensor shape , TF_DataType dtype = TF_DataType . TF_FLOAT , string name = null )
397404 {
398405 dtype = dtype . as_base_dtype ( ) ;
@@ -891,7 +898,7 @@ public static Tensor transpose<T1>(T1 a, TensorShape perm, string name = "transp
891898 return tf_with ( ops . name_scope ( name , "transpose" , new { a } ) , scope =>
892899 {
893900 var a_tensor = ops . convert_to_tensor ( a ) ;
894- if ( perm == null )
901+ if ( perm == null )
895902 {
896903 var rank = a_tensor . rank ;
897904 perm = range ( 0 , rank ) . OrderByDescending ( x => x ) . ToArray ( ) ;
@@ -953,7 +960,9 @@ public static Tensor slice(Tensor input, Tensor begin, Tensor size, string name
953960 => tf . Context . RunInAutoMode2 (
954961 ( ) => tf . OpDefLib . _apply_op_helper ( "Slice" , name , new
955962 {
956- input , begin , size
963+ input ,
964+ begin ,
965+ size
957966 } ) . output ,
958967 ( ) => tf . Runner . TFE_FastPathExecute ( tf . Context , tf . Context . DeviceName ,
959968 "Slice" , name ,
@@ -969,8 +978,8 @@ public static Tensor slice(Tensor input, Tensor begin, Tensor size, string name
969978 tf . Runner . RecordGradient ( "Slice" , op . inputs , attrs , op . outputs ) ;
970979 } ,
971980 new Tensors ( input , begin , size ) ) ;
972-
973- public static Tensor stack ( object values , int axis = 0 , string name = "stack" )
981+
982+ public static Tensor stack ( object values , int axis = 0 , string name = "stack" )
974983 {
975984 if ( axis == 0 )
976985 // If the input is a constant list, it can be converted to a constant op
0 commit comments