Commit a8a5156

_SwitchGrad
1 parent c8a61b2 commit a8a5156

File tree

1 file changed: +74 -67 lines changed


src/TensorFlowNET.Core/Gradients/control_flow_grad.cs

Lines changed: 74 additions & 67 deletions
@@ -45,7 +45,19 @@ public static Tensor[] _SwitchGrad(Operation op, Tensor[] grads)
             switch (op_ctxt)
             {
                 case WhileContext cwhile:
-                    throw new NotImplementedException("_SwitchGrad WhileContext");
+                    {
+                        var merge_grad = grad_ctxt.grad_state.switch_map.get(op);
+                        if (merge_grad != null)
+                            throw new NotImplementedException("_SwitchGrad merge_grad != null");
+                        else if (grads[0] != null)
+                        {
+                            merge_grad = merge(new[] { grads[0], grads[0] }, name: "b_switch")[0];
+                            grad_ctxt.grad_state.switch_map[op] = merge_grad;
+                            return new Tensor[] { merge_grad, null };
+                        }
+                        else
+                            return new Tensor[] { null, null };
+                    }
                 case CondContext ccond:
                     {
                         var zero_grad = grads[1 - op_ctxt.branch];
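
A note on the new WhileContext branch above: when backprop first reaches a Switch that sits inside a while loop, only the gradient coming from the loop's Exit (grads[0]) is known; the gradient contributed by the next iteration does not exist yet. The code therefore builds merge(grads[0], grads[0]) named "b_switch" as a placeholder and caches it in grad_ctxt.grad_state.switch_map so the same Merge can be looked up again once the backprop loop body is built (the already-cached case is still left as NotImplementedException in this commit). The following is a minimal, self-contained sketch of that caching pattern; FakeOp, FakeTensor and GradForSwitch are hypothetical stand-ins, not TensorFlow.NET types.

    using System;
    using System.Collections.Generic;

    // Hypothetical stand-ins for illustration only; not the TensorFlow.NET types.
    class FakeOp { public string Name; public FakeOp(string name) { Name = name; } }
    class FakeTensor { public string Name; public FakeTensor(string name) { Name = name; } }

    static class SwitchGradSketch
    {
        // Mirrors grad_ctxt.grad_state.switch_map: one backprop Merge per forward Switch.
        static readonly Dictionary<FakeOp, FakeTensor> switch_map = new Dictionary<FakeOp, FakeTensor>();

        static FakeTensor Merge(FakeTensor a, FakeTensor b, string name)
            => new FakeTensor($"{name}({a.Name},{b.Name})");

        // First visit: build Merge(grad, grad) as a placeholder for the not-yet-built
        // next-iteration gradient and cache it. On a repeat visit the real gradient code
        // would attach the new gradient to the cached Merge (the commit above still
        // throws for that case); the sketch simply returns the cached tensor.
        public static FakeTensor GradForSwitch(FakeOp switchOp, FakeTensor exitGrad)
        {
            if (switch_map.TryGetValue(switchOp, out var cached))
                return cached;

            var mergeGrad = Merge(exitGrad, exitGrad, "b_switch");
            switch_map[switchOp] = mergeGrad;
            return mergeGrad;
        }

        static void Main()
        {
            var sw = new FakeOp("while/Switch");
            var g = GradForSwitch(sw, new FakeTensor("exit_grad"));
            Console.WriteLine(g.Name);   // prints: b_switch(exit_grad,exit_grad)
        }
    }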
@@ -74,7 +86,7 @@ public static Tensor[] _SwitchGrad(Operation op, Tensor[] grads)
         /// <param name="inputs"></param>
         /// <param name="name"></param>
         /// <returns></returns>
-        internal static Tensor[] merge(Tensor[] inputs, string name = null)
+        internal static MergeOutput merge(Tensor[] inputs, string name = null)
         {
             return tf_with(ops.name_scope(name, "Merge", inputs), scope =>
             {
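
merge now returns MergeOutput instead of Tensor[], yet the WhileContext hunk above still writes merge(...)[0], so MergeOutput evidently exposes an indexer (or similar) over the merged output and the value index. The class itself is defined elsewhere in TensorFlowNET.Core and is not part of this diff; the sketch below is only a guess at the minimal shape such a wrapper needs to keep that call site compiling, with hypothetical member names.

    // Hypothetical sketch only; the real MergeOutput lives elsewhere in the repo
    // and may expose different members. It merely has to keep merge(...)[0] legal.
    public class MergeOutputSketch
    {
        public Tensor output { get; }       // the merged tensor (index 0)
        public Tensor value_index { get; }  // which input was forwarded (index 1)

        public MergeOutputSketch(Tensor output, Tensor value_index)
        {
            this.output = output;
            this.value_index = value_index;
        }

        // Indexer so call sites such as merge(...)[0] keep working after the
        // return type change from Tensor[] to MergeOutput.
        public Tensor this[int index] => index == 0 ? output : value_index;
    }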
@@ -146,7 +158,7 @@ public static Tensor[] _MergeGrad(Operation op, Tensor[] grads)
         }

         [RegisterGradient("RefMerge")]
-        public Tensor[] _RefMergeGrad(Operation op, Tensor[] grads)
+        public static Tensor[] _RefMergeGrad(Operation op, Tensor[] grads)
         {
             return _MergeGrad(op, grads);
         }
@@ -155,43 +167,32 @@ public Tensor[] _RefMergeGrad(Operation op, Tensor[] grads)
         /// Gradients for an exit op are calculated using an Enter op.
         /// </summary>
         [RegisterGradient("Exit")]
-        public Tensor[] _ExitGrad(Operation op, Tensor[] grads)
+        public static Tensor[] _ExitGrad(Operation op, Tensor[] grads)
         {
-            throw new NotImplementedException("_ExitGrad");
-            // graph = ops.get_default_graph()
-            //# pylint: disable=protected-access
-            // op_ctxt = op._get_control_flow_context()
-            // grad_ctxt = graph._get_control_flow_context()
-            // # pylint: enable=protected-access
-            // if not grad_ctxt.back_prop:
-            //   # The flag `back_prop` is set by users to suppress gradient
-            //   # computation for this loop. If the attribute `back_prop` is false,
-            //   # no gradient computation.
-            //   return None
+            var grad = grads[0];
+            var graph = ops.get_default_graph();
+            var op_ctxt = op._get_control_flow_context();
+            var grad_ctxt = graph._get_control_flow_context() as WhileContext;
+            // The flag `back_prop` is set by users to suppress gradient
+            // computation for this loop. If the attribute `back_prop` is false,
+            // no gradient computation.
+            if (!grad_ctxt.back_prop)
+                return null;
+
+            if (op_ctxt.grad_state != null)
+                throw new TypeError("Second-order gradient for while loops not supported.");
+
+            grad_ctxt.AddName(grad.name);

-            // if op_ctxt.grad_state:
-            //   raise TypeError("Second-order gradient for while loops not supported.")
+            grad_ctxt.Enter();
+            var result = control_flow_ops._Enter(
+                grad, grad_ctxt.name, is_constant: false,
+                parallel_iterations: grad_ctxt.parallel_iterations,
+                name: "b_exit");

-            // if isinstance(grad, ops.Tensor) :
-            //   grad_ctxt.AddName(grad.name)
-            // else:
-            //   if not isinstance(grad, (ops.IndexedSlices, sparse_tensor.SparseTensor)):
-            //     raise TypeError("Type %s not supported" % type(grad))
-            //   grad_ctxt.AddName(grad.values.name)
-            //   grad_ctxt.AddName(grad.indices.name)
-            //   dense_shape = grad.dense_shape
-            //   if dense_shape is not None:
-            //     grad_ctxt.AddName(dense_shape.name)
-            // grad_ctxt.Enter()
-            // # pylint: disable=protected-access
-            // result = control_flow_ops._Enter(
-            //     grad, grad_ctxt.name, is_constant=False,
-            //     parallel_iterations=grad_ctxt.parallel_iterations,
-            //     name="b_exit")
-            // # pylint: enable=protected-access
-            // grad_ctxt.loop_enters.append(result)
-            // grad_ctxt.Exit()
-            // return result
+            grad_ctxt.loop_enters.append(result);
+            grad_ctxt.Exit();
+            return new[] { result };
         }

         /// <summary>
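
_ExitGrad is now implemented rather than throwing, and it follows the usual TensorFlow while-loop gradient construction: each forward control-flow op gets a dual op in the backprop loop, so the gradient of an Exit is an Enter (named "b_exit") into the gradient loop, guarded by the back_prop flag and the second-order check above. A small illustrative sketch of that forward-to-backprop pairing, assuming the standard TensorFlow scheme; the dictionary and names below are explanatory, not part of the library.

    using System;
    using System.Collections.Generic;

    // Illustrative only: the forward-op -> backprop-op duality that the gradient
    // functions in this file implement (standard TensorFlow while-loop scheme).
    static class WhileLoopGradDuality
    {
        static readonly Dictionary<string, string> Dual = new Dictionary<string, string>
        {
            ["Exit"] = "Enter",             // _ExitGrad builds "b_exit" via control_flow_ops._Enter
            ["Enter"] = "Exit",             // _EnterGrad calls control_flow_ops.exit for loop variables
            ["NextIteration"] = "Identity", // _NextIterationGrad passes grads straight through
            ["Switch"] = "Merge",           // _SwitchGrad builds "b_switch" via merge
            ["Merge"] = "Switch",           // _MergeGrad (unchanged in this commit) builds a Switch
        };

        static void Main()
        {
            foreach (var pair in Dual)
                Console.WriteLine($"{pair.Key,-13} -> {pair.Value}");
        }
    }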
@@ -200,15 +201,15 @@ public Tensor[] _ExitGrad(Operation op, Tensor[] grads)
         /// Note that the backprop next_iteration is added in switch grad.
         /// </summary>
         [RegisterGradient("NextIteration")]
-        public Tensor[] _NextIterationGrad(object _, Tensor[] grad)
+        public static Tensor[] _NextIterationGrad(Operation op, Tensor[] grads)
         {
-            return grad;
+            return grads;
         }

         [RegisterGradient("RefNextIteration")]
-        public Tensor[] _RefNextIterationGrad(object _, Tensor[] grad)
+        public static Tensor[] _RefNextIterationGrad(Operation op, Tensor[] grads)
         {
-            return grad;
+            return grads;
         }

         /// <summary>
@@ -218,33 +219,39 @@ public Tensor[] _RefNextIterationGrad(object _, Tensor[] grad)
         /// For loop invariants, we need to add an accumulator loop.
         /// </summary>
         [RegisterGradient("Enter")]
-        public Tensor[] _EnterGrad(Tensor op, Tensor[] grad)
+        public static Tensor[] _EnterGrad(Operation op, Tensor[] grads)
         {
-            throw new NotImplementedException("_EnterGrad");
-            // graph = ops.get_default_graph()
-            //# pylint: disable=protected-access
-            // grad_ctxt = graph._get_control_flow_context()
-            // # pylint: enable=protected-access
-            // if not grad_ctxt.back_prop:
-            //   # Skip gradient computation, if the attribute `back_prop` is false.
-            //   return grad
-            // if grad_ctxt.grad_state is None:
-            //   # Pass the gradient through if we are not in a gradient while context.
-            //   return grad
-            // if op.get_attr("is_constant"):
-            //   # Add a gradient accumulator for each loop invariant.
-            //   if isinstance(grad, ops.Tensor) :
-            //     result = grad_ctxt.AddBackpropAccumulator(op, grad)
-            //   elif isinstance(grad, ops.IndexedSlices) :
-            //     result = grad_ctxt.AddBackpropIndexedSlicesAccumulator(op, grad)
-            //   else:
-            //     # TODO(yuanbyu, lukasr): Add support for SparseTensor.
-            //     raise TypeError("Type %s not supported" % type(grad))
-            // else:
-            //   result = exit(grad)
-            // grad_ctxt.loop_exits.append(result)
-            // grad_ctxt.ExitResult([result])
-            // return result
+            Tensor result = null;
+            var grad = grads[0];
+            var graph = ops.get_default_graph();
+            var grad_ctxt = graph._get_control_flow_context() as WhileContext;
+            if (!grad_ctxt.back_prop)
+                // Skip gradient computation, if the attribute `back_prop` is false.
+                return grads;
+            if (grad_ctxt.grad_state == null)
+                // Pass the gradient through if we are not in a gradient while context.
+                return grads;
+            if (op.get_attr<bool>("is_constant"))
+            {
+                throw new NotImplementedException("_EnterGrad is_constant");
+                // Add a gradient accumulator for each loop invariant.
+                // if isinstance(grad, ops.Tensor) :
+                //   result = grad_ctxt.AddBackpropAccumulator(op, grad)
+                // elif isinstance(grad, ops.IndexedSlices) :
+                //   result = grad_ctxt.AddBackpropIndexedSlicesAccumulator(op, grad)
+                // else:
+                //   # TODO(yuanbyu, lukasr): Add support for SparseTensor.
+                //   raise TypeError("Type %s not supported" % type(grad))
+            }
+
+            else
+            {
+                result = control_flow_ops.exit(grad);
+                grad_ctxt.loop_exits.append(result);
+                grad_ctxt.ExitResult(new[] { result });
+            }
+
+            return new Tensor[] { result };
         }

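Finally, the rewritten _EnterGrad boils down to three outcomes: pass the incoming gradient through when back_prop is off or when we are not inside a gradient while-context, hit the still-unimplemented accumulator path for loop invariants (is_constant), and otherwise leave the gradient loop through control_flow_ops.exit. A compact, self-contained restatement of that decision order, using hypothetical boolean inputs in place of the real context and attribute lookups:

    using System;

    // Hypothetical restatement of the branch order in _EnterGrad; the booleans stand in
    // for grad_ctxt.back_prop, grad_ctxt.grad_state and op.get_attr<bool>("is_constant").
    enum EnterGradOutcome { PassThrough, AccumulatorNotImplemented, ExitGradientLoop }

    static class EnterGradDecision
    {
        public static EnterGradOutcome Decide(bool backProp, bool inGradWhileContext, bool isConstant)
        {
            if (!backProp)
                return EnterGradOutcome.PassThrough;               // return grads;
            if (!inGradWhileContext)
                return EnterGradOutcome.PassThrough;               // return grads;
            if (isConstant)
                return EnterGradOutcome.AccumulatorNotImplemented; // NotImplementedException today
            return EnterGradOutcome.ExitGradientLoop;              // control_flow_ops.exit(grad)
        }

        static void Main()
        {
            // Typical case: a loop variable inside a backprop-enabled gradient while-context.
            Console.WriteLine(Decide(backProp: true, inGradWhileContext: true, isConstant: false));
            // prints: ExitGradientLoop
        }
    }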