
Commit e31ff4c

Adding optim schedule feature for codertimo#17
1 parent c897384 commit e31ff4c

2 files changed: +40 -3 lines changed

bert_pytorch/trainer/optim_schedule.py

Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
+'''A wrapper class for optimizer'''
+import numpy as np
+
+
+class ScheduledOptim():
+    '''A simple wrapper class for learning rate scheduling'''
+
+    def __init__(self, optimizer, d_model, n_warmup_steps):
+        self._optimizer = optimizer
+        self.n_warmup_steps = n_warmup_steps
+        self.n_current_steps = 0
+        self.init_lr = np.power(d_model, -0.5)
+
+    def step_and_update_lr(self):
+        "Step with the inner optimizer"
+        self._update_learning_rate()
+        self._optimizer.step()
+
+    def zero_grad(self):
+        "Zero out the gradients by the inner optimizer"
+        self._optimizer.zero_grad()
+
+    def _get_lr_scale(self):
+        return np.min([
+            np.power(self.n_current_steps, -0.5),
+            np.power(self.n_warmup_steps, -1.5) * self.n_current_steps])
+
+    def _update_learning_rate(self):
+        ''' Learning rate scheduling per step '''
+
+        self.n_current_steps += 1
+        lr = self.init_lr * self._get_lr_scale()
+
+        for param_group in self._optimizer.param_groups:
+            param_group['lr'] = lr
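
The scale computed above is the inverse-square-root warmup schedule from the Transformer paper: lr = d_model^-0.5 * min(step^-0.5, step * n_warmup_steps^-1.5), i.e. the learning rate grows roughly linearly for n_warmup_steps steps and then decays as step^-0.5. A minimal standalone sketch of driving the wrapper in isolation (the SGD optimizer, d_model=768, and warmup of 10000 below are illustrative assumptions, not part of this commit):

# Hypothetical standalone check of the schedule; all values are placeholders.
import torch
from torch.optim import SGD

param = torch.nn.Parameter(torch.zeros(10))
# lr=0.0 is only a placeholder; the schedule overwrites it before every step.
sched = ScheduledOptim(SGD([param], lr=0.0), d_model=768, n_warmup_steps=10000)

for step in (1, 100, 10000, 40000):
    sched.n_current_steps = step - 1      # jump to a step of interest; the update adds 1
    sched.zero_grad()
    param.sum().backward()
    sched.step_and_update_lr()
    print(step, sched._optimizer.param_groups[0]['lr'])

# The learning rate peaks at roughly d_model**-0.5 * n_warmup_steps**-0.5 (about 3.6e-4
# for these placeholder values) at step 10000, then decays as step**-0.5.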

bert_pytorch/trainer/pretrain.py

Lines changed: 5 additions & 3 deletions
@@ -4,6 +4,7 @@
 from torch.utils.data import DataLoader

 from ..model import BERTLM, BERT
+from .optim_schedule import ScheduledOptim

 import tqdm

@@ -21,7 +22,7 @@ class BERTTrainer:

     def __init__(self, bert: BERT, vocab_size: int,
                  train_dataloader: DataLoader, test_dataloader: DataLoader = None,
-                 lr: float = 1e-4, betas=(0.9, 0.999), weight_decay: float = 0.01,
+                 lr: float = 1e-4, betas=(0.9, 0.999), weight_decay: float = 0.01, warmup_steps=10000,
                  with_cuda: bool = True, cuda_devices=None, log_freq: int = 10):
         """
         :param bert: BERT model which you want to train
@@ -55,6 +56,7 @@ def __init__(self, bert: BERT, vocab_size: int,

         # Setting the Adam optimizer with hyper-param
         self.optim = Adam(self.model.parameters(), lr=lr, betas=betas, weight_decay=weight_decay)
+        self.optim_schedule = ScheduledOptim(self.optim, self.bert.hidden, n_warmup_steps=warmup_steps)

         # Using Negative Log Likelihood Loss function for predicting the masked_token
         self.criterion = nn.NLLLoss(ignore_index=0)
@@ -110,9 +112,9 @@ def iteration(self, epoch, data_loader, train=True):

            # 3. backward and optimization only in train
            if train:
-                self.optim.zero_grad()
+                self.optim_schedule.zero_grad()
                 loss.backward()
-                self.optim.step()
+                self.optim_schedule.step_and_update_lr()

            # next sentence prediction accuracy
            correct = next_sent_output.argmax(dim=-1).eq(data["is_next"]).sum().item()
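
With these changes the trainer no longer touches the Adam optimizer directly during training: each batch zeroes gradients, backpropagates, and then lets ScheduledOptim refresh the learning rate before stepping. A condensed sketch of that per-batch pattern (the linear model and squared-output loss below are dummies standing in for the trainer's BERTLM model and NLL losses, and the dimensions are placeholders):

# Hypothetical per-batch update mirroring the new `if train:` branch; model and loss are dummies.
import torch
from torch.optim import Adam

model = torch.nn.Linear(768, 2)
optim = Adam(model.parameters(), lr=1e-4, betas=(0.9, 0.999), weight_decay=0.01)
optim_schedule = ScheduledOptim(optim, d_model=768, n_warmup_steps=10000)

for _ in range(3):                                   # stands in for the tqdm loop over batches
    loss = model(torch.randn(4, 768)).pow(2).mean()  # placeholder for the trainer's NLL loss
    optim_schedule.zero_grad()
    loss.backward()
    optim_schedule.step_and_update_lr()              # updates lr, then calls optimizer.step()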
