Skip to content

Commit 691d11f

Browse files
TroyGardenmeta-codesync[bot]
authored andcommitted
add logging for inplace_copy_batch_to_gpu (#3532)
Summary: Pull Request resolved: #3532 # context * add logging for inplace_copy_batch_to_gpu * add argument for prefetch pipeline # run command * negative ``` python torchrec/distributed/benchmark/benchmark_train_pipeline.py \ --yaml_config=fbcode/torchrec/distributed/benchmark/yaml/sparse_data_dist_base.yml \ --loglevel=info \ --pipeline=base ``` * positive ``` python torchrec/distributed/benchmark/benchmark_train_pipeline.py \ --yaml_config=fbcode/torchrec/distributed/benchmark/yaml/sparse_data_dist_base.yml \ --loglevel=info \ --inplace_copy_batch_to_gpu=True ``` # output * negative ``` INFO:torchrec.distributed.train_pipeline.train_pipelines:train_pipeline uses inplace_copy_batch_to_gpu: False ``` * positive ``` INFO:torchrec.distributed.train_pipeline.train_pipelines:enqueue_batch_after_forward: False execute_all_batches: True inplace_copy_batch_to_gpu: True ``` Reviewed By: Raahul46, spmex Differential Revision: D83946090 fbshipit-source-id: 96fc15fee2d21f5efdd590aee7a51d081f57d7b5
1 parent 7dcdccb commit 691d11f

File tree

1 file changed

+138
-24
lines changed

1 file changed

+138
-24
lines changed

torchrec/distributed/train_pipeline/train_pipelines.py

Lines changed: 138 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,9 @@ def __init__(
198198
else None
199199
)
200200
self._inplace_copy_batch_to_gpu = inplace_copy_batch_to_gpu
201+
logger.info(
202+
f"train_pipeline uses inplace_copy_batch_to_gpu: {inplace_copy_batch_to_gpu}"
203+
)
201204

202205
# pyre-ignore
203206
self._stream_context = (
@@ -474,7 +477,8 @@ def __init__(
474477

475478
logger.info(
476479
f"enqueue_batch_after_forward: {self._enqueue_batch_after_forward} "
477-
f"execute_all_batches: {self._execute_all_batches}"
480+
f"execute_all_batches: {self._execute_all_batches} "
481+
f"inplace_copy_batch_to_gpu: {inplace_copy_batch_to_gpu}"
478482
)
479483

480484
if device.type == "cuda":
@@ -1486,30 +1490,91 @@ def start_embedding_lookup(
14861490

14871491
class PrefetchTrainPipelineSparseDist(TrainPipelineSparseDist[In, Out]):
14881492
"""
1489-
This pipeline overlaps device transfer, `ShardedModule.input_dist()`, and cache
1490-
prefetching with forward and backward. This helps hide the all2all latency while
1491-
preserving the training forward / backward ordering.
1492-
1493-
stage 4: forward, backward - uses default CUDA stream
1494-
stage 3: prefetch - uses prefetch CUDA stream
1495-
stage 2: ShardedModule.input_dist() - uses data_dist CUDA stream
1496-
stage 1: device transfer - uses memcpy CUDA stream
1497-
1498-
`ShardedModule.input_dist()` is only done for top-level modules in the call graph.
1499-
To be considered a top-level module, a module can only depend on 'getattr' calls on
1500-
input.
1501-
1502-
Input model must be symbolically traceable with the exception of `ShardedModule` and
1503-
`DistributedDataParallel` modules.
1493+
Advanced 4-stage pipelined training implementation with cache prefetching support.
1494+
1495+
This pipeline extends TrainPipelineSparseDist by adding a dedicated prefetch stage
1496+
that overlaps embedding cache prefetching with computation. It orchestrates four
1497+
concurrent CUDA streams to maximize GPU utilization by hiding memory transfer,
1498+
communication, and cache access latencies behind computation.
1499+
1500+
Pipeline Architecture:
1501+
The pipeline maintains 3 batches in flight, each at different stages:
1502+
1503+
Stage 1 (Batch i+2): Device Transfer
1504+
- Stream: memcpy CUDA stream
1505+
- Operation: Copy batch from CPU to GPU memory
1506+
- Overlap: Runs concurrently with all other stages
1507+
1508+
Stage 2 (Batch i+1): Input Distribution
1509+
- Stream: data_dist CUDA stream
1510+
- Operation: ShardedModule.input_dist() - all-to-all collective communication
1511+
- Overlap: Runs while batch i is being prefetched and processed
1512+
1513+
Stage 3 (Batch i+1): Cache Prefetch
1514+
- Stream: prefetch CUDA stream
1515+
- Operation: Prefetch embeddings from cache to GPU
1516+
- Overlap: Runs while batch i is in forward/backward pass
1517+
1518+
Stage 4 (Batch i): Forward/Backward/Optimizer
1519+
- Stream: default CUDA stream
1520+
- Operation: Model forward pass, loss computation, backward pass, optimizer step
1521+
- Overlap: Uses prefetched data from previous iterations
1522+
1523+
Key Features:
1524+
- Overlaps 4 pipeline stages across 3 batches for maximum throughput
1525+
- Hides embedding cache access latency using dedicated prefetch stream
1526+
- Preserves synchronous training semantics (same loss trajectory as non-pipelined)
1527+
- Supports both training and evaluation modes
1528+
- Compatible with sharded embedding modules (EBC, EC, etc.)
1529+
1530+
Requirements:
1531+
- Input model must be symbolically traceable except for ShardedModule and
1532+
DistributedDataParallel modules
1533+
- ShardedModule.input_dist() is only performed for top-level modules in the
1534+
call graph (modules that only depend on 'getattr' calls on input)
1535+
- Embedding modules must support cache prefetching operations
1536+
1537+
Performance Characteristics:
1538+
- Best suited for models with significant embedding lookup latency
1539+
- Achieves ~1.5-2x throughput improvement over TrainPipelineSparseDist when
1540+
cache prefetching benefits are significant
1541+
- Memory overhead: 3x batch size (3 batches in flight)
1542+
- Additional CUDA stream overhead for prefetch operations
15041543
15051544
Args:
1506-
model (torch.nn.Module): model to pipeline.
1507-
optimizer (torch.optim.Optimizer): optimizer to use.
1508-
device (torch.device): device where device transfer, sparse data dist, prefetch,
1509-
and forward/backward pass will happen.
1510-
execute_all_batches (bool): executes remaining batches in pipeline after
1511-
exhausting dataloader iterator.
1512-
apply_jit (bool): apply torch.jit.script to non-pipelined (unsharded) modules.
1545+
model (torch.nn.Module): Model to pipeline. Must contain ShardedModule instances
1546+
for sparse features and support cache prefetching.
1547+
optimizer (torch.optim.Optimizer): Optimizer to use for parameter updates.
1548+
device (torch.device): Device where all pipeline stages will execute (typically
1549+
CUDA device).
1550+
execute_all_batches (bool): If True, executes all remaining batches in pipeline
1551+
after dataloader is exhausted. If False, stops immediately when dataloader
1552+
ends. Default: True.
1553+
apply_jit (bool): If True, applies torch.jit.script to non-pipelined (unsharded)
1554+
modules for additional optimization. Default: False.
1555+
pipeline_postproc (bool): If True, enables pipelining of post-processing
1556+
operations. Default: True.
1557+
custom_model_fwd (Optional[Callable]): Custom forward function to use instead
1558+
of model's default forward. Should return (losses, output) tuple.
1559+
inplace_copy_batch_to_gpu (bool): If True, performs in-place device transfer
1560+
to reduce memory allocations. Default: False.
1561+
1562+
Example:
1563+
>>> model = MyModel()
1564+
>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
1565+
>>> pipeline = PrefetchTrainPipelineSparseDist(
1566+
... model=model,
1567+
... optimizer=optimizer,
1568+
... device=torch.device("cuda:0"),
1569+
... )
1570+
>>> for batch in dataloader:
1571+
... output = pipeline.progress(iter([batch]))
1572+
... # Training step is complete, output contains predictions
1573+
1574+
See Also:
1575+
- TrainPipelineSparseDist: Base 3-stage pipeline without prefetching
1576+
- TrainPipelineSemiSync: Semi-synchronous training alternative
1577+
- TrainPipelineFusedSparseDist: Pipeline with fused embedding lookup
15131578
"""
15141579

15151580
# The PipelinedForward class that is used in _rewrite_model
@@ -1526,6 +1591,7 @@ def __init__(
15261591
custom_model_fwd: Optional[
15271592
Callable[[Optional[In]], Tuple[torch.Tensor, Out]]
15281593
] = None,
1594+
inplace_copy_batch_to_gpu: bool = False,
15291595
) -> None:
15301596
super().__init__(
15311597
model=model,
@@ -1536,6 +1602,7 @@ def __init__(
15361602
context_type=PrefetchTrainPipelineContext,
15371603
pipeline_postproc=pipeline_postproc,
15381604
custom_model_fwd=custom_model_fwd,
1605+
inplace_copy_batch_to_gpu=inplace_copy_batch_to_gpu,
15391606
)
15401607
self._context = PrefetchTrainPipelineContext(version=0)
15411608
self._prefetch_stream: Optional[torch.Stream] = (
@@ -1551,6 +1618,19 @@ def __init__(
15511618
self._batch_ip3: Optional[In] = None
15521619

15531620
def _fill_pipeline(self, dataloader_iter: Iterator[In]) -> None:
1621+
"""
1622+
DEPRECATED: exists for backward compatibility
1623+
Initializes the prefetch pipeline with batches.
1624+
1625+
This method fills the pipeline with initial batches to enable overlapping of
1626+
device transfer, input dist, and cache prefetching operations.
1627+
1628+
Args:
1629+
dataloader_iter: Iterator that produces training batches.
1630+
1631+
Raises:
1632+
StopIteration: if the dataloader iterator is exhausted on the first batch.
1633+
"""
15541634
# pipeline is already filled
15551635
if self._batch_i and self._batch_ip1 and self._batch_ip2:
15561636
return
@@ -1578,6 +1658,27 @@ def _fill_pipeline(self, dataloader_iter: Iterator[In]) -> None:
15781658
self._start_sparse_data_dist(self._batch_ip1)
15791659

15801660
def progress(self, dataloader_iter: Iterator[In]) -> Out:
1661+
"""
1662+
Executes one training iteration with prefetch pipelining.
1663+
1664+
This method orchestrates a 4-stage pipeline to overlap:
1665+
- Stage 1: Device transfer (batch i+2) on memcpy stream
1666+
- Stage 2: Input dist (batch i+1) on data_dist stream
1667+
- Stage 3: Cache prefetch (batch i+1) on prefetch stream
1668+
- Stage 4: Forward/backward (batch i) on default stream
1669+
1670+
The pipeline maintains 3 batches in flight to maximize GPU utilization by
1671+
hiding memory transfer and communication latency.
1672+
1673+
Args:
1674+
dataloader_iter: Iterator that produces training batches.
1675+
1676+
Returns:
1677+
Model output from the current batch.
1678+
1679+
Raises:
1680+
StopIteration: if the dataloader iterator is exhausted.
1681+
"""
15811682
self._fill_pipeline(dataloader_iter)
15821683

15831684
if self._model.training:
@@ -1614,7 +1715,20 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
16141715

16151716
def _prefetch(self, batch: Optional[In]) -> None:
16161717
"""
1617-
Waits for input dist to finish, then prefetches data.
1718+
Prefetches embedding data from cache to GPU memory.
1719+
1720+
This method executes on the prefetch stream to overlap cache prefetching
1721+
with the forward pass of the previous batch. It waits for input dist to
1722+
complete, then prefetches embedding data and stores the results in the
1723+
pipeline context for use in the next forward pass.
1724+
1725+
Args:
1726+
batch: The batch to prefetch embeddings for. If None, this method
1727+
returns early without prefetching.
1728+
1729+
Note:
1730+
This operation runs on self._prefetch_stream to enable overlap with
1731+
forward/backward computation on the default stream.
16181732
"""
16191733
if batch is None:
16201734
return

0 commit comments

Comments
 (0)