@@ -45,7 +45,19 @@ public static Tensor[] _SwitchGrad(Operation op, Tensor[] grads)
 switch (op_ctxt)
 {
     case WhileContext cwhile:
-        throw new NotImplementedException("_SwitchGrad WhileContext");
+        {
+            var merge_grad = grad_ctxt.grad_state.switch_map.get(op);
+            if (merge_grad != null)
+                throw new NotImplementedException("_SwitchGrad merge_grad != null");
+            else if (grads[0] != null)
+            {
+                merge_grad = merge(new[] { grads[0], grads[0] }, name: "b_switch")[0];
+                grad_ctxt.grad_state.switch_map[op] = merge_grad;
+                return new Tensor[] { merge_grad, null };
+            }
+            else
+                return new Tensor[] { null, null };
+        }
    case CondContext ccond:
        {
            var zero_grad = grads[1 - op_ctxt.branch];
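Note: in the WhileContext branch above, grad_state.switch_map caches one backprop Merge per forward Switch op, so the gradient that later arrives through the other branch can be routed into the same Merge (that second-arrival path is still left as a NotImplementedException here). A framework-free sketch of the caching pattern, with illustrative names rather than the TensorFlow.NET API:

// Illustrative sketch only, not the library's types.
using System;
using System.Collections.Generic;

class SwitchMapSketch<TOp, TNode>
{
    private readonly Dictionary<TOp, TNode> _switchMap = new Dictionary<TOp, TNode>();

    // First gradient to reach the Switch: build the backprop Merge and remember it.
    // buildBackpropMerge stands in for merge(new[] { g, g }, name: "b_switch")[0].
    public TNode FirstVisit(TOp forwardSwitch, Func<TNode> buildBackpropMerge)
    {
        var mergeGrad = buildBackpropMerge();
        _switchMap[forwardSwitch] = mergeGrad;
        return mergeGrad;
    }

    // Gradient arriving later through the other branch: look up the same Merge.
    public bool TryGetExisting(TOp forwardSwitch, out TNode mergeGrad)
        => _switchMap.TryGetValue(forwardSwitch, out mergeGrad);
}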
@@ -74,7 +86,7 @@ public static Tensor[] _SwitchGrad(Operation op, Tensor[] grads)
 /// <param name="inputs"></param>
 /// <param name="name"></param>
 /// <returns></returns>
-internal static Tensor[] merge(Tensor[] inputs, string name = null)
+internal static MergeOutput merge(Tensor[] inputs, string name = null)
 {
     return tf_with(ops.name_scope(name, "Merge", inputs), scope =>
     {
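merge now returns a MergeOutput rather than a raw Tensor[], and call sites such as merge(...)[0] in the Switch gradient above index into it. A hypothetical sketch of such a wrapper, assuming member names output and value_index and an int indexer (the library's real MergeOutput may differ; Tensor is the surrounding library's tensor type):

// Hypothetical wrapper; member names are assumptions, only the [0] indexing is shown in the diff.
public class MergeResultSketch
{
    public Tensor output { get; }       // the merged tensor
    public Tensor value_index { get; }  // which input was forwarded

    public MergeResultSketch(Tensor output, Tensor value_index)
    {
        this.output = output;
        this.value_index = value_index;
    }

    // Positional access so existing call sites like merge(...)[0] keep working.
    public Tensor this[int index] => index == 0 ? output : value_index;
}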
@@ -146,7 +158,7 @@ public static Tensor[] _MergeGrad(Operation op, Tensor[] grads)
 }
 
 [RegisterGradient("RefMerge")]
-public Tensor[] _RefMergeGrad(Operation op, Tensor[] grads)
+public static Tensor[] _RefMergeGrad(Operation op, Tensor[] grads)
 {
     return _MergeGrad(op, grads);
 }
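This hunk and several below change the gradient methods to public static, and the later ones also give them the uniform (Operation op, Tensor[] grads) shape. A plausible motivation, sketched here with placeholder types rather than the library's actual RegisterGradient attribute and registry, is an attribute-driven lookup that scans static methods and invokes them without constructing an instance:

// Illustrative, self-contained sketch of an attribute-based gradient registry.
using System;
using System.Collections.Generic;
using System.Reflection;

[AttributeUsage(AttributeTargets.Method)]
class RegisterGradientSketch : Attribute
{
    public string OpType { get; }
    public RegisterGradientSketch(string opType) => OpType = opType;
}

static class GradientRegistrySketch
{
    // Build an op-type -> gradient-method map from a class's public static methods.
    public static Dictionary<string, MethodInfo> Scan(Type gradClass)
    {
        var map = new Dictionary<string, MethodInfo>();
        foreach (var m in gradClass.GetMethods(BindingFlags.Public | BindingFlags.Static))
        {
            var attr = m.GetCustomAttribute<RegisterGradientSketch>();
            if (attr != null)
                map[attr.OpType] = m;   // e.g. "RefMerge" -> _RefMergeGrad
        }
        return map;
    }
}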
@@ -155,43 +167,32 @@ public Tensor[] _RefMergeGrad(Operation op, Tensor[] grads)
 /// Gradients for an exit op are calculated using an Enter op.
 /// </summary>
 [RegisterGradient("Exit")]
-public Tensor[] _ExitGrad(Operation op, Tensor[] grads)
+public static Tensor[] _ExitGrad(Operation op, Tensor[] grads)
 {
-    throw new NotImplementedException("_ExitGrad");
-    // graph = ops.get_default_graph()
-    //# pylint: disable=protected-access
-    // op_ctxt = op._get_control_flow_context()
-    // grad_ctxt = graph._get_control_flow_context()
-    // # pylint: enable=protected-access
-    // if not grad_ctxt.back_prop:
-    //   # The flag `back_prop` is set by users to suppress gradient
-    //   # computation for this loop. If the attribute `back_prop` is false,
-    //   # no gradient computation.
-    //   return None
+    var grad = grads[0];
+    var graph = ops.get_default_graph();
+    var op_ctxt = op._get_control_flow_context();
+    var grad_ctxt = graph._get_control_flow_context() as WhileContext;
+    // The flag `back_prop` is set by users to suppress gradient
+    // computation for this loop. If the attribute `back_prop` is false,
+    // no gradient computation.
+    if (!grad_ctxt.back_prop)
+        return null;
+
+    if (op_ctxt.grad_state != null)
+        throw new TypeError("Second-order gradient for while loops not supported.");
+
+    grad_ctxt.AddName(grad.name);
 
-    // if op_ctxt.grad_state:
-    //   raise TypeError("Second-order gradient for while loops not supported.")
+    grad_ctxt.Enter();
+    var result = control_flow_ops._Enter(
+        grad, grad_ctxt.name, is_constant: false,
+        parallel_iterations: grad_ctxt.parallel_iterations,
+        name: "b_exit");
 
-    // if isinstance(grad, ops.Tensor):
-    //   grad_ctxt.AddName(grad.name)
-    // else:
-    //   if not isinstance(grad, (ops.IndexedSlices, sparse_tensor.SparseTensor)):
-    //     raise TypeError("Type %s not supported" % type(grad))
-    //   grad_ctxt.AddName(grad.values.name)
-    //   grad_ctxt.AddName(grad.indices.name)
-    //   dense_shape = grad.dense_shape
-    //   if dense_shape is not None:
-    //     grad_ctxt.AddName(dense_shape.name)
-    // grad_ctxt.Enter()
-    // # pylint: disable=protected-access
-    // result = control_flow_ops._Enter(
-    //     grad, grad_ctxt.name, is_constant=False,
-    //     parallel_iterations=grad_ctxt.parallel_iterations,
-    //     name="b_exit")
-    // # pylint: enable=protected-access
-    // grad_ctxt.loop_enters.append(result)
-    // grad_ctxt.Exit()
-    // return result
+    grad_ctxt.loop_enters.append(result);
+    grad_ctxt.Exit();
+    return new[] { result };
 }
 
 /// <summary>
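The Exit gradient above enters the gradient while-loop with an Enter op named "b_exit", mirroring the "b_switch" Merge built in the Switch gradient. Purely as an illustration of that forward/backprop pairing (not library code; the Merge row reflects the standard construction rather than anything shown in this diff):

// Rough correspondence only, based on the standard while-loop gradient construction.
using System.Collections.Generic;

static class ControlFlowDualitySketch
{
    public static readonly Dictionary<string, string> BackpropOpFor = new Dictionary<string, string>
    {
        ["Exit"] = "Enter",             // this hunk: _ExitGrad builds an Enter named "b_exit"
        ["Enter"] = "Exit",             // _EnterGrad below builds an exit for loop variables
        ["Switch"] = "Merge",           // _SwitchGrad builds a Merge named "b_switch"
        ["Merge"] = "Switch",           // standard construction; see _MergeGrad
        ["NextIteration"] = "Identity"  // _NextIterationGrad passes gradients straight through
    };
}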
@@ -200,15 +201,15 @@ public Tensor[] _ExitGrad(Operation op, Tensor[] grads)
 /// Note that the backprop next_iteration is added in switch grad.
 /// </summary>
 [RegisterGradient("NextIteration")]
-public Tensor[] _NextIterationGrad(object _, Tensor[] grad)
+public static Tensor[] _NextIterationGrad(Operation op, Tensor[] grads)
 {
-    return grad;
+    return grads;
 }
 
 [RegisterGradient("RefNextIteration")]
-public Tensor[] _RefNextIterationGrad(object _, Tensor[] grad)
+public static Tensor[] _RefNextIterationGrad(Operation op, Tensor[] grads)
 {
-    return grad;
+    return grads;
 }
 
 /// <summary>
@@ -218,33 +219,39 @@ public Tensor[] _RefNextIterationGrad(object _, Tensor[] grad)
 /// For loop invariants, we need to add an accumulator loop.
 /// </summary>
 [RegisterGradient("Enter")]
-public Tensor[] _EnterGrad(Tensor op, Tensor[] grad)
+public static Tensor[] _EnterGrad(Operation op, Tensor[] grads)
 {
-    throw new NotImplementedException("_EnterGrad");
-    // graph = ops.get_default_graph()
-    //# pylint: disable=protected-access
-    // grad_ctxt = graph._get_control_flow_context()
-    // # pylint: enable=protected-access
-    // if not grad_ctxt.back_prop:
-    //   # Skip gradient computation, if the attribute `back_prop` is false.
-    //   return grad
-    // if grad_ctxt.grad_state is None:
-    //   # Pass the gradient through if we are not in a gradient while context.
-    //   return grad
-    // if op.get_attr("is_constant"):
-    //   # Add a gradient accumulator for each loop invariant.
-    //   if isinstance(grad, ops.Tensor):
-    //     result = grad_ctxt.AddBackpropAccumulator(op, grad)
-    //   elif isinstance(grad, ops.IndexedSlices):
-    //     result = grad_ctxt.AddBackpropIndexedSlicesAccumulator(op, grad)
-    //   else:
-    //     # TODO(yuanbyu, lukasr): Add support for SparseTensor.
-    //     raise TypeError("Type %s not supported" % type(grad))
-    // else:
-    //   result = exit(grad)
-    //   grad_ctxt.loop_exits.append(result)
-    //   grad_ctxt.ExitResult([result])
-    // return result
+    Tensor result = null;
+    var grad = grads[0];
+    var graph = ops.get_default_graph();
+    var grad_ctxt = graph._get_control_flow_context() as WhileContext;
+    if (!grad_ctxt.back_prop)
+        // Skip gradient computation, if the attribute `back_prop` is false.
+        return grads;
+    if (grad_ctxt.grad_state == null)
+        // Pass the gradient through if we are not in a gradient while context.
+        return grads;
+    if (op.get_attr<bool>("is_constant"))
+    {
+        throw new NotImplementedException("_EnterGrad is_constant");
+        // Add a gradient accumulator for each loop invariant.
+        // if isinstance(grad, ops.Tensor):
+        //     result = grad_ctxt.AddBackpropAccumulator(op, grad)
+        // elif isinstance(grad, ops.IndexedSlices):
+        //     result = grad_ctxt.AddBackpropIndexedSlicesAccumulator(op, grad)
+        // else:
+        //     # TODO(yuanbyu, lukasr): Add support for SparseTensor.
+        //     raise TypeError("Type %s not supported" % type(grad))
+    }
+
+    else
+    {
+        result = control_flow_ops.exit(grad);
+        grad_ctxt.loop_exits.append(result);
+        grad_ctxt.ExitResult(new[] { result });
+    }
+
+    return new Tensor[] { result };
 }
 
 
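The is_constant branch above is still a NotImplementedException. In the Python original it adds a gradient accumulator because a loop invariant feeds every iteration, so its total gradient is the sum of the per-iteration gradients. A purely conceptual sketch of that accumulation with plain arrays (illustrative names, not graph ops and not the TensorFlow.NET API):

// Conceptual only: what a backprop accumulator computes for a loop-invariant input.
using System.Collections.Generic;

static class InvariantGradSketch
{
    public static float[] Accumulate(IEnumerable<float[]> perIterationGrads, int size)
    {
        var total = new float[size];
        foreach (var g in perIterationGrads)
            for (int i = 0; i < size; i++)
                total[i] += g[i];   // dL/d(invariant) = sum over iterations of dL/d(use at that iteration)
        return total;
    }
}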