@@ -18,22 +18,25 @@ public class OptimizerV2 : Trackable, IOptimizer
1818 protected bool _hypers_created ;
1919 protected virtual string _name { get ; }
2020
21- ResourceVariable _iterations ;
22- List < ResourceVariable > _weight ;
21+ IVariableV1 _iterations ;
22+ protected ResourceVariable iterations => _iterations as ResourceVariable ;
23+ List < IVariableV1 > _weights ;
2324 Dictionary < string , float > _hyper ;
24- Dictionary < string , ResourceVariable > _hyper_variables ;
25+ Dictionary < string , IVariableV1 > _hyper_variables ;
2526 protected bool _momentum ;
2627 protected float _initial_decay = 0.0f ;
2728 protected bool _use_locking = true ;
2829
29- Dictionary < DeviceDType , Dictionary < string , Tensor > > apply_state ;
30+ Dictionary < string , Dictionary < string , IVariableV1 > > _slots ;
31+ List < string > _slot_names ;
3032
3133 public OptimizerV2 ( ) : base ( )
3234 {
33- _weight = new List < ResourceVariable > ( ) ;
35+ _weights = new List < IVariableV1 > ( ) ;
3436 _hyper = new Dictionary < string , float > ( ) ;
35- _hyper_variables = new Dictionary < string , ResourceVariable > ( ) ;
36- apply_state = new Dictionary < DeviceDType , Dictionary < string , Tensor > > ( ) ;
37+ _hyper_variables = new Dictionary < string , IVariableV1 > ( ) ;
38+ _slots = new Dictionary < string , Dictionary < string , IVariableV1 > > ( ) ;
39+ _slot_names = new List < string > ( ) ;
3740 }
3841
3942 public void apply_gradients ( ( Tensor , ResourceVariable ) grads_and_vars ,
@@ -61,7 +64,7 @@ public void apply_gradients(IEnumerable<(Tensor, ResourceVariable)> grads_and_va
6164 if ( grads_and_vars == null || grads_and_vars . Count ( ) == 0 )
6265 return control_flow_ops . no_op ( ) ;
6366
64- apply_state = _prepare ( var_list ) ;
67+ var apply_state = _prepare ( var_list ) ;
6568 if ( experimental_aggregate_gradients )
6669 {
6770 // var reduced_grads = _aggregate_gradients(grads_and_vars);
@@ -72,13 +75,13 @@ public void apply_gradients(IEnumerable<(Tensor, ResourceVariable)> grads_and_va
7275 } ) ;
7376 }
7477
75- void apply_grad_to_update_var ( ResourceVariable var , EagerTensor grad )
78+ void apply_grad_to_update_var ( ResourceVariable var , Tensor grad , Dictionary < DeviceDType , Dictionary < string , Tensor > > apply_state )
7679 {
7780 _resource_apply_dense ( var , grad , apply_state ) ;
7881 }
7982
8083 protected virtual Operation _resource_apply_dense ( IVariableV1 var ,
81- EagerTensor grad ,
84+ Tensor grad ,
8285 Dictionary < DeviceDType , Dictionary < string , Tensor > > _apply_state )
8386 {
8487 throw new NotImplementedException ( "_resource_apply_dense" ) ;
@@ -94,7 +97,7 @@ void _distributed_apply(IEnumerable<(Tensor, ResourceVariable)> grads_and_vars,
9497 {
9598 tf_with ( ops . name_scope ( "update" ) , delegate
9699 {
97- apply_grad_to_update_var ( var , grad as EagerTensor ) ;
100+ apply_grad_to_update_var ( var , grad , _apply_state ) ;
98101 } ) ;
99102 }
100103
@@ -107,6 +110,12 @@ Tensor[] _aggregate_gradients(IEnumerable<(Tensor, ResourceVariable)> grads_and_
107110 return grads_and_vars . Select ( x => x . Item1 ) . ToArray ( ) ;
108111 }
109112
113+ protected IVariableV1 get_slot ( IVariableV1 var , string slot_name )
114+ {
115+ var slot_dict = _slots [ var . UniqueId ] ;
116+ return slot_dict [ slot_name ] ;
117+ }
118+
110119 Dictionary < DeviceDType , Dictionary < string , Tensor > > _prepare ( IVariableV1 [ ] var_list )
111120 {
112121 var _apply_state = new Dictionary < DeviceDType , Dictionary < string , Tensor > > ( ) ;
@@ -125,6 +134,11 @@ Dictionary<DeviceDType, Dictionary<string, Tensor>> _prepare(IVariableV1[] var_l
125134 return _apply_state ;
126135 }
127136
137+ protected Dictionary < string , Tensor > _fallback_apply_state ( string var_device , TF_DataType var_dtype )
138+ {
139+ throw new NotImplementedException ( "" ) ;
140+ }
141+
128142 protected virtual void _prepare_local ( DeviceDType device_dtype ,
129143 Dictionary < DeviceDType , Dictionary < string , Tensor > > _apply_state )
130144 {
@@ -145,7 +159,7 @@ Tensor _decayed_lr(TF_DataType var_dtype)
145159 return lr_t ;
146160 }
147161
148- protected ResourceVariable _get_hyper ( string name , TF_DataType dtype = TF_DataType . DtInvalid )
162+ protected Tensor _get_hyper ( string name , TF_DataType dtype = TF_DataType . DtInvalid )
149163 {
150164 var value = _hyper_variables [ name ] ;
151165 return math_ops . cast ( value , dtype ) ;
@@ -160,7 +174,7 @@ void _create_all_weights(IVariableV1[] var_list)
160174 dtype : TF_DataType . TF_INT64 ,
161175 trainable : false ,
162176 aggregation : VariableAggregation . OnlyFirstReplica ) ;
163- _weight . Add ( _iterations ) ;
177+ _weights . Add ( _iterations ) ;
164178 }
165179
166180 _create_hypers ( ) ;
@@ -190,7 +204,7 @@ void _create_hypers()
190204 _hypers_created = true ;
191205 }
192206
193- void _create_slots ( IVariableV1 [ ] var_list )
207+ protected virtual void _create_slots ( IVariableV1 [ ] var_list )
194208 {
195209 if ( _momentum )
196210 {
@@ -199,6 +213,35 @@ void _create_slots(IVariableV1[] var_list)
199213 }
200214 }
201215
216+ protected IVariableV1 add_slot ( IVariableV1 var , string slot_name , IInitializer initializer = null )
217+ {
218+ if ( initializer == null )
219+ initializer = tf . zeros_initializer ;
220+
221+ if ( ! _slot_names . Contains ( slot_name ) )
222+ _slot_names . append ( slot_name ) ;
223+
224+ if ( ! _slots . ContainsKey ( var . UniqueId ) )
225+ _slots [ var . UniqueId ] = new Dictionary < string , IVariableV1 > ( ) ;
226+ var slot_dict = _slots [ var . UniqueId ] ;
227+ if ( ! slot_dict . ContainsKey ( slot_name ) )
228+ {
229+ var weight = tf . Variable ( initializer ,
230+ dtype : var . dtype ,
231+ trainable : false ,
232+ shape : var . shape ,
233+ name : $ "{ var . Name } /{ slot_name } ") ;
234+
235+ slot_dict [ slot_name ] = weight ;
236+ _weights . append ( weight ) ;
237+ return weight ;
238+ }
239+ else
240+ {
241+ return slot_dict [ slot_name ] ;
242+ }
243+ }
244+
202245 ResourceVariable add_weight ( string name ,
203246 TensorShape shape ,
204247 TF_DataType dtype = TF_DataType . TF_FLOAT ,
0 commit comments