 # limitations under the License.

 import math
-from typing import List, Tuple
+from typing import List, Optional, Tuple

 import torch
 from torch import Tensor
@@ -56,6 +56,7 @@ class Adan(Optimizer):
         eps: Term added to the denominator to improve numerical stability.
         weight_decay: Decoupled weight decay (L2 penalty)
         no_prox: How to perform the weight decay
+        caution: Enable caution from 'Cautious Optimizers'
         foreach: If True would use torch._foreach implementation. Faster but uses slightly more memory.
     """

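For reference, a minimal usage sketch of the two options touched by this change (the import path, model, and data below are illustrative assumptions, not part of this diff):

```python
import torch
import torch.nn.functional as F

# Hypothetical import path; point it at wherever this Adan module lives.
from adan import Adan

model = torch.nn.Linear(10, 2)
optimizer = Adan(
    model.parameters(),
    lr=1e-3,
    weight_decay=0.02,
    caution=True,   # enable the 'Cautious Optimizers' update mask
    foreach=None,   # None lets step() pick the multi- vs single-tensor path
)

x, y = torch.randn(8, 10), torch.randn(8, 2)
loss = F.mse_loss(model(x), y)
loss.backward()
optimizer.step()
optimizer.zero_grad()
```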
@@ -66,7 +67,8 @@ def __init__(self,
             eps: float = 1e-8,
             weight_decay: float = 0.0,
             no_prox: bool = False,
-            foreach: bool = True,
+            caution: bool = False,
+            foreach: Optional[bool] = None,
     ):
         if not 0.0 <= lr:
             raise ValueError('Invalid learning rate: {}'.format(lr))
@@ -85,6 +87,7 @@ def __init__(self,
             eps=eps,
             weight_decay=weight_decay,
             no_prox=no_prox,
+            caution=caution,
             foreach=foreach,
         )
         super().__init__(params, defaults)
@@ -93,6 +96,7 @@ def __setstate__(self, state):
         super(Adan, self).__setstate__(state)
         for group in self.param_groups:
             group.setdefault('no_prox', False)
+            group.setdefault('caution', False)

     @torch.no_grad()
     def restart_opt(self):
@@ -118,6 +122,11 @@ def step(self, closure=None):
             with torch.enable_grad():
                 loss = closure()

+        try:
+            has_scalar_maximum = 'Scalar' in torch.ops.aten._foreach_maximum_.overloads()
+        except:
+            has_scalar_maximum = False
+
         for group in self.param_groups:
             params_with_grad = []
             grads = []
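The try/except added above probes whether the installed PyTorch exposes a `Scalar` overload of `torch.ops.aten._foreach_maximum_`; the multi-tensor caution path later clamps a list of scalar tensors via `torch._foreach_maximum_(mask_scale, 1e-3)`, which needs that overload. A standalone sketch of the same capability check (the helper name is illustrative, not part of this diff):

```python
import torch

def foreach_maximum_supports_scalar() -> bool:
    # True when torch._foreach_maximum_(list_of_tensors, scalar) is available
    # in the installed PyTorch build; older builds lack the Scalar overload.
    try:
        return 'Scalar' in torch.ops.aten._foreach_maximum_.overloads()
    except Exception:
        return False
```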
@@ -161,9 +170,19 @@ def step(self, closure=None):
             if not params_with_grad:
                 continue

-            kwargs = dict(
-                params=params_with_grad,
-                grads=grads,
+            if group['foreach'] is None:
+                use_foreach = not group['caution'] or has_scalar_maximum
+            else:
+                use_foreach = group['foreach']
+
+            if use_foreach:
+                func = _multi_tensor_adan
+            else:
+                func = _single_tensor_adan
+
+            func(
+                params_with_grad,
+                grads,
                 exp_avgs=exp_avgs,
                 exp_avg_sqs=exp_avg_sqs,
                 exp_avg_diffs=exp_avg_diffs,
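The dispatch added above replaces the old `kwargs`-plus-`if group['foreach']` pattern: with `foreach=None`, the multi-tensor kernel is chosen unless caution is enabled on a PyTorch build without the Scalar `_foreach_maximum_` overload. A small sketch of that decision, with names that are illustrative only:

```python
def select_adan_path(foreach, caution, has_scalar_maximum):
    # Mirrors the selection logic above: prefer the foreach (multi-tensor) kernel,
    # but fall back to the single-tensor loop when caution needs an op the build lacks.
    if foreach is None:
        use_foreach = not caution or has_scalar_maximum
    else:
        use_foreach = foreach
    return 'multi_tensor' if use_foreach else 'single_tensor'

assert select_adan_path(None, caution=True, has_scalar_maximum=False) == 'single_tensor'
assert select_adan_path(None, caution=True, has_scalar_maximum=True) == 'multi_tensor'
assert select_adan_path(True, caution=True, has_scalar_maximum=False) == 'multi_tensor'
```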
@@ -178,13 +197,9 @@ def step(self, closure=None):
                 weight_decay=group['weight_decay'],
                 eps=group['eps'],
                 no_prox=group['no_prox'],
+                caution=group['caution'],
             )

-            if group['foreach']:
-                _multi_tensor_adan(**kwargs)
-            else:
-                _single_tensor_adan(**kwargs)
-
         return loss

@@ -206,6 +221,7 @@ def _single_tensor_adan(
         weight_decay: float,
         eps: float,
         no_prox: bool,
+        caution: bool,
 ):
     for i, param in enumerate(params):
         grad = grads[i]
@@ -227,6 +243,12 @@ def _single_tensor_adan(
         step_size_diff = lr * beta2 / bias_correction2
         step_size = lr / bias_correction1

+        if caution:
+            # Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
+            mask = (exp_avg * grad > 0).to(grad.dtype)
+            mask.div_(mask.mean().clamp_(min=1e-3))
+            exp_avg = exp_avg * mask
+
         if no_prox:
             param.mul_(1 - lr * weight_decay)
             param.addcdiv_(exp_avg, denom, value=-step_size)
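The caution block keeps only the coordinates where the first-moment estimate and the current gradient agree in sign, then rescales by the inverse of the surviving fraction (clamped at 1e-3 so a nearly empty mask cannot blow up the step), following the cited paper. A toy, illustrative-only example of that masking on hand-picked values:

```python
import torch

exp_avg = torch.tensor([ 0.5, -0.2,  0.1, -0.4])
grad    = torch.tensor([ 0.3,  0.1,  0.2, -0.1])

mask = (exp_avg * grad > 0).to(grad.dtype)   # 1 where momentum and grad agree in sign
mask = mask / mask.mean().clamp(min=1e-3)    # rescale so the mean update magnitude is preserved
masked_update = exp_avg * mask

print(mask)            # approx [1.3333, 0.0000, 1.3333, 1.3333]
print(masked_update)   # approx [0.6667, -0.0000, 0.1333, -0.5333]
```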
@@ -257,6 +279,7 @@ def _multi_tensor_adan(
         weight_decay: float,
         eps: float,
         no_prox: bool,
+        caution: bool,
 ):
     if len(params) == 0:
         return
@@ -282,6 +305,15 @@ def _multi_tensor_adan(
     step_size_diff = lr * beta2 / bias_correction2
     step_size = lr / bias_correction1

+    if caution:
+        # Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
+        masks = torch._foreach_mul(exp_avgs, grads)
+        masks = [(m > 0).to(g.dtype) for m, g in zip(masks, grads)]
+        mask_scale = [m.mean() for m in masks]
+        torch._foreach_maximum_(mask_scale, 1e-3)
+        torch._foreach_div_(masks, mask_scale)
+        exp_avgs = torch._foreach_mul(exp_avgs, masks)
+
     if no_prox:
         torch._foreach_mul_(params, 1 - lr * weight_decay)
         torch._foreach_addcdiv_(params, exp_avgs, denom, value=-step_size)
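The multi-tensor block computes the same mask as the single-tensor version, just batched over the whole parameter list with `torch._foreach_*` ops. A small illustrative check of that equivalence, assuming a PyTorch build where `_foreach_maximum_` accepts a scalar (the same capability probed in `step()`):

```python
import torch

# Toy check (illustrative): the foreach caution mask matches the per-tensor math.
exp_avgs = [torch.randn(3), torch.randn(4)]
grads    = [torch.randn(3), torch.randn(4)]

# foreach path, mirroring the block above
masks = torch._foreach_mul(exp_avgs, grads)
masks = [(m > 0).to(g.dtype) for m, g in zip(masks, grads)]
mask_scale = [m.mean() for m in masks]
torch._foreach_maximum_(mask_scale, 1e-3)
torch._foreach_div_(masks, mask_scale)
foreach_out = torch._foreach_mul(exp_avgs, masks)

# per-tensor reference
for ea, g, fo in zip(exp_avgs, grads, foreach_out):
    m = (ea * g > 0).to(g.dtype)
    m = m / m.mean().clamp(min=1e-3)
    assert torch.allclose(fo, ea * m)
```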