We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent cd950e6 commit 700aebcCopy full SHA for 700aebc
timm/optim/lookahead.py
@@ -4,6 +4,9 @@
4
5
Hacked together by / Copyright 2020 Ross Wightman
6
"""
7
+from collections import OrderedDict
8
+from typing import Callable, Dict
9
+
10
import torch
11
from torch.optim.optimizer import Optimizer
12
from collections import defaultdict
@@ -12,6 +15,8 @@
15
class Lookahead(Optimizer):
13
16
def __init__(self, base_optimizer, alpha=0.5, k=6):
14
17
# NOTE super().__init__() not called on purpose
18
+ self._optimizer_step_pre_hooks: Dict[int, Callable] = OrderedDict()
19
+ self._optimizer_step_post_hooks: Dict[int, Callable] = OrderedDict()
20
if not 0.0 <= alpha <= 1.0:
21
raise ValueError(f'Invalid slow update rate: {alpha}')
22
if not 1 <= k:
0 commit comments