11using System ;
22using System . Collections . Generic ;
3+ using System . Linq ;
34using System . Text ;
45using distribute_lib = Tensorflow . Distribute ;
56
@@ -48,8 +49,10 @@ public Optimizer minimize(Tensor loss,
4849 /// <param name="gate_gradients"></param>
4950 public List < KeyValuePair < object , object > > compute_gradients ( Tensor loss ,
5051 List < RefVariable > var_list = null ,
52+ int ? aggregation_method = null ,
5153 GateGradientType gate_gradients = GateGradientType . GATE_OP ,
52- bool colocate_gradients_with_ops = false )
54+ bool colocate_gradients_with_ops = false ,
55+ List < Tensor > grad_loss = null )
5356 {
5457 int num_towers = 1 ;
5558 if ( distribute_lib . get_loss_reduction ( ) == VariableAggregationType . MEAN )
@@ -65,10 +68,13 @@ public List<KeyValuePair<object, object>> compute_gradients(Tensor loss,
6568 break ;
6669 }
6770
68- foreach ( var v in var_list )
69- {
71+ var processors = var_list . Select ( v => optimizer . _get_processor ( v ) ) ;
72+ var var_refs = processors . Select ( x => x . target ( ) ) . ToList ( ) ;
7073
71- }
74+ gradients_impl . gradients ( loss , var_refs , grad_ys : grad_loss ,
75+ gate_gradients : ( gate_gradients == GateGradientType . GATE_OP ) ,
76+ aggregation_method : aggregation_method ,
77+ colocate_gradients_with_ops : colocate_gradients_with_ops ) ;
7278
7379 return null ;
7480 }
0 commit comments