Skip to content

Commit 2e89fa7

Browse files
committed
add strided_slice for Tensor, _SwitchRefOrTensor
1 parent 2ee6513 commit 2e89fa7

File tree

12 files changed

+245
-38
lines changed

12 files changed

+245
-38
lines changed

src/TensorFlowNET.Core/APIs/tf.array.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ public static Tensor expand_dims(Tensor input, int axis = -1, string name = null
2828
/// <param name="name"></param>
2929
/// <param name="conjugate"></param>
3030
/// <returns></returns>
31-
public static Tensor transpose(Tensor a, int[] perm = null, string name = "transpose", bool conjugate = false)
31+
public static Tensor transpose<T1, T2>(T1 a, T2 perm, string name = "transpose", bool conjugate = false)
3232
=> array_ops.transpose(a, perm, name, conjugate);
3333

3434
public static Tensor squeeze(Tensor input, int[] axis = null, string name = null, int squeeze_dims = -1)

src/TensorFlowNET.Core/Gradients/array_grad.py.cs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,21 @@ public static Tensor[] _ReshapeGrad(Operation op, Tensor[] grads)
1010
{
1111
return new Tensor[] { array_ops.reshape(grads[0], array_ops.shape(op.inputs[0])), null };
1212
}
13+
14+
/// <summary>
/// Gradient for the "Squeeze" op: reshapes the incoming gradient
/// <c>grads[0]</c> back to the shape of the op's first input.
/// </summary>
/// <param name="op">The forward operation; only <c>op.inputs[0]</c> is read (via _ReshapeToInput).</param>
/// <param name="grads">Incoming gradients; only <c>grads[0]</c> is used.</param>
/// <returns>A single-element array holding the reshaped gradient.</returns>
public static Tensor[] _SqueezeGrad(Operation op, Tensor[] grads)
15+
{
16+
return new Tensor[] { _ReshapeToInput(op, grads[0]) };
17+
}
18+
19+
/// <summary>
/// Reshapes <paramref name="grad"/> to the shape of <c>op.inputs[0]</c>,
/// where that shape is obtained as a tensor from <c>array_ops.shape</c>.
/// </summary>
/// <param name="op">Operation whose first input supplies the target shape.</param>
/// <param name="grad">Gradient tensor to reshape.</param>
/// <returns>The result of <c>array_ops.reshape</c> on <paramref name="grad"/>.</returns>
private static Tensor _ReshapeToInput(Operation op, Tensor grad)
20+
{
21+
return array_ops.reshape(grad, array_ops.shape(op.inputs[0]));
22+
}
23+
24+
/// <summary>
/// Gradient for the "Transpose" op: transposes the incoming gradient
/// <c>grads[0]</c> using the inverse of the permutation held in
/// <c>op.inputs[1]</c>.
/// </summary>
/// <param name="op">The forward operation; <c>op.inputs[1]</c> is read as the permutation.</param>
/// <param name="grads">Incoming gradients; only <c>grads[0]</c> is used.</param>
/// <returns>
/// A two-element array: the transposed gradient, followed by <c>null</c>
/// (no gradient is produced for the permutation input).
/// </returns>
public static Tensor[] _TransposeGrad(Operation op, Tensor[] grads)
25+
{
26+
var p = op.inputs[1];
27+
return new Tensor[] { array_ops.transpose(grads[0], array_ops.invert_permutation(p)), null };
28+
}
1329
}
1430
}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Gradients
6+
{
7+
/// <summary>
/// Gradient functions for control-flow ops. Currently contains only the
/// gradient for "Merge".
/// </summary>
public class control_flow_grad
8+
{
9+
/// <summary>
/// Gradient for the "Merge" op: routes the incoming gradient through a
/// switch keyed on the predicate of the control-flow context that owns
/// the op producing <c>op.inputs[0]</c>.
/// </summary>
/// <param name="op">The forward operation; only <c>op.inputs[0].op</c> is read.</param>
/// <param name="grads">
/// Incoming gradients. <c>grads[0]</c> is forwarded; <c>grads[1]</c> is read
/// and discarded, so the array must have at least two elements.
/// </param>
/// <returns>A two-element array holding both outputs of the switch.</returns>
public static Tensor[] _MergeGrad(Operation op, Tensor[] grads)
10+
{
11+
var grad = grads[0];
12+
// NOTE(review): value is never used after this assignment.
var _ = grads[1];
13+
var input_op = op.inputs[0].op;
14+
// NOTE(review): `graph` is never used in this method.
var graph = ops.get_default_graph();
15+
var op_ctxt = control_flow_util.GetOutputContext(input_op);
16+
// NOTE(review): op_ctxt is dereferenced without a null check; confirm
// input_op always carries a control-flow context when this runs.
var pred = op_ctxt.pred;
17+
18+
var results = control_flow_ops._SwitchRefOrTensor(grad, pred, name: "cond_grad");
19+
return new Tensor[] { results.Item1, results.Item2 };
20+
}
21+
}
22+
}

