@@ -446,6 +446,84 @@ public override void AddOp(Operation op)
446446 return ( total_iterations , next_n ) ;
447447 }
448448
/// <summary>
/// Add an accumulation loop for every loop invariant.
/// </summary>
/// <param name="op">The Enter op for a loop invariant.</param>
/// <param name="grad">The partial gradient of an iteration for a loop invariant.</param>
/// <returns>The gradient for a loop invariant.</returns>
public Tensor AddBackpropAccumulator(Operation op, Tensor grad)
{
    Tensor acc = null;
    Exit();
    // Create a zeros tensor with the right shape for acc. If we don't
    // know the full shape statically, we will have to get the shape
    // dynamically from the forward inference. Getting the shape right
    // for the zeros is only needed for the base case when the loop exits
    // without running any iterations.
    var shape = grad.TensorShape;
    if (shape.is_fully_defined())
    {
        // Static shape: a constant zero built in the outer context suffices.
        if (outer_context != null)
            outer_context.Enter();
        acc = constant_op.constant(0, grad.dtype, shape: shape, name: "b_acc");
        if (outer_context != null)
            outer_context.Exit();
    }
    else
    {
        var value = op.inputs[0];
        // NOTE(review): the Python reference additionally requires
        // outer_context.grad_state != null before taking this branch — confirm
        // whether that guard is needed here as well.
        if (outer_context is WhileContext wc)
        {
            // We are in a nested while loop: the dynamic shape must be
            // captured in the forward pass and replayed in the backprop loop.
            var forward_ctxt = grad_state.forward_context;
            forward_ctxt.outer_context.Enter();
            var zeros_shape = array_ops.shape_internal(value, optimize: false);
            forward_ctxt.outer_context.Exit();
            var outer_grad_state = grad_state.outer_grad_state;
            // Accumulate the shape in the forward loop, then read it back
            // in the outer backprop context to build correctly-sized zeros.
            var history_zeros_shape = outer_grad_state.AddForwardAccumulator(zeros_shape);
            outer_context.Enter();
            var real_shape = outer_grad_state.AddBackpropAccumulatedValue(
                history_zeros_shape, zeros_shape);
            acc = array_ops.zeros(real_shape, grad.dtype);
            outer_context.Exit();
        }
        else
        {
            // Not nested: compute the dynamic shape directly in the outer context.
            if (outer_context != null)
                outer_context.Enter();
            var zeros_shape = array_ops.shape_internal(value, optimize: false);
            acc = array_ops.zeros(zeros_shape, grad.dtype);
            if (outer_context != null)
                outer_context.Exit();
        }
        // BUGFIX: removed a leftover `throw new NotImplementedException(...)`
        // here — it made the fully-computed dynamic-shape path above
        // unreachable and crashed every gradient whose static shape is not
        // fully defined.
    }

    // Build the accumulation loop: Enter -> Merge -> Switch -> Add -> NextIteration,
    // with the Merge's second input back-patched to close the cycle.
    Enter();
    AddName(acc.name);
    var enter_acc = _Enter(
        acc,
        _name,
        is_constant: false,
        parallel_iterations: _parallel_iterations,
        name: "b_acc");
    loop_enters.append(enter_acc);
    var merge_acc = merge(new[] { enter_acc, enter_acc }, name: "b_acc")[0];

    var switch_result = @switch(merge_acc, _pivot);
    var (switch_acc_false, switch_acc_true) = (switch_result[0], switch_result[1]);

    // Add the iteration's partial gradient on the "still looping" branch.
    var add_acc = math_ops.add(switch_acc_true, grad);
    var next_acc = _NextIteration(add_acc);
    merge_acc.op._update_input(1, next_acc);

    // The "loop done" branch exits with the accumulated gradient.
    var result_acc = exit(switch_acc_false, name: "b_acc");
    loop_exits.append(result_acc);
    ExitResult(new[] { result_acc });
    return result_acc;
}
526+
449527 /// <summary>
450528 /// Add the backprop loop that controls the iterations.
451529 /// </summary>
0 commit comments