
Commit 6bc1efb

nshazeer authored and Ryan Sepassi committed
New "Adafactor" optimizer - more memory efficient than Adam.
PiperOrigin-RevId: 181767655
1 parent 63135e0 commit 6bc1efb

File tree: 1 file changed, +158 -0 lines changed

tensor2tensor/utils/optimize.py

Lines changed: 158 additions & 0 deletions
@@ -85,6 +85,9 @@ def __init__(self, optimizer_name, lr, hparams):
           beta1=hparams.optimizer_adam_beta1,
           beta2=hparams.optimizer_adam_beta2,
           epsilon=hparams.optimizer_adam_epsilon)
+    elif optimizer_name == "Adafactor":
+      self._opt = AdafactorOptimizer(
+          lr / 500.0, epsilon=hparams.optimizer_adam_epsilon)
     else:
       self._opt = tf.contrib.layers.OPTIMIZER_CLS_NAMES[optimizer_name](lr)
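A minimal sketch of how this branch would be reached (not part of the commit): assuming the usual tensor2tensor setup where the `optimizer_name` argument in the hunk header above comes from `hparams.optimizer`, switching a model onto the new optimizer is a one-line hparams change; `transformer_base()` is just an illustrative hparams set.

```
from tensor2tensor.models import transformer

# Assumed example hparams set; any set carrying the standard `optimizer`
# field would work the same way.
hparams = transformer.transformer_base()
hparams.optimizer = "Adafactor"
# The branch above then builds
#   AdafactorOptimizer(lr / 500.0, epsilon=hparams.optimizer_adam_epsilon)
# so Adafactor sees a learning rate 500x smaller than the configured one.
```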

@@ -252,3 +255,158 @@ def get_variable_initializer(hparams):
         hparams.initializer_gain, mode="fan_avg", distribution="uniform")
   else:
     raise ValueError("Unrecognized initializer: %s" % hparams.initializer)
+
+
+class AdafactorOptimizer(tf.train.Optimizer):
+  """Optimizer that implements the Adafactor algorithm.
+
+  Adafactor is similar to Adam, but seeks to reduce the memory
+  requirements due to the moment estimates. The auxiliary memory
+  requirements for an `AxB` weight matrix are `A+B` for Adafactor,
+  versus `2AB` for Adam.
+
+  Adam is described in [Kingma et al., 2014](http://arxiv.org/abs/1412.6980)
+  ([pdf](http://arxiv.org/pdf/1412.6980.pdf)).
+
+  The differences are as follows:
+
+  1. No momentum - this removes the first-moment estimate.
+  2. For an AxB weight matrix, instead of keeping a full AxB second-moment
+     estimate matrix, Adafactor keeps only the row and column means of that
+     matrix and estimates the full second-moment matrix on the fly from
+     those means.
+  3. Adafactor uses a variable decay rate for the second-moment estimates -
+     faster decay at the start of training and slower decay later. This
+     eliminates the awkwardness in Adam related to having biased moment
+     estimates at the start of training.
+
+  For non-2d variables:
+  We initialize
+  ```
+  t <- 0
+  v <- zeros(shape(var))
+  ```
+
+  The update rule is as follows:
+  ```
+  t <- t + 1
+  decay_horizon = min(t, t * relative_decay_horizon + absolute_decay_horizon)
+  decay_rate = 1 - 1 / decay_horizon
+  v <- decay_rate * v + (1 - decay_rate) * grad^2
+  var <- var - lr * grad / (sqrt(v) + epsilon)
+  ```
+
+  For 2d variables:
+  We initialize
+  ```
+  t <- 0
+  v_r <- zeros([num_rows])
+  v_c <- zeros([num_cols])
+  ```
+
+  The update rule is as follows:
+  ```
+  t <- t + 1
+  decay_horizon = min(t, t * relative_decay_horizon + absolute_decay_horizon)
+  decay_rate = 1 - 1 / decay_horizon
+  v_r <- decay_rate * v_r + (1 - decay_rate) * reduce_mean(grad^2, 1)
+  v_c <- decay_rate * v_c + (1 - decay_rate) * reduce_mean(grad^2, 0)
+  approx_v = expand_dims(v_r, 1) * expand_dims(v_c, 0) / reduce_mean(v_c)
+  var <- var - lr * grad / (sqrt(approx_v) + epsilon)
+  ```
+
+  TODO(noam): write a paper.
+  TODO(noam): we should also apply the 2d logic to the two final dimensions
+  of >2d convolutional kernels.
+  """
+
+  def __init__(self,
+               learning_rate=0.001,
+               epsilon=1e-8,
+               relative_decay_horizon=0.2,
+               absolute_decay_horizon=100.0,
+               use_locking=False,
+               name="Adafactor"):
+    """Construct a new Adafactor optimizer.
+
+    See class comment.
+
+    Args:
+      learning_rate: A Tensor or a floating point value. The learning rate.
+      epsilon: A small constant for numerical stability.
+      relative_decay_horizon: a floating point value <= 1
+      absolute_decay_horizon: a floating point value (representing a step count)
+      use_locking: If True use locks for update operations.
+      name: Optional name for the operations created when applying gradients.
+        Defaults to "Adafactor".
+    """
+    super(AdafactorOptimizer, self).__init__(use_locking, name)
+    self._lr = learning_rate
+    self._relative_decay_horizon = relative_decay_horizon
+    self._absolute_decay_horizon = absolute_decay_horizon
+    self._epsilon = epsilon
+
+  def _prepare(self):
+    global_step = tf.to_float(tf.train.get_or_create_global_step()) + 1.0
+    decay_horizon = tf.minimum(global_step,
+                               global_step * self._relative_decay_horizon +
+                               self._absolute_decay_horizon)
+    self._mixing_rate = 1.0 / decay_horizon
+    self._decay_rate = 1.0 - self._mixing_rate
+    self._epsilon = tf.to_float(self._epsilon)
+    self._lr = tf.to_float(self._lr)
+
+  def _should_use_factored_second_moment_estimate(self, shape):
+    """Should we use a factored second moment estimator.
+
+    Based on the shape of the variable.
+
+    Args:
+      shape: a list of integers
+    Returns:
+      a boolean
+    """
+    return len(shape) == 2
+
+  def _create_slots(self, var_list):
+    for v in var_list:
+      shape = v.get_shape().as_list()
+      if self._should_use_factored_second_moment_estimate(shape):
+        r_val = tf.zeros([shape[0]], dtype=tf.float32)
+        c_val = tf.zeros([shape[1]], dtype=tf.float32)
+        self._get_or_make_slot(v, r_val, "vr", self._name)
+        self._get_or_make_slot(v, c_val, "vc", self._name)
+      else:
+        self._zeros_slot(v, "v", self._name)
+
+  def _apply_dense(self, grad, var):
+    return self._resource_apply_dense(grad, var)
+
+  def _resource_apply_dense(self, grad, var):
+    shape = var.get_shape().as_list()
+    grad_squared = tf.square(grad)
+    updates = []
+    if self._should_use_factored_second_moment_estimate(shape):
+      vr = self.get_slot(var, "vr")
+      new_vr = (self._decay_rate * vr +
+                self._mixing_rate * tf.reduce_mean(grad_squared, 1))
+      vc = self.get_slot(var, "vc")
+      new_vc = (self._decay_rate * vc +
+                self._mixing_rate * tf.reduce_mean(grad_squared, 0))
+      vr_update = tf.assign(vr, new_vr, use_locking=self._use_locking)
+      vc_update = tf.assign(vc, new_vc, use_locking=self._use_locking)
+      updates = [vr_update, vc_update]
+      vr = tf.sqrt(new_vr) + self._epsilon
+      vc = tf.sqrt(new_vc) + self._epsilon
+      vc /= tf.reduce_mean(vc)
+      denom = tf.expand_dims(vr, 1) * tf.expand_dims(vc, 0)
+    else:
+      v = self.get_slot(var, "v")
+      new_v = (self._decay_rate * v + self._mixing_rate * grad_squared)
+      v_update = tf.assign(v, new_v, use_locking=self._use_locking)
+      updates = [v_update]
+      denom = tf.sqrt(new_v) + self._epsilon
+    subtrahend = self._lr * grad / denom
+    var_update = tf.assign_sub(var, subtrahend, use_locking=self._use_locking)
+    updates = [var_update] + updates
+    return tf.group(*updates)
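A short, self-contained usage sketch of the new class through the standard `tf.train.Optimizer` interface (illustrative only; the weight shape, learning rate, and epsilon are assumptions, with the learning rate scaled down by 500x as in the branch added above). A 2-D weight is used so the factored `vr`/`vc` slots, rather than the full `v` slot, are exercised:

```
import tensorflow as tf

from tensor2tensor.utils.optimize import AdafactorOptimizer

# Toy problem: a single 2-D weight, so _create_slots allocates the factored
# "vr" (256 entries) and "vc" (512 entries) accumulators instead of a full
# 256x512 second-moment matrix.
x = tf.random_normal([8, 256])
w = tf.get_variable("w", [256, 512])
loss = tf.reduce_mean(tf.square(tf.matmul(x, w)))

opt = AdafactorOptimizer(learning_rate=0.1 / 500.0, epsilon=1e-6)
train_op = opt.minimize(loss, global_step=tf.train.get_or_create_global_step())

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  for _ in range(5):
    sess.run(train_op)
```

`minimize()` routes through `_create_slots`, `_prepare`, and `_apply_dense` above; passing `global_step` lets the decay-horizon schedule computed in `_prepare` move from fast to slow decay as training progresses. Note that the implementation folds `epsilon` into `sqrt(vr)` and `sqrt(vc)` before forming the denominator, which differs slightly from the docstring's `sqrt(approx_v) + epsilon`.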
