Skip to content

Commit 824dfe6

Browse files
committed
Pack/Unpack gradient. #847
1 parent f3102b9 commit 824dfe6

File tree

3 files changed

+19
-13
lines changed

3 files changed

+19
-13
lines changed

src/TensorFlowNET.Core/Gradients/array_grad.cs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,22 @@ public static Tensor[] _ReshapeGrad(Operation op, Tensor[] grads)
223223
return new Tensor[] { array_ops.reshape(grads[0], array_ops.shape(op.inputs[0])), null };
224224
}
225225

226+
[RegisterGradient("Pack")]
227+
public static Tensor[] _PackGrad(Operation op, Tensor[] grads)
228+
{
229+
var grad = grads[0];
230+
var num = op.get_attr<int>("N");
231+
var axis = op.get_attr<int>("axis");
232+
return array_ops.unstack(grad, num: num, axis: axis);
233+
}
234+
235+
[RegisterGradient("Unpack")]
236+
public static Tensor[] _UnpackGrad(Operation op, Tensor[] grads)
237+
{
238+
var axis = op.get_attr<int>("axis");
239+
return new[] { array_ops.stack(grads, axis: axis) };
240+
}
241+
226242
[RegisterGradient("Pad")]
227243
public static Tensor[] _PadGrad(Operation op, Tensor[] grads)
228244
{

src/TensorFlowNET.Core/Operations/array_ops.cs

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -494,20 +494,12 @@ public static Tensor stack(Tensor[] values, int axis = 0, string name = "stack")
494494
return ops.convert_to_tensor(values, name: name);
495495
}
496496

497-
var value_shape = ops.convert_to_tensor(values[0], name: name).shape;
498-
499497
return gen_array_ops.pack(values, axis: axis, name: name);
500498
}
501499

502500
public static Tensor[] unstack(Tensor value, int? num = null, int axis = 0, string name = "unstack")
503501
{
504-
if (num == null)
505-
{
506-
value = ops.convert_to_tensor(value);
507-
var value_shape = value.shape;
508-
num = (int)value_shape.dims[axis];
509-
}
510-
502+
num = num ?? value.shape.as_int_list()[axis];
511503
return gen_array_ops.unpack(value, num: num.Value, axis: axis, name: name);
512504
}
513505

src/TensorFlowNET.Core/Operations/gen_array_ops.cs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -265,10 +265,8 @@ public static (Tensor, Tensor) unique(Tensor x, TF_DataType out_idx = TF_DataTyp
265265
}
266266

267267
public static Tensor[] unpack(Tensor value, int num, int axis = 0, string name = null)
268-
{
269-
var _op = tf.OpDefLib._apply_op_helper("Unpack", name, new { value, num, axis });
270-
return _op.outputs;
271-
}
268+
=> tf.Context.ExecuteOp("Unpack", name, new ExecuteOpArgs(value, num)
269+
.SetAttributes(new { axis }));
272270

273271
public static Tensor where(Tensor condition, string name = null)
274272
{

0 commit comments

Comments
 (0)