Skip to content

Commit bad610d

Browse files
committed
_TopKGrad, _SoftmaxCrossEntropyWithLogitsGrad
1 parent cd77358 commit bad610d

File tree

10 files changed

+174
-18
lines changed

10 files changed

+174
-18
lines changed

src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,17 @@ public static Tensor[] _GradientsHelper(Tensor[] ys,
131131
// for ops that do not have gradients.
132132
var grad_fn = ops.get_gradient_function(op);
133133

134+
foreach(var (i, out_grad) in enumerate(out_grads))
135+
{
136+
if(out_grad == null)
137+
{
138+
if (loop_state != null)
139+
;
140+
else
141+
out_grads[i] = control_flow_ops.ZerosLikeOutsideLoop(op, i);
142+
}
143+
}
144+
134145
with(ops.name_scope(op.name + "_grad"), scope1 =>
135146
{
136147
string name1 = scope1;
@@ -240,28 +251,27 @@ private static bool _IsPartitionedCall(Operation op)
240251
private static Tensor[] _AggregatedGrads(Dictionary<string, Tensor[][]> grads, Operation op, string gradient_uid, object loop_state, int aggregation_method = 0)
241252
{
242253
var out_grads = _GetGrads(grads, op);
243-
for(int i = 0; i < out_grads.Length; i++)
254+
var return_grads = new Tensor[out_grads.Length];
255+
256+
foreach(var (i, out_grad) in enumerate(out_grads))
244257
{
245-
var out_grad = out_grads[i];
246-
if(loop_state != null)
258+
if (loop_state != null)
247259
{
248260

249261
}
250262

251-
// Grads have to be Tensors or IndexedSlices
252-
253263
// Aggregate multiple gradients, and convert [] to None.
254-
if(out_grad != null)
264+
if (out_grad != null)
255265
{
256-
if(out_grad.Length < 2)
266+
if (out_grad.Length < 2)
257267
{
258268
string used = "nop";
259-
return new Tensor[] { out_grad[0] };
269+
return_grads[i] = out_grad[0];
260270
}
261271
}
262272
}
263273

264-
return null;
274+
return return_grads;
265275
}
266276

267277
/// <summary>

src/TensorFlowNET.Core/Gradients/nn_grad.py.cs

Lines changed: 55 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using System;
22
using System.Collections.Generic;
3+
using System.Linq;
34
using System.Text;
45
using Tensorflow.Operations;
56

@@ -13,16 +14,17 @@ public class nn_grad
1314
/// <param name="op"></param>
1415
/// <param name="grad"></param>
1516
/// <returns></returns>
16-
/// <summary>
/// Gradient for the BiasAdd op: the incoming gradient flows through
/// unchanged to the value input, and gen_nn_ops.bias_add_grad produces
/// the gradient for the bias input.
/// </summary>
/// <param name="op">The forward BiasAdd operation.</param>
/// <param name="grads">Incoming gradients; only grads[0] is used.</param>
/// <returns>Gradients w.r.t. the value input and the bias input.</returns>
public static Tensor[] _BiasAddGrad(Operation op, Tensor[] grads)
{
    var outBackprop = grads[0];
    // "data_format" may be absent on the op; get_attr then yields null.
    var dataFormat = op.get_attr("data_format")?.ToString();
    var biasGrad = gen_nn_ops.bias_add_grad(out_backprop: outBackprop, data_format: dataFormat);
    return new Tensor[] { outBackprop, biasGrad };
}
2224

23-
/// <summary>
/// Gradient for the Relu op, computed by gen_nn_ops.relu_grad from the
/// incoming gradient and the forward activation (op.outputs[0]).
/// </summary>
/// <param name="op">The forward Relu operation.</param>
/// <param name="grads">Incoming gradients; only grads[0] is used.</param>
/// <returns>A single-element array with the gradient w.r.t. the Relu input.</returns>
public static Tensor[] _ReluGrad(Operation op, Tensor[] grads)
{
    var reluGrad = gen_nn_ops.relu_grad(grads[0], op.outputs[0]);
    return new Tensor[] { reluGrad };
}
2729

2830
/// <summary>
@@ -37,8 +39,57 @@ public static Tensor[] _SoftmaxCrossEntropyWithLogitsGrad(Operation op, Tensor[]
3739
var grad_loss = grads[0];
3840
var grad_grad = grads[1];
3941
var softmax_grad = op.outputs[1];
42+
var grad = _BroadcastMul(grad_loss, softmax_grad);
4043

41-
throw new NotImplementedException("_SoftmaxCrossEntropyWithLogitsGrad");
44+
var logits = op.inputs[0];
45+
if(grad_grad != null && !IsZero(grad_grad))
46+
{
47+
throw new NotImplementedException("_SoftmaxCrossEntropyWithLogitsGrad");
48+
}
49+
50+
return new Tensor[]
51+
{
52+
grad,
53+
_BroadcastMul(grad_loss, -nn_ops.log_softmax(logits))
54+
};
55+
}
56+
57+
/// <summary>
/// Returns true when `g` is statically known to be all zeros, i.e. when it
/// was produced by a "Zeros" or "ZerosLike" op. Any other producer is not
/// handled yet.
/// </summary>
private static bool IsZero(Tensor g)
{
    switch (g.op.type)
    {
        case "ZerosLike":
        case "Zeros":
            return true;
        default:
            // NOTE(review): presumably a full port would also inspect
            // constant fill values here — confirm against TF's nn_grad.py.
            throw new NotImplementedException("IsZero");
    }
}
64+
65+
/// <summary>
/// Multiplies `vec` (expanded with a trailing dimension of size 1) by
/// `mat`, relying on broadcasting in the * operator.
/// </summary>
private static Tensor _BroadcastMul(Tensor vec, Tensor mat)
{
    var column = array_ops.expand_dims(vec, -1);
    return column * mat;
}
70+
71+
/// <summary>
/// Return the gradients for TopK. Only the index-flattening prologue is
/// ported so far; the scatter of the flattened gradient back into the
/// input shape is not implemented yet.
/// </summary>
/// <param name="op">The forward TopK/TopKV2 operation.</param>
/// <param name="grads">Incoming gradients; grads[0] is w.r.t. the values output.</param>
/// <returns>Gradients w.r.t. the TopK inputs (not yet implemented).</returns>
public static Tensor[] _TopKGrad(Operation op, Tensor[] grads)
{
    var grad = grads[0];
    // grads[1] (w.r.t. the indices output) is read but intentionally unused.
    var _ = grads[1];

    var in_shape = array_ops.shape(op.inputs[0]);
    var ind_shape = array_ops.shape(op.outputs[1]);

    // int32 is not supported on GPU hence up-casting
    var ind_lastdim = array_ops.gather(
        math_ops.cast(ind_shape, TF_DataType.TF_INT64),
        array_ops.size(ind_shape) - 1);

    // Flatten indices to 2D.
    var ind_2d = array_ops.reshape(op.outputs[1], array_ops.stack(new object[] { -1, ind_lastdim }));

    throw new NotImplementedException("nn_grad._TopKGrad");
}
4394
}
4495
}

src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,18 @@ public static Func<Operation, Tensor[], Tensor[]> get_gradient_function(Operatio
1414
// map tensorflow\python\ops\math_grad.py
1515
return (oper, out_grads) =>
1616
{
17-
Console.WriteLine($"get_gradient_function: {oper.type} '{oper.name}'");
17+
// Console.WriteLine($"get_gradient_function: {oper.type} '{oper.name}'");
1818

1919
switch (oper.type)
2020
{
2121
case "Add":
2222
return math_grad._AddGrad(oper, out_grads);
23+
case "BiasAdd":
24+
return nn_grad._BiasAddGrad(oper, out_grads);
2325
case "Identity":
2426
return math_grad._IdGrad(oper, out_grads);
27+
case "MatMul":
28+
return math_grad._MatMulGrad(oper, out_grads);
2529
case "Mul":
2630
return math_grad._MulGrad(oper, out_grads);
2731
case "Mean":
@@ -36,8 +40,13 @@ public static Func<Operation, Tensor[], Tensor[]> get_gradient_function(Operatio
3640
return math_grad._RealDivGrad(oper, out_grads);
3741
case "Reshape":
3842
return array_grad._ReshapeGrad(oper, out_grads);
43+
case "Relu":
44+
return nn_grad._ReluGrad(oper, out_grads);
3945
case "SoftmaxCrossEntropyWithLogits":
4046
return nn_grad._SoftmaxCrossEntropyWithLogitsGrad(oper, out_grads);
47+
case "TopK":
48+
case "TopKV2":
49+
return nn_grad._TopKGrad(oper, out_grads);
4150
default:
4251
throw new NotImplementedException($"get_gradient_function {oper.type}");
4352
}

src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,16 @@ public static Tensor[] _fused_batch_norm(Tensor x,
9494
return _op.outputs;
9595
}
9696

97+
/// <summary>
/// Adds a "LogSoftmax" op to the graph for the given logits.
/// </summary>
/// <param name="logits">Input tensor.</param>
/// <param name="name">Optional name for the created op.</param>
/// <returns>The single output tensor of the LogSoftmax op.</returns>
public static Tensor log_softmax(Tensor logits, string name = null)
{
    var op = _op_def_lib._apply_op_helper("LogSoftmax", name: name, args: new { logits });
    return op.outputs[0];
}
106+
97107
public static Tensor max_pool(Tensor input,
98108
int[] ksize,
99109
int[] strides,

src/TensorFlowNET.Core/Operations/Operation.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,10 @@ public object get_attr(string name)
185185
if (oneof_value == "type")
186186
return x.Type;
187187

188-
return x.GetType().GetProperty(oneof_value).GetValue(x);
188+
object result = x.GetType().GetProperty(oneof_value).GetValue(x);
189+
if (result is Google.Protobuf.ByteString byteString)
190+
return byteString.ToStringUtf8();
191+
return result;
189192
}
190193

191194
public TF_AttrMetadata GetAttributeMetadata(string attr_name, Status s)

src/TensorFlowNET.Core/Operations/array_ops.py.cs

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,10 @@ private static Tensor _constant_if_small<T>(T value, Shape shape, TF_DataType dt
4646
}
4747
}
4848

49-
public static Tensor _autopacking_helper(Tensor[] list_or_tuple, TF_DataType dtype, string name)
49+
public static Tensor _autopacking_helper(object[] list_or_tuple, TF_DataType dtype, string name)
5050
{
5151
var must_pack = false;
52-
var converted_elems = new List<Tensor>();
52+
var converted_elems = new List<object>();
5353
return with(ops.name_scope(name), scope =>
5454
{
5555
foreach (var (i, elem) in enumerate(list_or_tuple))
@@ -58,7 +58,27 @@ public static Tensor _autopacking_helper(Tensor[] list_or_tuple, TF_DataType dty
5858
must_pack = true;
5959
}
6060

61-
return gen_array_ops.pack(converted_elems.ToArray(), name: scope);
61+
if(must_pack)
62+
{
63+
var elems_as_tensors = new List<Tensor>();
64+
foreach (var (i, elem) in enumerate(converted_elems))
65+
{
66+
if (elem is Tensor tensor)
67+
elems_as_tensors.Add(tensor);
68+
else
69+
{
70+
var elem_tensor = constant_op.constant(elem, dtype: dtype, name: i.ToString());
71+
elems_as_tensors.Add(elem_tensor);
72+
}
73+
}
74+
75+
return gen_array_ops.pack(elems_as_tensors.ToArray(), name: scope);
76+
}
77+
else
78+
{
79+
// return converted_elems.ToArray();
80+
throw new NotImplementedException("_autopacking_helper.converted_elems");
81+
}
6282
});
6383
}
6484

@@ -355,5 +375,15 @@ public static Tensor transpose(Tensor a, int[] perm = null, string name = "trans
355375

356376
public static Tensor slice<Tb, Ts>(Tensor input, Tb[] begin, Ts[] size, string name = null)
357377
=> gen_array_ops.slice(input, begin, size, name: name);
378+
379+
/// <summary>
/// Stacks `values` along `axis`. Only axis == 0 is supported so far, in
/// which case the list is handed to ops.convert_to_tensor in one shot.
/// </summary>
/// <param name="values">The values to stack (e.g. an object[] of tensors/constants).</param>
/// <param name="axis">Axis to stack along; only 0 is implemented.</param>
/// <param name="name">Name for the resulting tensor.</param>
public static Tensor stack(object values, int axis = 0, string name = "stack")
{
    if (axis != 0)
        throw new NotImplementedException("array_ops.stack");

    // If the input is a constant list, it can be converted to a constant op
    return ops.convert_to_tensor(values, name: name);
}
387+
358388
}
359389
}

src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
using System.Linq;
44
using System.Text;
55
using Tensorflow.Operations;
6+
using util = Tensorflow.control_flow_util;
67

78
namespace Tensorflow
89
{
@@ -226,5 +227,18 @@ public static (Tensor, Tensor) @switch(Tensor data,
226227
return gen_control_flow_ops.@switch(data, pred, name: name);
227228
});
228229
}
230+
231+
/// <summary>
/// Builds a zeros tensor matching op.outputs[index]. The Switch/loop case
/// and resource-dtype outputs are not ported yet.
/// </summary>
/// <param name="op">Operation whose output is being zeroed.</param>
/// <param name="index">Index of the output to mirror.</param>
public static Tensor ZerosLikeOutsideLoop(Operation op, int index)
{
    var val = op.outputs[index];

    // Switch ops belong to control-flow contexts; that path is unported.
    if (util.IsSwitch(op))
        throw new NotImplementedException("ZerosLikeOutsideLoop");

    // Resource tensors cannot be zeroed with zeros_like here yet.
    if (val.dtype == TF_DataType.TF_RESOURCE)
        throw new NotImplementedException("ZerosLikeOutsideLoop");

    return array_ops.zeros_like(val, optimize: false);
}
229243
}
230244
}

