@@ -32,12 +32,14 @@

 from torchrec.distributed.model_parallel import ShardedModule
 from torchrec.distributed.train_pipeline.pipeline_context import (
+    EmbeddingTrainPipelineContext,
     In,
     PrefetchTrainPipelineContext,
     TrainPipelineContext,
 )
 from torchrec.distributed.train_pipeline.runtime_forwards import (
     BaseForward,
+    InSyncEmbeddingPipelinedForward,
     PipelinedForward,
     PrefetchPipelinedForward,
 )
@@ -48,6 +50,7 @@
     _prefetch_embeddings,
     _rewrite_model,
     _start_data_dist,
+    _start_embedding_lookup,
     use_context_for_postprocs,
 )
 from torchrec.distributed.types import Awaitable
@@ -91,7 +94,8 @@ class PipelineStage:
 class SparseDataDistUtil(Generic[In]):
     """
     Helper class exposing methods for sparse data dist and prefetch pipelining.
-    Currently used for `StagedTrainPipeline` pipeline stages
+    Currently used for `StagedTrainPipeline` pipeline stages.
+    Specifying `embedding_lookup_stream` makes `StagedTrainPipeline` semi-synchronous.
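+    Embedding lookups for a batch are then kicked off on their own stream ahead of
+    that batch's model forward, overlapping lookup work with the preceding batch.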

     Args:
         model (torch.nn.Module): Model to pipeline
@@ -100,8 +104,11 @@ class SparseDataDistUtil(Generic[In]):
         prefetch_stream (Optional[torch.cuda.Stream]): Stream on which model prefetch runs
             Defaults to `None`. This needs to be passed in to enable prefetch pipelining.
         pipeline_postproc (bool): whether to pipeline postproc modules. Defaults to `False`.
+        embedding_lookup_stream (Optional[torch.cuda.Stream]): Stream on which embedding lookup runs
+            Defaults to `None`. This needs to be passed in to enable embedding lookup pipelining.

     Example::
+        # With prefetch pipeline:
         sdd = SparseDataDistUtil(
             model=model,
             data_dist_stream=torch.cuda.Stream(),
@@ -129,6 +136,33 @@ class SparseDataDistUtil(Generic[In]):
             ),
         ]

+        # With embedding lookup pipeline:
+        sdd = SparseDataDistUtil(
+            model=model,
+            data_dist_stream=torch.cuda.Stream(),
+            embedding_lookup_stream=torch.cuda.Stream(), <-- required to enable embedding lookup pipeline
+        )
+        pipeline = [
+            PipelineStage(
+                name="data_copy",
+                runnable=lambda batch, context: batch.to(
+                    self._device, non_blocking=True
+                ),
+                stream=torch.cuda.Stream(),
+            ),
+            PipelineStage(
+                name="start_sparse_data_dist",
+                runnable=sdd.start_sparse_data_dist,
+                stream=sdd.data_dist_stream,
+                fill_callback=sdd.wait_sdd_fill_callback,
+            ),
+            PipelineStage(
+                name="start_embedding_lookup",
+                runnable=sdd.start_embedding_lookup,
+                stream=sdd.embedding_lookup_stream,
+            ),
+        ]
+
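+        # The same data-dist and embedding-lookup stages can also be constructed via
+        # the `start_sparse_data_dist_stage` / `start_embedding_lookup_stage` helpers
+        # defined below.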
         return StagedTrainPipeline(pipeline_stages=pipeline)
     """

@@ -144,12 +178,14 @@ def __init__(
         apply_jit: bool = False,
         prefetch_stream: Optional[torch.Stream] = None,
         pipeline_postproc: bool = False,
+        embedding_lookup_stream: Optional[torch.Stream] = None,
     ) -> None:
         super().__init__()
         self.model = model
         self.data_dist_stream = data_dist_stream
         self.apply_jit = apply_jit
         self.prefetch_stream = prefetch_stream
+        self.embedding_lookup_stream = embedding_lookup_stream
         self._next_index: int = 0
         self._contexts: Deque[TrainPipelineContext] = deque()
         self.initialized = False
@@ -158,6 +194,10 @@ def __init__(
         self.fwd_hook: Optional[RemovableHandle] = None
         self._device: torch.device = data_dist_stream.device

+        assert not (
+            self._with_prefetch and self._with_embedding_lookup
+        ), "Cannot enable both prefetch and embedding lookup at the same time. Prefetch is redundant with embedding lookup."
+
         self._stream_context: Callable[
             [Optional[torch.Stream]], torch.cuda.StreamContext
         ] = (
@@ -172,10 +212,17 @@ def __init__(
             Callable[[KeyedJaggedTensor], Awaitable[KJTAllToAllTensorsAwaitable]]
         ] = []

-        self._pipelined_forward: Type[BaseForward[TrainPipelineContext]] = cast(
-            Type[BaseForward[TrainPipelineContext]],
-            (PrefetchPipelinedForward if self._with_prefetch else PipelinedForward),
-        )
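+        # Select the forward override installed on pipelined modules; it determines
+        # where each module's forward gets its inputs: prefetched tensors, in-flight
+        # embedding lookups, or awaited input dist results.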
+        self._pipelined_forward: Type[BaseForward[TrainPipelineContext]]
+        if self._with_prefetch:
+            self._pipelined_forward = cast(
+                Type[BaseForward[TrainPipelineContext]], PrefetchPipelinedForward
+            )
+        elif self._with_embedding_lookup:
+            self._pipelined_forward = cast(
+                Type[BaseForward[TrainPipelineContext]], InSyncEmbeddingPipelinedForward
+            )
+        else:
+            self._pipelined_forward = PipelinedForward

         self._default_stream: Optional[torch.Stream] = (
             (torch.get_device_module(self._device).Stream())
@@ -196,6 +243,10 @@ def __init__(
     def _with_prefetch(self) -> bool:
         return self.prefetch_stream is not None

+    @property
+    def _with_embedding_lookup(self) -> bool:
+        return self.embedding_lookup_stream is not None
+
     def _is_reattaching(self) -> bool:
         return len(self._contexts) > 0

@@ -239,7 +290,8 @@ def _pipelined_postprocs_fqns(self) -> Set[str]:
     # advancing the list at the beginning of the `progress`.
-    # Tricky part is that SparseDataDistUtil might be participating in TWO stages:
+    # Tricky part is that SparseDataDistUtil might be participating in up to THREE stages:
     # * "main" with start_data_dist -> wait_data_dist pair for `runnable` and `fill_callback`
-    # * "prefetch" with prefetch -> load_prefetch for `runnable` and `fill_callback`
+    # * "prefetch" with prefetch -> load_prefetch for `runnable` and `fill_callback` (optional)
+    # * "embedding_lookup" with start_embedding_lookup for `runnable` (optional)
     #
     # For this to work, we:
     # (1) need to manage contexts in a lockstep with batch advancing through stages (_advance_context)
@@ -248,18 +300,18 @@ def _pipelined_postprocs_fqns(self) -> Set[str]:
     # (3) set contexts for the _pipelined_modules and _pipelined_postprocs to the "current batch context"
     #     for the model to run correctly (_set_module_context)
     #
-    # SDD Util uses two or three contexts, depending on if prefetch is present
+    # SDD Util uses two or three contexts, depending on whether prefetch or embedding_lookup is enabled
     # * context[0] is always the "current batch" context - used for model forward (outside this class)
-    # * context[1] is used for prefetch if it is set, and start/wait_sparse_data_dist if not
-    # * context[2] is used for start/wait_sparse_data_dist if prefetch is not set
+    # * context[1] is used for prefetch/embedding_lookup if either is set, and start/wait_sparse_data_dist if not
+    # * context[2] is used for start/wait_sparse_data_dist if prefetch or embedding_lookup is set
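+    # E.g. with embedding_lookup enabled: while batch i runs model forward on
+    # context[0], batch i+1 runs embedding lookup on context[1] and batch i+2
+    # runs start/wait_sparse_data_dist on context[2].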

     def _create_context(self, index: int) -> TrainPipelineContext:
         version = self._TRAIN_CONTEXT_VERSION
-        return (
-            PrefetchTrainPipelineContext(index=index, version=version)
-            if self._with_prefetch
-            else TrainPipelineContext(index=index, version=version)
-        )
+        if self._with_prefetch:
+            return PrefetchTrainPipelineContext(index=index, version=version)
+        if self._with_embedding_lookup:
+            return EmbeddingTrainPipelineContext(index=index, version=version)
+        return TrainPipelineContext(index=index, version=version)

     def _add_context(self) -> None:
         self._contexts.append(self._create_context(self._next_index))
@@ -283,7 +335,7 @@ def _assert_contexts_count(self) -> None:
         if not self._WITH_CONTEXT_ASSERTIONS:
             return
         contexts_len = len(self._contexts)
-        expected = 3 if self._with_prefetch else 2
+        expected = 3 if (self._with_prefetch or self._with_embedding_lookup) else 2
         assert (
             contexts_len == expected
         ), f"Expected to have {expected} contexts, but had {contexts_len}"
@@ -325,6 +377,14 @@ def _assert_module_input_post_prefetch(
             specified_keys == expected_fqns
         ), f"Context(idx:{context.index}).module_input_post_prefetch {specified_keys} != pipelined modules fqns {expected_fqns}"

+    def _assert_embedding_a2a_requests(
+        self, context: EmbeddingTrainPipelineContext, expected_fqns: Set[str]
+    ) -> None:
+        specified_keys = context.embedding_a2a_requests.keys()
+        assert (
+            specified_keys == expected_fqns
+        ), f"Context(idx:{context.index}).embedding_a2a_requests {specified_keys} != pipelined modules fqns {expected_fqns}"
+
     def _context_for_model_forward(self) -> TrainPipelineContext:
         ctx = self._current_context()
         if self.should_assert_context_invariants(ctx):
@@ -333,13 +393,16 @@ def _context_for_model_forward(self) -> TrainPipelineContext:
                 assert isinstance(ctx, PrefetchTrainPipelineContext)
                 self._assert_module_input_post_prefetch(ctx, target_fqns)
                 self._assert_module_contexts_post_prefetch(ctx, target_fqns)
+            elif self._with_embedding_lookup:
+                assert isinstance(ctx, EmbeddingTrainPipelineContext)
+                self._assert_embedding_a2a_requests(ctx, target_fqns)
             else:
                 self._assert_input_dist_tensors(ctx, target_fqns)
                 self._assert_module_contexts(ctx, target_fqns)
         return ctx

     def _start_dist_context(self) -> TrainPipelineContext:
-        if self._with_prefetch:
+        if self._with_prefetch or self._with_embedding_lookup:
             ctx = self._contexts[2]
         else:
             ctx = self._contexts[1]
@@ -367,6 +430,16 @@ def _prefetch_context(self) -> PrefetchTrainPipelineContext:
             self._assert_module_contexts(ctx, target_fqns)
         return ctx

+    def _embedding_lookup_context(self) -> EmbeddingTrainPipelineContext:
+        ctx = self._contexts[1]
+        assert isinstance(
+            ctx, EmbeddingTrainPipelineContext
+        ), "Pass embedding_lookup_stream into SparseDataDistUtil to use embedding_lookup_context()"
+        if self.should_assert_context_invariants(ctx):
+            target_fqns = self._pipelined_modules_fqns()
+            self._assert_embedding_a2a_requests(ctx, target_fqns)
+        return ctx
+
     # ====== End "Named" contexts ====== #

     # === End context management === #
@@ -408,7 +481,7 @@ def _initialize_or_reattach(self, batch: In) -> None:
             context_for_rewrite = self._current_context()
         else:
             # if initializing, no contexts are present, so we add them:
-            if self._with_prefetch:
+            if self._with_prefetch or self._with_embedding_lookup:
                 self._contexts.append(self._create_context(-2))  # throwaway context
             self._contexts.append(self._create_context(-1))  # throwaway context
             self._add_context()  # actual context to be used for everything in the initial iteration
@@ -548,3 +621,50 @@ def load_prefetch(self) -> None:
         # with version=1, there's nothing to do - they are managed at a context level,
         # so this is essentially done by _advance_context + prefetch above
         pass
+
+    def start_embedding_lookup(self, batch: In) -> In:
+        """
+        Initiates embedding lookup on the embedding_lookup_stream after data distribution.
+        This enables pipelining of embedding lookup operations independently from the main
+        model forward pass.
+        """
+        context = self._embedding_lookup_context()
+        with record_function(f"## start_embedding_lookup {context.index} ##"):
+            current_stream = torch.get_device_module(self._device).current_stream()
+            with self._stream_context(self.embedding_lookup_stream):
+                for module in self._pipelined_modules:
+                    _start_embedding_lookup(
+                        module,
+                        context,
+                        source_stream=self.data_dist_stream,
+                        target_stream=current_stream,
+                        stream_context=self._stream_context,
+                    )
+        return batch
+
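+    # Convenience constructors for the corresponding `StagedTrainPipeline` stages
+    # (see the class docstring example); `runnable` can be overridden to wrap or
+    # replace the default stage runnable.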
+    def start_sparse_data_dist_stage(
+        self,
+        name: str = "start_sparse_data_dist",
+        runnable: Optional[RunnableType] = None,
+    ) -> PipelineStage:
+        return PipelineStage(
+            name=name,
+            runnable=runnable or self.start_sparse_data_dist,
+            stream=self.data_dist_stream,
+            fill_callback=self.wait_sdd_fill_callback,
+            data_exhausted_callback=self.data_exhausted_callback,
+        )
+
+    def start_embedding_lookup_stage(
+        self,
+        name: str = "start_embedding_lookup",
+        runnable: Optional[RunnableType] = None,
+    ) -> PipelineStage:
+        assert (
+            self.embedding_lookup_stream is not None
+        ), "embedding_lookup_stream is not set"
+        return PipelineStage(
+            name=name,
+            runnable=runnable or self.start_embedding_lookup,
+            stream=self.embedding_lookup_stream,
+        )