src/TensorFlowNET.Core/Gradients/nn_grad.py.cs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ public static Tensor[] _TopKGrad(Operation op, Tensor[] grads)
9393

9494
var in_lastdim = array_ops.gather(math_ops.cast(in_shape, TF_DataType.TF_INT64),
9595
array_ops.size(in_shape) - 1);
96-
var outerdim = array_ops.shape(ind_2d);
96+
var outerdim = array_ops.shape(ind_2d)[0];
9797

9898
// Compute linear indices (flattened to 1D).
9999
var cast1 = math_ops.cast(outerdim, TF_DataType.TF_INT64);
@@ -102,7 +102,17 @@ public static Tensor[] _TopKGrad(Operation op, Tensor[] grads)
102102
var cast2 = math_ops.cast(dim2, TF_DataType.TF_INT32);
103103
var ind = array_ops.reshape(ind_2d + cast2, new int[] { -1 });
104104

105-
throw new NotImplementedException("nn_grad._TopKGrad");
105+
// Substitute grad to appropriate locations and fill the rest with zeros,
106+
// finally reshaping it to the original input shape.
107+
var scatter = gen_array_ops.scatter_nd(array_ops.expand_dims(ind, -1),
108+
array_ops.reshape(grad, new int[] { -1 }),
109+
new Tensor[] { math_ops.reduce_prod(in_shape) });
110+
111+
return new Tensor[]
112+
{
113+
array_ops.reshape(scatter, in_shape),
114+
array_ops.zeros(new int[0], dtype: TF_DataType.TF_INT32)
115+
};
106116
}
107117
}
108118
}

