Skip to content

Commit cd77358

Browse files
committed
change input and output parameters to Tensor[] for ops.get_gradient_function
1 parent e80a681 commit cd77358

File tree

6 files changed

+102
-75
lines changed

6 files changed

+102
-75
lines changed

src/TensorFlowNET.Core/Gradients/array_grad.py.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@ namespace Tensorflow.Gradients
66
{
77
public class array_grad
88
{
9-
public static (Tensor, Tensor) _ReshapeGrad(Operation op, Tensor grad)
9+
public static Tensor[] _ReshapeGrad(Operation op, Tensor[] grads)
1010
{
11-
return (array_ops.reshape(grad, array_ops.shape(op.inputs[0])), null);
11+
return new Tensor[] { array_ops.reshape(grads[0], array_ops.shape(op.inputs[0])), null };
1212
}
1313
}
1414
}

src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ public static Tensor[] _GradientsHelper(Tensor[] ys,
136136
string name1 = scope1;
137137
if (grad_fn != null)
138138
{
139-
in_grads = _MaybeCompile(grad_scope, op, out_grads[0], null, grad_fn);
139+
in_grads = _MaybeCompile(grad_scope, op, out_grads, null, grad_fn);
140140
_VerifyGeneratedGradients(in_grads, op);
141141
}
142142

@@ -226,7 +226,7 @@ private static void _VerifyGeneratedGradients(Tensor[] grads, Operation op)
226226
$"inputs {op.inputs._inputs.Count()}");
227227
}
228228

