Skip to content

Commit ed28348

Browse files
committed
Merge: [GNMT/PyT] Added synchronization before collecting timers, switched to correct averaging when reporting avg throughput
2 parents 7ddd062 + 327898a commit ed28348

File tree

5 files changed

+21
-10
lines changed

5 files changed

+21
-10
lines changed

PyTorch/Translation/GNMT/seq2seq/inference/translator.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,8 @@ def evaluate(self, loader, epoch=0, iteration=0, warmup=0, summary=False):
182182
output = []
183183

184184
for i, (src, indices) in enumerate(loader):
185+
if device.type == 'cuda':
186+
torch.cuda.synchronize()
185187
translate_timer = time.time()
186188
src, src_length = src
187189
stats['total_enc_len'] = int(src_length.sum())
@@ -207,12 +209,14 @@ def evaluate(self, loader, epoch=0, iteration=0, warmup=0, summary=False):
207209
detok = self.tokenizer.detokenize(pred)
208210
output.append(detok)
209211

212+
if device.type == 'cuda':
213+
torch.cuda.synchronize()
210214
elapsed = time.time() - translate_timer
211215
batch_time.update(elapsed, batch_size)
212216

213217
total_tokens = stats['total_dec_len'] + stats['total_enc_len']
214218
ttps = total_tokens / elapsed
215-
tot_tok_per_sec.update(ttps, batch_size)
219+
tot_tok_per_sec.update(ttps, elapsed)
216220

217221
iterations.update(stats['iters'])
218222
enc_seq_len.update(stats['total_enc_len'] / batch_size, batch_size)

PyTorch/Translation/GNMT/seq2seq/train/trainer.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,8 @@ def feed_data(self, data_loader, training=True):
222222

223223
batch_size = data_loader.batch_size
224224

225+
if self.device.type == 'cuda':
226+
torch.cuda.synchronize()
225227
end = time.time()
226228
for i, (src, tgt) in enumerate(data_loader):
227229
self.save_counter += 1
@@ -241,12 +243,14 @@ def feed_data(self, data_loader, training=True):
241243
losses_per_sentence.update(loss_per_sentence, batch_size)
242244

243245
# measure elapsed time
246+
if self.device.type == 'cuda':
247+
torch.cuda.synchronize()
244248
elapsed = time.time() - end
245249
batch_time.update(elapsed)
246-
src_tok_time.update(num_toks['src'] / elapsed)
247-
tgt_tok_time.update(num_toks['tgt'] / elapsed)
250+
src_tok_time.update(num_toks['src'] / elapsed, elapsed)
251+
tgt_tok_time.update(num_toks['tgt'] / elapsed, elapsed)
248252
tot_num_toks = num_toks['tgt'] + num_toks['src']
249-
tot_tok_time.update(tot_num_toks / elapsed)
253+
tot_tok_time.update(tot_num_toks / elapsed, elapsed)
250254
self.loss = losses_per_token.avg
251255

252256
if training and i in eval_iters:
@@ -298,6 +302,8 @@ def feed_data(self, data_loader, training=True):
298302
if rank == 0:
299303
self.save(identifier=identifier)
300304

305+
if self.device.type == 'cuda':
306+
torch.cuda.synchronize()
301307
end = time.time()
302308

303309
tot_tok_time.reduce('sum')

PyTorch/Translation/GNMT/seq2seq/utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,10 +132,13 @@ def setup_seeds(master_seed, epochs, device):
132132

133133
def barrier():
134134
"""
135-
Call torch.distributed.barrier() if distributed is in use
135+
Call torch.distributed.barrier() if distributed is in use, else calls
136+
torch.cuda.synchronize() if CUDA is initialized.
136137
"""
137138
if torch.distributed.is_available() and torch.distributed.is_initialized():
138139
torch.distributed.barrier()
140+
elif torch.cuda.is_available() and torch.cuda.is_initialized():
141+
torch.cuda.synchronize()
139142

140143

141144
def get_rank():

PyTorch/Translation/GNMT/train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -634,7 +634,7 @@ def main():
634634
logging.info(f'Total training time {training_time:.0f} s')
635635

636636
table = TrainingTable()
637-
avg_training_perf = sum(training_perf) / len(training_perf)
637+
avg_training_perf = len(training_perf) / sum(1 / v for v in training_perf)
638638
table.add(utils.get_world_size(), args.train_batch_size, test_bleu,
639639
avg_training_perf, training_time)
640640
if utils.get_rank() == 0:

PyTorch/Translation/GNMT/translate.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -352,12 +352,10 @@ def main():
352352
latency_table.write('Inference latency', 'fp16',
353353
relative=relative, reverse_speedup=True)
354354

355-
avg_throughput = np.array(stats['throughputs']).mean()
356-
avg_latency = np.array(stats['runtimes']).mean()
357355
summary = {
358-
'eval_throughput': avg_throughput,
356+
'eval_throughput': stats['tokens_per_sec'],
359357
'eval_bleu': stats['bleu'],
360-
'eval_avg_latency': avg_latency,
358+
'eval_avg_latency': np.array(stats['runtimes']).mean(),
361359
}
362360
for p in args.percentiles:
363361
summary[f'eval_{p}%_latency'] = np.percentile(stats['runtimes'], p)

0 commit comments

Comments
 (0)