Skip to content

Commit f3754c0

Browse files
Fix BERT measuring (#239)
1 parent 86a3c18 commit f3754c0

File tree

3 files changed

+16
-8
lines changed

3 files changed

+16
-8
lines changed

natural_language_processing/extractive_question_answering/bert_large/run_mlperf.py

Lines changed: 10 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -47,7 +47,7 @@ def run_single_pass(tf_runner, squad):
4747
tf_runner.set_input_tensor("input_mask:0", squad.get_attention_mask_array())
4848
tf_runner.set_input_tensor("segment_ids:0", squad.get_token_type_ids_array())
4949

50-
output = tf_runner.run(batch_size)
50+
output = tf_runner.run(batch_size * seq_size)
5151

5252
for i in range(batch_size):
5353
answer_start_id, answer_end_id = np.argmax(output["logits:0"][i], axis=0)
@@ -85,8 +85,8 @@ def run_pytorch_fp(model_path, batch_size, num_runs, timeout, squad_path, disabl
8585
from utils.pytorch import PyTorchRunner
8686

8787
def run_single_pass(pytorch_runner, squad):
88-
89-
output = pytorch_runner.run(batch_size, **dict(squad.get_input_arrays()))
88+
input_tensor = squad.get_input_arrays()
89+
output = pytorch_runner.run(batch_size * input_tensor["input_ids"].size()[1], **dict(input_tensor))
9090

9191
for i in range(batch_size):
9292
answer_start_id = output[0][i].argmax()
@@ -137,8 +137,9 @@ def run_pytorch_cuda(model_path, batch_size, num_runs, timeout, squad_path, disa
137137
from transformers import AutoTokenizer, BertConfig, BertForQuestionAnswering
138138

139139
def run_single_pass(pytorch_runner, squad):
140-
141-
output = pytorch_runner.run(batch_size, **{k: v.cuda() for k, v in squad.get_input_arrays().items()})
140+
input_tensor = squad.get_input_arrays()
141+
output = pytorch_runner.run(batch_size * input_tensor["input_ids"].size()[1],
142+
**{k: v.cuda() for k, v in input_tensor.items()})
142143

143144
for i in range(batch_size):
144145
answer_start_id = output[0][i].argmax()
@@ -189,9 +190,11 @@ def main():
189190
download_squad_1_1_dataset()
190191

191192
if args.framework == "tf":
193+
if args.batch_size > 1:
194+
print_goodbye_message_and_die("This model supports only BS=1")
195+
192196
if args.model_path is None:
193-
print_goodbye_message_and_die(
194-
"a path to model is unspecified!")
197+
print_goodbye_message_and_die("a path to model is unspecified!")
195198

196199
if args.precision == "fp32":
197200
run_tf_fp32(**vars(args))

tests/test_pytorch_models.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -200,7 +200,7 @@ def test_bert_large_mlperf(self):
200200
def wrapper(**kwargs):
201201
kwargs["q"].put(run_pytorch_fp32(**kwargs)[0])
202202

203-
exact_match_ref, f1_ref = 0.792, 0.825
203+
exact_match_ref, f1_ref = 0.750, 0.817
204204
acc = run_process(wrapper, {"model_path": self.model_path, "squad_path": self.dataset_path,
205205
"batch_size": 1, "num_runs": 24, "timeout": None, "disable_jit_freeze": False})
206206
self.assertTrue(acc["exact_match"] / exact_match_ref > 0.95)

utils/nlp/squad.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44
import numpy as np
55
import json
66
import re
7+
import random
78
import string
89
from collections import Counter
910
import utils.misc as utils
@@ -71,8 +72,12 @@ def __examples(self):
7172
7273
:yield: str, str, list: context, questions, list of possible (correct) answers
7374
"""
75+
random.seed(44)
76+
random.shuffle(self.__dataset)
7477
for section in self.__dataset:
78+
random.shuffle(section["paragraphs"])
7579
for paragraph in section["paragraphs"]:
80+
random.shuffle(paragraph["qas"])
7681
for qas in paragraph["qas"]:
7782
yield paragraph["context"], qas["question"], qas["answers"]
7883

0 commit comments

Comments
 (0)