 import logging
 from dataclasses import dataclass
 from functools import partial
-from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union, Protocol, Iterator
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
 from fnmatch import fnmatch
 import importlib

 import torch
 import torch.nn as nn
-import torch.optim as optim
+import torch.optim

 from ._param_groups import param_groups_layer_decay, param_groups_weight_decay
+from ._types import ParamsT, OptimType, OptimizerCallable
 from .adabelief import AdaBelief
 from .adafactor import Adafactor
 from .adafactor_bv import AdafactorBigVision

 _logger = logging.getLogger(__name__)

-# Type variables
-T = TypeVar('T')
-Params = Union[Iterator[nn.Parameter], Iterator[Dict[str, Any]]]
-OptimType = TypeVar('OptimType', bound='optim.Optimizer')
-

 def _import_class(class_string: str) -> Type:
     """Dynamically import a class from a string."""
@@ -55,11 +51,6 @@ def _import_class(class_string: str) -> Type:
         raise ImportError(f"Could not import {class_string}: {e}")


-class OptimizerCallable(Protocol):
-    """Protocol for optimizer constructor signatures."""
-
-    def __call__(self, params: Params, **kwargs) -> optim.Optimizer: ...
-

 @dataclass(frozen=True)
 class OptimInfo:
@@ -76,7 +67,7 @@ class OptimInfo:
         defaults: Optional default parameters for the optimizer
     """
     name: str
-    opt_class: Union[str, Type[optim.Optimizer]]
+    opt_class: Union[str, OptimType]
     description: str = ''
     has_eps: bool = True
     has_momentum: bool = False
@@ -185,7 +176,7 @@ def get_optimizer_class(
             self,
             name_or_info: Union[str, OptimInfo],
             bind_defaults: bool = True,
-    ) -> Union[Type[optim.Optimizer], OptimizerCallable]:
+    ) -> Union[OptimType, OptimizerCallable]:
         """Get the optimizer class with any default arguments applied.

         This allows direct instantiation of optimizers with their default configs
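
A minimal usage sketch of the bound-defaults behavior described above (assumes get_optimizer_class is re-exported from timm.optim as in recent timm releases; the model and hyperparameters are placeholders):

import torch.nn as nn
from timm.optim import get_optimizer_class

model = nn.Linear(10, 2)  # placeholder model

# bind_defaults=True returns a callable (typically a functools.partial) with the
# registry defaults already applied, e.g. 'sgd' comes back with nesterov=True bound.
sgd_nesterov = get_optimizer_class('sgd', bind_defaults=True)
opt = sgd_nesterov(model.parameters(), lr=0.1, momentum=0.9)

# bind_defaults=False returns the bare class (torch.optim.SGD for 'sgd').
sgd_cls = get_optimizer_class('sgd', bind_defaults=False)
opt = sgd_cls(model.parameters(), lr=0.1, momentum=0.9, nesterov=True)
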
@@ -234,17 +225,17 @@ def get_optimizer_class(

     def create_optimizer(
             self,
-            model_or_params: Union[nn.Module, Params],
+            model_or_params: Union[nn.Module, ParamsT],
             opt: str,
             lr: Optional[float] = None,
             weight_decay: float = 0.,
             momentum: float = 0.9,
             foreach: Optional[bool] = None,
             weight_decay_exclude_1d: bool = True,
             layer_decay: Optional[float] = None,
-            param_group_fn: Optional[Callable[[nn.Module], Params]] = None,
+            param_group_fn: Optional[Callable[[nn.Module], ParamsT]] = None,
             **kwargs: Any,
-    ) -> optim.Optimizer:
+    ) -> torch.optim.Optimizer:
         """Create an optimizer instance.

         Args:
@@ -347,15 +338,15 @@ def _register_sgd_variants(registry: OptimizerRegistry) -> None:
     sgd_optimizers = [
         OptimInfo(
             name='sgd',
-            opt_class=optim.SGD,
+            opt_class=torch.optim.SGD,
             description='torch.Optim Stochastic Gradient Descent (SGD) with Nesterov momentum',
             has_eps=False,
             has_momentum=True,
             defaults={'nesterov': True}
         ),
         OptimInfo(
             name='momentum',
-            opt_class=optim.SGD,
+            opt_class=torch.optim.SGD,
             description='torch.Optim Stochastic Gradient Descent (SGD) with classical momentum',
             has_eps=False,
             has_momentum=True,
@@ -386,13 +377,13 @@ def _register_adam_variants(registry: OptimizerRegistry) -> None:
     adam_optimizers = [
         OptimInfo(
             name='adam',
-            opt_class=optim.Adam,
+            opt_class=torch.optim.Adam,
             description='torch.optim.Adam, Adaptive Moment Estimation',
             has_betas=True
         ),
         OptimInfo(
             name='adamw',
-            opt_class=optim.AdamW,
+            opt_class=torch.optim.AdamW,
             description='torch.optim.AdamW, Adam with decoupled weight decay',
             has_betas=True
         ),
@@ -448,7 +439,7 @@ def _register_adam_variants(registry: OptimizerRegistry) -> None:
         ),
         OptimInfo(
             name='adamax',
-            opt_class=optim.Adamax,
+            opt_class=torch.optim.Adamax,
             description='torch.optim.Adamax, Adam with infinity norm for more stable updates',
             has_betas=True
         ),
@@ -526,6 +517,87 @@ def _register_lamb_lars(registry: OptimizerRegistry) -> None:
         registry.register(opt)


+def _register_cautious_optimizers(registry: OptimizerRegistry) -> None:
+    cautious_optimizers = [
+        OptimInfo(
+            name='cadafactor',
+            opt_class=Adafactor,
+            description='Cautious Adafactor',
+            defaults={'caution': True}
+        ),
+        OptimInfo(
+            name='cadafactorbv',
+            opt_class=AdafactorBigVision,
+            description='Cautious Big Vision Adafactor',
+            defaults={'caution': True}
+        ),
+        OptimInfo(
+            name='cadamw',
+            opt_class=AdamWLegacy,
+            description='Cautious AdamW',
+            has_betas=True,
+            defaults={'caution': True}
+        ),
+        OptimInfo(
+            name='cadopt',
+            opt_class=Adopt,
+            description='Cautious Adopt',
+            defaults={'caution': True}
+        ),
+        OptimInfo(
+            name='cadoptw',
+            opt_class=Adopt,
+            description='Cautious AdoptW (decoupled decay)',
+            defaults={'decoupled': True, 'caution': True}
+        ),
+        OptimInfo(
+            name='clamb',
+            opt_class=Lamb,
+            description='Cautious LAMB',
+            has_betas=True,
+            defaults={'caution': True}
+        ),
+        OptimInfo(
+            name='claprop',
+            opt_class=LaProp,
+            description='Cautious LaProp',
+            has_betas=True,
+            defaults={'caution': True}
+        ),
+        OptimInfo(
+            name='clion',
+            opt_class=Lion,
+            description='Cautious Lion',
+            has_eps=False,
+            has_betas=True,
+            defaults={'caution': True}
+        ),
+        OptimInfo(
+            name='cnadamw',
+            opt_class=NAdamW,
+            description='Cautious NAdamW',
+            has_betas=True,
+            defaults={'caution': True}
+        ),
+        OptimInfo(
+            name='crmsproptf',
+            opt_class=RMSpropTF,
+            description='Cautious TensorFlow-style RMSprop',
+            has_momentum=True,
+            defaults={'alpha': 0.9, 'caution': True}
+        ),
+        OptimInfo(
+            name='csgdw',
+            opt_class=SGDW,
+            description='Cautious SGD with decoupled weight decay and Nesterov momentum',
+            has_eps=False,
+            has_momentum=True,
+            defaults={'nesterov': True, 'caution': True}
+        ),
+    ]
+    for opt in cautious_optimizers:
+        registry.register(opt)
+
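
The new cautious variants are selected purely by name; caution=True is bound through OptimInfo.defaults at creation time. A rough usage sketch (placeholder model and hyperparameters, assuming create_optimizer_v2 and list_optimizers are re-exported from timm.optim):

import torch.nn as nn
from timm.optim import create_optimizer_v2, list_optimizers

model = nn.Linear(10, 2)  # placeholder model

# 'cadamw' resolves to AdamWLegacy with caution=True pulled from the registry defaults.
opt = create_optimizer_v2(model, opt='cadamw', lr=1e-3, weight_decay=0.05)

# All registered names, including the new c-prefixed entries, remain discoverable:
print([name for name in list_optimizers() if name.startswith('c')])
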
 def _register_other_optimizers(registry: OptimizerRegistry) -> None:
     """Register miscellaneous optimizers"""
     other_optimizers = [
@@ -545,12 +617,12 @@ def _register_other_optimizers(registry: OptimizerRegistry) -> None:
         ),
         OptimInfo(
             name='adadelta',
-            opt_class=optim.Adadelta,
+            opt_class=torch.optim.Adadelta,
             description='torch.optim.Adadelta, Adapts learning rates based on running windows of gradients'
         ),
         OptimInfo(
             name='adagrad',
-            opt_class=optim.Adagrad,
+            opt_class=torch.optim.Adagrad,
             description='torch.optim.Adagrad, Adapts learning rates using cumulative squared gradients',
             defaults={'eps': 1e-8}
         ),
@@ -617,7 +689,7 @@ def _register_other_optimizers(registry: OptimizerRegistry) -> None:
         ),
         OptimInfo(
             name='rmsprop',
-            opt_class=optim.RMSprop,
+            opt_class=torch.optim.RMSprop,
             description='torch.optim.RMSprop, Root Mean Square Propagation',
             has_momentum=True,
             defaults={'alpha': 0.9}
@@ -765,6 +837,7 @@ def _register_default_optimizers() -> None:
     _register_other_optimizers(default_registry)
     _register_apex_optimizers(default_registry)
     _register_bnb_optimizers(default_registry)
+    _register_cautious_optimizers(default_registry)

     # Register aliases
     default_registry.register_alias('nesterov', 'sgd')
@@ -839,7 +912,7 @@ def get_optimizer_info(name: str) -> OptimInfo:
 def get_optimizer_class(
         name: str,
         bind_defaults: bool = True,
-) -> Union[Type[optim.Optimizer], OptimizerCallable]:
+) -> Union[OptimType, OptimizerCallable]:
     """Get optimizer class by name with option to bind default arguments.

     Retrieves the optimizer class or a partial function with default arguments bound.
@@ -874,17 +947,17 @@ def get_optimizer_class(


 def create_optimizer_v2(
-        model_or_params: Union[nn.Module, Params],
+        model_or_params: Union[nn.Module, ParamsT],
         opt: str = 'sgd',
         lr: Optional[float] = None,
         weight_decay: float = 0.,
         momentum: float = 0.9,
         foreach: Optional[bool] = None,
         filter_bias_and_bn: bool = True,
         layer_decay: Optional[float] = None,
-        param_group_fn: Optional[Callable[[nn.Module], Params]] = None,
+        param_group_fn: Optional[Callable[[nn.Module], ParamsT]] = None,
         **kwargs: Any,
-) -> optim.Optimizer:
+) -> torch.optim.Optimizer:
     """Create an optimizer instance via timm registry.

     Creates and configures an optimizer with appropriate parameter groups and settings.
@@ -985,7 +1058,11 @@ def optimizer_kwargs(cfg):
     return kwargs


-def create_optimizer(args, model, filter_bias_and_bn=True):
+def create_optimizer(
+        args,
+        model: Union[nn.Module, ParamsT],
+        filter_bias_and_bn: bool = True,
+) -> torch.optim.Optimizer:
     """ Legacy optimizer factory for backwards compatibility.
     NOTE: Use create_optimizer_v2 for new code.
     """
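
For completeness, a sketch of the two forms model_or_params can take after the ParamsT change (placeholder model and values; when a full nn.Module is passed, the factory can build weight-decay / layer-decay parameter groups, while a pre-built iterable or list of group dicts is passed through to the optimizer as-is):

import torch.nn as nn
from timm.optim import create_optimizer_v2

model = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 2))

# 1) Pass the module itself: bias/1d params are excluded from weight decay by default.
opt = create_optimizer_v2(model, opt='sgd', lr=0.1, momentum=0.9, weight_decay=1e-4)

# 2) Pass explicit parameter groups (ParamsT) when grouping is handled by the caller.
param_groups = [
    {'params': model[0].parameters(), 'lr': 0.01},
    {'params': model[1].parameters(), 'lr': 0.1},
]
opt = create_optimizer_v2(param_groups, opt='sgd', lr=0.1, momentum=0.9)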