Skip to content

Commit 04b15e9

Browse files
Dmitrii Dokshin authored and meta-codesync[bot] committed
Add support for embedding lookup in SDD util (meta-pytorch#3491)
Summary: Pull Request resolved: meta-pytorch#3491
Reviewed By: che-sh
Differential Revision: D85642359
fbshipit-source-id: 3de432b5cfc29fca21d07a9ca1244b506c412e14
1 parent 8edd904 commit 04b15e9

File tree

2 files changed

+216
-17
lines changed

2 files changed

+216
-17
lines changed

torchrec/distributed/train_pipeline/pipeline_stage.py

Lines changed: 137 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,14 @@
3232

3333
from torchrec.distributed.model_parallel import ShardedModule
3434
from torchrec.distributed.train_pipeline.pipeline_context import (
35+
EmbeddingTrainPipelineContext,
3536
In,
3637
PrefetchTrainPipelineContext,
3738
TrainPipelineContext,
3839
)
3940
from torchrec.distributed.train_pipeline.runtime_forwards import (
4041
BaseForward,
42+
InSyncEmbeddingPipelinedForward,
4143
PipelinedForward,
4244
PrefetchPipelinedForward,
4345
)
@@ -48,6 +50,7 @@
4850
_prefetch_embeddings,
4951
_rewrite_model,
5052
_start_data_dist,
53+
_start_embedding_lookup,
5154
use_context_for_postprocs,
5255
)
5356
from torchrec.distributed.types import Awaitable
@@ -91,7 +94,8 @@ class PipelineStage:
9194
class SparseDataDistUtil(Generic[In]):
9295
"""
9396
Helper class exposing methods for sparse data dist and prefetch pipelining.
94-
Currently used for `StagedTrainPipeline` pipeline stages
97+
Currently used for `StagedTrainPipeline` pipeline stages.\n
98+
Specifying `embedding_lookup_stream` makes `StagedTrainPipeline` semi-synchronous.
9599
96100
Args:
97101
model (torch.nn.Module): Model to pipeline
@@ -100,8 +104,11 @@ class SparseDataDistUtil(Generic[In]):
100104
prefetch_stream (Optional[torch.cuda.Stream]): Stream on which model prefetch runs
101105
Defaults to `None`. This needs to be passed in to enable prefetch pipelining.
102106
pipeline_postproc (bool): whether to pipeline postproc modules. Defaults to `False`.
107+
embedding_lookup_stream (Optional[torch.cuda.Stream]): Stream on which embedding lookup runs
108+
Defaults to `None`. This needs to be passed in to enable embedding lookup pipelining.
103109
104110
Example::
111+
# With prefetch pipeline:
105112
sdd = SparseDataDistUtil(
106113
model=model,
107114
data_dist_stream=torch.cuda.Stream(),
@@ -129,6 +136,33 @@ class SparseDataDistUtil(Generic[In]):
129136
),
130137
]
131138
139+
# With embedding lookup pipeline:
140+
sdd = SparseDataDistUtil(
141+
model=model,
142+
data_dist_stream=torch.cuda.Stream(),
143+
embedding_lookup_stream=torch.cuda.Stream(), <-- required to enable embedding lookup pipeline
144+
)
145+
pipeline = [
146+
PipelineStage(
147+
name="data_copy",
148+
runnable=lambda batch, context: batch.to(
149+
self._device, non_blocking=True
150+
),
151+
stream=torch.cuda.Stream(),
152+
),
153+
PipelineStage(
154+
name="start_sparse_data_dist",
155+
runnable=sdd.start_sparse_data_dist,
156+
stream=sdd.data_dist_stream,
157+
fill_callback=sdd.wait_sdd_fill_callback,
158+
),
159+
PipelineStage(
160+
name="start_embedding_lookup",
161+
runnable=sdd.start_embedding_lookup,
162+
stream=sdd.embedding_lookup_stream,
163+
),
164+
]
165+
132166
return StagedTrainPipeline(pipeline_stages=pipeline)
133167
"""
134168

