
Implementation adapted from https://github.com/sail-sg/Adan
"""
# Copyright 2022 Garena Online Private Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import List, Tuple

import torch
from torch import Tensor
from torch.optim.optimizer import Optimizer


class MultiTensorApply(object):
    """Thin wrapper for launching a fused multi-tensor op over lists of tensors with a fixed chunk size."""
    available = False
    warned = False

    def __init__(self, chunk_size):
        try:
            MultiTensorApply.available = True
            self.chunk_size = chunk_size
        except ImportError as err:
            MultiTensorApply.available = False
            MultiTensorApply.import_err = err

    def __call__(self, op, noop_flag_buffer, tensor_lists, *args):
        return op(self.chunk_size, noop_flag_buffer, tensor_lists, *args)
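
# Illustrative (hypothetical) usage of MultiTensorApply with a fused kernel. `fused_op` and
# `noop_flag` are placeholders for an op that follows the apex-style
# (chunk_size, noop_flag, tensor_lists, *args) calling convention assumed by __call__ above:
#
#   multi_tensor_applier = MultiTensorApply(chunk_size=2048 * 32)
#   multi_tensor_applier(fused_op, noop_flag, [params, grads, exp_avgs], lr)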


class Adan(Optimizer):
    """ Implements a PyTorch variant of Adan.

    Adan was proposed in Adan: Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep Models,
    https://arxiv.org/abs/2208.06677

    Arguments:
        params: Iterable of parameters to optimize or dicts defining parameter groups.
        lr: Learning rate.
        betas: Coefficients used for the running averages of the gradient, the gradient
            difference, and the squared gradient.
        eps: Term added to the denominator to improve numerical stability.
        weight_decay: Decoupled weight decay (L2 penalty).
        no_prox: How to perform the weight decay: if True, apply the decoupled decay before the
            parameter update; otherwise apply the proximal decay after the update.
        foreach: If True, use the torch._foreach implementation. It is faster but uses slightly
            more memory.
    """
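
    # Illustrative usage sketch (assumes `model`, `criterion`, `x`, and `y` already exist);
    # the hyper-parameters shown are simply this class's defaults:
    #
    #   optimizer = Adan(model.parameters(), lr=1e-3, betas=(0.98, 0.92, 0.99),
    #                    eps=1e-8, weight_decay=0.0, no_prox=False, foreach=True)
    #   criterion(model(x), y).backward()
    #   optimizer.step()
    #   optimizer.zero_grad()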

    def __init__(self,
                 params,
                 lr: float = 1e-3,
                 betas: Tuple[float, float, float] = (0.98, 0.92, 0.99),
                 eps: float = 1e-8,
                 weight_decay: float = 0.0,
                 no_prox: bool = False,
                 foreach: bool = True,
                 ):
        if not 0.0 <= lr:
            raise ValueError('Invalid learning rate: {}'.format(lr))
        if not 0.0 <= eps:
            raise ValueError('Invalid epsilon value: {}'.format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError('Invalid beta parameter at index 0: {}'.format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError('Invalid beta parameter at index 1: {}'.format(betas[1]))
        if not 0.0 <= betas[2] < 1.0:
            raise ValueError('Invalid beta parameter at index 2: {}'.format(betas[2]))

        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            no_prox=no_prox,
            foreach=foreach,
        )
        super().__init__(params, defaults)

    def __setstate__(self, state):
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault('no_prox', False)

    @torch.no_grad()
    def restart_opt(self):
        for group in self.param_groups:
            group['step'] = 0
            for p in group['params']:
                if p.requires_grad:
                    state = self.state[p]
                    # State initialization
                    state['exp_avg'] = torch.zeros_like(p)
                    # Exponential moving average of gradient difference
                    state['exp_avg_diff'] = torch.zeros_like(p)
                    # Exponential moving average of squared gradient
                    state['exp_avg_sq'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure=None):
73- """ Performs a single optimization step.
74- """
115+ """Performs a single optimization step."""
75116 loss = None
76117 if closure is not None :
77118 with torch .enable_grad ():
78119 loss = closure ()
79120
80121 for group in self .param_groups :
122+ params_with_grad = []
123+ grads = []
124+ exp_avgs = []
125+ exp_avg_sqs = []
126+ exp_avg_diffs = []
127+ neg_pre_grads = []
128+
81129 beta1 , beta2 , beta3 = group ['betas' ]
82130 # assume same step across group now to simplify things
83- # per parameter step can be easily support by making it tensor, or pass list into kernel
131+ # per parameter step can be easily supported by making it a tensor, or pass list into kernel
84132 if 'step' in group :
85133 group ['step' ] += 1
86134 else :
                group['step'] = 1

            bias_correction1 = 1.0 - beta1 ** group['step']
            bias_correction2 = 1.0 - beta2 ** group['step']
            bias_correction3 = 1.0 - beta3 ** group['step']

            for p in group['params']:
                if p.grad is None:
                    continue
                params_with_grad.append(p)
                grads.append(p.grad)

                state = self.state[p]
                if len(state) == 0:
                    state['exp_avg'] = torch.zeros_like(p)
                    state['exp_avg_sq'] = torch.zeros_like(p)
                    state['exp_avg_diff'] = torch.zeros_like(p)

                if 'neg_pre_grad' not in state or group['step'] == 1:
                    state['neg_pre_grad'] = -p.grad.clone()

                exp_avgs.append(state['exp_avg'])
                exp_avg_sqs.append(state['exp_avg_sq'])
                exp_avg_diffs.append(state['exp_avg_diff'])
                neg_pre_grads.append(state['neg_pre_grad'])

            if not params_with_grad:
                continue

            kwargs = dict(
                params=params_with_grad,
                grads=grads,
                exp_avgs=exp_avgs,
                exp_avg_sqs=exp_avg_sqs,
                exp_avg_diffs=exp_avg_diffs,
                neg_pre_grads=neg_pre_grads,
                beta1=beta1,
                beta2=beta2,
                beta3=beta3,
                bias_correction1=bias_correction1,
                bias_correction2=bias_correction2,
                bias_correction3_sqrt=math.sqrt(bias_correction3),
                lr=group['lr'],
                weight_decay=group['weight_decay'],
                eps=group['eps'],
                no_prox=group['no_prox'],
            )

            if group['foreach']:
                _multi_tensor_adan(**kwargs)
            else:
                _single_tensor_adan(**kwargs)

        return loss


def _single_tensor_adan(
        params: List[Tensor],
        grads: List[Tensor],
        exp_avgs: List[Tensor],
        exp_avg_sqs: List[Tensor],
        exp_avg_diffs: List[Tensor],
        neg_pre_grads: List[Tensor],
        *,
        beta1: float,
        beta2: float,
        beta3: float,
        bias_correction1: float,
        bias_correction2: float,
        bias_correction3_sqrt: float,
        lr: float,
        weight_decay: float,
        eps: float,
        no_prox: bool,
):
    for i, param in enumerate(params):
        grad = grads[i]
        exp_avg = exp_avgs[i]
        exp_avg_sq = exp_avg_sqs[i]
        exp_avg_diff = exp_avg_diffs[i]
        neg_grad_or_diff = neg_pre_grads[i]

        # To save memory, `neg_grad_or_diff` is reused in place as a temporary:
        # it holds -g_{t-1} on entry and becomes the gradient difference g_t - g_{t-1} here.
        neg_grad_or_diff.add_(grad)

        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)  # m_t
        exp_avg_diff.mul_(beta2).add_(neg_grad_or_diff, alpha=1 - beta2)  # diff_t

        neg_grad_or_diff.mul_(beta2).add_(grad)
        exp_avg_sq.mul_(beta3).addcmul_(neg_grad_or_diff, neg_grad_or_diff, value=1 - beta3)  # n_t

        denom = (exp_avg_sq.sqrt() / bias_correction3_sqrt).add_(eps)
        step_size_diff = lr * beta2 / bias_correction2
        step_size = lr / bias_correction1

        if no_prox:
            param.mul_(1 - lr * weight_decay)
            param.addcdiv_(exp_avg, denom, value=-step_size)
            param.addcdiv_(exp_avg_diff, denom, value=-step_size_diff)
        else:
            param.addcdiv_(exp_avg, denom, value=-step_size)
            param.addcdiv_(exp_avg_diff, denom, value=-step_size_diff)
            param.div_(1 + lr * weight_decay)

        # Store -g_t for the next step's gradient difference.
        neg_grad_or_diff.zero_().add_(grad, alpha=-1.0)
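
# For reference, the loop above corresponds to the Adan recursions from the paper
# (https://arxiv.org/abs/2208.06677), written in the notation of the code with g_t = grad:
#
#   m_t = beta1 * m_{t-1} + (1 - beta1) * g_t                                   # exp_avg
#   v_t = beta2 * v_{t-1} + (1 - beta2) * (g_t - g_{t-1})                       # exp_avg_diff
#   n_t = beta3 * n_{t-1} + (1 - beta3) * (g_t + beta2 * (g_t - g_{t-1}))**2    # exp_avg_sq
#
# followed by the parameter update with bias-corrected moments and eta = lr:
#
#   update = m_t / bias_correction1 + beta2 * v_t / bias_correction2
#   denom  = sqrt(n_t) / sqrt(bias_correction3) + eps
#   no_prox=True:  param <- (1 - eta * weight_decay) * param - eta * update / denom
#   no_prox=False: param <- (param - eta * update / denom) / (1 + eta * weight_decay)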


def _multi_tensor_adan(
        params: List[Tensor],
        grads: List[Tensor],
        exp_avgs: List[Tensor],
        exp_avg_sqs: List[Tensor],
        exp_avg_diffs: List[Tensor],
        neg_pre_grads: List[Tensor],
        *,
        beta1: float,
        beta2: float,
        beta3: float,
        bias_correction1: float,
        bias_correction2: float,
        bias_correction3_sqrt: float,
        lr: float,
        weight_decay: float,
        eps: float,
        no_prox: bool,
):
    if len(params) == 0:
        return

    # To save memory, `neg_pre_grads` is reused in place as a temporary:
    # it holds -g_{t-1} on entry and becomes the gradient difference g_t - g_{t-1} here.
    torch._foreach_add_(neg_pre_grads, grads)

    torch._foreach_mul_(exp_avgs, beta1)
    torch._foreach_add_(exp_avgs, grads, alpha=1 - beta1)  # m_t

    torch._foreach_mul_(exp_avg_diffs, beta2)
    torch._foreach_add_(exp_avg_diffs, neg_pre_grads, alpha=1 - beta2)  # diff_t

    torch._foreach_mul_(neg_pre_grads, beta2)
    torch._foreach_add_(neg_pre_grads, grads)
    torch._foreach_mul_(exp_avg_sqs, beta3)
    torch._foreach_addcmul_(exp_avg_sqs, neg_pre_grads, neg_pre_grads, value=1 - beta3)  # n_t

    denom = torch._foreach_sqrt(exp_avg_sqs)
    torch._foreach_div_(denom, bias_correction3_sqrt)
    torch._foreach_add_(denom, eps)

    step_size_diff = lr * beta2 / bias_correction2
    step_size = lr / bias_correction1

    if no_prox:
        torch._foreach_mul_(params, 1 - lr * weight_decay)
        torch._foreach_addcdiv_(params, exp_avgs, denom, value=-step_size)
        torch._foreach_addcdiv_(params, exp_avg_diffs, denom, value=-step_size_diff)
    else:
        torch._foreach_addcdiv_(params, exp_avgs, denom, value=-step_size)
        torch._foreach_addcdiv_(params, exp_avg_diffs, denom, value=-step_size_diff)
        torch._foreach_div_(params, 1 + lr * weight_decay)

    # Store -g_t for the next step's gradient difference.
    torch._foreach_zero_(neg_pre_grads)
    torch._foreach_add_(neg_pre_grads, grads, alpha=-1.0)
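

# Illustrative smoke test (a sketch, not part of the upstream implementation): checks that the
# foreach and single-tensor code paths produce matching updates on a toy problem.
if __name__ == '__main__':
    torch.manual_seed(0)
    ref = torch.nn.Linear(4, 2)

    models, opts = [], []
    for foreach in (True, False):
        m = torch.nn.Linear(4, 2)
        m.load_state_dict(ref.state_dict())  # identical initial weights for both runs
        models.append(m)
        opts.append(Adan(m.parameters(), lr=1e-2, weight_decay=0.02, foreach=foreach))

    x = torch.randn(8, 4)
    for _ in range(3):
        for m, opt in zip(models, opts):
            opt.zero_grad()
            m(x).pow(2).mean().backward()
            opt.step()

    for p_a, p_b in zip(models[0].parameters(), models[1].parameters()):
        assert torch.allclose(p_a, p_b, rtol=1e-5, atol=1e-6)
    print('foreach and single-tensor Adan updates match')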