Skip to content

Commit bee3a10

Browse files
committed
ones_like fix
1 parent 17a4fe0 commit bee3a10

File tree

4 files changed

+43
-23
lines changed

4 files changed

+43
-23
lines changed

src/TensorFlowNET.Core/Gradients/math_grad.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,9 @@ public static Tensor[] _ExpGrad(Operation op, Tensor[] grads)
138138
[RegisterNoGradient("GreaterEqual")]
139139
public static Tensor[] _GreaterEqualGrad(Operation op, Tensor[] grads) => null;
140140

141+
[RegisterNoGradient("OnesLike")]
142+
public static Tensor[] _OnesLike(Operation op, Tensor[] grads) => null;
143+
141144
[RegisterNoGradient("ZerosLike")]
142145
public static Tensor[] _ZerosLike(Operation op, Tensor[] grads) => null;
143146

src/TensorFlowNET.Core/Operations/array_ops.cs

Lines changed: 31 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,7 @@ public static Tensor _autopacking_helper(IEnumerable<object> list_or_tuple, TF_D
274274
{
275275
if (elem is EagerTensor eager_tensor)
276276
{
277-
if(switch_to_graph)
277+
if (switch_to_graph)
278278
elems_as_tensors.Add(constant_op.constant(eager_tensor.numpy(), dtype: dtype, name: i.ToString()));
279279
else
280280
elems_as_tensors.Add(eager_tensor);
@@ -366,8 +366,30 @@ public static Tensor rank_internal(Tensor input, string name = null, bool optimi
366366
/// <param name="name"></param>
367367
/// <param name="optimize"></param>
368368
/// <returns></returns>
369-
public static Tensor ones_like<T>(T tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true)
370-
=> ones_like_impl(tensor, dtype, name, optimize);
369+
public static Tensor ones_like(Tensor tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true)
370+
{
371+
return tf_with(ops.name_scope(name, "ones_like", new Tensor[] { tensor }), scope =>
372+
{
373+
name = scope;
374+
tensor = ops.convert_to_tensor(tensor, name: "tensor");
375+
376+
// is_fully_defined return unexpected value.
377+
if (optimize && tensor_util.to_shape(tensor.shape).is_fully_defined() && dtype != TF_DataType.TF_VARIANT)
378+
{
379+
380+
}
381+
382+
if (dtype != TF_DataType.DtInvalid && dtype != tensor.dtype && dtype != TF_DataType.TF_VARIANT)
383+
{
384+
throw new NotImplementedException("ones_like");
385+
// return ones(shape_internal(tensor, optimize: optimize), dtype: dtype, name: name);
386+
}
387+
else
388+
{
389+
return gen_array_ops.ones_like(tensor, name: name);
390+
}
391+
});
392+
}
371393

372394
public static Tensor reshape(Tensor tensor, Tensor shape, string name = null)
373395
=> gen_array_ops.reshape(tensor, shape, name: name);
@@ -378,21 +400,6 @@ public static Tensor reshape(Tensor tensor, TensorShape shape, string name = nul
378400
public static Tensor reshape(Tensor tensor, object[] shape, string name = null)
379401
=> gen_array_ops.reshape(tensor, shape, name: name);
380402

381-
private static Tensor ones_like_impl<T>(T tensor, TF_DataType dtype, string name, bool optimize = true)
382-
{
383-
return tf_with(ops.name_scope(name, "ones_like", new { tensor }), scope =>
384-
{
385-
name = scope;
386-
var tensor1 = ops.convert_to_tensor(tensor, name: "tensor");
387-
var ones_shape = shape_internal(tensor1, optimize: optimize);
388-
if (dtype == TF_DataType.DtInvalid)
389-
dtype = tensor1.dtype;
390-
var ret = ones(ones_shape, dtype: dtype, name: name);
391-
ret.shape = tensor1.shape;
392-
return ret;
393-
});
394-
}
395-
396403
public static Tensor ones(Tensor shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null)
397404
{
398405
dtype = dtype.as_base_dtype();
@@ -891,7 +898,7 @@ public static Tensor transpose<T1>(T1 a, TensorShape perm, string name = "transp
891898
return tf_with(ops.name_scope(name, "transpose", new { a }), scope =>
892899
{
893900
var a_tensor = ops.convert_to_tensor(a);
894-
if(perm == null)
901+
if (perm == null)
895902
{
896903
var rank = a_tensor.rank;
897904
perm = range(0, rank).OrderByDescending(x => x).ToArray();
@@ -953,7 +960,9 @@ public static Tensor slice(Tensor input, Tensor begin, Tensor size, string name
953960
=> tf.Context.RunInAutoMode2(
954961
() => tf.OpDefLib._apply_op_helper("Slice", name, new
955962
{
956-
input, begin, size
963+
input,
964+
begin,
965+
size
957966
}).output,
958967
() => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
959968
"Slice", name,
@@ -969,8 +978,8 @@ public static Tensor slice(Tensor input, Tensor begin, Tensor size, string name
969978
tf.Runner.RecordGradient("Slice", op.inputs, attrs, op.outputs);
970979
},
971980
new Tensors(input, begin, size));
972-
973-
public static Tensor stack(object values, int axis = 0, string name = "stack")
981+
982+
public static Tensor stack(object values, int axis = 0, string name = "stack")
974983
{
975984
if (axis == 0)
976985
// If the input is a constant list, it can be converted to a constant op

src/TensorFlowNET.Core/Operations/gen_array_ops.cs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -591,6 +591,15 @@ public static Tensor transpose<T1>(Tensor x, T1 perm, string name = null)
591591
return _op.outputs[0];
592592
}
593593

594+
public static Tensor ones_like(Tensor x, string name = null)
595+
=> tf.Context.RunInAutoMode(()
596+
=> tf.OpDefLib._apply_op_helper("OnesLike", name, new { x }).output, ()
597+
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
598+
"OnesLike", name,
599+
null,
600+
x).FirstOrDefault(),
601+
x);
602+
594603
public static Tensor zeros_like(Tensor x, string name = null)
595604
=> tf.Context.RunInAutoMode(()
596605
=> tf.OpDefLib._apply_op_helper("ZerosLike", name, new { x }).output, ()

test/TensorFlowNET.UnitTest/ManagedAPI/TensorOperate.cs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,6 @@ public void ConcatAndSplitTest()
132132
}
133133

134134
#region ones/zeros like
135-
[Ignore]
136135
[TestMethod]
137136
public void TestOnesLike()
138137
{

0 commit comments

Comments
 (0)