src/TensorFlowNET.Core/Operations/control_flow_util.py.cs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,15 @@ public static bool IsLoopExit(Operation op)
1515
{
1616
return op.type == "Exit" || op.type == "RefExit";
1717
}
18+
19+
/// <summary>
/// Return true if `op` is a Switch (a "Switch" or "RefSwitch" op).
/// </summary>
/// <param name="op">Operation to inspect.</param>
/// <returns>True when the op type is Switch or RefSwitch.</returns>
public static bool IsSwitch(Operation op)
{
    switch (op.type)
    {
        case "Switch":
        case "RefSwitch":
            return true;
        default:
            return false;
    }
}
1828
}
1929
}

src/TensorFlowNET.Core/Operations/nn_ops.cs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,23 @@ public static Tensor bias_add(Tensor value,
4242
});
4343
}
4444

45+
/// <summary>
/// Computes log-softmax activations by delegating to the shared _softmax
/// helper with gen_nn_ops.log_softmax as the compute op.
/// </summary>
/// <param name="logits">Input tensor.</param>
/// <param name="axis">Dimension to operate over; defaults to the last.</param>
/// <param name="name">Optional op name.</param>
public static Tensor log_softmax(Tensor logits, int axis = -1, string name = null)
    => _softmax(logits, gen_nn_ops.log_softmax, axis, name);
