1+ """ Adafactor Optimizer
2+
3+ Lifted from https://github.com/pytorch/fairseq/blob/master/fairseq/optim/adafactor.py
4+
5+ Original header/copyright below.
6+
7+ """
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
import math


class Adafactor(torch.optim.Optimizer):
17+ """Implements Adafactor algorithm.
18+ This implementation is based on: `Adafactor: Adaptive Learning Rates with Sublinear Memory Cost`
19+ (see https://arxiv.org/abs/1804.04235)
20+
21+ Note that this optimizer internally adjusts the learning rate depending on the
22+ *scale_parameter*, *relative_step* and *warmup_init* options.
23+
24+ To use a manual (external) learning rate schedule you should set `scale_parameter=False` and
25+ `relative_step=False`.
26+
27+ Arguments:
28+ params (iterable): iterable of parameters to optimize or dicts defining parameter groups
29+ lr (float, optional): external learning rate (default: None)
30+ eps (tuple[float, float]): regularization constants for square gradient
31+ and parameter scale respectively (default: (1e-30, 1e-3))
32+ clip_threshold (float): threshold of root mean square of final gradient update (default: 1.0)
33+ decay_rate (float): coefficient used to compute running averages of square gradient (default: -0.8)
34+ beta1 (float): coefficient used for computing running averages of gradient (default: None)
35+ weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
36+ scale_parameter (bool): if True, learning rate is scaled by root mean square of parameter (default: True)
37+ relative_step (bool): if True, time-dependent learning rate is computed
38+ instead of external learning rate (default: True)
39+ warmup_init (bool): time-dependent learning rate computation depends on
40+ whether warm-up initialization is being used (default: False)
41+ """
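    # Minimal usage sketch (`model` is a placeholder torch.nn.Module, not defined in this file):
    #
    #   # Relative-step mode: lr is derived internally from the step count and parameter scale
    #   optimizer = Adafactor(model.parameters())
    #
    #   # External schedule: pass an explicit lr and disable parameter scaling
    #   optimizer = Adafactor(model.parameters(), lr=1e-3, scale_parameter=False, weight_decay=1e-5)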

    def __init__(self, params, lr=None, eps=1e-30, eps_scale=1e-3, clip_threshold=1.0,
                 decay_rate=-0.8, betas=None, weight_decay=0.0, scale_parameter=True, warmup_init=False):
        relative_step = lr is None
        if warmup_init and not relative_step:
            raise ValueError('warmup_init requires relative_step=True (i.e. lr=None)')

        beta1 = None if betas is None else betas[0]  # compatibility with the standard betas arg; only beta1 is used
        defaults = dict(lr=lr, eps=eps, eps_scale=eps_scale, clip_threshold=clip_threshold, decay_rate=decay_rate,
                        beta1=beta1, weight_decay=weight_decay, scale_parameter=scale_parameter,
                        relative_step=relative_step, warmup_init=warmup_init)
        super(Adafactor, self).__init__(params, defaults)

    @staticmethod
    def _get_lr(param_group, param_state):
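        # Relative step size (when lr is None): rho_t = min(1e-2, 1/sqrt(t)), or
        # min(1e-6 * t, 1/sqrt(t)) when warmup_init is set; the final lr is
        # rho_t * max(eps_scale, RMS(param)) when scale_parameter is True.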
        if param_group['relative_step']:
            min_step = 1e-6 * param_state['step'] if param_group['warmup_init'] else 1e-2
            lr_t = min(min_step, 1.0 / math.sqrt(param_state['step']))
            param_scale = 1.0
            if param_group['scale_parameter']:
                param_scale = max(param_group['eps_scale'], param_state['RMS'])
            param_group['lr'] = lr_t * param_scale
        return param_group['lr']

    @staticmethod
    def _get_options(param_group, param_shape):
        factored = len(param_shape) >= 2
        use_first_moment = param_group['beta1'] is not None
        return factored, use_first_moment

    @staticmethod
    def _rms(tensor):
        return tensor.norm(2) / (tensor.numel() ** 0.5)

    def _approx_sq_grad(self, exp_avg_sq_row, exp_avg_sq_col):
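        # Rank-1 reconstruction of the factored second moment: with row means R and column means C,
        # V_hat ~= (R * C) / mean(R), so r_factor * c_factor approximates 1 / sqrt(V_hat),
        # which step() then multiplies by the gradient.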
        r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
        c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
        return torch.mul(r_factor, c_factor)

    def step(self, closure=None):
82+ """Performs a single optimization step.
83+ Arguments:
84+ closure (callable, optional): A closure that reevaluates the model and returns the loss.
85+ """
86+ loss = None
87+ if closure is not None :
88+ loss = closure ()
89+
90+ for group in self .param_groups :
91+ for p in group ['params' ]:
92+ if p .grad is None :
93+ continue
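                # Gradients and parameters stored in half precision are cast to fp32 for the
                # update math; the result is copied back at the end of the step.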
                grad = p.grad.data
                if grad.dtype in {torch.float16, torch.bfloat16}:
                    grad = grad.float()
                if grad.is_sparse:
                    raise RuntimeError('Adafactor does not support sparse gradients.')

                state = self.state[p]
                grad_shape = grad.shape

                factored, use_first_moment = self._get_options(group, grad_shape)
                # State initialization
                if len(state) == 0:
                    state['step'] = 0

                    if use_first_moment:
                        # Exponential moving average of gradient values
                        state['exp_avg'] = torch.zeros_like(grad)
                    if factored:
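                        # Factored second moment: only row and column statistics are stored,
                        # i.e. O(n + m) memory instead of O(n * m) for an n x m parameter.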
                        state['exp_avg_sq_row'] = torch.zeros(grad_shape[:-1]).to(grad)
                        state['exp_avg_sq_col'] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad)
                    else:
                        state['exp_avg_sq'] = torch.zeros_like(grad)

                    state['RMS'] = 0
                else:
                    if use_first_moment:
                        state['exp_avg'] = state['exp_avg'].to(grad)
                    if factored:
                        state['exp_avg_sq_row'] = state['exp_avg_sq_row'].to(grad)
                        state['exp_avg_sq_col'] = state['exp_avg_sq_col'].to(grad)
                    else:
                        state['exp_avg_sq'] = state['exp_avg_sq'].to(grad)

                p_data_fp32 = p.data
                if p.data.dtype in {torch.float16, torch.bfloat16}:
                    p_data_fp32 = p_data_fp32.float()

                state['step'] += 1
                state['RMS'] = self._rms(p_data_fp32)
                lr_t = self._get_lr(group, state)

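                # Increasing decay schedule for the second moment: beta2_t = 1 - t ** decay_rate
                # (decay_rate defaults to -0.8, as in the paper).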
                beta2t = 1.0 - math.pow(state['step'], group['decay_rate'])
                update = grad ** 2 + group['eps']
                if factored:
                    exp_avg_sq_row = state['exp_avg_sq_row']
                    exp_avg_sq_col = state['exp_avg_sq_col']

                    exp_avg_sq_row.mul_(beta2t).add_(1.0 - beta2t, update.mean(dim=-1))
                    exp_avg_sq_col.mul_(beta2t).add_(1.0 - beta2t, update.mean(dim=-2))
                    # exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=1.0 - beta2t)  # pytorch 1.6+
                    # exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=1.0 - beta2t)

                    # Approximation of exponential moving average of square of gradient
                    update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
                    update.mul_(grad)
                else:
                    exp_avg_sq = state['exp_avg_sq']

                    exp_avg_sq.mul_(beta2t).add_(1.0 - beta2t, update)
                    # exp_avg_sq.mul_(beta2t).add_(update, alpha=1.0 - beta2t)  # pytorch 1.6+
                    update = exp_avg_sq.rsqrt().mul_(grad)

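                # Update clipping: scale the update down if its RMS exceeds clip_threshold,
                # then apply the (possibly relative) step size lr_t.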
                update.div_((self._rms(update) / group['clip_threshold']).clamp_(min=1.0))
                update.mul_(lr_t)

                if use_first_moment:
                    exp_avg = state['exp_avg']
                    exp_avg.mul_(group['beta1']).add_(1 - group['beta1'], update)
                    # exp_avg.mul_(group['beta1']).add_(update, alpha=1 - group['beta1'])  # pytorch 1.6+
                    update = exp_avg

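                # Weight decay is applied directly to the parameters (scaled by lr_t),
                # i.e. p <- p * (1 - weight_decay * lr_t), rather than added to the gradient.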
                if group['weight_decay'] != 0:
                    p_data_fp32.add_(-group['weight_decay'] * lr_t, p_data_fp32)
                    # p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * lr_t)  # pytorch 1.6+

                p_data_fp32.add_(-update)

                if p.data.dtype in {torch.float16, torch.bfloat16}:
                    p.data.copy_(p_data_fp32)

        return loss