Commit 700aebc

Fix Pytorch 2.0 breakage for Lookahead optimizer adapter
1 parent: cd950e6

File tree

1 file changed: 5 additions, 0 deletions


timm/optim/lookahead.py

Lines changed: 5 additions & 0 deletions
@@ -4,6 +4,9 @@
 
 Hacked together by / Copyright 2020 Ross Wightman
 """
+from collections import OrderedDict
+from typing import Callable, Dict
+
 import torch
 from torch.optim.optimizer import Optimizer
 from collections import defaultdict
@@ -12,6 +15,8 @@
 class Lookahead(Optimizer):
     def __init__(self, base_optimizer, alpha=0.5, k=6):
         # NOTE super().__init__() not called on purpose
+        self._optimizer_step_pre_hooks: Dict[int, Callable] = OrderedDict()
+        self._optimizer_step_post_hooks: Dict[int, Callable] = OrderedDict()
         if not 0.0 <= alpha <= 1.0:
             raise ValueError(f'Invalid slow update rate: {alpha}')
         if not 1 <= k:
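Context: PyTorch 2.0's torch.optim.Optimizer wraps each optimizer's step() in a hook-dispatching wrapper that reads self._optimizer_step_pre_hooks and self._optimizer_step_post_hooks. Those attributes are normally created in Optimizer.__init__(), which Lookahead deliberately skips, so step() broke under PyTorch 2.0; the two added lines create the (empty) hook dicts up front. Below is a minimal usage sketch of the adapter; the toy model, data, and SGD base optimizer are illustrative assumptions, not part of the commit.

# Sketch: wrapping a base optimizer with timm's Lookahead adapter.
# The model, input, and SGD settings are assumptions for illustration only.
import torch
from timm.optim.lookahead import Lookahead

model = torch.nn.Linear(10, 2)
base = torch.optim.SGD(model.parameters(), lr=0.1)
optimizer = Lookahead(base, alpha=0.5, k=6)  # signature as shown in the diff above

x = torch.randn(4, 10)
loss = model(x).sum()
loss.backward()
optimizer.step()   # under PyTorch 2.0 this call requires the hook dicts added above
model.zero_grad()

Each step() is delegated to the wrapped base optimizer; every k steps the fast weights are interpolated toward the slow weights with factor alpha ("k steps forward, 1 step back").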

0 commit comments