
Commit dab0360

Add NadamW based on mlcommons algorithm, added multi-tensor step
1 parent fb4f220 commit dab0360
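
The commit title describes NAdamW as the MLCommons-style combination of NAdam (Adam with Nesterov momentum) and AdamW-style decoupled weight decay. The new timm/optim/nadamw.py file itself is not shown in this diff, so the following per-parameter update is only a hypothetical sketch of that algorithm; the function name, signature, and exact bias-correction details are assumptions, not the committed implementation.

import torch

def nadamw_step(param, grad, exp_avg, exp_avg_sq, step, lr=1e-3,
                beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.01):
    # Hypothetical helper, not the committed implementation.
    # Decoupled weight decay: shrink the parameter directly, not via the gradient.
    param.mul_(1.0 - lr * weight_decay)

    # Update biased first and second moment estimates.
    exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

    bias_correction1 = 1.0 - beta1 ** step
    bias_correction2 = 1.0 - beta2 ** step
    denom = (exp_avg_sq / bias_correction2).sqrt().add_(eps)

    # Nesterov-style step: momentum term plus a raw-gradient correction.
    param.addcdiv_(exp_avg, denom, value=-lr * beta1 / bias_correction1)
    param.addcdiv_(grad, denom, value=-lr * (1.0 - beta1) / bias_correction1)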

File tree

timm/optim/nadam.py
timm/optim/optim_factory.py

2 files changed: +9 -1 lines changed

timm/optim/nadam.py

Lines changed: 6 additions & 1 deletion
@@ -32,7 +32,12 @@ def __init__(self, params, lr=2e-3, betas=(0.9, 0.999), eps=1e-8,
         if not 0.0 <= lr:
             raise ValueError("Invalid learning rate: {}".format(lr))
         defaults = dict(
-            lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, schedule_decay=schedule_decay)
+            lr=lr,
+            betas=betas,
+            eps=eps,
+            weight_decay=weight_decay,
+            schedule_decay=schedule_decay,
+        )
         super(Nadam, self).__init__(params, defaults)

     @torch.no_grad()
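
The commit title also mentions a multi-tensor step, i.e. a foreach-based variant that applies each stage of the update to a whole list of parameters at once. Below is a hedged sketch of that idea, assuming all parameters in the group share the same step count; the torch._foreach_* ops are private PyTorch APIs, and the committed implementation is not shown in this diff.

import math
import torch

def nadamw_multi_tensor_step(params, grads, exp_avgs, exp_avg_sqs, step,
                             lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8,
                             weight_decay=0.01):
    # Hypothetical sketch, not the committed implementation.
    bias_correction1 = 1.0 - beta1 ** step
    bias_correction2 = 1.0 - beta2 ** step

    # Decoupled weight decay applied to all parameters at once.
    torch._foreach_mul_(params, 1.0 - lr * weight_decay)

    # First and second moment updates, batched across the parameter list.
    torch._foreach_mul_(exp_avgs, beta1)
    torch._foreach_add_(exp_avgs, grads, alpha=1.0 - beta1)
    torch._foreach_mul_(exp_avg_sqs, beta2)
    torch._foreach_addcmul_(exp_avg_sqs, grads, grads, value=1.0 - beta2)

    # Shared denominator: sqrt(v_hat) + eps for every parameter.
    denom = torch._foreach_sqrt(exp_avg_sqs)
    torch._foreach_div_(denom, math.sqrt(bias_correction2))
    torch._foreach_add_(denom, eps)

    # Nesterov-style update: momentum term plus a raw-gradient correction.
    torch._foreach_addcdiv_(params, exp_avgs, denom, value=-lr * beta1 / bias_correction1)
    torch._foreach_addcdiv_(params, grads, denom, value=-lr * (1.0 - beta1) / bias_correction1)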

timm/optim/optim_factory.py

Lines changed: 3 additions & 0 deletions
@@ -22,6 +22,7 @@
 from .lookahead import Lookahead
 from .madgrad import MADGRAD
 from .nadam import Nadam
+from .nadamw import NAdamW
 from .nvnovograd import NvNovoGrad
 from .radam import RAdam
 from .rmsprop_tf import RMSpropTF
@@ -301,6 +302,8 @@ def create_optimizer_v2(
             optimizer = optim.Nadam(parameters, **opt_args)
         except AttributeError:
             optimizer = Nadam(parameters, **opt_args)
+    elif opt_lower == 'nadamw':
+        optimizer = NAdamW(parameters, **opt_args)
     elif opt_lower == 'radam':
         optimizer = RAdam(parameters, **opt_args)
     elif opt_lower == 'adamax':
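
With the factory hook in place, the new optimizer should be selectable by name. A minimal usage sketch, assuming create_optimizer_v2 is importable from timm.optim and accepts the usual lr / weight_decay keywords; only the 'nadamw' opt string is confirmed by this diff.

import torch.nn as nn
from timm.optim import create_optimizer_v2

model = nn.Linear(10, 2)
# 'nadamw' is matched via opt_lower in the factory; lr/weight_decay kwargs assumed.
optimizer = create_optimizer_v2(model, opt='nadamw', lr=1e-3, weight_decay=0.01)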
