
Commit ba3c97c

Some Lookahead cleanup and fixes
1 parent fac58f6 commit ba3c97c

File tree

2 files changed: 41 additions, 39 deletions


timm/optim/lookahead.py

Lines changed: 41 additions & 39 deletions
@@ -13,37 +13,40 @@ def __init__(self, base_optimizer, alpha=0.5, k=6):
             raise ValueError(f'Invalid slow update rate: {alpha}')
         if not 1 <= k:
             raise ValueError(f'Invalid lookahead steps: {k}')
-        self.alpha = alpha
-        self.k = k
+        defaults = dict(lookahead_alpha=alpha, lookahead_k=k, lookahead_step=0)
         self.base_optimizer = base_optimizer
         self.param_groups = self.base_optimizer.param_groups
         self.defaults = base_optimizer.defaults
+        self.defaults.update(defaults)
         self.state = defaultdict(dict)
-        for group in self.param_groups:
-            group["step_counter"] = 0
+        # manually add our defaults to the param groups
+        for name, default in defaults.items():
+            for group in self.param_groups:
+                group.setdefault(name, default)
 
-    def update_slow_weights(self, group):
+    def update_slow(self, group):
         for fast_p in group["params"]:
             if fast_p.grad is None:
                 continue
             param_state = self.state[fast_p]
-            if "slow_buffer" not in param_state:
-                param_state["slow_buffer"] = torch.empty_like(fast_p.data)
-                param_state["slow_buffer"].copy_(fast_p.data)
-            slow = param_state["slow_buffer"]
-            slow.add_(self.alpha, fast_p.data - slow)
+            if 'slow_buffer' not in param_state:
+                param_state['slow_buffer'] = torch.empty_like(fast_p.data)
+                param_state['slow_buffer'].copy_(fast_p.data)
+            slow = param_state['slow_buffer']
+            slow.add_(group['lookahead_alpha'], fast_p.data - slow)
             fast_p.data.copy_(slow)
 
     def sync_lookahead(self):
         for group in self.param_groups:
-            self.update_slow_weights(group)
+            self.update_slow(group)
 
     def step(self, closure=None):
+        #assert id(self.param_groups) == id(self.base_optimizer.param_groups)
         loss = self.base_optimizer.step(closure)
         for group in self.param_groups:
-            group['step_counter'] += 1
-            if group['step_counter'] % self.k == 0:
-                self.update_slow_weights(group)
+            group['lookahead_step'] += 1
+            if group['lookahead_step'] % group['lookahead_k'] == 0:
+                self.update_slow(group)
         return loss
 
     def state_dict(self):
@@ -52,37 +55,36 @@ def state_dict(self):
             (id(k) if isinstance(k, torch.Tensor) else k): v
             for k, v in self.state.items()
         }
-        fast_state = fast_state_dict["state"]
-        param_groups = fast_state_dict["param_groups"]
+        fast_state = fast_state_dict['state']
+        param_groups = fast_state_dict['param_groups']
         return {
-            "state": fast_state,
-            "slow_state": slow_state,
-            "param_groups": param_groups,
+            'state': fast_state,
+            'slow_state': slow_state,
+            'param_groups': param_groups,
         }
 
     def load_state_dict(self, state_dict):
+        fast_state_dict = {
+            'state': state_dict['state'],
+            'param_groups': state_dict['param_groups'],
+        }
+        self.base_optimizer.load_state_dict(fast_state_dict)
+
+        # We want to restore the slow state, but share param_groups reference
+        # with base_optimizer. This is a bit redundant but least code
+        slow_state_new = False
         if 'slow_state' not in state_dict:
-            print('Loading state_dict from optimizer without Lookahead applied')
+            print('Loading state_dict from optimizer without Lookahead applied.')
             state_dict['slow_state'] = defaultdict(dict)
+            slow_state_new = True
         slow_state_dict = {
-            "state": state_dict["slow_state"],
-            "param_groups": state_dict["param_groups"],
-        }
-        fast_state_dict = {
-            "state": state_dict["state"],
-            "param_groups": state_dict["param_groups"],
+            'state': state_dict['slow_state'],
+            'param_groups': state_dict['param_groups'],  # this is pointless but saves code
         }
         super(Lookahead, self).load_state_dict(slow_state_dict)
-        self.base_optimizer.load_state_dict(fast_state_dict)
-
-    def add_param_group(self, param_group):
-        r"""Add a param group to the :class:`Optimizer` s `param_groups`.
-        This can be useful when fine tuning a pre-trained network as frozen
-        layers can be made trainable and added to the :class:`Optimizer` as
-        training progresses.
-        Args:
-            param_group (dict): Specifies what Tensors should be optimized along
-            with group specific optimization options.
-        """
-        param_group['step_counter'] = 0
-        self.base_optimizer.add_param_group(param_group)
+        self.param_groups = self.base_optimizer.param_groups  # make both ref same container
+        if slow_state_new:
+            # reapply defaults to catch missing lookahead specific ones
+            for name, default in self.defaults.items():
+                for group in self.param_groups:
+                    group.setdefault(name, default)
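
For context, this change moves the alpha/k/step bookkeeping off instance attributes and into each param group (lookahead_alpha, lookahead_k, lookahead_step), so the wrapper's settings travel with param_groups and survive state_dict()/load_state_dict() round trips. Below is a minimal usage sketch: the Lookahead constructor signature and methods are taken from the diff above, while the import path, model, and training loop are illustrative assumptions, not part of the commit.

# Usage sketch (assumptions: import path, toy model/optimizer; the Lookahead
# API itself -- alpha/k args, step(), sync_lookahead(), state_dict(),
# load_state_dict() -- comes from the diff above).
import torch
from timm.optim.lookahead import Lookahead  # module path assumed from the file name

model = torch.nn.Linear(10, 2)                       # toy model for illustration
base = torch.optim.SGD(model.parameters(), lr=0.1)   # any torch optimizer acts as the "fast" optimizer
optimizer = Lookahead(base, alpha=0.5, k=6)          # adds lookahead_* defaults to every param group

for _ in range(12):
    loss = model(torch.randn(4, 10)).sum()
    base.zero_grad()
    loss.backward()
    optimizer.step()          # base step every iteration; slow-weight interpolation every k-th step

optimizer.sync_lookahead()    # fold slow weights into the params before eval or checkpointing

ckpt = optimizer.state_dict()     # contains 'state', 'slow_state', 'param_groups'
optimizer.load_state_dict(ckpt)   # also accepts checkpoints saved without Lookahead applied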

timm/optim/nvnovograd.py

Whitespace-only changes.
