1- """ PyTorch Implementation of the Kron PSGD optimizer
1+ """ PyTorch Implementation of the Kron (PSGD) optimizer
22
3- FIXME attribution
4- * https://github.com/evanatyourservice/kron_torch (direct source)
5- * https://github.com/lixilinx/psgd_torch (original)
6- * https://github.com/ClashLuke/HeavyBall (added improvements)
3+ This is a PSGD optimizer using a Kronecker-factored preconditioner.
4+
5+ This impl was adapted from https://github.com/evanatyourservice/kron_torch
6+ by Evan Walters, licensed CC-BY-4.0.
7+
8+ Contributions to the above were also made by
9+ * Lucas Nestler, added to his https://github.com/ClashLuke/HeavyBall implementation.
10+ * Omead Pooladzandi https://github.com/opooladz
11+
12+ The above work drew from https://github.com/lixilinx/psgd_torch by Xi-Lin Li
13+
14+ This `timm` impl
15+ * works with a wider variety of torch versions
16+ * fixes some checkpoint save/restore (resume) issues
17+ * adds decoupled weight-decay option
18+ * has some refactoring, cleanup of args, default/group items
19+ * warns about not having opt_einsum (unusable without it)
720
821"""
922import logging
3043except AttributeError :
3144 has_dynamo = False
3245
46+ from ._types import ParamsT
47+
3348_logger = logging .getLogger (__name__ )
3449
3550
@@ -85,7 +100,7 @@ class Kron(torch.optim.Optimizer):
85100
86101 def __init__ (
87102 self ,
88- params ,
103+ params : ParamsT ,
89104 lr : float = 0.001 ,
90105 momentum : float = 0.9 ,
91106 weight_decay : float = 0.0 ,
@@ -94,6 +109,8 @@ def __init__(
94109 min_ndim_triangular : int = 2 ,
95110 memory_save_mode : Optional [str ] = None ,
96111 momentum_into_precond_update : bool = True ,
112+ precond_lr : float = 0.1 ,
113+ precond_init_scale : float = 1.0 ,
97114 mu_dtype : Optional [torch .dtype ] = None ,
98115 precond_dtype : Optional [torch .dtype ] = None ,
99116 decoupled_decay : bool = False ,
@@ -119,8 +136,8 @@ def __init__(
119136 min_ndim_triangular = min_ndim_triangular ,
120137 memory_save_mode = memory_save_mode ,
121138 momentum_into_precond_update = momentum_into_precond_update ,
122- precond_lr = 0.1 , # precond lr hardcoded to 0.1
123- precond_init_scale = 1.0 , # precond init scale hardcoded to 1.0
139+ precond_lr = precond_lr ,
140+ precond_init_scale = precond_init_scale ,
124141 mu_dtype = mu_dtype ,
125142 precond_dtype = precond_dtype ,
126143 decoupled_decay = decoupled_decay ,
0 commit comments