
Commit 7ddd062

Merge: [TXL/PyT] Added barriers when reporting time, switched to correct averaging when reporting avg throughput
2 parents: 959c677 + 8f82237

File tree: 3 files changed, 48 additions and 16 deletions


PyTorch/LanguageModeling/Transformer-XL/pytorch/eval.py

Lines changed: 22 additions & 10 deletions
@@ -168,16 +168,19 @@ def format_log(loss, split, args):
     return log_str


-def evaluate(eval_iter, model, meters, log_interval, max_size=None, repeat=1):
+def evaluate(
+    eval_iter, model, device, meters, log_interval, max_size=None, repeat=1
+):
     total_len, total_loss = 0, 0.
     eval_step = 0

     log_throughput = 0
     log_latency = 0
     log_loss = 0

-    torch.cuda.synchronize()
+    utils.distributed.barrier()
     start_time = time.time()
+
     with torch.no_grad():
         mems = None
         for _ in range(repeat):
@@ -186,10 +189,12 @@ def evaluate(eval_iter, model, meters, log_interval, max_size=None, repeat=1):
                     break
                 eval_step += 1

-                torch.cuda.synchronize()
+                utils.distributed.barrier()
                 start_iter = time.time()
+
                 loss, mems = model(data, target, mems)
-                torch.cuda.synchronize()
+
+                utils.distributed.barrier()
                 elapsed = time.time() - start_iter

                 loss = loss.float().mean()
@@ -204,7 +209,7 @@ def evaluate(eval_iter, model, meters, log_interval, max_size=None, repeat=1):
                 target_tokens = target.numel()
                 throughput = target_tokens / elapsed
                 throughput = utils.distributed.all_reduce_item(throughput, op='sum')
-                meters['eval_throughput'].update(throughput)
+                meters['eval_throughput'].update(throughput, elapsed)
                 log_throughput += throughput

                 if eval_step % log_interval == 0:
@@ -238,8 +243,8 @@ def evaluate(eval_iter, model, meters, log_interval, max_size=None, repeat=1):
                     log_loss = 0

     utils.distributed.barrier()
-    torch.cuda.synchronize()
     total_time = time.time() - start_time
+
     logging.info('Time : {:.2f}s, {:.2f}ms/segment'.format(
         total_time, 1000 * total_time / (idx+1)))

@@ -251,13 +256,18 @@ def evaluate(eval_iter, model, meters, log_interval, max_size=None, repeat=1):
 def compile_model(model, device, args):
     inp = torch.randint(0, 1000, (args.tgt_len, args.batch_size)).to(device)
     tgt = torch.randint(0, 1000, (args.tgt_len, args.batch_size)).to(device)
+
+    utils.distributed.barrier()
     start = time.time()
+
     with torch.no_grad():
         mems = None
         for _ in range(2):
             _, mems = model(inp, tgt, mems)
-    torch.cuda.synchronize()
+
+    utils.distributed.barrier()
     stop = time.time()
+
     logging.info(f'Building the model took {stop - start:.2f} seconds')


@@ -450,7 +460,7 @@ def main():
     meters['eval_throughput'] = AverageMeter(warmup=warmup, keep=args.save_data)
     meters['eval_latency'] = AverageMeter(warmup=warmup, keep=args.save_data)

-    loss = evaluate(iter, model, meters, args.log_interval, args.max_size, args.repeat)
+    loss = evaluate(iter, model, device, meters, args.log_interval, args.max_size, args.repeat)
     perplexity = math.exp(loss)
     log_str = format_log(loss, args.split, args)

@@ -476,15 +486,17 @@ def main():
         }
         with open(data_path, 'wb') as f:
             pickle.dump(data, f)
-        logging.info(f'Throughput Avg: {throughput_data.mean():.2f} tok/s')
+
+        avg_throughput = meters['eval_throughput'].avg
+        logging.info(f'Throughput Avg: {avg_throughput:.2f} tok/s')
         logging.info(f'Latency Avg: {1000.0 * latency_data.mean():.2f} ms')
         for p in args.percentiles:
             logging.info(f'Latency {p}%: {1000.0 * np.percentile(latency_data, p):.2f} ms')

         logging.info('=' * 100)

         summary.update({
-            'eval_throughput': throughput_data.mean(),
+            'eval_throughput': avg_throughput,
             'eval_avg_latency': 1000 * latency_data.mean(),
         })
         for p in args.percentiles:
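
The switch from update(throughput) to update(throughput, elapsed) is what the commit message calls "correct averaging": each per-iteration throughput is weighted by how long that iteration took, so the reported average equals total tokens divided by total time rather than a plain mean of per-iteration rates. A minimal sketch of a duration-weighted meter, assuming only that the second update() argument acts as a weight (the repo's actual AverageMeter is not part of this diff):

# Hypothetical, simplified stand-in for the repo's AverageMeter.
class WeightedAverageMeter:
    def __init__(self):
        self.sum = 0.0
        self.count = 0.0
        self.avg = 0.0

    def update(self, val, n=1):
        # With val = tokens/second and n = elapsed seconds, sum accumulates
        # total tokens and count accumulates total seconds, so avg is the
        # overall tokens-per-second rate.
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


meter = WeightedAverageMeter()
meter.update(1000.0, 2.0)   # 1000 tok/s sustained for 2 s
meter.update(3000.0, 1.0)   # 3000 tok/s sustained for 1 s
print(meter.avg)            # 5000 tokens / 3 s = 1666.7 tok/s (naive mean: 2000)

Reading meters['eval_throughput'].avg in main() then reports this duration-weighted value instead of throughput_data.mean().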

PyTorch/LanguageModeling/Transformer-XL/pytorch/train.py

Lines changed: 22 additions & 5 deletions
@@ -513,6 +513,7 @@ def train(tr_iter, va_iter, model, para_model, mems, model_config, optimizer,
     cur_loss = float('inf')
     target_tokens = 0
     log_step = 0
+    utils.distributed.barrier()
     log_start_time = time.time()

     if args.varlen:
@@ -586,16 +587,18 @@ def train(tr_iter, va_iter, model, para_model, mems, model_config, optimizer,
             cur_loss = utils.distributed.all_reduce_item(cur_loss, op='mean')
             train_loss = 0

-            elapsed = time.time() - log_start_time
+            utils.distributed.barrier()
+            current_time = time.time()
+            elapsed = current_time - log_start_time
             avg_elapsed = elapsed / log_step
             avg_elapsed = utils.distributed.all_reduce_item(avg_elapsed, op='max')
-            log_start_time = time.time()
+            log_start_time = current_time
             log_step = 0

             lr = optimizer.param_groups[0]['lr']
             throughput = target_tokens / elapsed
             throughput = utils.distributed.all_reduce_item(throughput, op='sum')
-            meters['train_throughput'].update(throughput)
+            meters['train_throughput'].update(throughput, elapsed)
             target_tokens = 0

             log_str = '| epoch {:3d} step {:>8d} | batches {:>6d} / {:d} | lr {:.3e} ' \
@@ -634,21 +637,26 @@ def train(tr_iter, va_iter, model, para_model, mems, model_config, optimizer,
             interrupted = timeout_handler.interrupted

         if (do_periodic_eval or is_final_step or interrupted) and not args.no_eval:
+            utils.distributed.barrier()
             eval_start_time = time.time()
+
             val_loss = evaluate(va_iter, model, args)
             val_loss = utils.distributed.all_reduce_item(val_loss, op='mean')

+            utils.distributed.barrier()
+            eval_elapsed = time.time() - eval_start_time
+
             logging.info('-' * 100)
             log_str = '| Eval {:3d} at step {:>8d} | time: {:5.2f}s ' \
                       '| valid loss {:5.2f}'.format(
                           train_step // args.eval_interval,
                           train_step,
-                          (time.time() - eval_start_time),
+                          eval_elapsed,
                           val_loss,
                           )

             dllogger_data = {
-                'valid_elapsed': (time.time() - eval_start_time),
+                'valid_elapsed': eval_elapsed,
                 'valid_loss': val_loss,
                 }

@@ -683,6 +691,7 @@ def train(tr_iter, va_iter, model, para_model, mems, model_config, optimizer,
                     scheduler_sparse.step(val_loss)

             # subtract eval time from timers for training
+            utils.distributed.barrier()
             log_start_time += time.time() - eval_start_time

         if interrupted:
@@ -1022,7 +1031,10 @@ def lr_lambda(step):
     ###########################################################################
     # Loop over epochs.
     # At any point you can hit Ctrl + C to break out of training early.
+
+    utils.distributed.barrier()
     start_time = time.time()
+
     with TimeoutHandler() as timeout_handler:
         try:
             for epoch in itertools.count(start=start_epoch):
@@ -1046,6 +1058,7 @@ def lr_lambda(step):
         except KeyboardInterrupt:
             logging.info('-' * 100)
             logging.info('Exiting from training early')
+    utils.distributed.barrier()
     elapsed = time.time() - start_time

     ###########################################################################
@@ -1064,9 +1077,13 @@ def lr_lambda(step):
         model.load_state_dict(checkpoint['model_state'])

         # Run on test data.
+        utils.distributed.barrier()
         test_start_time = time.time()
+
         test_loss = evaluate(te_iter, model, args)
         test_loss = utils.distributed.all_reduce_item(test_loss, 'mean')
+
+        utils.distributed.barrier()
         test_elapsed = time.time() - test_start_time

         logging.info('=' * 100)
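
The training loop applies the same pattern: a barrier before every timestamp so pending CUDA work and rank skew do not leak into the measured interval, a single current_time that both closes one logging window and opens the next, and the validation time subtracted from log_start_time so eval does not inflate training throughput. A minimal single-process sketch of the logging-window bookkeeping, with barrier() as a no-op stand-in for utils.distributed.barrier() and step() as a hypothetical training step:

import time

def barrier():
    # Stand-in for utils.distributed.barrier(); a no-op in this single-process sketch.
    pass

def step():
    # Hypothetical training step.
    time.sleep(0.01)

tokens_per_step = 4096
log_interval = 5

barrier()
log_start_time = time.time()
target_tokens = 0

for train_step in range(1, 3 * log_interval + 1):
    step()
    target_tokens += tokens_per_step

    if train_step % log_interval == 0:
        barrier()                      # align before reading the clock
        current_time = time.time()
        elapsed = current_time - log_start_time
        throughput = target_tokens / elapsed
        print(f'step {train_step:3d} | {1000 * elapsed / log_interval:.1f} ms/step '
              f'| {throughput:.0f} tok/s')
        log_start_time = current_time  # the same timestamp opens the next window
        target_tokens = 0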

PyTorch/LanguageModeling/Transformer-XL/pytorch/utils/distributed.py

Lines changed: 4 additions & 1 deletion
@@ -37,10 +37,13 @@ def init_distributed(cuda):

 def barrier():
     """
-    Call torch.distributed.barrier() if distributed is in use
+    Call torch.distributed.barrier() if distributed is in use, otherwise call
+    torch.cuda.synchronize() if CUDA is initialized.
     """
     if torch.distributed.is_available() and torch.distributed.is_initialized():
         torch.distributed.barrier()
+    elif torch.cuda.is_available() and torch.cuda.is_initialized():
+        torch.cuda.synchronize()


 def get_rank():
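
With this fallback, a single call site works for every configuration: multi-process runs align all ranks with torch.distributed.barrier(), single-GPU runs drain the asynchronous CUDA stream, and CPU-only runs fall through to a no-op. A usage sketch with barrier() restated locally so the snippet runs outside the repo (in the repo the call is utils.distributed.barrier()):

import time
import torch

def barrier():
    # Same logic as the helper above, restated so this snippet is standalone.
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        torch.distributed.barrier()      # multi-process: wait for every rank
    elif torch.cuda.is_available() and torch.cuda.is_initialized():
        torch.cuda.synchronize()         # single GPU: wait for queued kernels

device = 'cuda' if torch.cuda.is_available() else 'cpu'
x = torch.randn(2048, 2048, device=device)

barrier()                                # sync before the start timestamp
start = time.time()
y = x @ x                                # launched asynchronously on CUDA
barrier()                                # sync so the matmul is actually included
print(f'matmul took {1000 * (time.time() - start):.2f} ms on {device}')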
