Skip to content

Commit f260a07

Browse files
committed
fixed model eval
1 parent a3e5532 commit f260a07

File tree

1 file changed

+9
-4
lines changed

1 file changed

+9
-4
lines changed

notebooks/seq2seq_utils.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import numpy as np
1010
import dill as dpickle
1111
from annoy import AnnoyIndex
12-
from tqdm import tqdm
12+
from tqdm import tqdm, tqdm_notebook
1313
from random import random
1414
from nltk.translate.bleu_score import corpus_bleu
1515

@@ -410,11 +410,16 @@ def evaluate_model(self, holdout_bodies, holdout_titles):
410410
"""
411411
actual, predicted = list(), list()
412412
# step over the whole set
413-
for issue_body, issue_title in zip(holdout_bodies, holdout_titles):
414-
_, yhat = self.generate_issue_title(issue_body)
413+
assert len(holdout_bodies) == len(holdout_titles)
414+
num_examples = len(holdout_bodies)
415415

416-
actual.append(self.pp_title.process_text([issue_title])[0])
416+
logging.warning('Generating predictions.')
417+
for i in tqdm_notebook(range(num_examples)):
418+
_, yhat = self.generate_issue_title(holdout_bodies[i])
419+
420+
actual.append(self.pp_title.process_text([holdout_titles[i]])[0])
417421
predicted.append(self.pp_title.process_text([yhat])[0])
418422
# calculate BLEU score
423+
logging.warning('Calculating BLEU.')
419424
bleu = corpus_bleu(actual, predicted)
420425
return bleu

0 commit comments

Comments
 (0)