@@ -13,37 +13,40 @@ def __init__(self, base_optimizer, alpha=0.5, k=6):
             raise ValueError(f'Invalid slow update rate: {alpha}')
         if not 1 <= k:
             raise ValueError(f'Invalid lookahead steps: {k}')
-        self.alpha = alpha
-        self.k = k
+        defaults = dict(lookahead_alpha=alpha, lookahead_k=k, lookahead_step=0)
         self.base_optimizer = base_optimizer
         self.param_groups = self.base_optimizer.param_groups
         self.defaults = base_optimizer.defaults
+        self.defaults.update(defaults)
         self.state = defaultdict(dict)
-        for group in self.param_groups:
-            group["step_counter"] = 0
+        # manually add our defaults to the param groups
+        for name, default in defaults.items():
+            for group in self.param_groups:
+                group.setdefault(name, default)

-    def update_slow_weights(self, group):
+    def update_slow(self, group):
         for fast_p in group["params"]:
             if fast_p.grad is None:
                 continue
             param_state = self.state[fast_p]
-            if "slow_buffer" not in param_state:
-                param_state["slow_buffer"] = torch.empty_like(fast_p.data)
-                param_state["slow_buffer"].copy_(fast_p.data)
-            slow = param_state["slow_buffer"]
-            slow.add_(self.alpha, fast_p.data - slow)
+            if 'slow_buffer' not in param_state:
+                param_state['slow_buffer'] = torch.empty_like(fast_p.data)
+                param_state['slow_buffer'].copy_(fast_p.data)
+            slow = param_state['slow_buffer']
+            slow.add_(group['lookahead_alpha'], fast_p.data - slow)
             fast_p.data.copy_(slow)

     def sync_lookahead(self):
         for group in self.param_groups:
-            self.update_slow_weights(group)
+            self.update_slow(group)

     def step(self, closure=None):
+        #assert id(self.param_groups) == id(self.base_optimizer.param_groups)
         loss = self.base_optimizer.step(closure)
         for group in self.param_groups:
-            group['step_counter'] += 1
-            if group['step_counter'] % self.k == 0:
-                self.update_slow_weights(group)
+            group['lookahead_step'] += 1
+            if group['lookahead_step'] % group['lookahead_k'] == 0:
+                self.update_slow(group)
         return loss

     def state_dict(self):
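
For reference, a minimal usage sketch of the refactored wrapper (not part of the commit). It assumes the Lookahead class from this diff is importable, e.g. from timm.optim (the import path is an assumption), and shows that the lookahead hyperparameters now live in each param group rather than as attributes on the wrapper, and that the slow weights are folded back into the fast weights every k-th call to step().

import torch
import torch.nn as nn
from timm.optim import Lookahead  # assumed import path for the class in this diff

model = nn.Linear(4, 2)
base = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
opt = Lookahead(base, alpha=0.5, k=6)

# hyperparameters are per-group entries now, not wrapper attributes
group = opt.param_groups[0]
print(group['lookahead_alpha'], group['lookahead_k'], group['lookahead_step'])  # 0.5 6 0

for _ in range(6):
    loss = model(torch.randn(8, 4)).pow(2).mean()
    model.zero_grad()
    loss.backward()
    opt.step()  # on the 6th call, lookahead_step % lookahead_k == 0, so update_slow() runs

# after update_slow(), the fast weights were copied from the slow buffer
assert torch.equal(model.weight.data, opt.state[model.weight]['slow_buffer'])
opt.sync_lookahead()  # can also be called explicitly, e.g. before validation or checkpointing
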
@@ -52,37 +55,36 @@ def state_dict(self):
             (id(k) if isinstance(k, torch.Tensor) else k): v
             for k, v in self.state.items()
         }
-        fast_state = fast_state_dict["state"]
-        param_groups = fast_state_dict["param_groups"]
+        fast_state = fast_state_dict['state']
+        param_groups = fast_state_dict['param_groups']
         return {
-            "state": fast_state,
-            "slow_state": slow_state,
-            "param_groups": param_groups,
+            'state': fast_state,
+            'slow_state': slow_state,
+            'param_groups': param_groups,
         }

     def load_state_dict(self, state_dict):
+        fast_state_dict = {
+            'state': state_dict['state'],
+            'param_groups': state_dict['param_groups'],
+        }
+        self.base_optimizer.load_state_dict(fast_state_dict)
+
+        # We want to restore the slow state, but share param_groups reference
+        # with base_optimizer. This is a bit redundant but least code
+        slow_state_new = False
         if 'slow_state' not in state_dict:
-            print('Loading state_dict from optimizer without Lookahead applied')
+            print('Loading state_dict from optimizer without Lookahead applied.')
             state_dict['slow_state'] = defaultdict(dict)
+            slow_state_new = True
         slow_state_dict = {
-            "state": state_dict["slow_state"],
-            "param_groups": state_dict["param_groups"],
-        }
-        fast_state_dict = {
-            "state": state_dict["state"],
-            "param_groups": state_dict["param_groups"],
+            'state': state_dict['slow_state'],
+            'param_groups': state_dict['param_groups'],  # this is pointless but saves code
         }
         super(Lookahead, self).load_state_dict(slow_state_dict)
-        self.base_optimizer.load_state_dict(fast_state_dict)
-
-    def add_param_group(self, param_group):
-        r"""Add a param group to the :class:`Optimizer` s `param_groups`.
-        This can be useful when fine tuning a pre-trained network as frozen
-        layers can be made trainable and added to the :class:`Optimizer` as
-        training progresses.
-        Args:
-            param_group (dict): Specifies what Tensors should be optimized along
-                with group specific optimization options.
-        """
-        param_group['step_counter'] = 0
-        self.base_optimizer.add_param_group(param_group)
+        self.param_groups = self.base_optimizer.param_groups  # make both ref same container
+        if slow_state_new:
+            # reapply defaults to catch missing lookahead specific ones
+            for name, default in self.defaults.items():
+                for group in self.param_groups:
+                    group.setdefault(name, default)
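
And a sketch of the two checkpoint paths handled by the rewritten load_state_dict() (again not part of the commit; the import path is an assumption): a checkpoint saved from a Lookahead-wrapped optimizer restores the slow state as saved, while a checkpoint saved from a bare base optimizer hits the 'slow_state' fallback, after which the lookahead_* defaults are re-applied to the shared param groups.

import torch
import torch.nn as nn
from timm.optim import Lookahead  # assumed import path

model = nn.Linear(4, 2)

# path 1: checkpoint saved with Lookahead applied ('slow_state' key present)
opt = Lookahead(torch.optim.SGD(model.parameters(), lr=0.1), alpha=0.5, k=6)
ckpt = opt.state_dict()            # contains 'state', 'slow_state', 'param_groups'
opt.load_state_dict(ckpt)          # slow state restored as saved

# path 2: checkpoint saved from a bare optimizer (no 'slow_state' key)
bare_ckpt = torch.optim.SGD(model.parameters(), lr=0.1).state_dict()
opt2 = Lookahead(torch.optim.SGD(model.parameters(), lr=0.1), alpha=0.5, k=6)
opt2.load_state_dict(bare_ckpt)    # prints the notice, rebuilds an empty slow state
assert opt2.param_groups is opt2.base_optimizer.param_groups  # shared container restored
assert 'lookahead_k' in opt2.param_groups[0]  # lookahead defaults re-applied after loading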