@@ -78,6 +78,26 @@ public class GradLoopState
         /// </summary>
         public int pending_exits_count { get; set; }

+        Operation _grad_sync;
+        public Operation grad_sync
+        {
+            get
+            {
+                if (_grad_sync == null)
+                {
+                    tf_with(ops.control_dependencies(null), delegate
+                    {
+                        _grad_sync = gen_control_flow_ops.control_trigger(name: "b_sync");
+                    });
+                    _grad_sync._set_control_flow_context(_grad_context);
+                    _grad_index.op._add_control_input(_grad_sync);
+                    if (_grad_context.outer_context != null)
+                        _grad_context.outer_context.AddInnerOp(_grad_sync);
+                }
+                return _grad_sync;
+            }
+        }
+
         public GradLoopState(WhileContext forward_ctxt, GradLoopState outer_grad_state_)
         {
             // Information needed by backprop.
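The new `grad_sync` property creates a single `control_trigger` op lazily on first access, outside any inherited control dependencies (`ops.control_dependencies(null)` clears the stack), and wires it into the gradient context. Below is a minimal self-contained sketch of that create-once property pattern; `FakeOperation` is an illustrative stand-in, not a TensorFlow.NET type.

```csharp
using System;

// Illustrative stand-in for an Operation; not a TensorFlow.NET type.
class FakeOperation
{
    public string Name;
}

class LoopState
{
    FakeOperation _grad_sync;

    // First read creates the synchronization op exactly once; later reads
    // return the cached instance, mirroring the grad_sync getter above.
    public FakeOperation grad_sync
    {
        get
        {
            if (_grad_sync == null)
            {
                _grad_sync = new FakeOperation { Name = "b_sync" };
                // The real getter also attaches the op to the gradient
                // context and registers it with the outer context.
            }
            return _grad_sync;
        }
    }
}

class Demo
{
    static void Main()
    {
        var state = new LoopState();
        Console.WriteLine(ReferenceEquals(state.grad_sync, state.grad_sync)); // True
    }
}
```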
@@ -155,7 +175,7 @@ public GradLoopState(WhileContext forward_ctxt, GradLoopState outer_grad_state_)
         /// <returns>The stack that contains the accumulated history of the tensor.</returns>
         public Tensor AddForwardAccumulator(Tensor value, bool dead_branch = false)
         {
-            using (_forward_index.graph.as_default())
+            _forward_index.graph.as_default();
             {
                 var curr_ctxt = ops.get_default_graph()._get_control_flow_context();
                 return tf_with(ops.control_dependencies(null), delegate
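Note that replacing `using (_forward_index.graph.as_default())` with a bare call drops the disposable scope, so the previous default graph is no longer restored when the block ends. A small stand-alone sketch of that difference, assuming a hypothetical `GraphScope` type (not a TensorFlow.NET API) that restores the prior default on `Dispose`:

```csharp
using System;

// Stand-in for the disposable scope that graph.as_default() is assumed
// to return in the `using` form; illustrative only.
class GraphScope : IDisposable
{
    public static string Current = "outer";
    readonly string _previous;

    public GraphScope(string name) { _previous = Current; Current = name; }
    public void Dispose() => Current = _previous;   // restores the previous default
}

class Demo
{
    static void Main()
    {
        using (new GraphScope("forward_index_graph"))
        {
            Console.WriteLine(GraphScope.Current);  // forward_index_graph
        }
        Console.WriteLine(GraphScope.Current);      // outer: restored by Dispose

        // With a bare call (no using), nothing restores the previous default:
        var scope = new GraphScope("forward_index_graph");
        Console.WriteLine(GraphScope.Current);      // still forward_index_graph
    }
}
```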
@@ -220,38 +240,33 @@ public Tensor AddForwardAccumulator(Tensor value, bool dead_branch = false)

         public Tensor AddBackpropAccumulatedValue(Tensor history_value, Tensor value, bool dead_branch = false)
         {
-            throw new NotImplementedException();
-            // history_ctxt = history_value.op._get_control_flow_context()
-            // # Find the cond context that controls history_value if any.
-            // cond_ctxt = None
-            // value_ctxt = value.op._get_control_flow_context()
-            // while value_ctxt and value_ctxt != history_ctxt:
-            //   if isinstance(value_ctxt, CondContext):
-            //     cond_ctxt = value_ctxt
-            //     break
-            //   value_ctxt = value_ctxt.outer_context
-            // with ops.control_dependencies(None):
-            //   self.grad_context.Enter()
-            //   if cond_ctxt:
-            //     # Guard stack pop with a switch if it is controlled by a cond.
-            //     grad_state = self
-            //     pred = None
-            //     while pred is None and grad_state:
-            //       pred = grad_state.history_map.get(cond_ctxt.pred.name)
-            //       grad_state = grad_state.outer_grad_state
-            //     if pred is None:
-            //       pred = cond_ctxt.pred
-            //     branch = (1 - cond_ctxt.branch) if dead_branch else cond_ctxt.branch
-            //     history_value = _SwitchRefOrTensor(history_value, pred)[branch]
-            //   pop = gen_data_flow_ops.stack_pop_v2(history_value,
-            //                                        value.dtype.base_dtype)
-            //   pop.set_shape(value.get_shape())
-            //   self.grad_context.Exit()
-            //   parallel_iterations = self.grad_context.parallel_iterations
-            //   if parallel_iterations > 1:
-            //     # All pops are ordered after pivot_for_body and before grad_sync.
-            //     self.grad_sync._add_control_input(pop.op)
-            //   return pop
+            var history_ctxt = history_value.op._get_control_flow_context();
+            // Find the cond context that controls history_value if any.
+            CondContext cond_ctxt = null;
+            Tensor pop = null;
+            var value_ctxt = value.op._get_control_flow_context();
+            while (value_ctxt != null && value_ctxt != history_ctxt)
+            {
+                // Stop at the innermost cond context, as in the Python original.
+                if (value_ctxt is CondContext cc) { cond_ctxt = cc; break; }
+                value_ctxt = value_ctxt.outer_context;
+            }
+            tf_with(ops.control_dependencies(null), delegate
+            {
+                grad_context.Enter();
+                if (cond_ctxt != null)
+                {
+                    throw new NotImplementedException("AddBackpropAccumulatedValue");
+                }
+                pop = gen_data_flow_ops.stack_pop_v2(history_value, value.dtype.as_base_dtype());
+                pop.set_shape(value.TensorShape);
+                grad_context.Exit();
+            });
+            var parallel_iterations = grad_context.parallel_iterations;
+            // All pops are ordered after pivot_for_body and before grad_sync.
+            if (parallel_iterations > 1)
+                grad_sync._add_control_input(pop.op);
+            return pop;
         }

         /// <summary>
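`AddForwardAccumulator` pushes a tensor's value onto a stack at every forward-loop iteration, and the rewritten `AddBackpropAccumulatedValue` pops those values back in reverse order inside the gradient loop via `stack_pop_v2`, serializing the pops through `grad_sync` when `parallel_iterations > 1`. The sketch below is only an analogy for that push-forward / pop-backward pairing, using a plain `Stack<double>` instead of the library's stack ops:

```csharp
using System;
using System.Collections.Generic;

class Demo
{
    static void Main()
    {
        var history = new Stack<double>();

        // Forward loop: x_{i+1} = x_i^2. Each iterate is pushed because the
        // backward pass will need it (the role of AddForwardAccumulator).
        double x = 1.1;
        for (int i = 0; i < 3; i++)
        {
            history.Push(x);
            x = x * x;
        }

        // Backward loop: pop the saved iterates in reverse order and apply
        // the chain rule (the role of AddBackpropAccumulatedValue).
        double grad = 1.0;
        while (history.Count > 0)
        {
            double saved = history.Pop();
            grad *= 2.0 * saved;        // d(x^2)/dx evaluated at the saved x
        }
        Console.WriteLine(grad);        // d(x_3)/d(x_0)
    }
}
```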
@@ -272,11 +287,28 @@ public Tensor GetRealValue(Tensor value)
                 var enter_op = util.GetLoopConstantEnter(cur_value);
                 if (enter_op != null)
                 {
-                    throw new NotImplementedException("GetRealValue");
+                    // Special case: cur_value comes from a constant Enter node.
+                    cur_value = enter_op.inputs[0];
+                    cur_grad_state = cur_grad_state.outer_grad_state;
+                    if (cur_grad_state == null)
+                    {
+                        // We are now outside all nested loops for this gradient(),
+                        // so `value` is a loop invariant and there is no need to
+                        // save its history. Just make cur_value enter the right
+                        // control flow context.
+                        real_value = _grad_context.AddValue(cur_value);
+                        break;
+                    }
                 }
                 else if (constant_op.is_constant(cur_value))
                 {
-                    throw new NotImplementedException("GetRealValue");
+                    // The value to be forwarded is a constant, so there is no
+                    // need to save its history. Just clone the constant in the
+                    // gradient loop so that it lives in the right control flow
+                    // context.
+                    real_value = constant_op.constant(
+                        tensor_util.constant_value(cur_value), dtype: cur_value.dtype);
+                    break;
                 }
                 else
                 {
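The two new branches in `GetRealValue` cover loop-invariant inputs: a value fed through a constant `Enter` node is unwound one nesting level at a time until the walk leaves all loops, and a plain constant is simply re-created in the gradient loop rather than accumulated. Below is a stand-alone sketch of that decision walk with illustrative stand-in types (`FakeValue` and `FakeGradState` are not TensorFlow.NET APIs):

```csharp
using System;

// Illustrative stand-ins; not TensorFlow.NET types.
class FakeValue
{
    public FakeValue EnteredFrom;   // non-null if produced by a constant Enter node
    public double? ConstantValue;   // non-null if the value is a literal constant
    public string Name;
}

class FakeGradState
{
    public FakeGradState OuterGradState;  // enclosing loop's gradient state, null at the top
}

class Demo
{
    // Mirrors the shape of GetRealValue's new fast paths: unwrap constant
    // Enter nodes through the nested loops, and clone plain constants,
    // instead of recording a per-iteration history.
    static string Resolve(FakeValue value, FakeGradState gradState)
    {
        var cur = value;
        var state = gradState;
        while (true)
        {
            if (cur.EnteredFrom != null)
            {
                cur = cur.EnteredFrom;                 // step out of one loop level
                state = state.OuterGradState;
                if (state == null)
                    return $"loop invariant: reuse {cur.Name} directly";
            }
            else if (cur.ConstantValue != null)
            {
                return $"constant: clone value {cur.ConstantValue} in the gradient loop";
            }
            else
            {
                return $"loop-varying: accumulate history for {cur.Name}";
            }
        }
    }

    static void Main()
    {
        var constant = new FakeValue { ConstantValue = 3.0, Name = "c" };
        var entered = new FakeValue { EnteredFrom = constant, Name = "enter(c)" };
        var state = new FakeGradState { OuterGradState = null };

        Console.WriteLine(Resolve(entered, state));
        Console.WriteLine(Resolve(constant, state));
    }
}
```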