Skip to content

Commit 47f953b

Browse files
authored
Merge pull request #736 from MPnoy/ones_like-fix
ones_like fix
2 parents 26a04bd + a11cb71 commit 47f953b

File tree

3 files changed

+43
-7
lines changed

3 files changed

+43
-7
lines changed

src/TensorFlowNET.Core/Gradients/math_grad.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,9 @@ public static Tensor[] _ExpGrad(Operation op, Tensor[] grads)
138138
[RegisterNoGradient("GreaterEqual")]
139139
public static Tensor[] _GreaterEqualGrad(Operation op, Tensor[] grads) => null;
140140

141+
[RegisterNoGradient("OnesLike")]
142+
public static Tensor[] _OnesLike(Operation op, Tensor[] grads) => null;
143+
141144
[RegisterNoGradient("ZerosLike")]
142145
public static Tensor[] _ZerosLike(Operation op, Tensor[] grads) => null;
143146

src/TensorFlowNET.Core/Operations/array_ops.cs

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,7 @@ public static Tensor _autopacking_helper(IEnumerable<object> list_or_tuple, TF_D
274274
{
275275
if (elem is EagerTensor eager_tensor)
276276
{
277-
if(switch_to_graph)
277+
if (switch_to_graph)
278278
elems_as_tensors.Add(constant_op.constant(eager_tensor.numpy(), dtype: dtype, name: i.ToString()));
279279
else
280280
elems_as_tensors.Add(eager_tensor);
@@ -366,8 +366,30 @@ public static Tensor rank_internal(Tensor input, string name = null, bool optimi
366366
/// <param name="name"></param>
367367
/// <param name="optimize"></param>
368368
/// <returns></returns>
369-
public static Tensor ones_like<T>(T tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true)
370-
=> ones_like_impl(tensor, dtype, name, optimize);
369+
public static Tensor ones_like(Tensor tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true)
370+
{
371+
return tf_with(ops.name_scope(name, "ones_like", new Tensor[] { tensor }), scope =>
372+
{
373+
name = scope;
374+
tensor = ops.convert_to_tensor(tensor, name: "tensor");
375+
376+
// is_fully_defined return unexpected value.
377+
if (optimize && tensor_util.to_shape(tensor.shape).is_fully_defined() && dtype != TF_DataType.TF_VARIANT)
378+
{
379+
380+
}
381+
382+
if (dtype != TF_DataType.DtInvalid && dtype != tensor.dtype && dtype != TF_DataType.TF_VARIANT)
383+
{
384+
throw new NotImplementedException("ones_like");
385+
// return ones(shape_internal(tensor, optimize: optimize), dtype: dtype, name: name);
386+
}
387+
else
388+
{
389+
return gen_array_ops.ones_like(tensor, name: name);
390+
}
391+
});
392+
}
371393

372394
public static Tensor reshape(Tensor tensor, Tensor shape, string name = null)
373395
=> gen_array_ops.reshape(tensor, shape, name: name);
@@ -888,7 +910,7 @@ public static Tensor transpose<T1>(T1 a, TensorShape perm, string name = "transp
888910
return tf_with(ops.name_scope(name, "transpose", new { a }), scope =>
889911
{
890912
var a_tensor = ops.convert_to_tensor(a);
891-
if(perm == null)
913+
if (perm == null)
892914
{
893915
var rank = a_tensor.rank;
894916
perm = range(0, rank).OrderByDescending(x => x).ToArray();
@@ -950,7 +972,9 @@ public static Tensor slice(Tensor input, Tensor begin, Tensor size, string name
950972
=> tf.Context.RunInAutoMode2(
951973
() => tf.OpDefLib._apply_op_helper("Slice", name, new
952974
{
953-
input, begin, size
975+
input,
976+
begin,
977+
size
954978
}).output,
955979
() => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
956980
"Slice", name,
@@ -966,8 +990,8 @@ public static Tensor slice(Tensor input, Tensor begin, Tensor size, string name
966990
tf.Runner.RecordGradient("Slice", op.inputs, attrs, op.outputs);
967991
},
968992
new Tensors(input, begin, size));
969-
970-
public static Tensor stack(object values, int axis = 0, string name = "stack")
993+
994+
public static Tensor stack(object values, int axis = 0, string name = "stack")
971995
{
972996
if (axis == 0)
973997
// If the input is a constant list, it can be converted to a constant op

src/TensorFlowNET.Core/Operations/gen_array_ops.cs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -591,6 +591,15 @@ public static Tensor transpose<T1>(Tensor x, T1 perm, string name = null)
591591
return _op.outputs[0];
592592
}
593593

594+
public static Tensor ones_like(Tensor x, string name = null)
595+
=> tf.Context.RunInAutoMode(()
596+
=> tf.OpDefLib._apply_op_helper("OnesLike", name, new { x }).output, ()
597+
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
598+
"OnesLike", name,
599+
null,
600+
x).FirstOrDefault(),
601+
x);
602+
594603
public static Tensor zeros_like(Tensor x, string name = null)
595604
=> tf.Context.RunInAutoMode(()
596605
=> tf.OpDefLib._apply_op_helper("ZerosLike", name, new { x }).output, ()

0 commit comments

Comments
 (0)