Skip to content

Commit 34c633b

Browse files
add training with schedule learning rate
1 parent de89ca0 commit 34c633b

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

deep_keyphrase/base_trainer.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import traceback
44
import logging
55
import os
6-
import gc
76
import torch
87
import torch.nn as nn
98
import torch.optim as optim
@@ -28,14 +27,20 @@ def __init__(self, args, model):
2827
self.model = nn.DataParallel(self.model)
2928
self.loss_func = nn.NLLLoss(ignore_index=self.vocab2id[PAD_WORD])
3029
self.optimizer = optim.Adam(self.model.parameters(), lr=self.args.learning_rate)
30+
self.scheduler = optim.lr_scheduler.StepLR(self.optimizer,
31+
self.args.schedule_step,
32+
self.args.schedule_gamma)
3133
self.logger = get_logger('train')
3234
self.train_loader = KeyphraseDataLoader(self.args.train_filename,
3335
self.vocab2id,
3436
self.args.batch_size,
3537
self.args.max_src_len,
3638
self.args.max_oov_count,
3739
self.args.max_target_len,
38-
'train')
40+
'train',
41+
pre_fetch=True,
42+
token_field=args.token_field,
43+
keyphrase_field=args.keyphrase_field)
3944
if self.args.train_from:
4045
self.dest_dir = os.path.dirname(self.args.train_from) + '/'
4146
else:
@@ -84,7 +89,6 @@ def train(self):
8489
step += 1
8590
self.writer.add_scalar('loss', loss, step)
8691
del loss
87-
gc.collect()
8892
if step and step % self.args.save_model_step == 0:
8993
torch.cuda.empty_cache()
9094
self.evaluate_and_save_model(step, epoch)

0 commit comments

Comments
 (0)