@@ -4,6 +4,7 @@
 from torch.utils.data import DataLoader

 from ..model import BERTLM, BERT
+from .optim_schedule import ScheduledOptim

 import tqdm

@@ -21,7 +22,7 @@ class BERTTrainer:

     def __init__(self, bert: BERT, vocab_size: int,
                  train_dataloader: DataLoader, test_dataloader: DataLoader = None,
-                 lr: float = 1e-4, betas=(0.9, 0.999), weight_decay: float = 0.01,
+                 lr: float = 1e-4, betas=(0.9, 0.999), weight_decay: float = 0.01, warmup_steps=10000,
                  with_cuda: bool = True, cuda_devices=None, log_freq: int = 10):
         """
         :param bert: BERT model which you want to train
@@ -55,6 +56,7 @@ def __init__(self, bert: BERT, vocab_size: int,

         # Setting the Adam optimizer with hyper-param
         self.optim = Adam(self.model.parameters(), lr=lr, betas=betas, weight_decay=weight_decay)
+        self.optim_schedule = ScheduledOptim(self.optim, self.bert.hidden, n_warmup_steps=warmup_steps)

         # Using Negative Log Likelihood Loss function for predicting the masked_token
         self.criterion = nn.NLLLoss(ignore_index=0)
@@ -110,9 +112,9 @@ def iteration(self, epoch, data_loader, train=True):

             # 3. backward and optimization only in train
             if train:
-                self.optim.zero_grad()
+                self.optim_schedule.zero_grad()
                 loss.backward()
-                self.optim.step()
+                self.optim_schedule.step_and_update_lr()

             # next sentence prediction accuracy
             correct = next_sent_output.argmax(dim=-1).eq(data["is_next"]).sum().item()