
Commit a4ca26f

TroyGarden authored and meta-codesync[bot] committed
Device-to-Host LazyAwaitable (#3477)
Summary:
Pull Request resolved: #3477

Workplace post: https://fb.workplace.com/groups/811751593969209/permalink/1285164823294548/

# TL;DR

* A new `DeviceToHostTensorAwaitable` class wraps the device-to-host data transfer and defers the `cudaEventSync` call until the data is actually used on the host (see the usage sketch after this commit message).
* It targets sync-point removal in training optimization, which often suffers from CPU-blocking sync points.

# why awaitable

* As shown in the diagram below, a comms op is often best overlapped with another (irrelevant) compute op to better utilize the device.
* The idea is to **defer** the `wait()` call until the function that uses the result of the comms op actually runs.
* A convenient way to achieve this deferring behavior is the `lazy_awaitable` concept, which is already [implemented in torchrec](https://github.com/meta-pytorch/torchrec/blob/main/torchrec/distributed/types.py#L368).
* diagram of (lazy_)awaitable in torchrec {F1982900178}

# why device-to-host transfer

* There are scenarios where on-device data is needed on the host side, such as metrics logging and data-dependent shape operations.
* Those patterns create a device-to-host sync (data transfer) that often blocks CPU execution, and the correct implementation (with `.to(non_blocking=True)` and a cuda event: [PR 3436](#3436)) usually spans multiple code domains, making it difficult to optimize.
* Here we borrow the `LazyAwaitable` concept used for device-side comms and wrap (1) the non-blocking device-to-host data transfer and (2) the `cuda_event.wait()` inside a `DeviceToHostTensorAwaitable` class for a better user experience.
* diagram of lazy_awaitable for device-to-host data transfer {F1982900233}

# results

* The "comms check" result lives on the device and is needed for validation (a host-side assertion).
* `DeviceToHostTensorAwaitable.wait()` **defers** the cudaEventSync until the very end, where the result is actually needed by the host.
* You can see that the post-comms computes are scheduled before the assertion on the host side. {F1982900468}

NOTE: in this version of the implementation we don't use a separate stream (as shown in the diagram above) for the non-blocking device-to-host data transfer, because the data volume is usually relatively small. {F1982901286}

Reviewed By: spmex

Differential Revision: D85211205

fbshipit-source-id: 41d03230dd9b190085545cfb76192d59375646c4
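A minimal usage sketch of the deferred-sync pattern from the TL;DR (not part of this commit): it assumes a CUDA device, `heavy_compute` and the tensor sizes are hypothetical placeholders, and `DeviceToHostTensorAwaitable` is the class this commit adds to `torchrec/distributed/types.py`.

```python
# Minimal sketch (illustration only): defer the device-to-host sync until the
# host actually needs the value. Assumes a CUDA device is available.
import torch

from torchrec.distributed.types import DeviceToHostTensorAwaitable


def heavy_compute(x: torch.Tensor) -> torch.Tensor:
    # hypothetical stand-in for post-comms compute that should not be blocked
    return (x @ x.T).relu()


x = torch.rand(1024, 1024, device="cuda")

# kicks off the non-blocking device-to-host copy; no cudaEventSync happens here
check = DeviceToHostTensorAwaitable(x.sum())

# the device keeps working while the copy is in flight
y = heavy_compute(x)

# the sync is deferred to the point of use on the host
host_value = check.wait().item()
```

Compared to calling `.item()` on the device tensor directly, the `cudaEventSynchronize` is pushed past `heavy_compute`, which is the overlap behavior shown in the traces above.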
1 parent d26aa0d · commit a4ca26f

File tree

2 files changed: +63 -2 lines changed


torchrec/distributed/benchmark/benchmark_comms.py

Lines changed: 45 additions & 2 deletions
@@ -22,7 +22,7 @@
 """
 
 from dataclasses import dataclass
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Dict, List, Optional
 
 import torch
 import torch.distributed as dist
@@ -39,6 +39,7 @@
     MultiProcessContext,
     run_multi_process_func,
 )
+from torchrec.distributed.types import DeviceToHostTensorAwaitable
 
 _cc = cmd_conf()
 
@@ -253,6 +254,46 @@ def a2a_async_twice(
     assert checks1 and checks2
 
 
+# all_to_all_single with async_op, deferring the host-side check via a lazy awaitable
+def lazyawaitable(
+    _batch_inputs: List[Dict[str, Any]],
+    dim: int,
+    num_mul: int,
+    num_concat: int,
+    ctx: MultiProcessContext,
+) -> None:
+    with record_function("## pre-comms compute ##"):
+        pre_comms = _compute(dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx)
+
+    with record_function("## all_to_all_single ##"):
+        # use zeros instead of empty to make sure no previous data is used
+        post_comms = torch.zeros_like(pre_comms)
+        req = dist.all_to_all_single(
+            output=post_comms,
+            input=pre_comms,
+            group=ctx.pg,
+            async_op=True,
+        )
+
+    with record_function("## irrelevant compute ##"):
+        pre_comms = _compute(dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx)
+
+    with record_function("## comms check ##"):
+        # the assertion fails without wait(); this wait() makes the main cuda stream
+        # wait for the comms to finish, so the post-comms compute is blocked until
+        # the comms are done
+        req.wait()
+        check_awaitable = DeviceToHostTensorAwaitable(_validate(post_comms, ctx))
+
+    with record_function("## post-comms compute ##"):
+        post_comms = _compute(
+            dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx, x=post_comms[0]
+        )
+
+    with record_function("## assert ##"):
+        assert check_awaitable.item()
+
+
 # single-rank runner
 def a2a_single_runner(rank: int, world_size: int, arg: AllToAllSingleRunConfig) -> None:
     # Ensure GPUs are available and we have enough of them
@@ -274,8 +315,10 @@ def a2a_single_runner(rank: int, world_size: int, arg: AllToAllSingleRunConfig)
         func = a2a_async_base
     elif arg.name.startswith("a2a_async_twice"):
        func = a2a_async_twice
+    elif arg.name.startswith("lazyawaitable"):
+        func = lazyawaitable
     else:
-        func = a2a_sync_base
+        raise ValueError(f"Unknown benchmark name: {arg.name}")
 
     result = benchmark_func(
         bench_inputs=[],

torchrec/distributed/types.py

Lines changed: 18 additions & 0 deletions
@@ -463,6 +463,24 @@ def _wait_impl(self) -> W:
         return self._obj
 
 
+class DeviceToHostTensorAwaitable(LazyAwaitable[torch.Tensor]):
+    """An awaitable that waits for a tensor to be copied from device to host."""
+
+    def __init__(self, tensor_on_device: torch.Tensor) -> None:
+        super().__init__()
+        # self._tensor has an uninitialized value at this moment
+        self._tensor: torch.Tensor = tensor_on_device.to("cpu", non_blocking=True)
+
+        # cuda event to record the completion of the copy
+        self._event = torch.cuda.Event()
+        self._event.record()
+
+    def _wait_impl(self) -> torch.Tensor:
+        # wait for the copy to complete
+        self._event.synchronize()
+        return self._tensor
+
+
 KT = TypeVar("KT")
 VT_co = TypeVar("VT_co")
 ParentW = TypeVar("ParentW")
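A brief usage note on the class above (a sketch, not part of the diff): because torchrec's `LazyAwaitable` forwards use of the wrapped result to `wait()`, the benchmark calls `check_awaitable.item()` directly; the explicit `wait()` spelling from the summary is equivalent. `validation_result` below is a hypothetical on-device scalar tensor and a CUDA device is assumed.

```python
import torch

from torchrec.distributed.types import DeviceToHostTensorAwaitable

# hypothetical on-device validation result (a boolean scalar tensor)
validation_result = torch.arange(4, device="cuda").sum() == 6

# starts the non-blocking copy to the host and records a cuda event; no host sync yet
check_awaitable = DeviceToHostTensorAwaitable(validation_result)

# ... more device work can be scheduled here without blocking on the copy ...

assert check_awaitable.wait().item()  # explicit: wait(), then read the value on the host
assert check_awaitable.item()         # implicit: the access triggers the wait, as in the benchmark
```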
