
Commit efed258: WhileContext

1 parent 59b7eb0

15 files changed: +256 additions, -196 deletions

README.md

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@
  [![Documentation Status](https://readthedocs.org/projects/tensorflownet/badge/?version=latest)](https://tensorflownet.readthedocs.io/en/latest/?badge=latest)
  [![Badge](https://img.shields.io/badge/link-996.icu-red.svg)](https://996.icu/#/en_US)

- TF.NET is a member project of [SciSharp STACK](https://github.com/SciSharp). <a href="http://scisharpstack.org"><img src="https://github.com/SciSharp/SciSharp/blob/master/art/scisharp_badge.png" width="200" height="200" align="right" /></a>
+ TF.NET is a member project of [SciSharp STACK](https://github.com/SciSharp).


  ![tensors_flowing](docs/assets/tensors_flowing.gif)

src/TensorFlowNET.Core/Binding.Util.cs

Lines changed: 17 additions & 0 deletions
@@ -30,6 +30,20 @@ namespace Tensorflow
      /// </summary>
      public static partial class Binding
      {
+         public static T2 get<T1, T2>(this Dictionary<T1, T2> dict, T1 key)
+             => key == null ?
+                 default(T2) :
+                 (dict.ContainsKey(key) ? dict[key] : default(T2));
+
+         public static void add<T>(this IList<T> list, T element)
+             => list.Add(element);
+
+         public static void append<T>(this IList<T> list, T element)
+             => list.Add(element);
+
+         public static void extend<T>(this List<T> list, IEnumerable<T> elements)
+             => list.AddRange(elements);
+
          private static string _tostring(object obj)
          {
              switch (obj)
@@ -81,6 +95,9 @@ public static int len(object a)
              throw new NotImplementedException("len() not implemented for type: " + a.GetType());
          }

+         public static T[] list<T>(IEnumerable<T> list)
+             => list.ToArray();
+
          public static IEnumerable<int> range(int end)
          {
              return Enumerable.Range(0, end);
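
The new Binding helpers mirror Python built-ins (dict.get, list.append/extend, list(), range()) so ported control-flow code reads closer to the Python original. A quick usage sketch (illustrative only, not part of this commit; it assumes nothing beyond the extension methods shown above):

using System;
using System.Collections.Generic;
using static Tensorflow.Binding;

class BindingHelpersDemo
{
    static void Main()
    {
        var counts = new Dictionary<string, int>();
        var missing = counts.get("grad");          // no KeyNotFoundException; returns default(int), i.e. 0

        var names = new List<string>();
        names.append("Enter");                     // same as names.Add("Enter")
        names.extend(new[] { "Merge", "Exit" });   // same as names.AddRange(...)

        var snapshot = list(names);                // copies an IEnumerable<T> into a T[]
        foreach (var i in range(snapshot.Length))  // 0, 1, 2
            Console.WriteLine($"{i}: {snapshot[i]}");
    }
}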

src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowState.cs

Lines changed: 10 additions & 51 deletions
@@ -109,11 +109,12 @@ public void EnterGradWhileContext(Operation op, bool before)
              grad_state.grad_context.Enter();
          }

-         // def ExitGradWhileContext(self, op, before):
-         //   """Exit the WhileContext for gradient computation."""
-         //   grad_state = self.GetGradState(op, before)
-         //   if grad_state:
-         //     grad_state.grad_context.Exit()
+         public void ExitGradWhileContext(Operation op, bool before)
+         {
+             var grad_state = GetGradState(op, before);
+             if (grad_state != null)
+                 grad_state.grad_context.Exit();
+         }

          // def AddWhileContext(self, op, between_op_list, between_ops):
          //   """Add the grad state for the while loop that op belongs to.
@@ -287,51 +288,9 @@ public Tensor ZerosLikeForExit(Tensor val)
              return result;
          }

-         // def PostProcessing(self):
-         //   """Perform postprocessing at the end of gradients().
-         //
-         //   We have created the gradient graph at this point. So this function
-         //   can be used to perform any postprocessing on the gradient graph.
-         //   We currently perform the following postprocessing:
-         //   1. Patch the gradient graph if the output of a loop variable
-         //      doesn't depend on its input.
-         //   """
-         //   for _, grad_state in self._map.items():
-         //     for _, b_merge in grad_state.switch_map.items():
-         //       if b_merge.op.inputs[0] == b_merge.op.inputs[1]:
-         //         # The value of this loop variable at iteration i+1 doesn't
-         //         # depend on its value at iteration i. So use zeros as the
-         //         # gradients for all iterations > 0.
-         //         dtype = b_merge.op.inputs[0].dtype
-         //         shape = b_merge.op.inputs[0].get_shape()
-         //         # pylint: disable=protected-access
-         //         if shape.is_fully_defined():
-         //           grad_state.grad_context.Enter()
-         //           # Create a zeros and use it for iterations > 0.
-         //           grad_val = constant_op.constant(0, dtype=dtype, shape=shape)
-         //           next_grad_val = _NextIteration(grad_val)
-         //           grad_state.grad_context.Exit()
-         //         else:
-         //           # Create a zeros in the outer grad context.
-         //           outer_grad_ctxt = grad_state.grad_context.outer_context
-         //           if outer_grad_ctxt:
-         //             outer_grad_ctxt.Enter()
-         //           enter_grad_op = b_merge.op.inputs[0].op
-         //           enter_grad = enter_grad_op.inputs[0]
-         //           grad_shape = array_ops.shape_internal(enter_grad, optimize=False)
-         //           grad_val = array_ops.zeros(grad_shape)
-         //           if outer_grad_ctxt:
-         //             outer_grad_ctxt.Exit()
-         //           # Use the zeros for iterations > 0.
-         //           grad_state.grad_context.Enter()
-         //           next_grad_val = _NextIteration(grad_val)
-         //           grad_state.grad_context.Exit()
-         //         b_merge.op._update_input(1, next_grad_val)
-         //   # pylint: enable=protected-access
-
+         public void PostProcessing()
+         {
+             throw new NotImplementedException("PostProcessing");
+         }
      }
-
-
-
-
  }
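
ExitGradWhileContext is the counterpart of EnterGradWhileContext shown in the hunk context above; the two are expected to be paired around per-op gradient construction. A minimal pairing sketch (illustrative only; GradContextPairing, WithGradWhileContext, and buildGradient are hypothetical stand-ins, not APIs added by this commit, and the ControlFlowState namespace is assumed from the file's folder):

using System;
using Tensorflow;
using Tensorflow.Operations.ControlFlows;

static class GradContextPairing
{
    // Hypothetical helper: enter the op's gradient WhileContext (if any),
    // build the gradient ops, and always exit again.
    public static void WithGradWhileContext(ControlFlowState loopState, Operation op, Action buildGradient)
    {
        loopState.EnterGradWhileContext(op, before: false);
        try
        {
            buildGradient();                                     // gradient ops are created inside the context
        }
        finally
        {
            loopState.ExitGradWhileContext(op, before: false);   // mirrors the Enter above
        }
    }
}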

src/TensorFlowNET.Core/Operations/ControlFlows/GradLoopState.cs

Lines changed: 95 additions & 114 deletions
@@ -17,7 +17,9 @@ limitations under the License.
  using System;
  using System.Collections;
  using System.Collections.Generic;
+ using System.Linq;
  using static Tensorflow.Binding;
+ using util = Tensorflow.control_flow_util;

  namespace Tensorflow.Operations.ControlFlows
  {
@@ -56,6 +58,7 @@ public class GradLoopState
          public GradLoopState outer_grad_state => _outer_grad_state;

          Tensor _forward_index;
+         public Tensor forward_index => _forward_index;
          Tensor _grad_index;

          Tensor[] _forward_loop_exits;
@@ -152,63 +155,52 @@ public GradLoopState(WhileContext forward_ctxt, GradLoopState outer_grad_state_)
          /// <returns>The stack that contains the accumulated history of the tensor.</returns>
          public Tensor AddForwardAccumulator(Tensor value, bool dead_branch = false)
          {
-             throw new NotImplementedException("AddForwardAccumulator");
-             // # curr_ctxt is the context that tf.gradients was called in.
-             // with self._forward_index.graph.as_default():
-             //   curr_ctxt = ops.get_default_graph()._get_control_flow_context()  # pylint: disable=protected-access
-             //   with ops.control_dependencies(None):
-             //     if curr_ctxt:
-             //       curr_ctxt.Enter()
-             //     with ops.colocate_with(value):
-             //       # We only need to pass maximum_iterations to the stack if
-             //       # we're inside an XLA context.
-             //       if not util.IsInXLAContext(value.op):
-             //         max_size = constant_op.constant(-1, dtypes.int32)
-             //       else:
-             //         max_size = GetMaxSizeFromNestedMaximumIterations(
-             //             value, self.forward_context)
-             //       acc = gen_data_flow_ops.stack_v2(
-             //           max_size=max_size, elem_type=value.dtype.base_dtype, name="f_acc")
-             //     if curr_ctxt:
-             //       curr_ctxt.Exit()
-
-             //     # Make acc available in the forward context.
-             //     enter_acc = self.forward_context.AddValue(acc)
-
-             //     # Add the stack_push op in the context of value.op.
-             //     swap_enabled = self.forward_context.swap_memory
-             //     value_ctxt = util.GetOutputContext(value.op)
-             //     if value_ctxt == self.forward_context:
-             //       # value is not nested in the forward context.
-             //       self.forward_context.Enter()
-             //       push = gen_data_flow_ops.stack_push_v2(
-             //           enter_acc, value, swap_memory=swap_enabled)
-             //       self.forward_context.Exit()
-             //       # Protect stack push and order it before forward_index.
-             //       self.forward_index.op._add_control_input(push.op)
-             //     else:
-             //       # value is in a cond context within the forward context.
-             //       if not isinstance(value_ctxt, CondContext):
-             //         raise TypeError("value_ctxt is not a CondContext: %s" % value_ctxt)
-             //       if dead_branch:
-             //         # The special case for creating a zero tensor for a dead
-             //         # branch of a switch. See ControlFlowState.ZerosLike().
-             //         value_ctxt.outer_context.Enter()
-             //         push = gen_data_flow_ops.stack_push_v2(
-             //             enter_acc, value, swap_memory=swap_enabled)
-             //         value_ctxt.outer_context.Exit()
-             //         push.op._set_control_flow_context(value_ctxt)
-             //       else:
-             //         value_ctxt.Enter()
-             //         push = gen_data_flow_ops.stack_push_v2(
-             //             enter_acc, value, swap_memory=swap_enabled)
-             //         value_ctxt.Exit()
-             //       # Protect stack push and order it before forward_sync.
-             //       self.forward_sync._add_control_input(push.op)
-             //     # Order stack push after the successor of forward_index
-             //     add_op = self.forward_index.op.inputs[0].op
-             //     push.op._add_control_input(add_op)
-             //     return acc
+             using (_forward_index.graph.as_default())
+             {
+                 var curr_ctxt = ops.get_default_graph()._get_control_flow_context();
+                 return tf_with(ops.control_dependencies(null), delegate
+                 {
+                     Tensor acc = null;
+                     Tensor push = null;
+                     if (curr_ctxt != null)
+                         curr_ctxt.Enter();
+                     ops.colocate_with(value);
+                     {
+                         // We only need to pass maximum_iterations to the stack if
+                         // we're inside an XLA context.
+                         var max_size = constant_op.constant(-1, dtypes.int32);
+                         acc = gen_data_flow_ops.stack_v2(
+                             max_size: max_size, elem_type: value.dtype.as_base_dtype(), name: "f_acc");
+                     }
+                     if (curr_ctxt != null)
+                         curr_ctxt.Exit();
+
+                     // Make acc available in the forward context.
+                     var enter_acc = forward_context.AddValue(acc);
+
+                     // Add the stack_push op in the context of value.op.
+                     var swap_enabled = forward_context.swap_memory;
+                     var value_ctxt = util.GetOutputContext(value.op);
+                     if (value_ctxt == forward_context)
+                     {
+                         // value is not nested in the forward context.
+                         forward_context.Enter();
+                         push = gen_data_flow_ops.stack_push_v2(enter_acc, value, swap_memory: swap_enabled);
+                         forward_context.Exit();
+                         // Protect stack push and order it before forward_index.
+                         forward_index.op._add_control_input(push.op);
+                     }
+                     else
+                     {
+                         throw new NotImplementedException("AddForwardAccumulator");
+                     }
+
+                     // Order stack push after the successor of forward_index
+                     var add_op = forward_index.op.inputs[0].op;
+                     push.op._add_control_input(add_op);
+                     return acc;
+                 });
+             }
          }

          // """Add the getter for an accumulated value in the grad context.
@@ -225,6 +217,7 @@ public Tensor AddForwardAccumulator(Tensor value, bool dead_branch = false)
          //   Returns:
          //     The current value (the top of the stack).
          //   """
+
          public Tensor AddBackpropAccumulatedValue(Tensor history_value, Tensor value, bool dead_branch = false)
          {
              throw new NotImplementedException();
@@ -261,62 +254,50 @@ public Tensor AddBackpropAccumulatedValue(Tensor history_value, Tensor value, bo
          //     return pop
          }

-         // def GetRealValue(self, value):
-         //   """Get the real value of `value`.
-         //
-         //   If backprop "uses" a value produced by forward inference, an accumulator
-         //   is added in the forward loop to accumulate its values. We use the
-         //   accumulated value. This method must be called in the grad loop context.
-         //   `value` must be in forward and needed for backprop.
-         //
-         //   Args:
-         //     value: A tensor to be captured.
-         //
-         //   Returns:
-         //     The same tensor obtained from the saved history.
-         //   """
-         //   assert value.op.type not in ["Variable", "VariableV2"]
-         //   real_value = self._history_map.get(value.name)
-         //   if real_value is None:
-         //     cur_value = value
-         //     cur_grad_state = self
-         //     while True:
-         //       enter_op = util.GetLoopConstantEnter(cur_value)
-         //       if enter_op:
-         //         # Special case: cur_value comes from a constant Enter node.
-         //         cur_value = enter_op.inputs[0]
-         //         cur_grad_state = cur_grad_state.outer_grad_state
-         //         if cur_grad_state is None:
-         //           # We are now outside all nested loops for this gradient(),
-         //           # so `value` is a loop invariant and there is no need to
-         //           # save the history of value. Just make cur_value to enter
-         //           # the right control flow context.
-         //           real_value = self._grad_context.AddValue(cur_value)
-         //           break
-         //       elif constant_op.is_constant(cur_value):
-         //         # If the value to be forwarded is a constant, clone the constant in
-         //         # the gradient loop rather than using a stack.
-         //         # TODO(phawkins): consider hoisting the constant out of the loop
-         //         # instead.
-         //         real_value = constant_op.constant(
-         //             tensor_util.constant_value(cur_value), dtype=cur_value.dtype)
-         //         break
-         //       else:
-         //         # Record the history of this value in forward_ctxt.
-         //         self._grad_context.Exit()
-         //         history_value = cur_grad_state.AddForwardAccumulator(cur_value)
-         //         self._grad_context.Enter()
-         //         break
-
-         //     if real_value is None:
-         //       # Add the stack pop op in the grad context.
-         //       real_value = cur_grad_state.AddBackpropAccumulatedValue(
-         //           history_value, cur_value)
-         //       if cur_grad_state != self:
-         //         real_value = self._grad_context.AddValue(real_value)
-         //     self._history_map[value.name] = real_value
-         //   return real_value
-
+         /// <summary>
+         /// Get the real value of `value`.
+         /// </summary>
+         /// <param name="value">A tensor to be captured.</param>
+         /// <returns>The same tensor obtained from the saved history.</returns>
+         public Tensor GetRealValue(Tensor value)
+         {
+             Tensor real_value = null;
+             if (real_value == null)
+             {
+                 var cur_value = value;
+                 var cur_grad_state = this;
+                 Tensor history_value = null;
+                 while (true)
+                 {
+                     var enter_op = util.GetLoopConstantEnter(cur_value);
+                     if (enter_op != null)
+                     {
+                         throw new NotImplementedException("GetRealValue");
+                     }
+                     else if (constant_op.is_constant(cur_value))
+                     {
+                         throw new NotImplementedException("GetRealValue");
+                     }
+                     else
+                     {
+                         // Record the history of this value in forward_ctxt.
+                         _grad_context.Exit();
+                         history_value = cur_grad_state.AddForwardAccumulator(cur_value);
+                         _grad_context.Enter();
+                         break;
+                     }
+                 }
+
+                 if (real_value == null)
+                 {
+                     // Add the stack pop op in the grad context.
+                     real_value = cur_grad_state.AddBackpropAccumulatedValue(history_value, cur_value);
+                     if (cur_grad_state != this)
+                         real_value = _grad_context.AddValue(real_value);
+                 }
+                 _history_map[value.name] = real_value;
+             }
+             return real_value;
+         }
      }
  }
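
The reason AddForwardAccumulator and AddBackpropAccumulatedValue exist: the gradient of a while loop needs the values the forward loop produced, one per iteration and in reverse order, so the forward loop pushes them onto a stack and the gradient loop pops them. A TensorFlow-free C# analogy of that push-in-forward, pop-in-backprop pattern (purely illustrative, not code from this commit):

using System;
using System.Collections.Generic;

class AccumulatorSketch
{
    static void Main()
    {
        // Forward while-loop: y = ((x^2)^2)^2 = x^8, pushing each iteration's input.
        var history = new Stack<double>();   // plays the role of the stack_v2 accumulator ("f_acc")
        double x = 1.5;
        double y = x;
        for (int i = 0; i < 3; i++)
        {
            history.Push(y);                 // stack_push_v2: save the value this iteration consumes
            y = y * y;                       // loop body: square
        }

        // Gradient while-loop: runs in reverse, popping the saved forward values.
        double grad = 1.0;
        while (history.Count > 0)
        {
            var saved = history.Pop();       // stack_pop: replay the forward value for this iteration
            grad *= 2.0 * saved;             // d(v*v)/dv = 2v, chained across iterations
        }

        // d/dx of x^8 is 8*x^7; the loop above reproduces it from the saved history.
        Console.WriteLine($"stack-based grad = {grad}, closed form = {8 * Math.Pow(x, 7)}");
    }
}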

src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs

Lines changed: 2 additions & 3 deletions
@@ -530,10 +530,9 @@ public override Tensor AddValue(Tensor val)
              }
              if (forward_ctxt == grad_ctxt.grad_state.forward_context)
              {
-                 throw new NotImplementedException("forward_ctxt == grad_ctxt.grad_state.forward_context");
-                 /*real_val = grad_ctxt.grad_state.GetRealValue(val);
+                 var real_val = grad_ctxt.grad_state.GetRealValue(val);
                  _external_values[val.name] = real_val;
-                 return real_val;*/
+                 return real_val;
              }
          }
      }
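
The AddValue change follows a memoize-through-GetRealValue shape: resolve the forward tensor once via the grad state, cache it under val.name in _external_values, and return the cached tensor on later references. A generic sketch of that caching pattern (ExternalValueCache and GetOrAdd are hypothetical helpers for illustration, not types in TF.NET):

using System;
using System.Collections.Generic;

// Hypothetical memoization helper mirroring the shape of WhileContext.AddValue;
// `resolve` stands in for grad_ctxt.grad_state.GetRealValue(val).
class ExternalValueCache<TKey, TValue>
{
    private readonly Dictionary<TKey, TValue> _externalValues = new Dictionary<TKey, TValue>();

    public TValue GetOrAdd(TKey name, Func<TValue> resolve)
    {
        if (_externalValues.TryGetValue(name, out var cached))
            return cached;                 // already captured in this context
        var real = resolve();              // e.g. the real value recovered from the forward history
        _externalValues[name] = real;      // remember it for later references
        return real;
    }
}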
