@@ -17,7 +17,9 @@ limitations under the License.
 using System;
 using System.Collections;
 using System.Collections.Generic;
+using System.Linq;
 using static Tensorflow.Binding;
+using util = Tensorflow.control_flow_util;

 namespace Tensorflow.Operations.ControlFlows
 {
@@ -56,6 +58,7 @@ public class GradLoopState
         public GradLoopState outer_grad_state => _outer_grad_state;

         Tensor _forward_index;
+        public Tensor forward_index => _forward_index;
         Tensor _grad_index;

         Tensor[] _forward_loop_exits;
@@ -152,63 +155,52 @@ public GradLoopState(WhileContext forward_ctxt, GradLoopState outer_grad_state_)
         /// <returns>The stack that contains the accumulated history of the tensor.</returns>
         public Tensor AddForwardAccumulator(Tensor value, bool dead_branch = false)
         {
-            throw new NotImplementedException("AddForwardAccumulator");
-            //    # curr_ctxt is the context that tf.gradients was called in.
-            //    with self._forward_index.graph.as_default():
-            //      curr_ctxt = ops.get_default_graph()._get_control_flow_context()  # pylint: disable=protected-access
-            //      with ops.control_dependencies(None):
-            //        if curr_ctxt:
-            //          curr_ctxt.Enter()
-            //        with ops.colocate_with(value):
-            //          # We only need to pass maximum_iterations to the stack if
-            //          # we're inside an XLA context.
-            //          if not util.IsInXLAContext(value.op):
-            //            max_size = constant_op.constant(-1, dtypes.int32)
-            //          else:
-            //            max_size = GetMaxSizeFromNestedMaximumIterations(
-            //                value, self.forward_context)
-            //          acc = gen_data_flow_ops.stack_v2(
-            //              max_size=max_size, elem_type=value.dtype.base_dtype, name="f_acc")
-            //        if curr_ctxt:
-            //          curr_ctxt.Exit()
-
-            //        # Make acc available in the forward context.
-            //        enter_acc = self.forward_context.AddValue(acc)
-
-            //        # Add the stack_push op in the context of value.op.
-            //        swap_enabled = self.forward_context.swap_memory
-            //        value_ctxt = util.GetOutputContext(value.op)
-            //        if value_ctxt == self.forward_context:
-            //          # value is not nested in the forward context.
-            //          self.forward_context.Enter()
-            //          push = gen_data_flow_ops.stack_push_v2(
-            //              enter_acc, value, swap_memory=swap_enabled)
-            //          self.forward_context.Exit()
-            //          # Protect stack push and order it before forward_index.
-            //          self.forward_index.op._add_control_input(push.op)
-            //        else:
-            //          # value is in a cond context within the forward context.
-            //          if not isinstance(value_ctxt, CondContext):
-            //            raise TypeError("value_ctxt is not a CondContext: %s" % value_ctxt)
-            //          if dead_branch:
-            //            # The special case for creating a zero tensor for a dead
-            //            # branch of a switch. See ControlFlowState.ZerosLike().
-            //            value_ctxt.outer_context.Enter()
-            //            push = gen_data_flow_ops.stack_push_v2(
-            //                enter_acc, value, swap_memory=swap_enabled)
-            //            value_ctxt.outer_context.Exit()
-            //            push.op._set_control_flow_context(value_ctxt)
-            //          else:
-            //            value_ctxt.Enter()
-            //            push = gen_data_flow_ops.stack_push_v2(
-            //                enter_acc, value, swap_memory=swap_enabled)
-            //            value_ctxt.Exit()
-            //          # Protect stack push and order it before forward_sync.
-            //          self.forward_sync._add_control_input(push.op)
-            //        # Order stack push after the successor of forward_index
-            //        add_op = self.forward_index.op.inputs[0].op
-            //        push.op._add_control_input(add_op)
-            //        return acc
+            using (_forward_index.graph.as_default())
+            {
+                // curr_ctxt is the context that tf.gradients was called in.
+                var curr_ctxt = ops.get_default_graph()._get_control_flow_context();
+                return tf_with(ops.control_dependencies(null), delegate
+                {
+                    Tensor acc = null;
+                    Tensor push = null;
+                    if (curr_ctxt != null)
+                        curr_ctxt.Enter();
+                    // Colocate the accumulator with value (the Python original scopes
+                    // this with `with ops.colocate_with(value):`).
+                    ops.colocate_with(value);
+                    {
+                        // We only need to pass maximum_iterations to the stack if
+                        // we're inside an XLA context; the XLA branch is not ported,
+                        // so an unbounded stack (max_size = -1) is always used here.
+                        var max_size = constant_op.constant(-1, dtypes.int32);
+                        acc = gen_data_flow_ops.stack_v2(
+                            max_size: max_size, elem_type: value.dtype.as_base_dtype(), name: "f_acc");
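+                        // Each forward-loop iteration pushes one entry onto this
+                        // stack; AddBackpropAccumulatedValue later pops the entries
+                        // in reverse iteration order during backprop.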
+                    }
+                    if (curr_ctxt != null)
+                        curr_ctxt.Exit();
+
+                    // Make acc available in the forward context.
+                    var enter_acc = forward_context.AddValue(acc);
+
+                    // Add the stack_push op in the context of value.op.
+                    var swap_enabled = forward_context.swap_memory;
+                    var value_ctxt = util.GetOutputContext(value.op);
+                    if (value_ctxt == forward_context)
+                    {
+                        // value is not nested in the forward context.
+                        forward_context.Enter();
+                        push = gen_data_flow_ops.stack_push_v2(enter_acc, value, swap_memory: swap_enabled);
+                        forward_context.Exit();
+                        // Protect stack push and order it before forward_index.
+                        forward_index.op._add_control_input(push.op);
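+                        // The push is now a control dependency of forward_index's
+                        // op, so each iteration's value is recorded before the
+                        // counter is forwarded to the next iteration.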
+                    }
+                    else
+                    {
+                        // value is in a cond context within the forward context;
+                        // that case is not ported yet.
+                        throw new NotImplementedException("AddForwardAccumulator");
+                    }
+
+                    // Order stack push after the successor of forward_index.
+                    var add_op = forward_index.op.inputs[0].op;
+                    push.op._add_control_input(add_op);
+                    return acc;
+                });
+            }
         }

         // """Add the getter for an accumulated value in the grad context.
@@ -225,6 +217,7 @@ public Tensor AddForwardAccumulator(Tensor value, bool dead_branch = false)
         //    Returns:
         //      The current value (the top of the stack).
         //    """
+
         public Tensor AddBackpropAccumulatedValue(Tensor history_value, Tensor value, bool dead_branch = false)
         {
             throw new NotImplementedException();
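+            // A sketch of what the body is expected to do, mirroring the Python
+            // implementation kept in the comment below (an assumption, not ported
+            // code): pop the value recorded by AddForwardAccumulator off the stack
+            // inside the grad context, roughly:
+            //
+            //   var pop = gen_data_flow_ops.stack_pop_v2(history_value, value.dtype.as_base_dtype());
+            //   pop.set_shape(value.shape);  // hypothetical member names
+            //   return pop;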
@@ -261,62 +254,50 @@ public Tensor AddBackpropAccumulatedValue(Tensor history_value, Tensor value, bo
             // return pop
         }

-        // def GetRealValue(self, value):
-        //    """Get the real value of `value`.
-
-        //    If backprop "uses" a value produced by forward inference, an accumulator
-        //    is added in the forward loop to accumulate its values. We use the
-        //    accumulated value. This method must be called in the grad loop context.
-        //    `value` must be in forward and needed for backprop.
-
-        //    Args:
-        //      value: A tensor to be captured.
-
-        //    Returns:
-        //      The same tensor obtained from the saved history.
-        //    """
-        //    assert value.op.type not in ["Variable", "VariableV2"]
-        //    real_value = self._history_map.get(value.name)
-        //    if real_value is None:
-        //      cur_value = value
-        //      cur_grad_state = self
-        //      while True:
-        //        enter_op = util.GetLoopConstantEnter(cur_value)
-        //        if enter_op:
-        //          # Special case: cur_value comes from a constant Enter node.
-        //          cur_value = enter_op.inputs[0]
-        //          cur_grad_state = cur_grad_state.outer_grad_state
-        //          if cur_grad_state is None:
-        //            # We are now outside all nested loops for this gradient(),
-        //            # so `value` is a loop invariant and there is no need to
-        //            # save the history of value. Just make cur_value to enter
-        //            # the right control flow context.
-        //            real_value = self._grad_context.AddValue(cur_value)
-        //            break
-        //        elif constant_op.is_constant(cur_value):
-        //          # If the value to be forwarded is a constant, clone the constant in
-        //          # the gradient loop rather than using a stack.
-        //          # TODO(phawkins): consider hoisting the constant out of the loop
-        //          # instead.
-        //          real_value = constant_op.constant(
-        //              tensor_util.constant_value(cur_value), dtype=cur_value.dtype)
-        //          break
-        //        else:
-        //          # Record the history of this value in forward_ctxt.
-        //          self._grad_context.Exit()
-        //          history_value = cur_grad_state.AddForwardAccumulator(cur_value)
-        //          self._grad_context.Enter()
-        //          break
-
-        //      if real_value is None:
-        //        # Add the stack pop op in the grad context.
-        //        real_value = cur_grad_state.AddBackpropAccumulatedValue(
-        //            history_value, cur_value)
-        //        if cur_grad_state != self:
-        //          real_value = self._grad_context.AddValue(real_value)
-        //      self._history_map[value.name] = real_value
-        //    return real_value
-
-
+        /// <summary>
+        /// Get the real value of `value`.
+        /// If backprop "uses" a value produced by forward inference, an accumulator
+        /// is added in the forward loop to accumulate its values, and the accumulated
+        /// value is used instead. This method must be called in the grad loop context;
+        /// `value` must be in the forward graph and needed for backprop.
+        /// </summary>
+        /// <param name="value">A tensor to be captured.</param>
+        /// <returns>The same tensor obtained from the saved history.</returns>
+        public Tensor GetRealValue(Tensor value)
+        {
+            // Reuse the cached result if this value's history was already captured.
+            Tensor real_value = _history_map.ContainsKey(value.name) ? _history_map[value.name] : null;
+            if (real_value == null)
+            {
+                var cur_value = value;
+                var cur_grad_state = this;
+                Tensor history_value = null;
+                while (true)
+                {
+                    var enter_op = util.GetLoopConstantEnter(cur_value);
+                    if (enter_op != null)
+                    {
+                        // Special case: cur_value comes from a constant Enter node;
+                        // walking out to the outer grad state is not ported yet.
+                        throw new NotImplementedException("GetRealValue");
+                    }
+                    else if (constant_op.is_constant(cur_value))
+                    {
+                        // A constant could be cloned in the gradient loop instead of
+                        // using a stack, but that path is not ported yet.
+                        throw new NotImplementedException("GetRealValue");
+                    }
+                    else
+                    {
+                        // Record the history of this value in forward_ctxt.
+                        _grad_context.Exit();
+                        history_value = cur_grad_state.AddForwardAccumulator(cur_value);
+                        _grad_context.Enter();
+                        break;
+                    }
+                }
+
+                if (real_value == null)
+                {
+                    // Add the stack pop op in the grad context.
+                    real_value = cur_grad_state.AddBackpropAccumulatedValue(history_value, cur_value);
+                    if (cur_grad_state != this)
+                        real_value = _grad_context.AddValue(real_value);
+                }
+                _history_map[value.name] = real_value;
+            }
+            return real_value;
+        }
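+
+        // Minimal usage sketch (assumed, based on the docstrings above, with
+        // hypothetical variable names): inside the grad loop context, a gradient
+        // that needs a forward-loop tensor resolves it through the saved history:
+        //
+        //   var real = grad_state.GetRealValue(forward_tensor);
+        //
+        // Internally this pairs AddForwardAccumulator (a stack push in the
+        // forward loop) with AddBackpropAccumulatedValue (the matching pop in
+        // the grad loop).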
     }
 }