src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ public static Func<Operation, Tensor[], Tensor[]> get_gradient_function(Operatio
2626
return math_grad._IdGrad(oper, out_grads);
2727
case "MatMul":
2828
return math_grad._MatMulGrad(oper, out_grads);
29+
case "Merge":
30+
return control_flow_grad._MergeGrad(oper, out_grads);
2931
case "Mul":
3032
return math_grad._MulGrad(oper, out_grads);
3133
case "Mean":
@@ -42,8 +44,12 @@ public static Func<Operation, Tensor[], Tensor[]> get_gradient_function(Operatio
4244
return array_grad._ReshapeGrad(oper, out_grads);
4345
case "Relu":
4446
return nn_grad._ReluGrad(oper, out_grads);
47+
case "Squeeze":
48+
return array_grad._SqueezeGrad(oper, out_grads);
4549
case "SoftmaxCrossEntropyWithLogits":
4650
return nn_grad._SoftmaxCrossEntropyWithLogitsGrad(oper, out_grads);
51+
case "Transpose":
52+
return array_grad._TransposeGrad(oper, out_grads);
4753
case "TopK":
4854
case "TopKV2":
4955
return nn_grad._TopKGrad(oper, out_grads);

src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,22 +10,28 @@ namespace Tensorflow.Operations
1010
public class CondContext : ControlFlowContext
1111
{
1212
private string _name;
13+
1314
/// <summary>
1415
/// The boolean tensor for the cond predicate
1516
/// </summary>
1617
private Tensor _pred;
18+
public Tensor pred => _pred;
19+
1720
/// <summary>
1821
/// The predicate tensor in this branch
1922
/// </summary>
2023
private Tensor _pivot;
24+
2125
/// <summary>
2226
/// 0 or 1 representing this branch
2327
/// </summary>
2428
private int _branch;
29+
2530
/// <summary>
2631
///
2732
/// </summary>
2833
private List<string> _values = new List<string>();
34+
2935
private Dictionary<string, Tensor> _external_values = new Dictionary<string, Tensor>();
3036

3137
/// <summary>

src/TensorFlowNET.Core/Operations/Operation.Control.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,5 +32,10 @@ public void _set_control_flow_context(CondContext ctx)
3232
{
3333
_control_flow_context = ctx;
3434
}
35+
36+
/// <summary>
/// Returns the control-flow context stored on this operation, i.e. the
/// value last assigned through <c>_set_control_flow_context</c>.
/// </summary>
/// <returns>The stored <c>CondContext</c> (returned as-is, without validation).</returns>
public CondContext _get_control_flow_context()
37+
{
38+
return _control_flow_context;
39+
}
3540
}
3641
}

src/TensorFlowNET.Core/Operations/array_ops.py.cs

Lines changed: 60 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@ namespace Tensorflow
77
{
88
public class array_ops : Python
99
{
10-
public static Tensor placeholder_with_default<T>(T input, int[] shape, string name = null) => gen_array_ops.placeholder_with_default(input, shape, name);
10+
public static Tensor placeholder_with_default<T>(T input, int[] shape, string name = null)
11+
=> gen_array_ops.placeholder_with_default(input, shape, name);
1112

1213
public static Tensor zeros(Shape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null)
1314
{
@@ -111,14 +112,14 @@ public static Tensor _autopacking_helper(object[] list_or_tuple, TF_DataType dty
111112
});
112113
}
113114

114-
public static Tensor expand_dims(Tensor input, int axis = -1, string name = null, int dim = -1) => expand_dims_v2(input, axis, name);
115+
public static Tensor expand_dims(Tensor input, int axis = -1, string name = null, int dim = -1)
116+
=> expand_dims_v2(input, axis, name);
115117

116-
private static Tensor expand_dims_v2(Tensor input, int axis, string name = null) => gen_array_ops.expand_dims(input, axis, name);
118+
private static Tensor expand_dims_v2(Tensor input, int axis, string name = null)
119+
=> gen_array_ops.expand_dims(input, axis, name);
117120

118121
public static Tensor rank(Tensor input, string name = null)
119-
{
120-
return math_ops.rank_internal(input, name, optimize: true);
121-
}
122+
=> math_ops.rank_internal(input, name, optimize: true);
122123

123124
/// <summary>
124125
/// Creates a tensor with all elements set to 1.
@@ -132,9 +133,7 @@ public static Tensor ones_like<T>(T tensor, TF_DataType dtype = TF_DataType.DtIn
132133
=> ones_like_impl(tensor, dtype, name, optimize);
133134

134135
public static Tensor reshape<T1, T2>(T1 tensor, T2 shape, string name = null)
135-
{
136-
return gen_array_ops.reshape(tensor, shape, null);
137-
}
136+
=> gen_array_ops.reshape(tensor, shape, null);
138137

139138
private static Tensor ones_like_impl<T>(T tensor, TF_DataType dtype, string name, bool optimize = true)
140139
{
@@ -239,14 +238,10 @@ public static Tensor where(Tensor condition, Tensor x = null, Tensor y = null, s
239238
/// </param>
240239
/// <returns>A `Tensor` of type `out_type`.</returns>
241240
public static Tensor shape(Tensor input, string name = null, TF_DataType out_type = TF_DataType.TF_INT32)
242-
{
243-
return shape_internal(input, name, optimize: true, out_type: out_type);
244-
}
241+
=> shape_internal(input, name, optimize: true, out_type: out_type);
245242

246243
public static Tensor size(Tensor input, string name = null, bool optimize = true, TF_DataType out_type = TF_DataType.TF_INT32)
247-
{
248-
return size_internal(input, name, optimize: optimize, out_type: out_type);
249-
}
244+
=> size_internal(input, name, optimize: optimize, out_type: out_type);
250245

251246
private static Tensor shape_internal(Tensor input, string name = null, bool optimize = true, TF_DataType out_type = TF_DataType.TF_INT32)
252247
{
@@ -323,8 +318,46 @@ public static Tensor zeros_like(Tensor tensor, TF_DataType dtype = TF_DataType.D
323318
/// <param name="name"></param>
324319
/// <returns></returns>
325320
public static Tensor stop_gradient(Tensor input, string name = null)
321+
=> gen_array_ops.stop_gradient(input, name);
322+
323+
/// <summary>
324+
/// Extracts a strided slice of a tensor (generalized python array indexing).
325+
/// </summary>
326+
/// <param name="input_"></param>
327+
/// <param name="begin"></param>
328+
/// <param name="end"></param>
329+
/// <param name="strides"></param>
330+
/// <param name="begin_mask"></param>
331+
/// <param name="end_mask"></param>
332+
/// <param name="ellipsis_mask"></param>
333+
/// <param name="new_axis_mask"></param>
334+
/// <param name="shrink_axis_mask"></param>
335+
/// <param name="name"></param>
336+
/// <returns></returns>
337+
public static Tensor strided_slice(Tensor input_, Tensor begin, Tensor end,
338+
Tensor strides = null,
339+
int begin_mask = 0,
340+
int end_mask = 0,
341+
int ellipsis_mask = 0,
342+
int new_axis_mask = 0,
343+
int shrink_axis_mask = 0,
344+
string name = null)
326345
{
327-
return gen_array_ops.stop_gradient(input, name);
346+
var op = gen_array_ops.strided_slice(
347+
input: input_,
348+
begin: begin,
349+
end: end,
350+
strides: strides,
351+
begin_mask: begin_mask,
352+
end_mask: end_mask,
353+
ellipsis_mask: ellipsis_mask,
354+
new_axis_mask: new_axis_mask,
355+
shrink_axis_mask: shrink_axis_mask,
356+
name: name);
357+
358+
string parent_name = name;
359+
360+
return op;
328361
}
329362

330363
/// <summary>
@@ -345,14 +378,14 @@ public static Tensor stop_gradient(Tensor input, string name = null)
345378
/// Contains the same data as `input`, but has one or more dimensions of
346379
/// size 1 removed.</returns>
347380
public static Tensor squeeze(Tensor input, int[] axis = null, string name = null, int[] squeeze_dims = null)
348-
{
349-
return gen_array_ops.squeeze(input, axis, name);
350-
}
381+
=> gen_array_ops.squeeze(input, axis, name);
351382

352383
public static Tensor identity(Tensor input, string name = null)
353-
{
354-
return gen_array_ops.identity(input, name);
355-
}
384+
=> gen_array_ops.identity(input, name);
385+
386+
public static Tensor invert_permutation(Tensor x, string name = null)
387+
=> gen_array_ops.invert_permutation(x, name: name);
388+
356389
/// <summary>
357390
/// Computes the shape of a broadcast given symbolic shapes.
358391
/// When shape_x and shape_y are Tensors representing shapes (i.e. the result of
@@ -368,26 +401,19 @@ public static Tensor identity(Tensor input, string name = null)
368401
/// <param name="shape_y"> A rank 1 integer `Tensor`, representing the shape of y.</param>
369402
/// <returns> A rank 1 integer `Tensor` representing the broadcasted shape.</returns>
370403
public static Tensor broadcast_dynamic_shape(Tensor shape_x, Tensor shape_y)
371-
{
372-
return gen_array_ops.broadcast_args(shape_x, shape_y);
373-
}
404+
=> gen_array_ops.broadcast_args(shape_x, shape_y);
374405

375406
public static Tensor broadcast_static_shape(Tensor shape_x, Tensor shape_y)
376-
{
377-
return Framework.common_shapes.broadcast_shape(shape_x, shape_y);
378-
}
407+
=> Framework.common_shapes.broadcast_shape(shape_x, shape_y);
379408

380409
public static Tensor gather(Tensor @params, Tensor indices, string name = null, int axis = 0)
381-
{
382-
return gen_array_ops.gather_v2(@params, indices, axis, name: name);
383-
}
410+
=> gen_array_ops.gather_v2(@params, indices, axis, name: name);
384411

385-
public static Tensor transpose(Tensor a, int[] perm = null, string name = "transpose", bool conjugate = false)
412+
public static Tensor transpose<T1, T2>(T1 a, T2 perm, string name = "transpose", bool conjugate = false)
386413
{
387414
return with(ops.name_scope(name, "transpose", new { a }), scope =>
388415
{
389-
name = scope;
390-
return gen_array_ops.transpose(a, perm, name);
416+
return gen_array_ops.transpose(a, perm, name: scope);
391417
});
392418
}
393419

src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,22 @@ public static Tensor _Identity(Tensor data, string name = null)
138138
return gen_array_ops.identity(data, name: name);
139139
}
140140

141+
/// <summary>
142+
/// Forwards `data` to an output determined by `pred`.
143+
/// </summary>
144+
/// <param name="data">Value to forward; first converted with <c>ops.convert_to_tensor_or_indexed_slices</c>.</param>
145+
/// <param name="pred">Predicate tensor passed to <c>@switch</c> to select the output.</param>
146+
/// <param name="name">Name passed to the underlying switch call; defaults to "Switch".</param>
147+
/// <returns>
/// The tuple returned by <c>@switch</c>. Presumably (false output, true output)
/// as in TensorFlow's Switch op; confirm against <c>@switch</c>.
/// </returns>
148+
public static (Tensor, Tensor) _SwitchRefOrTensor(Tensor data, Tensor pred, string name = "Switch")
149+
{
150+
data = ops.convert_to_tensor_or_indexed_slices(data, name: "data");
151+
152+
// NOTE(review): the return value of colocate_with is discarded here.
// If it is meant to act as a scope (as the Python `with` form does),
// the switch below may not actually be colocated with `data`; confirm.
ops.colocate_with(data, ignore_existing: true);
153+
154+
return @switch(data, pred, name: name);
155+
}
156+
141157
public static Tensor[] cond<T>(Tensor pred,
142158
Func<T[]> true_fn = null,
143159
Func<T[]> false_fn = null,

src/TensorFlowNET.Core/Operations/control_flow_util.py.cs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using System;
22
using System.Collections.Generic;
33
using System.Text;
4+
using Tensorflow.Operations;
45

56
namespace Tensorflow
67
{
@@ -25,5 +26,12 @@ public static bool IsSwitch(Operation op)
2526
{
2627
return op.type == "Switch" || op.type == "RefSwitch";
2728
}
29+
30+
/// <summary>
/// Returns the control-flow context of <paramref name="op"/>. Currently a
/// direct pass-through of <c>op._get_control_flow_context()</c>.
/// </summary>
/// <param name="op">Operation whose context is requested.</param>
/// <returns>The operation's <c>CondContext</c>, exactly as stored on it.</returns>
public static CondContext GetOutputContext(Operation op)
31+
{
32+
var ctxt = op._get_control_flow_context();
33+
34+
return ctxt;
35+
}
2836
}
2937
}

0 commit comments

Comments
 (0)