Skip to content

Commit 8e7bb0e

Browse files
committed
small fixes
1 parent bae9143 commit 8e7bb0e

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

cdqa/reader/reader_sklearn.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ def evaluate(input_file, args, model, tokenizer, prefix=""):
243243
# XLNet uses a more complex post-processing procedure
244244
write_predictions_extended(examples, features, all_results, args.n_best_size,
245245
args.max_answer_length, output_prediction_file,
246-
output_nbest_file, output_null_log_odds_file, args.predict_file,
246+
output_nbest_file, output_null_log_odds_file, input_file,
247247
model.config.start_n_top, model.config.end_n_top,
248248
args.version_2_with_negative, tokenizer, args.verbose_logging)
249249
else:
@@ -253,7 +253,7 @@ def evaluate(input_file, args, model, tokenizer, prefix=""):
253253
args.version_2_with_negative, args.null_score_diff_threshold)
254254

255255
# Evaluate with the official SQuAD script
256-
evaluate_options = EVAL_OPTS(data_file=args.predict_file,
256+
evaluate_options = EVAL_OPTS(data_file=input_file,
257257
pred_file=output_prediction_file,
258258
na_prob_file=output_null_log_odds_file)
259259
results = evaluate_on_squad(evaluate_options)
@@ -361,7 +361,7 @@ def predict(input_file, args, model, tokenizer, prefix=""):
361361
# XLNet uses a more complex post-processing procedure
362362
out_eval, final_prediction = write_predictions_extended(examples, features, all_results, args.n_best_size,
363363
args.max_answer_length, output_prediction_file,
364-
output_nbest_file, output_null_log_odds_file, args.predict_file,
364+
output_nbest_file, output_null_log_odds_file, input_file,
365365
model.config.start_n_top, model.config.end_n_top,
366366
args.version_2_with_negative, tokenizer, args.verbose_logging)
367367
else:

examples/tutorial-train-xlnet-squad.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
reader.device = torch.device('cpu')
3131

3232
# save CPU it locally
33-
joblib.dump(reader, os.path.join(reader.output_dir, 'bert_qa_vCPU.joblib'))
33+
joblib.dump(reader, os.path.join(reader.output_dir, 'xlnet_qa_vCPU.joblib'))
3434

3535
# evaluate the model
36-
reader.evaluate(X='dev-v2.0.json')
36+
out_eval, final_prediction = reader.evaluate(X='dev-v2.0.json')

0 commit comments

Comments
 (0)