229-
private static Tensor[] _MaybeCompile(string scope, Operation op, Tensor out_grads, Action func, Func<Operation, Tensor, Tensor[]> grad_fn)
229+
private static Tensor[] _MaybeCompile(string scope, Operation op, Tensor[] out_grads, Action func, Func<Operation, Tensor[], Tensor[]> grad_fn)
230230
{
231231
scope = scope.EndsWith("/") ? scope.Substring(0, scope.Length - 1) : scope;
232232
return grad_fn(op, out_grads);

src/TensorFlowNET.Core/Gradients/math_grad.py.cs

Lines changed: 31 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,13 @@ namespace Tensorflow.Gradients
1010
/// </summary>
1111
public class math_grad
1212
{
13-
public static (Tensor, Tensor) _AddGrad(Operation op, Tensor grad)
13+
public static Tensor[] _AddGrad(Operation op, Tensor[] grads)
1414
{
1515
var x = op.inputs[0];
1616
var y = op.inputs[1];
17+
var grad = grads[0];
1718
if (grad is Tensor && _ShapesFullySpecifiedAndEqual(x, y, grad))
18-
return (grad, grad);
19+
return new Tensor[] { grad, grad };
1920

2021
var sx = array_ops.shape(x);
2122
var sy = array_ops.shape(y);
@@ -24,21 +25,22 @@ public static (Tensor, Tensor) _AddGrad(Operation op, Tensor grad)
2425
var r1 = gen_array_ops.reshape(math_ops.reduce_sum(grad, rx), sx);
2526
var r2 = gen_array_ops.reshape(math_ops.reduce_sum(grad, ry), sy);
2627

27-
return (r1, r2);
28+
return new Tensor[] { r1, r2 };
2829
}
2930

30-
public static Tensor _IdGrad(Operation op, Tensor grad)
31+
public static Tensor[] _IdGrad(Operation op, Tensor[] grads)
3132
{
32-
return grad;
33+
return new Tensor[] { grads[0] };
3334
}
3435

35-
public static (Tensor, Tensor) _MulGrad(Operation op, Tensor grad)
36+
public static Tensor[] _MulGrad(Operation op, Tensor[] grads)
3637
{
3738
var x = op.inputs[0];
3839
var y = op.inputs[1];
40+
var grad = grads[0];
3941
if (grad is Tensor && _ShapesFullySpecifiedAndEqual(x, y, grad) &&
4042
new TF_DataType[] { tf.int32, tf.float32 }.Contains(grad.dtype))
41-
return (gen_math_ops.mul(grad, y), gen_math_ops.mul(grad, x));
43+
return new Tensor[] { gen_math_ops.mul(grad, y), gen_math_ops.mul(grad, x) };
4244

4345
var sx = array_ops.shape(x);
4446
var sy = array_ops.shape(y);
@@ -54,11 +56,12 @@ public static (Tensor, Tensor) _MulGrad(Operation op, Tensor grad)
5456
var reshape1 = gen_array_ops.reshape(reduce_sum1, sx);
5557
var reshape2 = gen_array_ops.reshape(reduce_sum2, sy);
5658

57-
return (reshape1, reshape2);
59+
return new Tensor[] { reshape1, reshape2 };
5860
}
5961

60-
public static (Tensor, Tensor) _MatMulGrad(Operation op, Tensor grad)
62+
public static Tensor[] _MatMulGrad(Operation op, Tensor[] grads)
6163
{
64+
var grad = grads[0];
6265
Tensor grad_a = null, grad_b = null;
6366

6467
var t_a = (bool)op.get_attr("transpose_a");
@@ -86,33 +89,35 @@ public static (Tensor, Tensor) _MatMulGrad(Operation op, Tensor grad)
8689
grad_b = gen_math_ops.mat_mul(grad, a, transpose_a: true, transpose_b: true);
8790
}
8891

89-
return (grad_a, grad_b);
92+
return new Tensor[] { grad_a, grad_b };
9093
}
9194

92-
public static (Tensor, Tensor) _MeanGrad(Operation op, Tensor grad)
95+
public static Tensor[] _MeanGrad(Operation op, Tensor[] grads)
9396
{
94-
var sum_grad = _SumGrad(op, grad).Item1;
97+
var grad = grads[0];
98+
var sum_grad = _SumGrad(op, grads)[0];
9599
var input_shape = op.inputs[0]._shape_tuple();
96100
var output_shape = op.outputs[0]._shape_tuple();
97101

98102
var input_shape_tensor = array_ops.shape(op.inputs[0]);
99103
var output_shape_tensor = array_ops.shape(op.outputs[0]);
100104
var factor = _safe_shape_div(math_ops.reduce_prod(input_shape_tensor), math_ops.reduce_prod(output_shape_tensor));
101105

102-
return (math_ops.truediv(sum_grad, math_ops.cast(factor, sum_grad.dtype)), null);
106+
return new Tensor[] { math_ops.truediv(sum_grad, math_ops.cast(factor, sum_grad.dtype)), null };
103107
}
104108

105109
private static Tensor _safe_shape_div(Tensor x, Tensor y)
106110
{
107111
return math_ops.floordiv(x, gen_math_ops.maximum(y, 1));
108112
}
109113

110-
public static (Tensor, Tensor) _SubGrad(Operation op, Tensor grad)
114+
public static Tensor[] _SubGrad(Operation op, Tensor[] grads)
111115
{
116+
var grad = grads[0];
112117
var x = op.inputs[0];
113118
var y = op.inputs[1];
114119
if (grad is Tensor && _ShapesFullySpecifiedAndEqual(x, y, grad))
115-
return (grad, -grad);
120+
return new Tensor[] { grad, -grad };
116121

117122
var sx = array_ops.shape(x);
118123
var sy = array_ops.shape(y);
@@ -121,16 +126,17 @@ public static (Tensor, Tensor) _SubGrad(Operation op, Tensor grad)
121126
var r1 = gen_array_ops.reshape(math_ops.reduce_sum(grad, rx), sx);
122127
var r2 = gen_array_ops.reshape(-math_ops.reduce_sum(grad, ry), sy);
123128

124-
return (r1, r2);
129+
return new Tensor[] { r1, r2 };
125130
}
126131

127132
public static bool _ShapesFullySpecifiedAndEqual(Tensor x, Tensor y, Tensor grad)
128133
{
129134
return x.NDims == y.NDims && y.NDims == grad.NDims && x.NDims > -1;
130135
}
131136

132-
public static (Tensor, Tensor) _SumGrad(Operation op, Tensor grad)
137+
public static Tensor[] _SumGrad(Operation op, Tensor[] grads)
133138
{
139+
var grad = grads[0];
134140
var input_0_shape = op.inputs[0]._shape_tuple();
135141
Tensor input_shape = null;
136142

@@ -145,7 +151,7 @@ public static (Tensor, Tensor) _SumGrad(Operation op, Tensor grad)
145151
input_shape = constant_op.constant(input_0_shape);
146152
else
147153
input_shape = array_ops.shape(op.inputs[0]);
148-
return (gen_array_ops.tile(grad, input_shape), null);
154+
return new Tensor[] { gen_array_ops.tile(grad, input_shape), null };
149155
}
150156
}
151157

