@@ -14,13 +14,20 @@ You may obtain a copy of the License at
 limitations under the License.
 ******************************************************************************/

+using System;
+using System.Linq;
+using System.Collections.Generic;
+using util = Tensorflow.control_flow_util;
+using static Tensorflow.Binding;
+
 namespace Tensorflow.Operations.ControlFlows
 {
     /// <summary>
     /// Maintain the mapping from the loops to their grad states.
     /// </summary>
     public class ControlFlowState
     {
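+        // Maps each forward while-loop context to the GradLoopState holding its gradient state.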
+        Dictionary<ControlFlowContext, GradLoopState> _map;
         //class ControlFlowState(object):
         //  """Maintain the mapping from the loops to their grad states."""

@@ -40,51 +47,67 @@ public class ControlFlowState
         //      return self._map.get(forward_ctxt)
         //    return None

-        //  def ProcessUnusedLoopExits(self, pending_count, to_ops_set):
-        //    """Process all the "unused" loop exits.
-
-        //    The "unused" exits of the loops are added to `unused_exits`. An exit is
-        //    unused if its pending_count is 0. If there is an exit with real gradient,
-        //    all these deferred exits will enter the backprop loop with zero gradient.
-        //    Otherwise, they will enter the backprop loop with None. As an example,
-        //    people often write:
-
-        //    ```python
-        //    v1, _ = tf.while_loop(p, b, [x1, x2])
-        //    result = gradients(v1, x1)
-        //    ```
-
-        //    The exit node for x2 is not included by the betweenness analysis. But we
-        //    need to backprop x2 if x2 is involved in computing v1.
-
-        //    Args:
-        //      pending_count: The number of backprop inputs for every op.
-        //      to_ops_set: The set of ops for ys in gradients(ys, xs)
-
-        //    Returns:
-        //      The set of unused loop exits that we know at this point we need
-        //      to backprop.
-        //    """
-        //    loop_exits = []
-        //    for grad_state in self._map.values():
-        //      for y in grad_state.forward_loop_exits:
-        //        if pending_count[y.op] == 0:
-        //          grad_state.pending_exits_count -= 1
-        //          if y.op not in to_ops_set:
-        //            grad_state.unused_exits.append(y)
-        //          if grad_state.pending_exits_count == 0:
-        //            loop_exits.extend(grad_state.unused_exits)
-        //      # Need to include Enters in backprop for higher-order gradients.
-        //      for y in grad_state.forward_context.loop_enters:
-        //        if pending_count[y.op] == 0:
-        //          pending_count[y.op] = 1
-        //    return loop_exits
-
-        //  def EnterGradWhileContext(self, op, before):
-        //    """Enter the WhileContext for gradient computation."""
-        //    grad_state = self.GetGradState(op, before)
-        //    if grad_state:
-        //      grad_state.grad_context.Enter()
+        public ControlFlowState()
+        {
+            _map = new Dictionary<ControlFlowContext, GradLoopState>();
+        }
+
+        /// <summary>
+        /// Return the grad state for this op if it's in a forward loop context.
+        /// </summary>
+        /// <param name="op"></param>
+        /// <param name="before"></param>
+        /// <returns></returns>
+        public GradLoopState GetGradState(Operation op, bool before)
+        {
+            ControlFlowContext forward_ctxt = null;
+            if (before && util.IsLoopExit(op))
+            {
+                forward_ctxt = op._get_control_flow_context();
+                forward_ctxt = forward_ctxt.outer_context;
+                if (forward_ctxt != null)
+                    forward_ctxt = forward_ctxt.GetWhileContext();
+            }
+            else
+                forward_ctxt = util.GetWhileContext(op);
+            if (forward_ctxt != null)
+                return _map.get(forward_ctxt);
+            return null;
+        }
+
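+        /// <summary>
+        /// Process all the "unused" loop exits. An exit is unused if its
+        /// pending_count is 0. If there is an exit with a real gradient, all of
+        /// these deferred exits will enter the backprop loop with zero gradient;
+        /// otherwise they will enter the backprop loop with None.
+        /// </summary>
+        /// <param name="pending_count">The number of backprop inputs for every op.</param>
+        /// <param name="to_ops_set">The set of ops for ys in gradients(ys, xs).</param>
+        /// <returns>The set of unused loop exits that we know at this point we need to backprop.</returns>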
+        public Tensor[] ProcessUnusedLoopExits(Dictionary<string, int> pending_count, List<Operation> to_ops_set)
+        {
+            var loop_exits = new List<Tensor>();
+            foreach (var grad_state in _map.Values)
+            {
+                foreach (var y in grad_state.forward_loop_exits)
+                {
+                    if (!pending_count.ContainsKey(y.op.name))
+                    {
+                        grad_state.pending_exits_count -= 1;
+                        if (!to_ops_set.Contains(y.op))
+                            grad_state.unused_exits.append(y);
+                        if (grad_state.pending_exits_count == 0)
+                            loop_exits.extend(grad_state.unused_exits);
+                    }
+                }
+
+                // Need to include Enters in backprop for higher-order gradients.
+                foreach (var y in grad_state.forward_context.loop_enters)
+                {
+                    if (!pending_count.ContainsKey(y.op.name))
+                        pending_count[y.op.name] = 1;
+                }
+            }
+
+            return loop_exits.ToArray();
+        }
+
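+        /// <summary>
+        /// Enter the WhileContext for gradient computation.
+        /// </summary>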
+        public void EnterGradWhileContext(Operation op, bool before)
+        {
+            var grad_state = GetGradState(op, before);
+            if (grad_state != null)
+                grad_state.grad_context.Enter();
+        }

         //  def ExitGradWhileContext(self, op, before):
         //    """Exit the WhileContext for gradient computation."""
@@ -118,6 +141,32 @@ public class ControlFlowState
         //        if loop_exit.op not in between_ops:
         //          between_ops.add(loop_exit.op)
         //          between_op_list.append(loop_exit.op)
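+        /// <summary>
+        /// Add the grad state for the while loop that op belongs to, creating it if
+        /// it does not already exist and including all of the loop's exits in the
+        /// between-ops needed for backprop.
+        /// </summary>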
+        public void AddWhileContext(Operation op, List<Operation> between_op_list, List<Operation> between_ops)
+        {
+            var forward_ctxt = op.GetWhileContext();
+            var grad_state = _map.ContainsKey(forward_ctxt) ? _map[forward_ctxt] : null;
+            if (grad_state == null)
+            {
+                GradLoopState outer_grad_state = null;
+                var outer_forward_ctxt = forward_ctxt.outer_context;
+                if (outer_forward_ctxt != null)
+                    outer_forward_ctxt = outer_forward_ctxt.GetWhileContext();
+                if (outer_forward_ctxt != null)
+                    outer_grad_state = _map[outer_forward_ctxt];
+                grad_state = new GradLoopState(forward_ctxt, outer_grad_state);
+                _map[forward_ctxt] = grad_state;
+
+                // We need to include all exits of a loop for backprop.
+                foreach (var loop_exit in grad_state.forward_loop_exits)
+                {
+                    if (!between_ops.Contains(loop_exit.op))
+                    {
+                        between_ops.add(loop_exit.op);
+                        between_op_list.append(loop_exit.op);
+                    }
+                }
+            }
+        }

         //  def ZerosLikeForExit(self, val):
         //    """Create zeros_like gradient for a loop exit.
@@ -174,70 +223,69 @@ public class ControlFlowState
         //        result = array_ops.zeros_like(val, optimize=False)
         //    return result

-        //  def ZerosLike(self, op, index):
-        //    """Create zeros_like for the specified output of an op.
-
-        //    If op is in a while loop that is part of gradients(), this method
-        //    must be called in its grad loop context.
-
-        //    Args:
-        //      op: A tensorflow operation.
-        //      index: the index for a specific output of the op.
-
-        //    Returns:
-        //      A zero tensor of the same shape of op.outputs[index].
-        //    """
-        //    if util.IsLoopSwitch(op):
-        //      return None
-        //    if op.graph._building_function:  # pylint: disable=protected-access
-        //      # The optimization here is tricky to apply to functions
-        //      return array_ops.zeros_like(op.outputs[index])
-        //    dead_branch = util.IsSwitch(op)
-        //    forward_ctxt = _GetWhileContext(op)
-        //    grad_state = self._map.get(forward_ctxt)
-        //    if grad_state is None:
-        //      # op is not in a while loop that is part of gradients().
-        //      return ZerosLikeOutsideLoop(op, index)
-        //    op_ctxt = op._get_control_flow_context()
-        //    val = ops.convert_to_tensor(op.outputs[index], name="tensor")
-        //    shape = val.get_shape()
-        //    if shape.is_fully_defined():
-        //      # If the shape is known statically, just create a zero tensor with
-        //      # the right shape in the grad loop context.
-        //      result = constant_op.constant(0, shape=shape.dims, dtype=val.dtype)
-        //      if dead_branch:
-        //        # op is a cond switch. Guard the zero tensor with a switch.
-        //        pred = grad_state.history_map.get(op_ctxt.pred.name)
-        //        branch = op_ctxt.branch
-        //        result = _SwitchRefOrTensor(result, pred)[1 - branch]
-        //    else:
-        //      # Unknown shape so keep a history of the shape at runtime.
-        //      if dead_branch:
-        //        # Need to add a special switch to guard the value.
-        //        pred = op_ctxt.pred
-        //        branch = op_ctxt.branch
-        //        op_ctxt.outer_context.Enter()
-        //        val = _SwitchRefOrTensor(op.inputs[0], pred)[1 - branch]
-        //        zeros_shape = array_ops.shape_internal(val, optimize=False)
-        //        op_ctxt.outer_context.Exit()
-        //        val.op._set_control_flow_context(op_ctxt)
-        //        zeros_shape.op._set_control_flow_context(op_ctxt)
-        //      else:
-        //        op_ctxt.Enter()
-        //        zeros_shape = array_ops.shape_internal(val, optimize=False)
-        //        op_ctxt.Exit()
-
-        //    # Add forward accumulator for shape.
-        //    grad_state.grad_context.Exit()
-        //    history_zeros_shape = grad_state.AddForwardAccumulator(
-        //        zeros_shape, dead_branch=dead_branch)
-        //    grad_state.grad_context.Enter()
-
-        //    # Create a zero tensor with the right shape.
-        //    shape = grad_state.AddBackpropAccumulatedValue(history_zeros_shape,
-        //                                                   zeros_shape, dead_branch)
-        //    result = array_ops.zeros(shape, val.dtype)
-        //    return result
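+        /// <summary>
+        /// Create zeros_like for the specified output of an op. If op is in a while
+        /// loop that is part of gradients(), this method must be called in its grad
+        /// loop context.
+        /// </summary>
+        /// <param name="op">A tensorflow operation.</param>
+        /// <param name="index">The index for a specific output of the op.</param>
+        /// <returns>A zero tensor of the same shape as op.outputs[index].</returns>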
+        public Tensor ZerosLike(Operation op, int index)
+        {
+            if (util.IsLoopSwitch(op))
+                return null;
+            if (op.graph.building_function)
+                return array_ops.zeros_like(op.outputs[index]);
+            var dead_branch = util.IsSwitch(op);
+            var forward_ctxt = util.GetWhileContext(op);
+            var grad_state = _map.get(forward_ctxt);
+            // op is not in a while loop that is part of gradients().
+            if (grad_state == null)
+                return ZerosLikeOutsideLoop(op, index);
+            throw new NotImplementedException("ZerosLike");
+        }
+
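+        /// <summary>
+        /// Create a zeros_like gradient for the specified output of an op that is
+        /// not in a while loop being differentiated.
+        /// </summary>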
+        public Tensor ZerosLikeOutsideLoop(Operation op, int index)
+        {
+            var val = op.outputs[index];
+            if (!util.IsSwitch(op))
+            {
+                if (val.dtype == dtypes.resource)
+                    throw new NotImplementedException("ZerosLikeOutsideLoop");
+                /*return array_ops.zeros(
+                    gen_resource_variable_ops.variable_shape(val),
+                    dtype: default_gradient.get_zeros_dtype(val));*/
+                return array_ops.zeros_like(val, optimize: false);
+            }
+            else
+                throw new NotImplementedException("ZerosLikeOutsideLoop");
+        }
+
+        /// <summary>
+        /// Create zeros_like gradient for a loop exit.
+        /// </summary>
+        /// <param name="val"></param>
+        /// <returns></returns>
+        public Tensor ZerosLikeForExit(Tensor val)
+        {
+            Tensor result = null;
+            var val_shape = val.TensorShape;
+            var forward_ctxt = val.op._get_control_flow_context();
+            var outer_forward_ctxt = forward_ctxt.outer_context;
+            if (outer_forward_ctxt != null)
+                outer_forward_ctxt = outer_forward_ctxt.GetWhileContext();
+            GradLoopState outer_grad_state = null;
+            if (outer_forward_ctxt != null)
+                outer_grad_state = _map.get(outer_forward_ctxt);
+            // This is a nested loop.
+            if (outer_grad_state != null)
+            {
+                throw new NotImplementedException("ZerosLikeForExit");
+            }
+            else
+            {
+                // If the shape is known statically, just create a zero tensor
+                // with the right shape.
+                if (val_shape.is_fully_defined())
+                    result = array_ops.zeros(val_shape.dims, val.dtype);
+                else
+                    result = array_ops.zeros_like(val, optimize: false);
+            }
+            return result;
+        }

         //  def PostProcessing(self):
         //    """Perform postprocessing at the end of gradients().