@@ -144,12 +178,14 @@ def __init__(
144178
apply_jit: bool = False,
145179
prefetch_stream: Optional[torch.Stream] = None,
146180
pipeline_postproc: bool = False,
181+
embedding_lookup_stream: Optional[torch.Stream] = None,
147182
) -> None:
148183
super().__init__()
149184
self.model = model
150185
self.data_dist_stream = data_dist_stream
151186
self.apply_jit = apply_jit
152187
self.prefetch_stream = prefetch_stream
188+
self.embedding_lookup_stream = embedding_lookup_stream
153189
self._next_index: int = 0
154190
self._contexts: Deque[TrainPipelineContext] = deque()
155191
self.initialized = False
@@ -158,6 +194,10 @@ def __init__(
158194
self.fwd_hook: Optional[RemovableHandle] = None
159195
self._device: torch.device = data_dist_stream.device
160196

197+
assert not (
198+
self._with_prefetch and self._with_embedding_lookup
199+
), "Cannot enable both prefetch and embedding lookup at the same time. Prefetch is redundant with embedding lookup."
200+
161201
self._stream_context: Callable[
162202
[Optional[torch.Stream]], torch.cuda.StreamContext
163203
] = (
@@ -172,10 +212,17 @@ def __init__(
172212
Callable[[KeyedJaggedTensor], Awaitable[KJTAllToAllTensorsAwaitable]]
173213
] = []
174214

175-
self._pipelined_forward: Type[BaseForward[TrainPipelineContext]] = cast(
176-
Type[BaseForward[TrainPipelineContext]],
177-
(PrefetchPipelinedForward if self._with_prefetch else PipelinedForward),
178-
)
215+
self._pipelined_forward: Type[BaseForward[TrainPipelineContext]]
216+
if self._with_prefetch:
217+
self._pipelined_forward = cast(
218+
Type[BaseForward[TrainPipelineContext]], PrefetchPipelinedForward
219+
)
220+
elif self._with_embedding_lookup:
221+
self._pipelined_forward = cast(
222+
Type[BaseForward[TrainPipelineContext]], InSyncEmbeddingPipelinedForward
223+
)
224+
else:
225+
self._pipelined_forward = PipelinedForward
179226

180227
self._default_stream: Optional[torch.Stream] = (
181228
(torch.get_device_module(self._device).Stream())
@@ -196,6 +243,10 @@ def __init__(
196243
def _with_prefetch(self) -> bool:
197244
return self.prefetch_stream is not None
198245

246+
@property
247+
def _with_embedding_lookup(self) -> bool:
248+
return self.embedding_lookup_stream is not None
249+
199250
def _is_reattaching(self) -> bool:
200251
return len(self._contexts) > 0
201252

@@ -239,7 +290,8 @@ def _pipelined_postprocs_fqns(self) -> Set[str]:
239290
# advancing the list at the beginning of the `progress`.
240291
# Tricky part is that SparseDataDistUtil might be participating in TWO stages:
241292
# * "main" with start_data_dist -> wait_data_dist pair for `runnable` and `fill_callback`
242-
# * "prefetch" with prefetch -> load_prefetch for `runnable` and `fill_callback`
293+
# * "prefetch" with prefetch -> load_prefetch for `runnable` and `fill_callback` (optional)
294+
# * "embedding_lookup" with start_embedding_lookup for `runnable` (optional)
243295
#
244296
# For this to work, we:
245297
# (1) need to manage contexts in a lockstep with batch advancing through stages (_advance_context)
@@ -248,18 +300,18 @@ def _pipelined_postprocs_fqns(self) -> Set[str]:
248300
# (3) set contexts for the _pipelined_modules and _pipelined_postprocs to the "current batch context"
249301
# for the model to run correctly (_set_module_context)
250302
#
251-
# SDD Util uses two or three contexts, depending on if prefetch is present
303+
# SDD Util uses two or three contexts, depending on if prefetch or embedding_lookup is enabled
252304
# * context[0] is always the "current batch" context - used for model forward (outside this class)
253-
# * context[1] is used for prefetch if it is set, and start/wait_sparse_data_dist if not
254-
# * context[2] is used for start/wait_sparse_data_dist if prefetch is not set
305+
# * context[1] is used for prefetch/embedding_lookup if either is set, and start/wait_sparse_data_dist if not
306+
# * context[2] is used for start/wait_sparse_data_dist if prefetch or embedding_lookup is set
255307

256308
def _create_context(self, index: int) -> TrainPipelineContext:
257309
version = self._TRAIN_CONTEXT_VERSION
258-
return (
259-
PrefetchTrainPipelineContext(index=index, version=version)
260-
if self._with_prefetch
261-
else TrainPipelineContext(index=index, version=version)
262-
)
310+
if self._with_prefetch:
311+
return PrefetchTrainPipelineContext(index=index, version=version)
312+
if self._with_embedding_lookup:
313+
return EmbeddingTrainPipelineContext(index=index, version=version)
314+
return TrainPipelineContext(index=index, version=version)
263315

264316
def _add_context(self) -> None:
265317
self._contexts.append(self._create_context(self._next_index))
@@ -283,7 +335,7 @@ def _assert_contexts_count(self) -> None:
283335
if not self._WITH_CONTEXT_ASSERTIONS:
284336
return
285337
contexts_len = len(self._contexts)
286-
expected = 3 if self._with_prefetch else 2
338+
expected = 3 if (self._with_prefetch or self._with_embedding_lookup) else 2
287339
assert (
288340
contexts_len == expected
289341
), f"Expected to have {expected} contexts, but had {contexts_len}"
@@ -325,6 +377,14 @@ def _assert_module_input_post_prefetch(
325377
specified_keys == expected_fqns
326378
), f"Context(idx:{context.index}).module_input_post_prefetch {specified_keys} != pipelined modules fqns {expected_fqns}"
327379

380+
def _assert_embedding_a2a_requests(
381+
self, context: EmbeddingTrainPipelineContext, expected_fqns: Set[str]
382+
) -> None:
383+
specified_keys = context.embedding_a2a_requests.keys()
384+
assert (
385+
specified_keys == expected_fqns
386+
), f"Context(idx:{context.index}).embedding_a2a_requests {specified_keys} != pipelined modules fqns {expected_fqns}"
387+
328388
def _context_for_model_forward(self) -> TrainPipelineContext:
329389
ctx = self._current_context()
330390
if self.should_assert_context_invariants(ctx):
@@ -333,13 +393,16 @@ def _context_for_model_forward(self) -> TrainPipelineContext:
333393
assert isinstance(ctx, PrefetchTrainPipelineContext)
334394
self._assert_module_input_post_prefetch(ctx, target_fqns)
335395
self._assert_module_contexts_post_prefetch(ctx, target_fqns)
396+
elif self._with_embedding_lookup:
397+
assert isinstance(ctx, EmbeddingTrainPipelineContext)
398+
self._assert_embedding_a2a_requests(ctx, target_fqns)
336399
else:
337400
self._assert_input_dist_tensors(ctx, target_fqns)
338401
self._assert_module_contexts(ctx, target_fqns)
339402
return ctx
340403

341404
def _start_dist_context(self) -> TrainPipelineContext:
342-
if self._with_prefetch:
405+
if self._with_prefetch or self._with_embedding_lookup:
343406
ctx = self._contexts[2]
344407
else:
345408
ctx = self._contexts[1]
@@ -367,6 +430,16 @@ def _prefetch_context(self) -> PrefetchTrainPipelineContext:
367430
self._assert_module_contexts(ctx, target_fqns)
368431
return ctx
369432

433+
def _embedding_lookup_context(self) -> EmbeddingTrainPipelineContext:
    """
    Return ``self._contexts[1]`` as an ``EmbeddingTrainPipelineContext``.

    When ``should_assert_context_invariants`` approves the context, it is also
    checked that its ``embedding_a2a_requests`` keys match the pipelined
    modules' FQNs.

    Raises:
        AssertionError: if the context at index 1 is not an
            ``EmbeddingTrainPipelineContext``, or if the invariant check fails.
    """
    ctx = self._contexts[1]
    assert isinstance(
        ctx, EmbeddingTrainPipelineContext
    ), "Pass embedding_lookup_stream into SparseDataDistUtil to use embedding_lookup_context()"
    if self.should_assert_context_invariants(ctx):
        target_fqns = self._pipelined_modules_fqns()
        self._assert_embedding_a2a_requests(ctx, target_fqns)
    return ctx
442+
370443
# ====== End "Named" contexts ====== #
371444

372445
# === End context management === #
@@ -408,7 +481,7 @@ def _initialize_or_reattach(self, batch: In) -> None:
408481
context_for_rewrite = self._current_context()
409482
else:
410483
# if initializing, no contexts are present, so we add them:
411-
if self._with_prefetch:
484+
if self._with_prefetch or self._with_embedding_lookup:
412485
self._contexts.append(self._create_context(-2)) # throwaway context
413486
self._contexts.append(self._create_context(-1)) # throwaway context
414487
self._add_context() # actual context to be used for everything in the initial iteration
@@ -548,3 +621,50 @@ def load_prefetch(self) -> None:
548621
# with version=1, there's nothing to do - they are managed at a context level,
549622
# so this is essentially done by _advance_context + prefetch above
550623
pass
624+
625+
def start_embedding_lookup(self, batch: In) -> In:
    """
    Initiates embedding lookup on the embedding_lookup_stream after data distribution.
    This enables pipelining of embedding lookup operations independently from the main
    model forward pass.

    Args:
        batch (In): the batch flowing through the pipeline stage; it is not
            inspected here.

    Returns:
        In: the same ``batch`` object that was passed in.
    """
    context = self._embedding_lookup_context()
    with record_function(f"## start_embedding_lookup {context.index} ##"):
        # Captured before switching to embedding_lookup_stream, so this is the
        # stream that was current when the stage was invoked.
        current_stream = torch.get_device_module(self._device).current_stream()
        with self._stream_context(self.embedding_lookup_stream):
            for module in self._pipelined_modules:
                _start_embedding_lookup(
                    module,
                    context,
                    source_stream=self.data_dist_stream,
                    target_stream=current_stream,
                    stream_context=self._stream_context,
                )
    return batch
644+
645+
def start_sparse_data_dist_stage(
    self,
    name: str = "start_sparse_data_dist",
    runnable: Optional[RunnableType] = None,
) -> PipelineStage:
    """
    Build the ``PipelineStage`` that runs sparse data distribution.

    Args:
        name (str): stage name. Defaults to ``"start_sparse_data_dist"``.
        runnable (Optional[RunnableType]): stage runnable. Defaults to
            ``self.start_sparse_data_dist`` when ``None`` (or otherwise falsy).

    Returns:
        PipelineStage: stage bound to ``self.data_dist_stream`` with
        ``self.wait_sdd_fill_callback`` as ``fill_callback`` and
        ``self.data_exhausted_callback`` as ``data_exhausted_callback``.
    """
    return PipelineStage(
        name=name,
        runnable=runnable or self.start_sparse_data_dist,
        stream=self.data_dist_stream,
        fill_callback=self.wait_sdd_fill_callback,
        data_exhausted_callback=self.data_exhausted_callback,
    )
657+
658+
def start_embedding_lookup_stage(
    self,
    name: str = "start_embedding_lookup",
    runnable: Optional[RunnableType] = None,
) -> PipelineStage:
    """
    Build the ``PipelineStage`` that runs embedding lookup.

    Args:
        name (str): stage name. Defaults to ``"start_embedding_lookup"``.
        runnable (Optional[RunnableType]): stage runnable. Defaults to
            ``self.start_embedding_lookup`` when ``None`` (or otherwise falsy).

    Returns:
        PipelineStage: stage bound to ``self.embedding_lookup_stream``.

    Raises:
        AssertionError: if ``embedding_lookup_stream`` was not provided.
    """
    # Checked first so a misconfigured util fails here rather than inside the stage.
    assert (
        self.embedding_lookup_stream is not None
    ), "embedding_lookup_stream is not set"
    return PipelineStage(
        name=name,
        runnable=runnable or self.start_embedding_lookup,
        stream=self.embedding_lookup_stream,
    )

torchrec/distributed/train_pipeline/tests/test_train_pipelines.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1713,6 +1713,85 @@ def gpu_postproc(x: StageOut) -> StageOut:
17131713
for out, ref_out in zip(pipelined_out, non_pipelined_outputs):
17141714
torch.testing.assert_close(out, ref_out)
17151715

1716+
# pyre-ignore
1717+
@unittest.skipIf(
    not torch.cuda.is_available(),
    "Not enough GPUs, this test requires at least one GPU",
)
def test_pipelining_embedding_lookup(self) -> None:
    # Trains two identically initialised sharded models on the same data, one
    # directly and one through a StagedTrainPipeline with an embedding lookup
    # stage, and checks that their predictions match batch by batch.
    model = self._setup_model()

    sharding_type = ShardingType.TABLE_WISE.value
    kernel_type = EmbeddingComputeKernel.FUSED.value

    sharded_model, optim = self._generate_sharded_model_and_optimizer(
        model, sharding_type, kernel_type
    )
    (
        sharded_model_pipelined,
        optim_pipelined,
    ) = self._generate_sharded_model_and_optimizer(
        model, sharding_type, kernel_type
    )

    copy_state_dict(
        sharded_model.state_dict(), sharded_model_pipelined.state_dict()
    )

    num_batches = 12
    data = self._generate_data(
        num_batches=num_batches,
        batch_size=32,
    )

    # Reference run: plain training loop without pipelining.
    non_pipelined_outputs = []
    for batch in data:
        batch = batch.to(self.device)
        optim.zero_grad()
        loss, pred = sharded_model(batch)
        loss.backward()
        optim.step()
        non_pipelined_outputs.append(pred)

    embedding_lookup_stream = torch.cuda.Stream()
    sdd = SparseDataDistUtil[ModelInput](
        model=sharded_model_pipelined,
        data_dist_stream=torch.cuda.Stream(),
        apply_jit=False,
        embedding_lookup_stream=embedding_lookup_stream,
    )

    pipeline_stages = [
        PipelineStage(
            name="data_copy",
            runnable=partial(get_h2d_func, device=self.device),
            stream=torch.cuda.Stream(),
        ),
        sdd.start_sparse_data_dist_stage(),
        sdd.start_embedding_lookup_stage(),
    ]
    pipeline = StagedTrainPipeline(
        pipeline_stages=pipeline_stages, compute_stream=torch.cuda.current_stream()
    )
    dataloader = iter(data)

    pipelined_out = []
    num_batches_processed = 0

    # Pipelined run: progress() is called until it returns a falsy value.
    while model_in := pipeline.progress(dataloader):
        num_batches_processed += 1
        optim_pipelined.zero_grad()
        loss, pred = sharded_model_pipelined(model_in)
        loss.backward()
        optim_pipelined.step()
        pipelined_out.append(pred)

    self.assertEqual(num_batches_processed, num_batches)

    self.assertEqual(len(pipelined_out), len(non_pipelined_outputs))
    for out, ref_out in zip(pipelined_out, non_pipelined_outputs):
        torch.testing.assert_close(out, ref_out)
1794+
17161795
# pyre-ignore
17171796
@unittest.skipIf(
17181797
not torch.cuda.is_available(),

0 commit comments

Comments
 (0)