
Commit 14c26e7

Mul gradient is not correct in TensorFlowOpLayer #698
1 parent 6509ae0 commit 14c26e7

File tree

10 files changed: +53 -39 lines changed

src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs

Lines changed: 2 additions & 1 deletion
@@ -380,7 +380,8 @@ bool SetOpAttrScalar(Context ctx, SafeOpHandle op,
                     c_api.TFE_OpSetAttrBool(op, key, Convert.ToBoolean(value));
                     break;
                 case TF_AttrType.TF_ATTR_INT:
-                    c_api.TFE_OpSetAttrInt(op, key, Convert.ToInt64(value));
+                    attr_list_sizes[key] = Convert.ToInt64(value);
+                    c_api.TFE_OpSetAttrInt(op, key, attr_list_sizes[key]);
                     break;
                 case TF_AttrType.TF_ATTR_FLOAT:
                     c_api.TFE_OpSetAttrFloat(op, key, Convert.ToSingle(value));
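
Note: `attr_list_sizes` is the dictionary this file uses to resolve the lengths of list-typed outputs, and some ops size their output list from a plain int attr. Recording the TF_ATTR_INT value there as well lets the executor find it later. A minimal sketch of the assumed lookup (the dictionary name is from this diff; the surrounding logic is an assumption based on TensorFlow's fast-path executor):

    // Hypothetical output-count inference in the fast path: an op such as
    // SplitV declares "num_split" outputs, so the executor must find that
    // value when sizing the returned Tensor[].
    long num_outputs = attr_list_sizes.TryGetValue("num_split", out var n)
        ? n      // recorded by the TF_ATTR_INT case above
        : 0;     // before this fix the int attr was never recorded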

src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs

Lines changed: 10 additions & 14 deletions
@@ -44,28 +44,24 @@ public EagerDefinedFunction Forward(Tensors inference_args)
         public void Record(Tensors flat_outputs, Tensors inference_args)
         {
             var (backward_function, to_record) = _wrap_backward_function(_forward_graph, _backward, flat_outputs);
-            tf.Runner.RecordGradient(_forward.Name, flat_outputs, new object[0], inference_args,
+            tf.Runner.RecordGradient(_forward.Name, inference_args, new object[0], to_record,
                 getBackwardFunction: () => backward_function);
         }

         (BackwardFunction, Tensors) _wrap_backward_function(FuncGraph forward_graph, ConcreteFunction backward, Tensors flat_outputs)
         {
             BackwardFunction _backward_function_wrapper = (output_grads, unneeded_gradients) =>
             {
-                return new Tensor[0];
-
-                /*var gradients = ops.gradientFunctions[op_name](new EagerOperation
+                var processed_args = new List<Tensor>();
+                var input_index = 0;
+                foreach (var (output_index, arg) in enumerate(output_grads))
                 {
-                    Name = op_name,
-                    NumInputs = op_inputs.Length,
-                    Inputs = op_inputs,
-                    NumOutputs = op_outputs.Length,
-                    Outputs = op_outputs,
-                    SkipInputIndices = unneeded_gradients,
-                    Attrs = attrs
-                }, output_grads);
-
-                return gradients;*/
+                    if (arg is null)
+                        throw new NotImplementedException("");
+                    processed_args.add(arg);
+                    input_index += 1;
+                }
+                return output_grads; // backward.Invoke(processed_args.ToArray());
             };

             return (_backward_function_wrapper, flat_outputs);
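
Note: the important part of this change is argument order. `RecordGradient` takes the op's inputs before its outputs, as the `sigmoid_grad` change later in this commit also shows (`RecordGradient("SigmoidGrad", op.inputs, attrs, op.outputs)`). The old call passed `flat_outputs` where the inputs belong, so the tape chained gradients through the wrong tensors. Annotated, with the parameter meaning inferred from the calls in this commit:

    // tf.Runner.RecordGradient(opName, inputs, attrs, results, getBackwardFunction)
    tf.Runner.RecordGradient(_forward.Name,
        inference_args,   // what the forward function consumed
        new object[0],    // no attrs for a traced function
        to_record,        // what it produced and the tape must watch
        getBackwardFunction: () => backward_function);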

src/TensorFlowNET.Core/Gradients/array_grad.cs

Lines changed: 5 additions & 2 deletions
@@ -85,10 +85,13 @@ private static Tensor[] _ConcatGradHelper(Operation op, Tensor grad, int start_v
         var out_grads = new List<Tensor>();
         if (concat_dim is EagerTensor)
         {
-            var non_neg_concat_dim = (int)concat_dim % input_values[0].rank;
+            var dim_int = (int)concat_dim;
+            var non_neg_concat_dim = dim_int < 0
+                ? input_values[0].rank + dim_int
+                : dim_int % input_values[0].rank;
             var sizes = input_values.Select(x => x.shape[non_neg_concat_dim]).ToArray();
             var sizes_tensor = constant_op.constant(sizes);
-            out_grads = gen_array_ops.split_v(grad, sizes_tensor, sizes[0], non_neg_concat_dim).ToList();
+            out_grads = array_ops.split(grad, sizes_tensor, non_neg_concat_dim).ToList();
         }
         else if (constant_op.is_constant(concat_dim))
         {
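
Note: two fixes in one hunk. First, the old `(int)concat_dim % rank` was a straight port from Python, but C#'s `%` keeps the dividend's sign, so a negative concat axis stayed negative; adding the rank instead yields the intended non-negative axis. Second, `split_v` was being called with `sizes[0]` (the first piece's size) where the piece count belongs; the new `array_ops.split` overload infers the count from `size_splits` itself. A worked example with assumed shapes:

    // Concat of [2, 3] and [2, 5] on axis -1 (rank 2): the axis must normalize to 1.
    int dim_int = -1, rank = 2;
    int wrong = dim_int % rank;   // -1 in C#: an invalid axis
    int right = rank + dim_int;   //  1: the last axis, as intended
    // The incoming grad has shape [2, 8]; splitting it with sizes { 3, 5 } along
    // axis 1 hands each input back a gradient of its own shape: [2, 3] and [2, 5].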

src/TensorFlowNET.Core/Gradients/math_grad.cs

Lines changed: 6 additions & 6 deletions
@@ -212,7 +212,7 @@ public static Tensor[] _MulGrad(Operation op, Tensor[] grads)
             };
         }

-        var broads = SmartBroadcastGradientArgs(x, y);
+        var broads = SmartBroadcastGradientArgs(x, y, grad);
         var (sx, rx, must_reduce_x) = broads[0];
         var (sy, ry, must_reduce_y) = broads[1];

@@ -468,7 +468,7 @@ public static Tensor[] _SubGrad(Operation op, Tensor[] grads)
             _ShapesFullySpecifiedAndEqual(x, y, grad))
             return new Tensor[] { grad, -grad };

-        var broads = SmartBroadcastGradientArgs(x, y);
+        var broads = SmartBroadcastGradientArgs(x, y, grad);
         var (sx, rx, must_reduce_x) = broads[0];
         var (sy, ry, must_reduce_y) = broads[1];

@@ -718,7 +718,7 @@ public static Tensor[] _PowGrad(Operation op, Tensor[] grads)

         var z = op.outputs[0];

-        var broads = SmartBroadcastGradientArgs(x, y);
+        var broads = SmartBroadcastGradientArgs(x, y, grad);
         var (sx, rx, must_reduce_x) = broads[0];
         var (sy, ry, must_reduce_y) = broads[1];

@@ -753,7 +753,7 @@ public static Tensor[] _PowGrad(Operation op, Tensor[] grads)
     /// <param name="x"></param>
     /// <param name="y"></param>
     /// <returns></returns>
-    private static (Tensor, Tensor, bool)[] SmartBroadcastGradientArgs(Tensor x, Tensor y)
+    private static (Tensor, Tensor, bool)[] SmartBroadcastGradientArgs(Tensor x, Tensor y, Tensor grad)
     {
         Tensor sx, sy;
         if (x.TensorShape.is_fully_defined() &&

@@ -771,8 +771,8 @@ private static (Tensor, Tensor, bool)[] SmartBroadcastGradientArgs(Tensor x, Ten
         var (rx, ry) = gen_array_ops.broadcast_gradient_args(sx, sy);
         return new[]
         {
-            (sx, rx, true),
-            (sy, ry, true)
+            (sx, rx, !x.TensorShape.Equals(grad.TensorShape)),
+            (sy, ry, !y.TensorShape.Equals(grad.TensorShape))
         };
     }
 }
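
Note: this is the heart of the `Mul` gradient fix. `SmartBroadcastGradientArgs` used to hard-code `must_reduce = true` for both operands; with `grad` passed in, it compares each operand's static shape against the gradient's shape (using the rank-aware `TensorShape.Equals` also added in this commit), so the reduction over broadcast dimensions only runs when broadcasting actually happened. The consuming code is not part of this hunk, so the sketch below of how a caller such as `_MulGrad` would typically use the flags is an assumption:

    // Hypothetical caller-side use of (sx, rx, must_reduce_x):
    var gx = math_ops.multiply(grad, y);           // raw dL/dx for z = x * y
    if (must_reduce_x)                             // only when x was broadcast
        gx = array_ops.reshape(math_ops.reduce_sum(gx, rx), sx);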

src/TensorFlowNET.Core/Operations/array_ops.cs

Lines changed: 9 additions & 0 deletions
@@ -885,6 +885,15 @@ public static Tensor transpose(Tensor a, Tensor perm, string name = "transpose",
             });
         }

+        public static Tensor[] split(Tensor value, Tensor size_splits, int axis, int num = -1,
+            string name = "split")
+        {
+            if (num == -1)
+                num = size_splits.shape[0];
+
+            return gen_array_ops.split_v(value, size_splits, axis, num, name: name);
+        }
+
         public static Tensor[] split<T>(Tensor value, int num_split, T axis,
             string name = "split")
         {
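
Note: the new overload accepts ragged piece sizes as a tensor and infers `num` from its length, which is what the concat-gradient fix above relies on. A hypothetical usage (the tensor names and shapes here are assumptions for illustration):

    // Split a [6, 4] tensor into row blocks of 2, 3 and 1 along axis 0.
    var sizes = constant_op.constant(new[] { 2, 3, 1 });
    var pieces = array_ops.split(value, sizes, axis: 0);   // num inferred as 3
    // pieces[0]: [2, 4], pieces[1]: [3, 4], pieces[2]: [1, 4]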

src/TensorFlowNET.Core/Operations/gen_array_ops.cs

Lines changed: 1 addition & 1 deletion
@@ -527,7 +527,7 @@ public static Tensor[] split_v(Tensor value, Tensor size_splits,
         var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
             "SplitV", name,
             null,
-            value, size_splits, axis,
+            value, size_splits, axis,
             "num_split", num_split);

         return results;

src/TensorFlowNET.Core/Operations/gen_math_ops.cs

Lines changed: 13 additions & 13 deletions
@@ -346,21 +346,21 @@ public static Tensor sigmoid(Tensor x, string name = "Sigmoid")
         /// <c>dy</c> is the corresponding input gradient.
         /// </remarks>
         public static Tensor sigmoid_grad(Tensor y, Tensor dy, string name = "SigmoidGrad")
-        {
-            if (tf.executing_eagerly())
-            {
-                var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
+            => tf.Context.RunInAutoMode2(
+                () => tf.OpDefLib._apply_op_helper("SigmoidGrad", name, new { y, dy }).output,
+                () => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
                     "SigmoidGrad", name,
                     null,
-                    y, dy);
-
-                return results[0];
-            }
-
-            var op = tf.OpDefLib._apply_op_helper("SigmoidGrad", name: name, args: new { y, dy });
-
-            return op.output;
-        }
+                    y, dy).FirstOrDefault(),
+                (op) =>
+                {
+                    var attrs = new object[]
+                    {
+                        "T", op.get_attr<TF_DataType>("T")
+                    };
+                    tf.Runner.RecordGradient("SigmoidGrad", op.inputs, attrs, op.outputs);
+                },
+                new Tensors(y, dy));

         public static Tensor sign<T>(T x, string name = "Sign")
         {
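
Note: `sigmoid_grad` is rewritten from an explicit eager/graph branch into the `RunInAutoMode2` helper: one delegate builds the graph op, one runs the eager fast path, one records the gradient for the tape, and the trailing `Tensors` lists the inputs. A rough sketch of the dispatch this helper appears to implement; the actual body lives in `Context` and may well differ:

    // Assumed shape of the helper, for orientation only:
    Tensor RunInAutoMode2(Func<Tensor> graphFn, Func<Tensor> eagerFn,
        Action<Operation> recordGradient, Tensors inputs)
    {
        if (tf.Context.executing_eagerly())
            return eagerFn();            // TFE_FastPathExecute path
        var output = graphFn();          // _apply_op_helper builds the graph op
        recordGradient(output.op);       // tape bookkeeping via RecordGradient
        return output;
    }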

src/TensorFlowNET.Core/Tensors/TensorShape.Equals.cs

Lines changed: 4 additions & 0 deletions
@@ -10,6 +10,10 @@ public override bool Equals(Object obj)
         switch (obj)
         {
             case TensorShape shape1:
+                if (rank == -1 && shape1.rank == -1)
+                    return false;
+                else if (rank != shape1.rank)
+                    return false;
                 return Enumerable.SequenceEqual(shape1.dims, dims);
             default:
                 return false;
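
Note: the guard makes `Equals` rank-aware. Mismatched ranks now return false before `SequenceEqual` compares ragged `dims` arrays, and two shapes of unknown rank (`rank == -1`) deliberately compare unequal, which keeps `SmartBroadcastGradientArgs` conservative: an unknown shape must still trigger the gradient reduction. For example (constructor form assumed):

    var a = new TensorShape(2, 3);
    var b = new TensorShape(2, 3);
    var c = new TensorShape(2, 3, 1);
    a.Equals(b);   // true: same rank, dims compare element-wise
    a.Equals(c);   // false: rank mismatch short-circuits
    // two fully-unknown shapes (rank == -1) also compare false by design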

src/TensorFlowNET.Keras/BackendImpl.cs

Lines changed: 1 addition & 1 deletion
@@ -239,7 +239,7 @@ public Tensor concatenate(Tensors tensors, int axis = -1)
         {
             var rank = tensors[0].NDims;
             if (rank > -1)
-                axis %= rank;
+                axis += rank;
             else
                 axis = 0;
         }
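
Note: the same C# remainder pitfall as the concat-gradient fix above. This block presumably runs only for a negative `axis` (the enclosing guard sits just outside the hunk), where Python's `%` would wrap but C#'s does not:

    // For two rank-2 inputs and axis == -1:
    int axis = -1, rank = 2;
    // old: axis %= rank  ->  -1 % 2 == -1 in C#, still negative
    axis += rank;         //  -1 + 2 == 1: concatenate along the last axis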

src/TensorFlowNET.Keras/Layers/Reshaping/Reshape.cs

Lines changed: 2 additions & 1 deletion
@@ -21,7 +21,8 @@ public Reshape(ReshapeArgs args)

         protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false)
         {
-            var shape = new List<int> { inputs.shape[0] };
+            var shape_tensor = array_ops.shape(inputs);
+            var shape = new List<int> { shape_tensor.shape[0] };
             shape.AddRange(args.TargetShape.dims);

             var result = array_ops.reshape(inputs, shape.ToArray());