@@ -155,11 +161,12 @@ public static (Tensor, Tensor) _SumGrad(Operation op, Tensor grad)
155161
var tile_scaling = _safe_shape_div(input_shape, output_shape_kept_dims);
156162
grad = gen_array_ops.reshape(grad, output_shape_kept_dims);
157163

158-
return (gen_array_ops.tile(grad, tile_scaling), null);
164+
return new Tensor[] { gen_array_ops.tile(grad, tile_scaling), null };
159165
}
160166

161-
public static (Tensor, Tensor) _RealDivGrad(Operation op, Tensor grad)
167+
public static Tensor[] _RealDivGrad(Operation op, Tensor[] grads)
162168
{
169+
var grad = grads[0];
163170
var x = op.inputs[0];
164171
var y = op.inputs[1];
165172

@@ -177,11 +184,12 @@ public static (Tensor, Tensor) _RealDivGrad(Operation op, Tensor grad)
177184
var reduce_sum2 = math_ops.reduce_sum(realdiv3, rx);
178185
var reshape2 = gen_array_ops.reshape(reduce_sum2, sx);
179186

180-
return (reshape2, reshape1);
187+
return new Tensor[] { reshape2, reshape1 };
181188
}
182189

183-
public static (Tensor, Tensor) _PowGrad(Operation op, Tensor grad)
190+
public static Tensor[] _PowGrad(Operation op, Tensor[] grads)
184191
{
192+
var grad = grads[0];
185193
var x = op.inputs[0];
186194
var y = op.inputs[1];
187195
var z = op.outputs[0];
@@ -212,7 +220,7 @@ public static (Tensor, Tensor) _PowGrad(Operation op, Tensor grad)
212220
var reduce_sum1 = math_ops.reduce_sum(mul1, ry);
213221
var gy = gen_array_ops.reshape(reduce_sum1, sy);
214222

215-
return (gx, gy);
223+
return new Tensor[] { gx, gy };
216224
}
217225
}
218226
}

src/TensorFlowNET.Core/Gradients/nn_grad.py.cs

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,32 @@ public class nn_grad
1313
/// <param name="op"></param>
1414
/// <param name="grad"></param>
1515
/// <returns></returns>
16-
public static (Tensor, Tensor) _BiasAddGrad(Operation op, Tensor grad)
16+
public static Tensor[] _BiasAddGrad(Operation op, Tensor grad)
1717
{
1818
string data_format = op.get_attr("data_format")?.ToString();
1919
var bias_add_grad = gen_nn_ops.bias_add_grad(out_backprop: grad, data_format: data_format);
20-
return (grad, bias_add_grad);
20+
return new Tensor[] { grad, bias_add_grad };
2121
}
2222

