@@ -19,17 +19,20 @@
 from .adafactor_bv import AdafactorBigVision
 from .adahessian import Adahessian
 from .adamp import AdamP
+from .adamw import AdamWLegacy
 from .adan import Adan
 from .adopt import Adopt
 from .lamb import Lamb
+from .laprop import LaProp
 from .lars import Lars
 from .lion import Lion
 from .lookahead import Lookahead
 from .madgrad import MADGRAD
-from .nadam import Nadam
+from .mars import Mars
+from .nadam import NAdamLegacy
 from .nadamw import NAdamW
 from .nvnovograd import NvNovoGrad
-from .radam import RAdam
+from .radam import RAdamLegacy
 from .rmsprop_tf import RMSpropTF
 from .sgdp import SGDP
 from .sgdw import SGDW
@@ -384,13 +387,19 @@ def _register_adam_variants(registry: OptimizerRegistry) -> None:
         OptimInfo(
             name='adam',
             opt_class=optim.Adam,
-            description='torch.optim Adam (Adaptive Moment Estimation)',
+            description='torch.optim.Adam, Adaptive Moment Estimation',
             has_betas=True
         ),
         OptimInfo(
             name='adamw',
             opt_class=optim.AdamW,
-            description='torch.optim Adam with decoupled weight decay regularization',
+            description='torch.optim.AdamW, Adam with decoupled weight decay',
+            has_betas=True
+        ),
+        OptimInfo(
+            name='adamwlegacy',
+            opt_class=AdamWLegacy,
+            description='legacy impl of AdamW that pre-dates inclusion in torch.optim',
             has_betas=True
         ),
         OptimInfo(
@@ -402,26 +411,45 @@ def _register_adam_variants(registry: OptimizerRegistry) -> None:
         ),
         OptimInfo(
             name='nadam',
-            opt_class=Nadam,
-            description='Adam with Nesterov momentum',
+            opt_class=torch.optim.NAdam,
+            description='torch.optim.NAdam, Adam with Nesterov momentum',
+            has_betas=True
+        ),
+        OptimInfo(
+            name='nadamlegacy',
+            opt_class=NAdamLegacy,
+            description='legacy impl of NAdam that pre-dates inclusion in torch.optim',
             has_betas=True
         ),
         OptimInfo(
             name='nadamw',
             opt_class=NAdamW,
-            description='Adam with Nesterov momentum and decoupled weight decay',
+            description='Adam with Nesterov momentum and decoupled weight decay, mlcommons/algorithmic-efficiency impl',
             has_betas=True
         ),
         OptimInfo(
             name='radam',
-            opt_class=RAdam,
-            description='Rectified Adam with variance adaptation',
+            opt_class=torch.optim.RAdam,
+            description='torch.optim.RAdam, Rectified Adam with variance adaptation',
+            has_betas=True
+        ),
+        OptimInfo(
+            name='radamlegacy',
+            opt_class=RAdamLegacy,
+            description='legacy impl of RAdam that pre-dates inclusion in torch.optim',
             has_betas=True
         ),
+        OptimInfo(
+            name='radamw',
+            opt_class=torch.optim.RAdam,
+            description='torch.optim.RAdamW, Rectified Adam with variance adaptation and decoupled weight decay',
+            has_betas=True,
+            defaults={'decoupled_weight_decay': True}
+        ),
         OptimInfo(
             name='adamax',
             opt_class=optim.Adamax,
-            description='torch.optim Adamax, Adam with infinity norm for more stable updates',
+            description='torch.optim.Adamax, Adam with infinity norm for more stable updates',
             has_betas=True
         ),
         OptimInfo(
@@ -518,12 +546,12 @@ def _register_other_optimizers(registry: OptimizerRegistry) -> None:
         OptimInfo(
             name='adadelta',
             opt_class=optim.Adadelta,
-            description='torch.optim Adadelta, Adapts learning rates based on running windows of gradients'
+            description='torch.optim.Adadelta, Adapts learning rates based on running windows of gradients'
         ),
         OptimInfo(
             name='adagrad',
             opt_class=optim.Adagrad,
-            description='torch.optim Adagrad, Adapts learning rates using cumulative squared gradients',
+            description='torch.optim.Adagrad, Adapts learning rates using cumulative squared gradients',
             defaults={'eps': 1e-8}
         ),
         OptimInfo(
@@ -549,6 +577,12 @@ def _register_other_optimizers(registry: OptimizerRegistry) -> None:
             has_betas=True,
             second_order=True,
         ),
+        OptimInfo(
+            name='laprop',
+            opt_class=LaProp,
+            description='Separating Momentum and Adaptivity in Adam',
+            has_betas=True,
+        ),
         OptimInfo(
             name='lion',
             opt_class=Lion,
@@ -569,6 +603,12 @@ def _register_other_optimizers(registry: OptimizerRegistry) -> None:
             has_momentum=True,
             defaults={'decoupled_decay': True}
         ),
+        OptimInfo(
+            name='mars',
+            opt_class=Mars,
+            description='Unleashing the Power of Variance Reduction for Training Large Models',
+            has_betas=True,
+        ),
         OptimInfo(
             name='novograd',
             opt_class=NvNovoGrad,
@@ -578,7 +618,7 @@ def _register_other_optimizers(registry: OptimizerRegistry) -> None:
         OptimInfo(
             name='rmsprop',
             opt_class=optim.RMSprop,
-            description='torch.optim RMSprop, Root Mean Square Propagation',
+            description='torch.optim.RMSprop, Root Mean Square Propagation',
             has_momentum=True,
             defaults={'alpha': 0.9}
         ),
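
For context on how these OptimInfo registrations get used: below is a minimal sketch of instantiating one of the newly registered optimizers through timm's optimizer factory. The chosen optimizer name, learning rate, and weight decay are purely illustrative, and list_optimizers() is assumed to be the registry listing helper exposed alongside this refactor.

import torch.nn as nn
from timm.optim import create_optimizer_v2, list_optimizers

model = nn.Linear(10, 2)

# Names registered via OptimInfo become valid values for the `opt` argument.
# Assumes list_optimizers() is exported by this refactor; output should include
# 'laprop', 'mars', 'adamwlegacy', 'radamw', etc.
print(list_optimizers())

# Build an optimizer by its registered name; hyperparameters are illustrative only.
optimizer = create_optimizer_v2(model, opt='laprop', lr=4e-4, weight_decay=0.05)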