@@ -137,7 +137,7 @@ public static Tensor[] _GradientsHelper(Tensor[] ys,
137137 if ( loop_state != null )
138138 ;
139139 else
140- out_grads [ i ] = control_flow_ops . ZerosLikeOutsideLoop ( op , i ) ;
140+ out_grads [ i ] = new List < Tensor > { control_flow_ops . ZerosLikeOutsideLoop ( op , i ) } ;
141141 }
142142 }
143143
@@ -146,7 +146,7 @@ public static Tensor[] _GradientsHelper(Tensor[] ys,
146146 string name1 = scope1 ;
147147 if ( grad_fn != null )
148148 {
149- in_grads = _MaybeCompile ( grad_scope , op , out_grads , null , grad_fn ) ;
149+ in_grads = _MaybeCompile ( grad_scope , op , out_grads [ 0 ] . ToArray ( ) , null , grad_fn ) ;
150150 _VerifyGeneratedGradients ( in_grads , op ) ;
151151 }
152152
@@ -310,10 +310,9 @@ private static IEnumerable<Tensor> _NonEagerInputs(Operation op, Tensor[] xs)
310310 yield return op . inputs [ i ] ;
311311 }
312312
313- private static Tensor [ ] _AggregatedGrads ( Dictionary < string , List < List < Tensor > > > grads , Operation op , string gradient_uid , object loop_state , int aggregation_method = 0 )
313+ private static List < List < Tensor > > _AggregatedGrads ( Dictionary < string , List < List < Tensor > > > grads , Operation op , string gradient_uid , object loop_state , int aggregation_method = 0 )
314314 {
315315 var out_grads = _GetGrads ( grads , op ) ;
316- var return_grads = new Tensor [ out_grads . Count ] ;
317316
318317 foreach ( var ( i , out_grad ) in enumerate ( out_grads ) )
319318 {
@@ -334,21 +333,21 @@ private static Tensor[] _AggregatedGrads(Dictionary<string, List<List<Tensor>>>
334333 throw new ValueError ( "_AggregatedGrads out_grad.Length == 0" ) ;
335334 }
336335
337- return_grads [ i ] = out_grad [ 0 ] ;
336+ out_grads [ i ] = out_grad ;
338337 }
339338 else
340339 {
341340 used = "add_n" ;
342- return_grads [ i ] = _MultiDeviceAddN ( out_grad . ToArray ( ) , gradient_uid ) ;
341+ out_grads [ i ] = new List < Tensor > { _MultiDeviceAddN ( out_grad . ToArray ( ) , gradient_uid ) } ;
343342 }
344343 }
345344 else
346345 {
347- return_grads [ i ] = null ;
346+ out_grads [ i ] = null ;
348347 }
349348 }
350349
351- return return_grads ;
350+ return out_grads ;
352351 }
353352
354353 /// <summary>
@@ -362,18 +361,18 @@ private static Tensor _MultiDeviceAddN(Tensor[] tensor_list, string gradient_uid
362361 // Basic function structure comes from control_flow_ops.group().
363362 // Sort tensors according to their devices.
364363 var tensors_on_device = new Dictionary < string , List < Tensor > > ( ) ;
365-
364+
366365 foreach ( var tensor in tensor_list )
367366 {
368367 if ( ! tensors_on_device . ContainsKey ( tensor . Device ) )
369368 tensors_on_device [ tensor . Device ] = new List < Tensor > ( ) ;
370369
371370 tensors_on_device [ tensor . Device ] . Add ( tensor ) ;
372371 }
373-
372+
374373 // For each device, add the tensors on that device first.
375374 var summands = new List < Tensor > ( ) ;
376- foreach ( var dev in tensors_on_device . Keys )
375+ foreach ( var dev in tensors_on_device . Keys )
377376 {
378377 var tensors = tensors_on_device [ dev ] ;
379378 ops . _colocate_with_for_gradient ( tensors [ 0 ] . op , gradient_uid , ignore_existing : true ) ;
0 commit comments