Summary:
Pull Request resolved: #3477
workplace post: https://fb.workplace.com/groups/811751593969209/permalink/1285164823294548/
# TL;DR
* A new `DeviceToHostTensorAwaitable` class wraps the device-to-host data transfer and defers the `cudaEventSynchronize` call until the data is actually used on the host.
* It is meant to help with sync-point removal when optimizing training, which often suffers from CPU-blocking sync points.
# why awaitable
* as shown in the following diagram, a comms op is often best overlapped with another (independent) compute op to better utilize the device
* the idea is to **defer** the `wait()` call until the function that uses the result of the comms op actually runs
* a convenient way to achieve this "deferring" behavior is to use the `lazy_awaitable` concept, which is already [implemented in torchrec](https://github.com/meta-pytorch/torchrec/blob/main/torchrec/distributed/types.py#L368)
* diagram of (lazy_)awaitable in torchrec
{F1982900178}
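A minimal sketch of the deferring idea, in plain Python. `SimpleLazyAwaitable` and `finish_comms` are hypothetical names made up for illustration; this is not torchrec's `LazyAwaitable`, only the general shape of the pattern (the blocking step runs once, at first use).

```python
from typing import Callable, Generic, Optional, TypeVar

T = TypeVar("T")


class SimpleLazyAwaitable(Generic[T]):
    """Hypothetical, simplified stand-in for a lazy awaitable (not torchrec's class)."""

    def __init__(self, wait_fn: Callable[[], T]) -> None:
        self._wait_fn = wait_fn  # the blocking step, e.g. waiting on a comms op
        self._result: Optional[T] = None
        self._done = False

    def wait(self) -> T:
        # Run the blocking step at most once and cache its result.
        if not self._done:
            self._result = self._wait_fn()
            self._done = True
        return self._result


events = []


def finish_comms() -> int:
    # Hypothetical blocking step; records when it actually runs.
    events.append("wait")
    return 42


pending = SimpleLazyAwaitable(finish_comms)  # nothing blocks here
events.append("independent compute")         # unrelated work is issued first
value = pending.wait()                       # blocks only where the result is used
```

The order recorded in `events` shows that the blocking step runs after the independent work, which is the overlap the diagram describes.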
# why device-to-host transfer
* there are scenarios where on-device data is needed on the host side, such as metrics logging and data-dependent shape operations.
* those patterns create a device-to-host sync (data transfer) that often blocks CPU execution, and the correct implementation (with `.to(non_blocking=True)` and a CUDA event: [PR 3436](#3436)) usually spans multiple code domains, making it difficult to optimize.
* here we borrow the `LazyAwaitable` concept used for device-side comms and wrap (1) the non-blocking device-to-host data transfer and (2) `cuda_event.wait()` inside a `DeviceToHostTensorAwaitable` class for a better user experience.
* diagram of lazy_awaitable for device-to-host data transfer
{F1982900233}
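The two wrapped steps can be sketched roughly as follows. This is an illustration under stated assumptions, not the `DeviceToHostTensorAwaitable` added in this PR: the class name `D2HAwaitableSketch` is hypothetical, the copy goes into a pinned host buffer on the current stream, and the host-side wait is done with `torch.cuda.Event.synchronize()`. On a machine without CUDA the sketch falls back to a plain copy so it still runs.

```python
import torch


class D2HAwaitableSketch:
    """Hypothetical sketch; not the torchrec DeviceToHostTensorAwaitable."""

    def __init__(self, device_tensor: torch.Tensor) -> None:
        self._event = None
        src = device_tensor.detach()
        if src.is_cuda:
            # Pinned host buffer so the copy can be asynchronous w.r.t. the host.
            self._host = torch.empty(
                src.shape, dtype=src.dtype, device="cpu", pin_memory=True
            )
            self._host.copy_(src, non_blocking=True)
            # Record an event on the current stream, right after the copy.
            self._event = torch.cuda.Event()
            self._event.record()
        else:
            # CPU fallback (assumption for portability): nothing to wait on.
            self._host = src.clone()

    def wait(self) -> torch.Tensor:
        # The host blocks only here, and only the first time.
        if self._event is not None:
            self._event.synchronize()
            self._event = None
        return self._host


device = "cuda" if torch.cuda.is_available() else "cpu"
check = torch.arange(4, device=device).sum()  # stands in for an on-device result
pending = D2HAwaitableSketch(check)           # copy is queued, host is not blocked
# ... more device work could be queued here ...
result = pending.wait()                       # sync happens where the value is used
```

Reading `result` before calling `wait()` would be unsafe on CUDA, which is why the sketch only exposes the host tensor through `wait()`.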
# results
* the "comms check" result is on device and is needed for validation (host-side assertion)
* `DeviceToHostTensorAwaitable.wait()` **defers** the `cudaEventSynchronize` until the very end, where the result is actually needed by the host.
* you can see that the post-comms computes are scheduled before the assertion on the host side.
{F1982900468}
NOTE: this version of the implementation doesn't use a separate stream (as shown in the diagram above) for the non-blocking device-to-host data transfer, because the data volume is usually relatively small.
{F1982901286}
Reviewed By: spmex
Differential Revision: D85211205
fbshipit-source-id: 41d03230dd9b190085545cfb76192d59375646c4