Commit e3a253e

add schedule learning option
1 parent 24443e2 commit e3a253e

File tree

1 file changed (+43 lines, -14 lines)

deep_keyphrase/copy_rnn/train.py

Lines changed: 43 additions & 14 deletions
@@ -8,7 +8,7 @@
 from deep_keyphrase.utils.vocab_loader import load_vocab
 from deep_keyphrase.copy_rnn.model import CopyRNN
 from deep_keyphrase.base_trainer import BaseTrainer
-from deep_keyphrase.dataloader import (KeyphraseDataLoader, TOKENS, TARGET)
+from deep_keyphrase.dataloader import TOKENS, TARGET
 from deep_keyphrase.copy_rnn.predict import CopyRnnPredictor


@@ -87,6 +87,8 @@ def train_batch(self, batch):
         torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.grad_norm)

         self.optimizer.step()
+        if self.args.schedule_lr:
+            self.scheduler.step()
         return loss

     def evaluate(self, step):
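
Note on the scheduler: the hunk above steps self.scheduler, but its construction is not part of this diff (presumably it happens elsewhere in train.py or in BaseTrainer). A minimal sketch of one plausible wiring, assuming torch.optim.lr_scheduler.StepLR driven by the -schedule_step and -schedule_gamma options added further down; args, model, and optimizer here are stand-ins, not the project's objects:

    import argparse
    import torch
    from torch.optim import Adam
    from torch.optim.lr_scheduler import StepLR

    # Stand-in args mirroring the new CLI defaults below.
    args = argparse.Namespace(learning_rate=1e-3, schedule_lr=True,
                              schedule_step=100000, schedule_gamma=0.5)
    model = torch.nn.Linear(4, 2)  # stand-in for CopyRNN
    optimizer = Adam(model.parameters(), lr=args.learning_rate)
    # Decay the learning rate by schedule_gamma every schedule_step steps.
    scheduler = StepLR(optimizer, step_size=args.schedule_step,
                       gamma=args.schedule_gamma)

    # Per-batch update, mirroring train_batch() above.
    loss = model(torch.randn(1, 4)).sum()
    loss.backward()
    optimizer.step()
    if args.schedule_lr:
        scheduler.step()

Under the new defaults this halves the learning rate every 100,000 steps: 1e-3, then 5e-4, then 2.5e-4, and so on.
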
@@ -99,26 +101,49 @@ def evaluate(self, step):
         pred_valid_filename += '.batch_{}.pred.jsonl'.format(step)
         eval_filename = self.dest_dir + self.args.exp_name + '.batch_{}.eval.json'.format(step)
         predictor.eval_predict(self.args.valid_filename, pred_valid_filename,
-                               self.args.eval_batch_size, self.model, True)
-        valid_macro_ret = self.macro_evaluator.evaluate(pred_valid_filename)
-        # valid_micro_ret = self.micro_evaluator.evaluate(pred_valid_filename)
-        for n, counter in valid_macro_ret.items():
+                               self.args.eval_batch_size, self.model, True,
+                               token_field=self.args.token_field,
+                               keyphrase_field=self.args.keyphrase_field)
+        valid_macro_all_ret = self.macro_evaluator.evaluate(pred_valid_filename)
+        valid_macro_present_ret = self.macro_evaluator.evaluate(pred_valid_filename, 'present')
+        valid_macro_absent_ret = self.macro_evaluator.evaluate(pred_valid_filename, 'absent')
+
+        for n, counter in valid_macro_all_ret.items():
             for k, v in counter.items():
                 name = 'valid/macro_{}@{}'.format(k, n)
                 self.writer.add_scalar(name, v, step)
+        for n in self.eval_topn:
+            name = 'present/valid macro_f1@{}'.format(n)
+            self.writer.add_scalar(name, valid_macro_present_ret[n]['f1'], step)
+        for n in self.eval_topn:
+            name = 'absent/valid macro_f1@{}'.format(n)
+            self.writer.add_scalar(name, valid_macro_absent_ret[n]['f1'], step)
         pred_test_filename = self.dest_dir + self.get_basename(self.args.test_filename)
         pred_test_filename += '.batch_{}.pred.jsonl'.format(step)

         predictor.eval_predict(self.args.test_filename, pred_test_filename,
                                self.args.eval_batch_size, self.model, True)
-        test_macro_ret = self.macro_evaluator.evaluate(pred_test_filename)
-        for n, counter in test_macro_ret.items():
+        test_macro_all_ret = self.macro_evaluator.evaluate(pred_test_filename)
+        test_macro_present_ret = self.macro_evaluator.evaluate(pred_test_filename, 'present')
+        test_macro_absent_ret = self.macro_evaluator.evaluate(pred_test_filename, 'absent')
+        for n, counter in test_macro_all_ret.items():
             for k, v in counter.items():
                 name = 'test/macro_{}@{}'.format(k, n)
                 self.writer.add_scalar(name, v, step)
-        write_json(eval_filename, {'valid_macro': valid_macro_ret, 'test_macro': test_macro_ret})
-        # valid_micro_ret = self.micro_evaluator.evaluate(pred_test_filename)
-        return valid_macro_ret[self.eval_topn[-1]]['f1']
+        for n in self.eval_topn:
+            name = 'present/test macro_f1@{}'.format(n)
+            self.writer.add_scalar(name, test_macro_present_ret[n]['f1'], step)
+        for n in self.eval_topn:
+            name = 'absent/test macro_f1@{}'.format(n)
+            self.writer.add_scalar(name, test_macro_absent_ret[n]['f1'], step)
+        total_statistics = {'valid_macro': valid_macro_all_ret,
+                            'valid_present_macro': valid_macro_present_ret,
+                            'valid_absent_macro': valid_macro_absent_ret,
+                            'test_macro': test_macro_all_ret,
+                            'test_present_macro': test_macro_present_ret,
+                            'test_absent_macro': test_macro_absent_ret}
+        write_json(eval_filename, total_statistics)
+        return valid_macro_all_ret[self.eval_topn[-1]]['f1']

     def parse_args(self):
         parser = argparse.ArgumentParser()
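
Note on 'present' vs. 'absent': the evaluator is now called two extra times with a class filter, scoring present and absent keyphrases separately. The filter's implementation is not shown in this diff; the conventional definition in keyphrase generation, sketched here as an assumption, is that a keyphrase is present when its token sequence occurs contiguously in the source document:

    def is_present(keyphrase, doc_tokens):
        # Assumed definition: the keyphrase's tokens appear as a
        # contiguous subsequence of the document's tokens.
        k = len(keyphrase)
        return any(doc_tokens[i:i + k] == keyphrase
                   for i in range(len(doc_tokens) - k + 1))

    doc = ['copy', 'mechanism', 'for', 'keyphrase', 'generation']
    print(is_present(['keyphrase', 'generation'], doc))  # True  -> present
    print(is_present(['sequence', 'labeling'], doc))     # False -> absent
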
@@ -131,10 +156,12 @@ def parse_args(self):
         parser.add_argument("-vocab_path", required=True, type=str, help='')
         parser.add_argument("-vocab_size", type=int, default=500000, help='')
         parser.add_argument("-train_from", default='', type=str, help='')
+        parser.add_argument("-token_field", default='tokens', type=str, help='')
+        parser.add_argument("-keyphrase_field", default='keyphrases', type=str, help='')
         parser.add_argument("-epochs", type=int, default=10, help='')
         parser.add_argument("-batch_size", type=int, default=64, help='')
-        parser.add_argument("-learning_rate", type=float, default=1e-4, help='')
-        parser.add_argument("-eval_batch_size", type=int, default=20, help='')
+        parser.add_argument("-learning_rate", type=float, default=1e-3, help='')
+        parser.add_argument("-eval_batch_size", type=int, default=50, help='')
         parser.add_argument("-dropout", type=float, default=0.1, help='')
         parser.add_argument("-grad_norm", type=float, default=0.0, help='')
         parser.add_argument("-max_grad", type=float, default=5.0, help='')
@@ -144,8 +171,10 @@ def parse_args(self):
         parser.add_argument('-tensorboard_dir', type=str, default='', help='')
         parser.add_argument('-logfile', type=str, default='train_log.log', help='')
         parser.add_argument('-save_model_step', type=int, default=5000, help='')
-        parser.add_argument('-early_stop_tolerance', type=int, default=50, help='')
-        parser.add_argument('-train_parallel', action='store_true', help='')
+        parser.add_argument('-early_stop_tolerance', type=int, default=100, help='')
+        parser.add_argument('-schedule_lr', action='store_true', help='')
+        parser.add_argument('-schedule_step', type=int, default=100000, help='')
+        parser.add_argument('-schedule_gamma', type=float, default=0.5, help='')

         # model specific parameter
         parser.add_argument("-embed_size", type=int, default=200, help='')
