@@ -95,7 +95,9 @@ class Kron(torch.optim.Optimizer):
         precond_dtype: Dtype of the preconditioner.
         decoupled_decay: AdamW style decoupled weight decay
         flatten: Flatten dimensions instead of fully relying on expressions for higher rank params
-        flatten_start_end: Range of dimensions to flatten, defaults to (2, -1).
+        flatten_start_dim: Start of flatten range, defaults to 2. Seems good tradeoff for ConvNets.
+        flatten_end_dim: End of flatten range, defaults to -1.
+        stochastic_weight_decay: Enable random modulation of weight decay
         deterministic: Deterministic behaviour across save / load (resume). FIXME slow, needs work
     """
 
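The flatten range is easiest to read as a torch.flatten call over those dimensions; the snippet below is a sketch of that interpretation (the reshape shown is an assumption for illustration, not code from this diff). With the defaults (start 2, end -1), a 4-D conv weight keeps its output/input channel dims and collapses the spatial ones:

    import torch

    w = torch.randn(64, 32, 3, 3)                      # conv weight: (out, in, kh, kw)
    flat = torch.flatten(w, start_dim=2, end_dim=-1)   # flatten dims 2..-1
    print(flat.shape)                                  # torch.Size([64, 32, 9])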
@@ -118,6 +120,7 @@ def __init__(
             flatten: bool = False,
             flatten_start_dim: int = 2,
             flatten_end_dim: int = -1,
+            stochastic_weight_decay: bool = False,
             deterministic: bool = False,
     ):
         if not has_opt_einsum:
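For context, a hypothetical construction call showing where the new flag slots in; the import path and model are placeholders, and the other keyword arguments are assumed to exist under the names used elsewhere in this file (lr, weight_decay, decoupled_decay):

    import torch.nn as nn
    # from <module containing Kron> import Kron   # import path assumed

    model = nn.Linear(10, 10)
    optimizer = Kron(
        model.parameters(),
        lr=0.001,
        weight_decay=0.01,
        decoupled_decay=True,            # AdamW-style decoupled decay
        stochastic_weight_decay=True,    # new flag from this change
    )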
@@ -147,6 +150,7 @@ def __init__(
             flatten=flatten,
             flatten_start_dim=flatten_start_dim,
             flatten_end_dim=flatten_end_dim,
+            stochastic_weight_decay=stochastic_weight_decay,
         )
         super(Kron, self).__init__(params, defaults)
 
@@ -353,11 +357,15 @@ def step(self, closure=None):
                 pre_grad = pre_grad.view(p.shape)
 
                 # Apply weight decay
-                if group["weight_decay"] != 0:
+                weight_decay = group["weight_decay"]
+                if weight_decay != 0:
+                    if group["stochastic_weight_decay"]:
+                        weight_decay = 2 * self.rng.random() * weight_decay
+
                     if group["decoupled_decay"]:
-                        p.mul_(1. - group["lr"] * group["weight_decay"])
+                        p.mul_(1. - group["lr"] * weight_decay)
                     else:
-                        pre_grad.add_(p, alpha=group["weight_decay"])
+                        pre_grad.add_(p, alpha=weight_decay)
 
                 # Update parameters
                 p.add_(pre_grad, alpha=-group["lr"])
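With the modulation above, each step draws a fresh uniform factor in [0, 2), so the decay applied at any single step ranges over [0, 2 * weight_decay) while its expectation stays at the configured value. A minimal standalone check of that property, assuming self.rng behaves like a numpy Generator (the diff only shows that it exposes .random()):

    import numpy as np

    rng = np.random.default_rng(0)
    weight_decay = 0.01

    # mirror the update: weight_decay' = 2 * rng.random() * weight_decay
    samples = np.array([2 * rng.random() * weight_decay for _ in range(100_000)])
    print(samples.min(), samples.max())   # spans roughly 0.0 .. 0.02
    print(samples.mean())                 # ~0.01, the configured weight_decay on average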