Commit c44b7ad

[Maskrcnn/PyT] Synchronize before reporting DLL time
1 parent: 6610c05

3 files changed: +12 −8 lines
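For context: CUDA kernel launches return control to the host immediately, so a bare time.time() taken after queuing GPU work can under-report elapsed time. That is the skew this commit removes from the DLLogger metrics. A minimal sketch of the effect (naive_elapsed and synchronized_elapsed are illustrative names, not part of this repo):

```python
import time
import torch

def naive_elapsed(fn):
    # time.time() right after launching GPU work can return before the
    # kernels have actually finished, because CUDA launches are async.
    start = time.time()
    fn()
    return time.time() - start

def synchronized_elapsed(fn):
    # Draining the CUDA stream before each timestamp makes the host
    # clock reflect the real end-to-end GPU time.
    torch.cuda.synchronize()
    start = time.time()
    fn()
    torch.cuda.synchronize()
    return time.time() - start

if torch.cuda.is_available():
    x = torch.randn(4096, 4096, device="cuda")
    print("naive:       ", naive_elapsed(lambda: x @ x))
    print("synchronized:", synchronized_elapsed(lambda: x @ x))
```

On a GPU, the naive number mostly measures kernel-launch overhead; the synchronized one includes the matmul itself.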

PyTorch/Segmentation/MaskRCNN/pytorch/maskrcnn_benchmark/engine/inference.py

Lines changed: 2 additions & 4 deletions
```diff
@@ -10,9 +10,7 @@
 from tqdm import tqdm
 
 from maskrcnn_benchmark.data.datasets.evaluation import evaluate
-from ..utils.comm import is_main_process
-from ..utils.comm import all_gather
-from ..utils.comm import synchronize
+from ..utils.comm import is_main_process, all_gather, synchronize, synchronized_timestamp
 
 
 def compute_on_dataset(model, data_loader, device, steps=-1):
@@ -83,7 +81,7 @@ def inference(
     )
     dataset = data_loader.dataset
     dllogger.log(step="PARAMETER", data={"eval_dataset_name": dataset_name, "eval_num_samples":len(dataset)})
-    start_time = time.time()
+    start_time = synchronized_timestamp()
     with torch.autograd.profiler.emit_nvtx(enabled=profile):
         predictions, latency = compute_on_dataset(model, data_loader, device, steps=steps)
     # wait for all processes to complete before measuring the time
```
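The change above means the start timestamp is taken only after previously queued GPU work has drained, so stray warm-up kernels are not silently folded into the measured evaluation window. A rough sketch of the resulting pattern (timed_inference and batches are hypothetical stand-ins, not the repo's actual inference path):

```python
import time
import torch

def synchronized_timestamp():
    torch.cuda.synchronize()  # same helper the commit adds to comm.py
    return time.time()

def timed_inference(model, batches):
    start_time = synchronized_timestamp()   # GPU queue is empty at t0
    with torch.no_grad():
        for batch in batches:
            model(batch)                     # asynchronous kernel launches
    torch.cuda.synchronize()                 # wait for the last batch
    return time.time() - start_time
```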

PyTorch/Segmentation/MaskRCNN/pytorch/maskrcnn_benchmark/engine/trainer.py

Lines changed: 4 additions & 4 deletions
```diff
@@ -7,7 +7,7 @@
 import torch
 import torch.distributed as dist
 
-from maskrcnn_benchmark.utils.comm import get_world_size
+from maskrcnn_benchmark.utils.comm import get_world_size, synchronized_timestamp
 from maskrcnn_benchmark.utils.metric_logger import MetricLogger
 
 def reduce_loss_dict(loss_dict):
@@ -90,8 +90,8 @@ def do_train(
     prefetcher = Prefetcher(data_loader, device)
     start_iter = arguments["iteration"]
     model.train()
-    start_training_time = time.time()
-    end = time.time()
+    start_training_time = synchronized_timestamp()
+    end = start_training_time
     if use_amp:
         scaler = torch.cuda.amp.GradScaler(init_scale=8192.0)
     for iteration, (images, targets) in enumerate(prefetcher, start_iter):
@@ -169,7 +169,7 @@ def _take_step():
         if early_exit:
             break
 
-    total_training_time = time.time() - start_training_time
+    total_training_time = synchronized_timestamp() - start_training_time
     total_time_str = str(datetime.timedelta(seconds=total_training_time))
     dllogger.log(step=tuple(), data={"e2e_train_time": total_training_time,
                                      "train_perf_fps": max_iter * cfg.SOLVER.IMS_PER_BATCH / total_training_time})
```

PyTorch/Segmentation/MaskRCNN/pytorch/maskrcnn_benchmark/utils/comm.py

Lines changed: 6 additions & 0 deletions
```diff
@@ -116,3 +116,9 @@ def reduce_dict(input_dict, average=True):
             values /= world_size
         reduced_dict = {k: v for k, v in zip(names, values)}
     return reduced_dict
+
+
+def synchronized_timestamp():
+    torch.cuda.synchronize()
+    return time.time()
+
```
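One caveat: as committed, the helper assumes a CUDA device is present. A variant sketch that degrades gracefully on CPU-only runs (the is_available() guard is an addition here, not part of the commit):

```python
import time
import torch

def synchronized_timestamp():
    # Block until all queued kernels on the current CUDA device finish,
    # then read the host clock; skip the sync when no GPU is available.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.time()
```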
