@@ -85,6 +85,9 @@ def __init__(self, optimizer_name, lr, hparams):
           beta1=hparams.optimizer_adam_beta1,
           beta2=hparams.optimizer_adam_beta2,
           epsilon=hparams.optimizer_adam_epsilon)
+    elif optimizer_name == "Adafactor":
+      self._opt = AdafactorOptimizer(
+          lr / 500.0, epsilon=hparams.optimizer_adam_epsilon)
     else:
       self._opt = tf.contrib.layers.OPTIMIZER_CLS_NAMES[optimizer_name](lr)
 
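For orientation, here is a hypothetical usage sketch (not part of the diff) showing how this branch would be reached. The wrapper class name `ConditionalOptimizer` and the `hparams` fields are assumptions based on the signature in the hunk header; only the fields read by the branch above are shown. The `lr / 500.0` rescaling presumably compensates for Adafactor's different update scale.

```python
import tensorflow as tf

# Hypothetical hparams for illustration.
hparams = tf.contrib.training.HParams(
    optimizer="Adafactor",          # selects the new elif branch above
    learning_rate=0.1,              # passed in as lr, then divided by 500.0
    optimizer_adam_epsilon=1e-9)    # reused as Adafactor's epsilon

# Assumed wrapper, per the __init__(self, optimizer_name, lr, hparams) header:
# opt = ConditionalOptimizer(hparams.optimizer, hparams.learning_rate, hparams)
# train_op = opt.minimize(loss)
```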
@@ -252,3 +255,158 @@ def get_variable_initializer(hparams):
         hparams.initializer_gain, mode="fan_avg", distribution="uniform")
   else:
     raise ValueError("Unrecognized initializer: %s" % hparams.initializer)
+
+
+class AdafactorOptimizer(tf.train.Optimizer):
+  """Optimizer that implements the Adafactor algorithm.
+
+  Adafactor is similar to Adam, but seeks to reduce the memory
+  requirements due to the moment estimates. The auxiliary memory
+  requirements for an `AxB` weight matrix are `A+B` for Adafactor,
+  versus `2AB` for Adam.
+
+  Adam is described in [Kingma et al., 2014](http://arxiv.org/abs/1412.6980)
+  ([pdf](http://arxiv.org/pdf/1412.6980.pdf)).
+
+  The differences are as follows:
+
+  1. No momentum - this removes the first-moment estimate.
+  2. For an AxB weight matrix, instead of keeping a full AxB second-moment
+     estimate matrix, Adafactor keeps only the row and column means of that
+     matrix, and reconstructs the full second-moment estimate on the fly
+     from those means.
+  3. Adafactor uses a variable decay rate for the second-moment estimates -
+     faster decay at the start of training and slower decay later. This
+     eliminates the awkwardness in Adam related to having biased moment
+     estimates at the start of training.
+
+  For non-2d variables:
+  We initialize
+  ```
+  t <- 0
+  v <- zeros(shape(var))
+  ```
+
+  The update rule is as follows:
+  ```
+  t <- t + 1
+  decay_horizon = min(t, t * relative_decay_horizon + absolute_decay_horizon)
+  decay_rate = 1 - 1 / decay_horizon
+  v <- decay_rate * v + (1 - decay_rate) * grad^2
+  var <- var - lr * grad / (sqrt(v) + epsilon)
+  ```
+
+  For 2d variables:
+  We initialize
+  ```
+  t <- 0
+  v_r <- zeros([num_rows])
+  v_c <- zeros([num_cols])
+  ```
+
+  The update rule is as follows:
+  ```
+  t <- t + 1
+  decay_horizon = min(t, t * relative_decay_horizon + absolute_decay_horizon)
+  decay_rate = 1 - 1 / decay_horizon
+  v_r <- decay_rate * v_r + (1 - decay_rate) * reduce_mean(grad^2, 1)
+  v_c <- decay_rate * v_c + (1 - decay_rate) * reduce_mean(grad^2, 0)
+  approx_v = expand_dims(v_r, 1) * expand_dims(v_c, 0) / reduce_mean(v_c)
+  var <- var - lr * grad / (sqrt(approx_v) + epsilon)
+  ```
+
+  TODO(noam): write a paper.
+  TODO(noam): we should also apply the 2d logic to the two final dimensions
+    of >2d convolutional kernels.
+  """
+
+  def __init__(self,
+               learning_rate=0.001,
+               epsilon=1e-8,
+               relative_decay_horizon=0.2,
+               absolute_decay_horizon=100.0,
+               use_locking=False,
+               name="Adafactor"):
+    """Construct a new Adafactor optimizer.
+
+    See class comment.
+
+    Args:
+      learning_rate: A Tensor or a floating point value. The learning rate.
+      epsilon: A small constant for numerical stability.
+      relative_decay_horizon: A floating point value <= 1. Asymptotically,
+        the decay horizon is this fraction of the current step count.
+      absolute_decay_horizon: A floating point value (a step count) added to
+        the relative horizon; the horizon is also capped at the step count.
+      use_locking: If True use locks for update operations.
+      name: Optional name for the operations created when applying gradients.
+        Defaults to "Adafactor".
342+ """
343+ super (AdafactorOptimizer , self ).__init__ (use_locking , name )
344+ self ._lr = learning_rate
345+ self ._relative_decay_horizon = relative_decay_horizon
346+ self ._absolute_decay_horizon = absolute_decay_horizon
347+ self ._epsilon = epsilon
348+
+  def _prepare(self):
+    global_step = tf.to_float(tf.train.get_or_create_global_step()) + 1.0
+    decay_horizon = tf.minimum(global_step,
+                               global_step * self._relative_decay_horizon +
+                               self._absolute_decay_horizon)
+    self._mixing_rate = 1.0 / decay_horizon
+    self._decay_rate = 1.0 - self._mixing_rate
+    self._epsilon = tf.to_float(self._epsilon)
+    self._lr = tf.to_float(self._lr)
+
+  def _should_use_factored_second_moment_estimate(self, shape):
+    """Whether to use a factored second-moment estimator.
+
+    Based on the shape of the variable.
+
+    Args:
+      shape: a list of integers
+    Returns:
+      a boolean
+    """
+    return len(shape) == 2
+
+  def _create_slots(self, var_list):
+    for v in var_list:
+      shape = v.get_shape().as_list()
+      if self._should_use_factored_second_moment_estimate(shape):
+        r_val = tf.zeros([shape[0]], dtype=tf.float32)
+        c_val = tf.zeros([shape[1]], dtype=tf.float32)
+        self._get_or_make_slot(v, r_val, "vr", self._name)
+        self._get_or_make_slot(v, c_val, "vc", self._name)
+      else:
+        self._zeros_slot(v, "v", self._name)
+
+  def _apply_dense(self, grad, var):
+    return self._resource_apply_dense(grad, var)
+
+  def _resource_apply_dense(self, grad, var):
+    shape = var.get_shape().as_list()
+    grad_squared = tf.square(grad)
+    if self._should_use_factored_second_moment_estimate(shape):
+      vr = self.get_slot(var, "vr")
+      new_vr = (self._decay_rate * vr +
+                self._mixing_rate * tf.reduce_mean(grad_squared, 1))
+      vc = self.get_slot(var, "vc")
+      new_vc = (self._decay_rate * vc +
+                self._mixing_rate * tf.reduce_mean(grad_squared, 0))
+      vr_update = tf.assign(vr, new_vr, use_locking=self._use_locking)
+      vc_update = tf.assign(vc, new_vc, use_locking=self._use_locking)
+      updates = [vr_update, vc_update]
+      # Rank-1 reconstruction of the second-moment estimate: normalize the
+      # column factor by its mean so the row factor carries the overall scale.
+      vr = tf.sqrt(new_vr) + self._epsilon
+      vc = tf.sqrt(new_vc) + self._epsilon
+      vc /= tf.reduce_mean(vc)
+      denom = tf.expand_dims(vr, 1) * tf.expand_dims(vc, 0)
+    else:
+      v = self.get_slot(var, "v")
+      new_v = (self._decay_rate * v + self._mixing_rate * grad_squared)
+      v_update = tf.assign(v, new_v, use_locking=self._use_locking)
+      updates = [v_update]
+      denom = tf.sqrt(new_v) + self._epsilon
+    subtrahend = self._lr * grad / denom
+    var_update = tf.assign_sub(var, subtrahend, use_locking=self._use_locking)
+    updates = [var_update] + updates
+    return tf.group(*updates)
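To make the factored estimate concrete, here is a small standalone NumPy sketch (not part of the commit; names are illustrative) that runs the 2d accumulators from the docstring and compares the rank-1 reconstruction `approx_v` against the full `AxB` accumulator it replaces, using the constructor's default horizons:

```python
import numpy as np

np.random.seed(0)
rows, cols = 4, 6
v_full = np.zeros((rows, cols))   # full accumulator: A*B numbers
v_r = np.zeros(rows)              # Adafactor row means: A numbers
v_c = np.zeros(cols)              # Adafactor column means: B numbers

relative_decay_horizon = 0.2      # defaults from __init__ above
absolute_decay_horizon = 100.0

for t in range(1, 201):
  grad_squared = np.random.randn(rows, cols) ** 2
  horizon = min(t, t * relative_decay_horizon + absolute_decay_horizon)
  decay_rate = 1.0 - 1.0 / horizon
  v_full = decay_rate * v_full + (1 - decay_rate) * grad_squared
  v_r = decay_rate * v_r + (1 - decay_rate) * grad_squared.mean(axis=1)
  v_c = decay_rate * v_c + (1 - decay_rate) * grad_squared.mean(axis=0)

# Rank-1 reconstruction from the docstring:
#   approx_v = expand_dims(v_r, 1) * expand_dims(v_c, 0) / reduce_mean(v_c)
approx_v = np.outer(v_r, v_c) / v_c.mean()
print(np.abs(approx_v - v_full).max())  # small when v_full is near rank-1
```

The row and column accumulators cost A+B floats per matrix rather than A*B, which is where the memory saving quoted in the class docstring comes from.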