@@ -198,6 +198,9 @@ def __init__(
             else None
         )
         self._inplace_copy_batch_to_gpu = inplace_copy_batch_to_gpu
+        logger.info(
+            f"train_pipeline uses inplace_copy_batch_to_gpu: {inplace_copy_batch_to_gpu}"
+        )
 
         # pyre-ignore
         self._stream_context = (
@@ -474,7 +477,8 @@ def __init__(
 
         logger.info(
             f"enqueue_batch_after_forward: {self._enqueue_batch_after_forward} "
-            f"execute_all_batches: {self._execute_all_batches}"
+            f"execute_all_batches: {self._execute_all_batches} "
+            f"inplace_copy_batch_to_gpu: {inplace_copy_batch_to_gpu}"
         )
 
         if device.type == "cuda":
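
The two hunks above only thread the new inplace_copy_batch_to_gpu flag through construction and logging. As a rough sketch of the idea behind the flag (generic PyTorch, not the TorchRec implementation; the buffer and helper names below are hypothetical), an in-place transfer reuses a pre-allocated GPU tensor instead of allocating a new one every step:

import torch

# Hedged sketch only: `gpu_buffer` and `copy_batch` are illustrative names, not
# torchrec API. Assumes the flag means "reuse an existing GPU allocation".
device = torch.device("cuda:0")
gpu_buffer = torch.empty(8192, 512, device=device)

def copy_batch(cpu_batch: torch.Tensor, inplace: bool) -> torch.Tensor:
    if inplace:
        # Reuse the pre-allocated buffer; no new GPU memory is requested.
        gpu_buffer.copy_(cpu_batch, non_blocking=True)
        return gpu_buffer
    # Default path: allocate a fresh GPU tensor for every batch.
    return cpu_batch.to(device, non_blocking=True)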
@@ -1486,30 +1490,91 @@ def start_embedding_lookup(
 
 class PrefetchTrainPipelineSparseDist(TrainPipelineSparseDist[In, Out]):
     """
-    This pipeline overlaps device transfer, `ShardedModule.input_dist()`, and cache
-    prefetching with forward and backward. This helps hide the all2all latency while
-    preserving the training forward / backward ordering.
-
-    stage 4: forward, backward - uses default CUDA stream
-    stage 3: prefetch - uses prefetch CUDA stream
-    stage 2: ShardedModule.input_dist() - uses data_dist CUDA stream
-    stage 1: device transfer - uses memcpy CUDA stream
-
-    `ShardedModule.input_dist()` is only done for top-level modules in the call graph.
-    To be considered a top-level module, a module can only depend on 'getattr' calls on
-    input.
-
-    Input model must be symbolically traceable with the exception of `ShardedModule` and
-    `DistributedDataParallel` modules.
+    Advanced 4-stage pipelined training implementation with cache prefetching support.
+
+    This pipeline extends TrainPipelineSparseDist by adding a dedicated prefetch stage
+    that overlaps embedding cache prefetching with computation. It orchestrates four
+    concurrent CUDA streams to maximize GPU utilization by hiding memory transfer,
+    communication, and cache access latencies behind computation.
+
+    Pipeline Architecture:
+        The pipeline maintains 3 batches in flight, each at a different stage:
+
+        Stage 1 (Batch i+2): Device Transfer
+            - Stream: memcpy CUDA stream
+            - Operation: Copy batch from CPU to GPU memory
+            - Overlap: Runs concurrently with all other stages
+
+        Stage 2 (Batch i+1): Input Distribution
+            - Stream: data_dist CUDA stream
+            - Operation: ShardedModule.input_dist() - all-to-all collective communication
+            - Overlap: Runs while batch i is being prefetched and processed
+
+        Stage 3 (Batch i+1): Cache Prefetch
+            - Stream: prefetch CUDA stream
+            - Operation: Prefetch embeddings from cache to GPU
+            - Overlap: Runs while batch i is in the forward/backward pass
+
+        Stage 4 (Batch i): Forward/Backward/Optimizer
+            - Stream: default CUDA stream
+            - Operation: Model forward pass, loss computation, backward pass, optimizer step
+            - Overlap: Uses prefetched data from previous iterations
+
+    Key Features:
+        - Overlaps 4 pipeline stages across 3 batches for maximum throughput
+        - Hides embedding cache access latency using a dedicated prefetch stream
+        - Preserves synchronous training semantics (same loss trajectory as non-pipelined)
+        - Supports both training and evaluation modes
+        - Compatible with sharded embedding modules (EBC, EC, etc.)
+
+    Requirements:
+        - Input model must be symbolically traceable except for ShardedModule and
+          DistributedDataParallel modules
+        - ShardedModule.input_dist() is only performed for top-level modules in the
+          call graph (modules that only depend on 'getattr' calls on input)
+        - Embedding modules must support cache prefetching operations
+
+    Performance Characteristics:
+        - Best suited for models with significant embedding lookup latency
+        - Can achieve ~1.5-2x throughput improvement over TrainPipelineSparseDist when
+          cache prefetching benefits are significant
+        - Memory overhead: 3x batch size (3 batches in flight)
+        - Additional CUDA stream overhead for prefetch operations
 
     Args:
-        model (torch.nn.Module): model to pipeline.
-        optimizer (torch.optim.Optimizer): optimizer to use.
-        device (torch.device): device where device transfer, sparse data dist, prefetch,
-            and forward/backward pass will happen.
-        execute_all_batches (bool): executes remaining batches in pipeline after
-            exhausting dataloader iterator.
-        apply_jit (bool): apply torch.jit.script to non-pipelined (unsharded) modules.
+        model (torch.nn.Module): Model to pipeline. Must contain ShardedModule instances
+            for sparse features and support cache prefetching.
+        optimizer (torch.optim.Optimizer): Optimizer to use for parameter updates.
+        device (torch.device): Device where all pipeline stages will execute (typically
+            a CUDA device).
+        execute_all_batches (bool): If True, executes all remaining batches in the
+            pipeline after the dataloader is exhausted. If False, stops immediately
+            when the dataloader ends. Default: True.
+        apply_jit (bool): If True, applies torch.jit.script to non-pipelined (unsharded)
+            modules for additional optimization. Default: False.
+        pipeline_postproc (bool): If True, enables pipelining of post-processing
+            operations. Default: True.
+        custom_model_fwd (Optional[Callable]): Custom forward function to use instead
+            of the model's default forward. Should return a (losses, output) tuple.
+        inplace_copy_batch_to_gpu (bool): If True, performs the device transfer in place
+            to reduce memory allocations. Default: False.
+
+    Example:
+        >>> model = MyModel()
+        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
+        >>> pipeline = PrefetchTrainPipelineSparseDist(
+        ...     model=model,
+        ...     optimizer=optimizer,
+        ...     device=torch.device("cuda:0"),
+        ... )
+        >>> dataloader_iter = iter(dataloader)
+        >>> while True:  # progress() raises StopIteration once the iterator is exhausted
+        ...     output = pipeline.progress(dataloader_iter)  # one pipelined training step
+
+    See Also:
+        - TrainPipelineSparseDist: Base 3-stage pipeline without prefetching
+        - TrainPipelineSemiSync: Semi-synchronous training alternative
+        - TrainPipelineFusedSparseDist: Pipeline with fused embedding lookup
     """
 
     # The PipelinedForward class that is used in _rewrite_model
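
The class docstring above describes four stages overlapped on separate CUDA streams. A minimal sketch of the underlying mechanism, in generic PyTorch rather than the pipeline's actual code: queue the host-to-device copy on a side stream, keep computing on the default stream, and synchronize only where the transferred tensor is consumed.

import torch

# Minimal overlap sketch, assuming a CUDA device is available.
device = torch.device("cuda:0")
copy_stream = torch.cuda.Stream(device=device)

cpu_batch = torch.randn(8192, 512).pin_memory()  # pinned memory allows async H2D copy
weight = torch.randn(512, 512, device=device)

with torch.cuda.stream(copy_stream):
    # "Stage 1": device transfer for the next batch runs on the side stream.
    gpu_batch = cpu_batch.to(device, non_blocking=True)

# "Stage 4" work for the previous batch could run here on the default stream.

# Before consuming gpu_batch on the default stream, order the two streams.
torch.cuda.current_stream(device).wait_stream(copy_stream)
gpu_batch.record_stream(torch.cuda.current_stream(device))
out = gpu_batch @ weight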
@@ -1526,6 +1591,7 @@ def __init__(
         custom_model_fwd: Optional[
             Callable[[Optional[In]], Tuple[torch.Tensor, Out]]
         ] = None,
+        inplace_copy_batch_to_gpu: bool = False,
     ) -> None:
         super().__init__(
             model=model,
@@ -1536,6 +1602,7 @@ def __init__(
             context_type=PrefetchTrainPipelineContext,
             pipeline_postproc=pipeline_postproc,
             custom_model_fwd=custom_model_fwd,
+            inplace_copy_batch_to_gpu=inplace_copy_batch_to_gpu,
         )
         self._context = PrefetchTrainPipelineContext(version=0)
         self._prefetch_stream: Optional[torch.Stream] = (
@@ -1551,6 +1618,19 @@ def __init__(
         self._batch_ip3: Optional[In] = None
 
     def _fill_pipeline(self, dataloader_iter: Iterator[In]) -> None:
+        """
+        DEPRECATED: exists for backward compatibility.
+        Initializes the prefetch pipeline with batches.
+
+        This method fills the pipeline with initial batches to enable overlapping of
+        device transfer, input dist, and cache prefetching operations.
+
+        Args:
+            dataloader_iter: Iterator that produces training batches.
+
+        Raises:
+            StopIteration: If the dataloader iterator is exhausted on the first batch.
+        """
         # pipeline is already filled
         if self._batch_i and self._batch_ip1 and self._batch_ip2:
             return
@@ -1578,6 +1658,27 @@ def _fill_pipeline(self, dataloader_iter: Iterator[In]) -> None:
         self._start_sparse_data_dist(self._batch_ip1)
 
     def progress(self, dataloader_iter: Iterator[In]) -> Out:
+        """
+        Executes one training iteration with prefetch pipelining.
+
+        This method orchestrates a 4-stage pipeline to overlap:
+        - Stage 1: Device transfer (batch i+2) on the memcpy stream
+        - Stage 2: Input dist (batch i+1) on the data_dist stream
+        - Stage 3: Cache prefetch (batch i+1) on the prefetch stream
+        - Stage 4: Forward/backward (batch i) on the default stream
+
+        The pipeline maintains 3 batches in flight to maximize GPU utilization by
+        hiding memory transfer and communication latency.
+
+        Args:
+            dataloader_iter: Iterator that produces training batches.
+
+        Returns:
+            Model output from the current batch.
+
+        Raises:
+            StopIteration: If the dataloader iterator is exhausted.
+        """
         self._fill_pipeline(dataloader_iter)
 
         if self._model.training:
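
The progress() docstring above keeps three batches in flight and shifts them by one slot per iteration. A toy sketch of just that bookkeeping, with placeholder comments standing in for the real stage work (not TorchRec code):

from typing import Iterator, Optional, TypeVar

T = TypeVar("T")

def toy_three_batch_rotation(batches: Iterator[T]) -> None:
    # batch_i would run forward/backward, batch_ip1 would be prefetched and
    # input-dist'ed, batch_ip2 would be copied to the device.
    batch_i: Optional[T] = next(batches, None)
    batch_ip1: Optional[T] = next(batches, None)
    batch_ip2: Optional[T] = next(batches, None)

    while batch_i is not None:
        # Stage 4 (batch_i), stages 2-3 (batch_ip1), and stage 1 (incoming batch)
        # would be launched here on their respective streams.
        batch_i, batch_ip1, batch_ip2 = batch_ip1, batch_ip2, next(batches, None)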
@@ -1614,7 +1715,20 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
 
     def _prefetch(self, batch: Optional[In]) -> None:
         """
-        Waits for input dist to finish, then prefetches data.
+        Prefetches embedding data from cache to GPU memory.
+
+        This method executes on the prefetch stream to overlap cache prefetching
+        with the forward pass of the previous batch. It waits for input dist to
+        complete, then prefetches embedding data and stores the results in the
+        pipeline context for use in the next forward pass.
+
+        Args:
+            batch: The batch to prefetch embeddings for. If None, this method
+                returns early without prefetching.
+
+        Note:
+            This operation runs on self._prefetch_stream to enable overlap with
+            forward/backward computation on the default stream.
         """
         if batch is None:
             return
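
The "waits for input dist to finish" step that _prefetch documents comes down to stream ordering. A hedged, generic sketch follows; the index_select stands in for the real cache prefetch, and none of these names are TorchRec API:

import torch

device = torch.device("cuda:0")
data_dist_stream = torch.cuda.Stream(device=device)
prefetch_stream = torch.cuda.Stream(device=device)

ids = torch.randint(0, 1000, (4096,), device=device)
table = torch.randn(1000, 64, device=device)

with torch.cuda.stream(data_dist_stream):
    # Stand-in for the all-to-all output of ShardedModule.input_dist().
    dist_ids = ids.clone()

# "Waits for input dist to finish": prefetch work is ordered after data_dist.
prefetch_stream.wait_stream(data_dist_stream)
with torch.cuda.stream(prefetch_stream):
    # Stand-in for pulling embedding rows close to the GPU ahead of forward().
    prefetched = table.index_select(0, dist_ids)

# The default (compute) stream must in turn wait before consuming `prefetched`.
torch.cuda.current_stream(device).wait_stream(prefetch_stream)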