49+
50+
/// <summary>
/// Shared helper for softmax-style ops. Applies `compute_op` directly when
/// `dim` is the last dimension (the only case the kernels handle here);
/// other axes are not ported yet — presumably they would require moving
/// `dim` to the end first (TODO confirm against TF's nn_ops.py).
/// </summary>
/// <param name="logits">Input tensor (converted via ops.convert_to_tensor).</param>
/// <param name="compute_op">Op builder taking (logits, name).</param>
/// <param name="dim">Dimension to operate over; -1 means last.</param>
/// <param name="name">Optional op name.</param>
public static Tensor _softmax(Tensor logits, Func<Tensor, string, Tensor> compute_op, int dim = -1, string name = null)
{
    logits = ops.convert_to_tensor(logits);

    var rank = logits.shape.Length;
    var lastDim = dim == -1 || dim == rank - 1;
    if (lastDim)
        return compute_op(logits, name);

    throw new NotImplementedException("_softmax helper");
}
61+
4562
public static Tensor softmax_cross_entropy_with_logits_v2_helper(Tensor labels,
4663
Tensor logits,
4764
int axis = -1,

src/TensorFlowNET.Core/ops.py.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -426,6 +426,8 @@ public static Tensor internal_convert_to_tensor(object value, TF_DataType dtype
426426
return constant_op.constant(doubleVal, dtype: dtype, name: name);
427427
case RefVariable varVal:
428428
return varVal._TensorConversionFunction(as_ref: as_ref);
429+
case object[] objects:
430+
return array_ops._autopacking_helper(objects, dtype: dtype, name: name);
429431
default:
430432
throw new NotImplementedException($"internal_convert_to_tensor: Can't convert {value.GetType().Name} to Tensor");
431433
}

0 commit comments

Comments
 (0)