@@ -107,7 +107,7 @@ public Optimizer(Tensor learning_rate, bool use_locking, string name = null)
         /// </returns>
         public Operation minimize(Tensor loss,
             IVariableV1 global_step = null,
-            List<ResourceVariable> var_list = null,
+            List<IVariableV1> var_list = null,
             GateGradientType gate_gradients = GateGradientType.GATE_OP,
             int? aggregation_method = null,
             bool colocate_gradients_with_ops = false, string name = null, Tensor grad_loss = null)
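With `var_list` widened to `List<IVariableV1>`, `minimize` accepts any variable implementation rather than only `ResourceVariable`. A minimal graph-mode sketch against the TensorFlow.NET training API (tensor names and constants are illustrative):

```csharp
using System.Collections.Generic;
using Tensorflow;
using static Tensorflow.Binding;

// Toy linear model: pred = X * W + b, squared-error loss.
var X = tf.placeholder(tf.float32);
var Y = tf.placeholder(tf.float32);
var W = tf.Variable(-0.06f, name: "weight");
var b = tf.Variable(-0.73f, name: "bias");

var pred = tf.add(tf.multiply(X, W), b);
var loss = tf.reduce_sum(tf.pow(pred - Y, 2.0f));

// var_list is now typed over the IVariableV1 interface.
var train_op = tf.train.GradientDescentOptimizer(0.01f)
    .minimize(loss, var_list: new List<IVariableV1> { W, b });
```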
@@ -142,17 +142,17 @@ public Operation minimize(Tensor loss,
         /// <returns>
         /// An `Operation` that applies the specified gradients. If `global_step`
         /// was not None, that operation also increments `global_step`.</returns>
-        public Operation apply_gradients(Tuple<Tensor, ResourceVariable>[] grads_and_vars, IVariableV1 global_step = null, string name = null)
+        public Operation apply_gradients(Tuple<Tensor, IVariableV1>[] grads_and_vars, IVariableV1 global_step = null, string name = null)
         {
             // No DistributionStrategy case.
-            var converted_grads_and_vars = new List<(Tensor, ResourceVariable, _OptimizableVariable)>();
+            var converted_grads_and_vars = new List<(Tensor, IVariableV1, _OptimizableVariable)>();
             foreach (var (g, v) in grads_and_vars)
             {
                 if (g != null)
                 {
                     // Convert the grad to Tensor or IndexedSlices if necessary.
                     var gR = ops.convert_to_tensor_or_indexed_slices(g);
-                    var p = optimizer._get_processor(v);
+                    var p = optimizer._get_processor(v as ResourceVariable);
                     converted_grads_and_vars.Add((gR, v, p));
                 }
             }
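The widened tuple type means the output of `compute_gradients` feeds directly into `apply_gradients` without casting. A sketch of the two-step path, using only the signatures in this diff (given some scalar `loss` tensor):

```csharp
var optimizer = tf.train.GradientDescentOptimizer(0.01f);

// (gradient, variable) pairs, now Tuple<Tensor, IVariableV1>[].
var grads_and_vars = optimizer.compute_gradients(loss);

// Passing global_step makes the returned op increment it on each run.
var global_step = tf.Variable(0, trainable: false, name: "global_step");
var train_op = optimizer.apply_gradients(grads_and_vars, global_step: global_step);
```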
@@ -230,7 +230,7 @@ public Operation apply_gradients(Tuple<Tensor, ResourceVariable>[] grads_and_var
         /// silently ignored).
         /// </summary>
         /// <param name="var_list"></param>
-        protected virtual void _create_slots(ResourceVariable[] var_list)
+        protected virtual void _create_slots(IVariableV1[] var_list)
         {

         }
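`_create_slots` is the hook subclasses override to allocate per-variable optimizer state, and they now receive `IVariableV1[]`. A hypothetical subclass sketch; `_zeros_slot` is assumed to mirror the Python `Optimizer._zeros_slot` helper used by the stock momentum/Adam optimizers:

```csharp
// Hypothetical optimizer that keeps one accumulator slot per variable.
public class MyMomentumOptimizer : Optimizer
{
    public MyMomentumOptimizer(Tensor learning_rate)
        : base(learning_rate, use_locking: false, name: "MyMomentum") { }

    protected override void _create_slots(IVariableV1[] var_list)
    {
        // One zero-initialized "momentum" slot per trainable variable
        // (assumed helper, following the Python Optimizer API).
        foreach (var v in var_list)
            _zeros_slot(v, "momentum", "MyMomentum");
    }
}
```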
@@ -369,8 +369,8 @@ protected IVariableV1 _get_non_slot_variable(string name, Graph graph = null)
         /// A list of (gradient, variable) pairs. Variable is always present, but
         /// gradient can be `None`.
         /// </returns>
-        public Tuple<Tensor, ResourceVariable>[] compute_gradients(Tensor loss,
-            List<ResourceVariable> var_list = null,
+        public Tuple<Tensor, IVariableV1>[] compute_gradients(Tensor loss,
+            List<IVariableV1> var_list = null,
             int? aggregation_method = null,
             GateGradientType gate_gradients = GateGradientType.GATE_OP,
             bool colocate_gradients_with_ops = false,
@@ -381,26 +381,13 @@ public Tuple<Tensor, ResourceVariable>[] compute_gradients(Tensor loss,

             if (var_list == null)
             {
-                var vars = ops.get_collection<ResourceVariable>(tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES);
+                var vars = ops.get_collection<IVariableV1>(tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES);
                 var tmp = variables.trainable_variables();
-                switch (tmp)
-                {
-                    case List<ResourceVariable> values:
-                        var_list = values.Concat(vars).ToList();
-                        break;
-                    /*case List<RefVariable> values:
-                        var_list = values.Concat(vars).ToList();
-                        break;
-                    case List<IVariableV1> values:
-                        var_list = values.Select(x => x as RefVariable).Concat(vars).ToList();
-                        break;*/
-                    default:
-                        throw new NotImplementedException("");
-                }
+                var_list = (tmp as List<IVariableV1>).Concat(vars).ToList();
             }

-            var_list = var_list.Concat(ops.get_collection<ResourceVariable>(tf.GraphKeys._STREAMING_MODEL_PORTS)).ToList();
-            var processors = var_list.Select(v => optimizer._get_processor(v)).ToList();
+            var_list = var_list.Concat(ops.get_collection<IVariableV1>(tf.GraphKeys._STREAMING_MODEL_PORTS)).ToList();
+            var processors = var_list.Select(v => optimizer._get_processor(v as ResourceVariable)).ToList();
             var var_refs = processors.Select(x => x.target()).ToArray();

             var grads = gradients_impl.gradients(new Tensor[] { loss }, var_refs, grad_ys: grad_loss == null ? null : new Tensor[] { grad_loss },
@@ -412,7 +399,7 @@ public Tuple<Tensor, ResourceVariable>[] compute_gradients(Tensor loss,
             grads = control_flow_ops.tuple(grads);

             var grads_and_vars = zip(grads, var_list)
-                .Select(x => new Tuple<Tensor, ResourceVariable>(x.Item1, x.Item2))
+                .Select(x => new Tuple<Tensor, IVariableV1>(x.Item1, x.Item2))
                 .ToArray();

             return grads_and_vars;
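Calling `compute_gradients` directly (instead of `minimize`) is mainly useful when the gradients need transforming before they are applied. A sketch of gradient clipping built on the pair type above, given a scalar `loss` tensor (the clip threshold is illustrative):

```csharp
using System;
using System.Linq;
using Tensorflow;
using static Tensorflow.Binding;

var optimizer = tf.train.GradientDescentOptimizer(0.01f);

// 1. (gradient, variable) pairs; the gradient slot can be null.
Tuple<Tensor, IVariableV1>[] grads_and_vars = optimizer.compute_gradients(loss);

// 2. Clip each gradient, skipping variables without one.
var clipped = grads_and_vars
    .Where(gv => gv.Item1 != null)
    .Select(gv => new Tuple<Tensor, IVariableV1>(
        tf.clip_by_value(gv.Item1, -1.0f, 1.0f), gv.Item2))
    .ToArray();

// 3. Apply the transformed gradients.
var train_op = optimizer.apply_gradients(clipped);
```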