@@ -47,7 +47,7 @@ def run_single_pass(tf_runner, squad):
4747 tf_runner .set_input_tensor ("input_mask:0" , squad .get_attention_mask_array ())
4848 tf_runner .set_input_tensor ("segment_ids:0" , squad .get_token_type_ids_array ())
4949
50- output = tf_runner .run (batch_size )
50+ output = tf_runner .run (batch_size * seq_size )
5151
5252 for i in range (batch_size ):
5353 answer_start_id , answer_end_id = np .argmax (output ["logits:0" ][i ], axis = 0 )
@@ -85,8 +85,8 @@ def run_pytorch_fp(model_path, batch_size, num_runs, timeout, squad_path, disabl
8585 from utils .pytorch import PyTorchRunner
8686
8787 def run_single_pass (pytorch_runner , squad ):
88-
89- output = pytorch_runner .run (batch_size , ** dict (squad . get_input_arrays () ))
88+ input_tensor = squad . get_input_arrays ()
89+ output = pytorch_runner .run (batch_size * input_tensor [ "input_ids" ]. size ()[ 1 ] , ** dict (input_tensor ))
9090
9191 for i in range (batch_size ):
9292 answer_start_id = output [0 ][i ].argmax ()
@@ -137,8 +137,9 @@ def run_pytorch_cuda(model_path, batch_size, num_runs, timeout, squad_path, disa
137137 from transformers import AutoTokenizer , BertConfig , BertForQuestionAnswering
138138
139139 def run_single_pass (pytorch_runner , squad ):
140-
141- output = pytorch_runner .run (batch_size , ** {k : v .cuda () for k , v in squad .get_input_arrays ().items ()})
140+ input_tensor = squad .get_input_arrays ()
141+ output = pytorch_runner .run (batch_size * input_tensor ["input_ids" ].size ()[1 ],
142+ ** {k : v .cuda () for k , v in input_tensor .items ()})
142143
143144 for i in range (batch_size ):
144145 answer_start_id = output [0 ][i ].argmax ()
@@ -189,9 +190,11 @@ def main():
189190 download_squad_1_1_dataset ()
190191
191192 if args .framework == "tf" :
193+ if args .batch_size > 1 :
194+ print_goodbye_message_and_die ("This model supports only BS=1" )
195+
192196 if args .model_path is None :
193- print_goodbye_message_and_die (
194- "a path to model is unspecified!" )
197+ print_goodbye_message_and_die ("a path to model is unspecified!" )
195198
196199 if args .precision == "fp32" :
197200 run_tf_fp32 (** vars (args ))
0 commit comments