23-
public static (Tensor, Tensor) _ReluGrad(Operation op, Tensor grad)
23+
public static Tensor[] _ReluGrad(Operation op, Tensor grad)
2424
{
25-
return (gen_nn_ops.relu_grad(grad, op.outputs[0]), null);
25+
return new Tensor[] { gen_nn_ops.relu_grad(grad, op.outputs[0]) };
26+
}
27+
28+
/// <summary>
29+
/// Gradient function for SoftmaxCrossEntropyWithLogits.
30+
/// </summary>
31+
/// <param name="op"></param>
32+
/// <param name="grad_loss"></param>
33+
/// <param name="grad_grad"></param>
34+
/// <returns></returns>
35+
public static Tensor[] _SoftmaxCrossEntropyWithLogitsGrad(Operation op, Tensor[] grads)
36+
{
37+
var grad_loss = grads[0];
38+
var grad_grad = grads[1];
39+
var softmax_grad = op.outputs[1];
40+
41+
throw new NotImplementedException("_SoftmaxCrossEntropyWithLogitsGrad");
2642
}
2743
}
2844
}
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
using Tensorflow.Gradients;
5+
6+
namespace Tensorflow
7+
{
8+
public partial class ops
9+
{
10+
public static Func<Operation, Tensor[], Tensor[]> get_gradient_function(Operation op)
11+
{
12+
if (op.inputs == null) return null;
13+
14+
// map tensorflow\python\ops\math_grad.py
15+
return (oper, out_grads) =>
16+
{
17+
Console.WriteLine($"get_gradient_function: {oper.type} '{oper.name}'");
18+
19+
switch (oper.type)
20+
{
21+
case "Add":
22+
return math_grad._AddGrad(oper, out_grads);
23+
case "Identity":
24+
return math_grad._IdGrad(oper, out_grads);
25+
case "Mul":
26+
return math_grad._MulGrad(oper, out_grads);
27+
case "Mean":
28+
return math_grad._MeanGrad(oper, out_grads);
29+
case "Sum":
30+
return math_grad._SumGrad(oper, out_grads);
31+
case "Sub":
32+
return math_grad._SubGrad(oper, out_grads);
33+
case "Pow":
34+
return math_grad._PowGrad(oper, out_grads);
35+
case "RealDiv":
36+
return math_grad._RealDivGrad(oper, out_grads);
37+
case "Reshape":
38+
return array_grad._ReshapeGrad(oper, out_grads);
39+
case "SoftmaxCrossEntropyWithLogits":
40+
return nn_grad._SoftmaxCrossEntropyWithLogitsGrad(oper, out_grads);
41+
default:
42+
throw new NotImplementedException($"get_gradient_function {oper.type}");
43+
}
44+
};
45+
}
46+
}
47+
}

src/TensorFlowNET.Core/ops.py.cs

Lines changed: 0 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -346,50 +346,6 @@ public static void _run_using_default_session(Operation operation, FeedItem[] fe
346346
session.run(operation, feed_dict);
347347
}
348348

349-
public static Func<Operation, Tensor, Tensor[]> get_gradient_function(Operation op)
350-
{
351-
if (op.inputs == null) return null;
352-
353-
// map tensorflow\python\ops\math_grad.py
354-
return (oper, out_grads) =>
355-
{
356-
// Console.WriteLine($"get_gradient_function: {oper.type} '{oper.name}'");
357-
358-
switch (oper.type)
359-
{
360-
case "Add":
361-
var add = math_grad._AddGrad(oper, out_grads);
362-
return new Tensor[] { add.Item1, add.Item2 };
363-
case "Identity":
364-
var id = math_grad._IdGrad(oper, out_grads);
365-
return new Tensor[] { id };
366-
case "Mul":
367-
var mul = math_grad._MulGrad(oper, out_grads);
368-
return new Tensor[] { mul.Item1, mul.Item2 };
369-
case "Mean":
370-
var mean = math_grad._MeanGrad(oper, out_grads);
371-
return new Tensor[] { mean.Item1, mean.Item2 };
372-
case "Sum":
373-
var sum = math_grad._SumGrad(oper, out_grads);
374-
return new Tensor[] { sum.Item1, sum.Item2 };
375-
case "Sub":
376-
var sub = math_grad._SubGrad(oper, out_grads);
377-
return new Tensor[] { sub.Item1, sub.Item2 };
378-
case "Pow":
379-
var pow = math_grad._PowGrad(oper, out_grads);
380-
return new Tensor[] { pow.Item1, pow.Item2 };
381-
case "RealDiv":
382-
var realdiv = math_grad._RealDivGrad(oper, out_grads);
383-
return new Tensor[] { realdiv.Item1, realdiv.Item2 };
384-
case "Reshape":
385-
var reshape = array_grad._ReshapeGrad(oper, out_grads);
386-
return new Tensor[] { reshape.Item1, reshape.Item2 };
387-
default:
388-
throw new NotImplementedException($"get_gradient_function {oper.type}");
389-
}
390-
};
391-
}
392-
393349
public static Tensor[] convert_n_to_tensor_or_indexed_slices(Tensor[] values, TF_DataType dtype = TF_DataType.DtInvalid, string name = null)
394350
{
395351
return internal_convert_n_to_tensor_or_indexed_slices(values, dtype: dtype, name: name);

0 commit comments

Comments
 (0)