# coding=utf-8
- # Copyright 2019 The Tensor2Tensor Authors.
+ # Copyright 2020 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# See the License for the specific language governing permissions and
# limitations under the License.

1629"""Multi-step optimizers simulating large batches.
1730
1831Optimizer variants which make it possible to use very large batch sizes with
from __future__ import division
from __future__ import print_function

- import tensorflow as tf
- from tensorflow.python.eager import context
- from tensorflow.python.framework import dtypes
- from tensorflow.python.framework import ops
- from tensorflow.python.ops import control_flow_ops
- from tensorflow.python.ops import math_ops
+ import tensorflow.compat.v1 as tf
+ # pylint: disable=g-direct-tensorflow-import
from tensorflow.python.ops import resource_variable_ops
- from tensorflow.python.ops import state_ops
- from tensorflow.python.training import optimizer
from tensorflow.python.training import training_ops
- from tensorflow.python.util.tf_export import tf_export
- from tensorflow.keras import backend as K
+ # pylint: enable=g-direct-tensorflow-import


- class MultistepAdamOptimizer(optimizer.Optimizer):
+ class MultistepAdamOptimizer(tf.train.Optimizer):
  """Adam with SGD updates every n steps with accumulated gradients."""

-   def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8,
-                use_locking=False, name="Adam", n=1):
-     super(MultistepAdamOptimizer, self).__init__(use_locking=use_locking, name=name)
+   def __init__(self,
+                learning_rate=0.001,
+                beta1=0.9,
+                beta2=0.999,
+                epsilon=1e-8,
+                use_locking=False,
+                name="Adam",
+                n=1):
+     super(MultistepAdamOptimizer, self).__init__(
+         use_locking=use_locking, name=name)
    self._lr = learning_rate
    self._beta1 = beta1
    self._beta2 = beta2
@@ -59,43 +72,46 @@ def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8,
    self._n_t = None  # n as tensor

  def _get_beta_accumulators(self):
-     with ops.init_scope():
-       if context.executing_eagerly():
+     with tf.init_scope():
+       if tf.executing_eagerly():
        graph = None
      else:
-         graph = ops.get_default_graph()
+         graph = tf.get_default_graph()
    return (self._get_non_slot_variable("beta1_power", graph=graph),
            self._get_non_slot_variable("beta2_power", graph=graph))

  def _create_slots(self, var_list):
    """Create slot variables for Adam with accumulated gradients."""
    first_var = min(var_list, key=lambda x: x.name)
-     self._create_non_slot_variable(initial_value=self._beta1, name="beta1_power", colocate_with=first_var)
-     self._create_non_slot_variable(initial_value=self._beta2, name="beta2_power", colocate_with=first_var)
-     #if iter is initialized as an int32, this optimizer could not run
-     #with tensorflow_hub with a tensorflow-gpu version
-     self._create_non_slot_variable(initial_value=0.0 if self._n == 1 else 1.0, name="iter", colocate_with=first_var)
+     self._create_non_slot_variable(
+         initial_value=self._beta1, name="beta1_power", colocate_with=first_var)
+     self._create_non_slot_variable(
+         initial_value=self._beta2, name="beta2_power", colocate_with=first_var)
+     # if iter is initialized as an int32, this optimizer could not run
+     # with tensorflow_hub with a tensorflow-gpu version
+     self._create_non_slot_variable(
+         initial_value=0.0 if self._n == 1 else 1.0,
+         name="iter",
+         colocate_with=first_var)
    # Create slots for the first and second moments, as well as grad_acc.
    for v in var_list:
      self._zeros_slot(v, "m", self._name)
      self._zeros_slot(v, "v", self._name)
      self._zeros_slot(v, "grad_acc", self._name)

- 
  def _get_iter_variable(self):
-     graph = (
-         None if tf.executing_eagerly() else tf.get_default_graph())
+     graph = (None if tf.executing_eagerly() else tf.get_default_graph())
    return self._get_non_slot_variable("iter", graph=graph)

  def _prepare(self):
    lr = self._call_if_callable(self._lr)
    beta1 = self._call_if_callable(self._beta1)
    beta2 = self._call_if_callable(self._beta2)
    epsilon = self._call_if_callable(self._epsilon)
-     self._beta1_t = ops.convert_to_tensor(beta1, name="beta1")
-     self._beta2_t = ops.convert_to_tensor(beta2, name="beta2")
-     self._lr_t = ops.convert_to_tensor(lr, name="learning_rate")
-     self._epsilon_t = ops.convert_to_tensor(epsilon, name="epsilon")
+     self._beta1_t = tf.convert_to_tensor(beta1, name="beta1")
+     self._beta2_t = tf.convert_to_tensor(beta2, name="beta2")
+     self._lr_t = tf.convert_to_tensor(lr, name="learning_rate")
+     self._epsilon_t = tf.convert_to_tensor(epsilon, name="epsilon")
    self._n_t = tf.convert_to_tensor(self._n, name="n")

  def _apply_cond(self, apply_fn, grad, var, *args, **kwargs):
@@ -106,8 +122,8 @@ def apply_adam(grad_acc, apply_fn, grad, var, *args, **kwargs):
      total_grad = (grad_acc + grad) / tf.cast(self._n_t, grad.dtype)
      adam_op = apply_fn(total_grad, var, *args, **kwargs)
      with tf.control_dependencies([adam_op]):
-         grad_acc_to_zero_op = grad_acc.assign(tf.zeros_like(grad_acc),
-                                               use_locking=self._use_locking)
+         grad_acc_to_zero_op = grad_acc.assign(
+             tf.zeros_like(grad_acc), use_locking=self._use_locking)
      return tf.group(adam_op, grad_acc_to_zero_op)

    def accumulate_gradient(grad_acc, grad):
@@ -126,14 +142,17 @@ def _apply_dense_in_action(self, grad, var):
    m = self.get_slot(var, "m")
    v = self.get_slot(var, "v")
    beta1_power, beta2_power = self._get_beta_accumulators()
-     return training_ops.apply_adam(var, m, v,
-         math_ops.cast(beta1_power, var.dtype.base_dtype),
-         math_ops.cast(beta2_power, var.dtype.base_dtype),
-         math_ops.cast(self._lr_t, var.dtype.base_dtype),
-         math_ops.cast(self._beta1_t, var.dtype.base_dtype),
-         math_ops.cast(self._beta2_t, var.dtype.base_dtype),
-         math_ops.cast(self._epsilon_t, var.dtype.base_dtype),
-         grad,
+     return training_ops.apply_adam(
+         var,
+         m,
+         v,
+         tf.cast(beta1_power, var.dtype.base_dtype),
+         tf.cast(beta2_power, var.dtype.base_dtype),
+         tf.cast(self._lr_t, var.dtype.base_dtype),
+         tf.cast(self._beta1_t, var.dtype.base_dtype),
+         tf.cast(self._beta2_t, var.dtype.base_dtype),
+         tf.cast(self._epsilon_t, var.dtype.base_dtype),
+         grad,
        use_locking=self._use_locking).op

  def _resource_apply_dense(self, grad, var):
@@ -143,41 +162,44 @@ def _resource_apply_dense_in_action(self, grad, var):
    m = self.get_slot(var, "m")
    v = self.get_slot(var, "v")
    beta1_power, beta2_power = self._get_beta_accumulators()
-     return training_ops.resource_apply_adam(var.handle,
-                                             m.handle,
+     return training_ops.resource_apply_adam(
+         var.handle,
+         m.handle,
        v.handle,
-         math_ops.cast(beta1_power, grad.dtype.base_dtype),
-         math_ops.cast(beta2_power, grad.dtype.base_dtype),
-         math_ops.cast(self._lr_t, var.dtype.base_dtype),
-         math_ops.cast(self._beta1_t, grad.dtype.base_dtype),
-         math_ops.cast(self._beta2_t, grad.dtype.base_dtype),
-         math_ops.cast(self._epsilon_t, grad.dtype.base_dtype),
-         grad, use_locking=self._use_locking)
+         tf.cast(beta1_power, grad.dtype.base_dtype),
+         tf.cast(beta2_power, grad.dtype.base_dtype),
+         tf.cast(self._lr_t, var.dtype.base_dtype),
+         tf.cast(self._beta1_t, grad.dtype.base_dtype),
+         tf.cast(self._beta2_t, grad.dtype.base_dtype),
+         tf.cast(self._epsilon_t, grad.dtype.base_dtype),
+         grad,
+         use_locking=self._use_locking)

  def _apply_sparse_shared(self, grad, var, indices, scatter_add):
    beta1_power, beta2_power = self._get_beta_accumulators()
-     beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype)
-     beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype)
-     lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
-     beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
-     beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
-     epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype)
-     lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power))
+     beta1_power = tf.cast(beta1_power, var.dtype.base_dtype)
+     beta2_power = tf.cast(beta2_power, var.dtype.base_dtype)
+     lr_t = tf.cast(self._lr_t, var.dtype.base_dtype)
+     beta1_t = tf.cast(self._beta1_t, var.dtype.base_dtype)
+     beta2_t = tf.cast(self._beta2_t, var.dtype.base_dtype)
+     epsilon_t = tf.cast(self._epsilon_t, var.dtype.base_dtype)
+     lr = (lr_t * tf.sqrt(1 - beta2_power) / (1 - beta1_power))
    # m_t = beta1 * m + (1 - beta1) * g_t
    m = self.get_slot(var, "m")
    m_scaled_g_values = grad * (1 - beta1_t)
-     m_t = state_ops.assign(m, m * beta1_t, use_locking=self._use_locking)
-     with ops.control_dependencies([m_t]):
+     m_t = tf.assign(m, m * beta1_t, use_locking=self._use_locking)
+     with tf.control_dependencies([m_t]):
      m_t = scatter_add(m, indices, m_scaled_g_values)
    # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
    v = self.get_slot(var, "v")
    v_scaled_g_values = (grad * grad) * (1 - beta2_t)
-     v_t = state_ops.assign(v, v * beta2_t, use_locking=self._use_locking)
-     with ops.control_dependencies([v_t]):
+     v_t = tf.assign(v, v * beta2_t, use_locking=self._use_locking)
+     with tf.control_dependencies([v_t]):
      v_t = scatter_add(v, indices, v_scaled_g_values)
-     v_sqrt = math_ops.sqrt(v_t)
-     var_update = state_ops.assign_sub(var, lr * m_t / (v_sqrt + epsilon_t), use_locking=self._use_locking)
-     return control_flow_ops.group(*[var_update, m_t, v_t])
+     v_sqrt = tf.sqrt(v_t)
+     var_update = tf.assign_sub(
+         var, lr * m_t / (v_sqrt + epsilon_t), use_locking=self._use_locking)
+     return tf.group(*[var_update, m_t, v_t])

  def _apply_sparse(self, grad, var):
    # TODO(fstahlberg): Implement a sparse version
@@ -191,39 +213,44 @@ def _resource_apply_sparse_duplicate_indices(self, grad, var, indices):
    # correctly (summing them). A real sparse implementation will probably want
    # to override _resource_apply_sparse instead so it gets them de-duplicated
    # automatically.
-     dense_grad = tf.convert_to_tensor(tf.IndexedSlices(values=grad,
-         indices=indices, dense_shape=tf.shape(var)))
-     return self._apply_cond(self._resource_apply_dense_in_action, dense_grad, var)
+     dense_grad = tf.convert_to_tensor(
+         tf.IndexedSlices(
+             values=grad, indices=indices, dense_shape=tf.shape(var)))
+     return self._apply_cond(self._resource_apply_dense_in_action, dense_grad,
+                             var)

  def _resource_scatter_add(self, x, i, v):
-     with ops.control_dependencies(
+     with tf.control_dependencies(
        [resource_variable_ops.resource_scatter_add(x.handle, i, v)]):
      return x.value()

  def _resource_apply_sparse(self, grad, var, indices):
-     return self._apply_sparse_shared(grad, var, indices, self._resource_scatter_add)
+     return self._apply_sparse_shared(grad, var, indices,
+                                      self._resource_scatter_add)

  def _finish(self, update_ops, name_scope):
    """Updates beta_power variables every n batches and incrs counter."""
    iter_ = self._get_iter_variable()
    beta1_power, beta2_power = self._get_beta_accumulators()
    with tf.control_dependencies(update_ops):
      with tf.colocate_with(iter_):
+ 
        def update_beta_op():
          update_beta1 = beta1_power.assign(
-               beta1_power * self._beta1_t,
-               use_locking=self._use_locking)
+               beta1_power * self._beta1_t, use_locking=self._use_locking)
          update_beta2 = beta2_power.assign(
-               beta2_power * self._beta2_t,
-               use_locking=self._use_locking)
+               beta2_power * self._beta2_t, use_locking=self._use_locking)
          return tf.group(update_beta1, update_beta2)
+ 
        maybe_update_beta = tf.cond(
            tf.equal(iter_, 0), update_beta_op, tf.no_op)
        with tf.control_dependencies([maybe_update_beta]):
-           #TODO(Cuong): It is suboptimal here because we have to cast twice (float to int,
-           #and then int to float)
-           update_iter = iter_.assign(K.cast(tf.mod(K.cast(iter_ + 1.0, dtype=dtypes.int32), self._n_t), dtype=dtypes.float32),
-                                      use_locking=self._use_locking)
+           # TODO(cuong): It is suboptimal here because we have to cast twice
+           # (float to int, and then int to float)
+           update_iter = iter_.assign(
+               tf.cast(
+                   tf.mod(tf.cast(iter_ + 1.0, tf.int32), self._n_t),
+                   tf.float32),
+               use_locking=self._use_locking)
    return tf.group(
        *update_ops + [update_iter, maybe_update_beta], name=name_scope)
- 
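For context, a minimal usage sketch follows (not part of the commit above). It assumes the class is importable as tensor2tensor.utils.multistep_optimizer.MultistepAdamOptimizer, the file this diff touches; the toy regression model, shapes, and the choice of n=4 are illustrative placeholders only. With n=4 the optimizer sums gradients into its "grad_acc" slots over four minibatches and applies one Adam update using the accumulated gradient divided by n, simulating a 4x larger batch.

import numpy as np
import tensorflow.compat.v1 as tf
from tensor2tensor.utils.multistep_optimizer import MultistepAdamOptimizer

tf.disable_v2_behavior()

# Toy TF1-style regression model (placeholder); any graph-mode loss works.
x = tf.placeholder(tf.float32, [None, 10])
y = tf.placeholder(tf.float32, [None, 1])
w = tf.get_variable("w", [10, 1])
loss = tf.reduce_mean(tf.square(tf.matmul(x, w) - y))

# n=4: gradients are accumulated for four calls to train_op, then a single
# Adam update is applied with the accumulated gradient divided by n.
opt = MultistepAdamOptimizer(learning_rate=1e-3, n=4)
train_op = opt.minimize(loss)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  for _ in range(8):  # two simulated large-batch updates
    xb = np.random.randn(32, 10).astype(np.float32)
    yb = np.random.randn(32, 1).astype(np.float32)
    sess.run(train_op, feed_dict={x: xb, y: yb})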