88from deep_keyphrase .utils .vocab_loader import load_vocab
99from deep_keyphrase .copy_rnn .model import CopyRNN
1010from deep_keyphrase .base_trainer import BaseTrainer
11- from deep_keyphrase .dataloader import ( KeyphraseDataLoader , TOKENS , TARGET )
11+ from deep_keyphrase .dataloader import TOKENS , TARGET
1212from deep_keyphrase .copy_rnn .predict import CopyRnnPredictor
1313
1414
@@ -87,6 +87,8 @@ def train_batch(self, batch):
8787 torch .nn .utils .clip_grad_norm_ (self .model .parameters (), self .args .grad_norm )
8888
8989 self .optimizer .step ()
90+ if self .args .schedule_lr :
91+ self .scheduler .step ()
9092 return loss
9193
9294 def evaluate (self , step ):
@@ -99,26 +101,49 @@ def evaluate(self, step):
99101 pred_valid_filename += '.batch_{}.pred.jsonl' .format (step )
100102 eval_filename = self .dest_dir + self .args .exp_name + '.batch_{}.eval.json' .format (step )
101103 predictor .eval_predict (self .args .valid_filename , pred_valid_filename ,
102- self .args .eval_batch_size , self .model , True )
103- valid_macro_ret = self .macro_evaluator .evaluate (pred_valid_filename )
104- # valid_micro_ret = self.micro_evaluator.evaluate(pred_valid_filename)
105- for n , counter in valid_macro_ret .items ():
104+ self .args .eval_batch_size , self .model , True ,
105+ token_field = self .args .token_field ,
106+ keyphrase_field = self .args .keyphrase_field )
107+ valid_macro_all_ret = self .macro_evaluator .evaluate (pred_valid_filename )
108+ valid_macro_present_ret = self .macro_evaluator .evaluate (pred_valid_filename , 'present' )
109+ valid_macro_absent_ret = self .macro_evaluator .evaluate (pred_valid_filename , 'absent' )
110+
111+ for n , counter in valid_macro_all_ret .items ():
106112 for k , v in counter .items ():
107113 name = 'valid/macro_{}@{}' .format (k , n )
108114 self .writer .add_scalar (name , v , step )
115+ for n in self .eval_topn :
116+ name = 'present/valid macro_f1@{}' .format (n )
117+ self .writer .add_scalar (name , valid_macro_present_ret [n ]['f1' ], step )
118+ for n in self .eval_topn :
119+ name = 'absent/valid macro_f1@{}' .format (n )
120+ self .writer .add_scalar (name , valid_macro_absent_ret [n ]['f1' ], step )
109121 pred_test_filename = self .dest_dir + self .get_basename (self .args .test_filename )
110122 pred_test_filename += '.batch_{}.pred.jsonl' .format (step )
111123
112124 predictor .eval_predict (self .args .test_filename , pred_test_filename ,
113- self .args .eval_batch_size , self .model , True )
125+ self .args .eval_batch_size , self .model , True ,
126+ token_field = self .args .token_field ,
127+ keyphrase_field = self .args .keyphrase_field )
114- test_macro_ret = self .macro_evaluator .evaluate (pred_test_filename )
115- for n , counter in test_macro_ret .items ():
126+ test_macro_all_ret = self .macro_evaluator .evaluate (pred_test_filename )
127+ test_macro_present_ret = self .macro_evaluator .evaluate (pred_test_filename , 'present' )
128+ test_macro_absent_ret = self .macro_evaluator .evaluate (pred_test_filename , 'absent' )
129+ for n , counter in test_macro_all_ret .items ():
116130 for k , v in counter .items ():
117131 name = 'test/macro_{}@{}' .format (k , n )
118132 self .writer .add_scalar (name , v , step )
119- write_json (eval_filename , {'valid_macro' : valid_macro_ret , 'test_macro' : test_macro_ret })
120- # valid_micro_ret = self.micro_evaluator.evaluate(pred_test_filename)
121- return valid_macro_ret [self .eval_topn [- 1 ]]['f1' ]
133+ for n in self .eval_topn :
134+ name = 'present/test macro_f1@{}' .format (n )
135+ self .writer .add_scalar (name , test_macro_present_ret [n ]['f1' ], step )
136+ for n in self .eval_topn :
137+ name = 'absent/test macro_f1@{}' .format (n )
138+ self .writer .add_scalar (name , test_macro_absent_ret [n ]['f1' ], step )
139+ total_statistics = {'valid_macro' : valid_macro_all_ret ,
140+ 'valid_present_macro' : valid_macro_present_ret ,
141+ 'valid_absent_macro' : valid_macro_absent_ret ,
142+ 'test_macro' : test_macro_all_ret ,
143+ 'test_present_macro' : test_macro_present_ret ,
144+ 'test_absent_macro' : test_macro_absent_ret }
145+ write_json (eval_filename , total_statistics )
146+ return valid_macro_all_ret [self .eval_topn [- 1 ]]['f1' ]
122147
123148 def parse_args (self ):
124149 parser = argparse .ArgumentParser ()
@@ -131,10 +156,12 @@ def parse_args(self):
131156 parser .add_argument ("-vocab_path" , required = True , type = str , help = '' )
132157 parser .add_argument ("-vocab_size" , type = int , default = 500000 , help = '' )
133158 parser .add_argument ("-train_from" , default = '' , type = str , help = '' )
159+ parser .add_argument ("-token_field" , default = 'tokens' , type = str , help = '' )
160+ parser .add_argument ("-keyphrase_field" , default = 'keyphrases' , type = str , help = '' )
134161 parser .add_argument ("-epochs" , type = int , default = 10 , help = '' )
135162 parser .add_argument ("-batch_size" , type = int , default = 64 , help = '' )
136- parser .add_argument ("-learning_rate" , type = float , default = 1e-4 , help = '' )
137- parser .add_argument ("-eval_batch_size" , type = int , default = 20 , help = '' )
163+ parser .add_argument ("-learning_rate" , type = float , default = 1e-3 , help = '' )
164+ parser .add_argument ("-eval_batch_size" , type = int , default = 50 , help = '' )
138165 parser .add_argument ("-dropout" , type = float , default = 0.1 , help = '' )
139166 parser .add_argument ("-grad_norm" , type = float , default = 0.0 , help = '' )
140167 parser .add_argument ("-max_grad" , type = float , default = 5.0 , help = '' )
@@ -144,8 +171,10 @@ def parse_args(self):
144171 parser .add_argument ('-tensorboard_dir' , type = str , default = '' , help = '' )
145172 parser .add_argument ('-logfile' , type = str , default = 'train_log.log' , help = '' )
146173 parser .add_argument ('-save_model_step' , type = int , default = 5000 , help = '' )
147- parser .add_argument ('-early_stop_tolerance' , type = int , default = 50 , help = '' )
148- parser .add_argument ('-train_parallel' , action = 'store_true' , help = '' )
174+ parser .add_argument ('-early_stop_tolerance' , type = int , default = 100 , help = '' )
175+ parser .add_argument ('-schedule_lr' , action = 'store_true' , help = '' )
176+ parser .add_argument ('-schedule_step' , type = int , default = 100000 , help = '' )
177+ parser .add_argument ('-schedule_gamma' , type = float , default = 0.5 , help = '' )
149178
150179 # model specific parameter
151180 parser .add_argument ("-embed_size" , type = int , default = 200 , help = '' )
0 commit comments