Skip to content

Commit da91c05

Browse files
TroyGarden authored and meta-codesync[bot] committed
add inplace_copy_batch_to_gpu in TrainPipeline (#3526)
Summary:
Pull Request resolved: #3526

This diff implements support for pre-allocated, in-place copy for host-to-device data transfer in TorchRec train pipelines, addressing CUDA memory overhead issues identified in production RecSys models. https://fb.workplace.com/groups/429376538334034/permalink/1497469664858044/

## Context

As described in the [RFC on Workplace](https://fb.workplace.com/groups/429376538334034/permalink/1497469664858044/), we identified an extra CUDA memory overhead of 3-6 GB per rank on top of the active memory snapshot in most RecSys model training pipelines. This overhead stems from PyTorch's caching allocator behavior when using side CUDA streams for non-blocking host-to-device transfers: the allocator associates transferred tensor memory with the side stream, preventing memory reuse in the main stream and causing up to 13 GB of extra memory footprint per rank in production models.

The solution proposed in [D86068070](https://www.internalfb.com/diff/D86068070) enables pre-allocating memory on the main stream and using in-place copy to reduce this overhead. In local train pipeline benchmarks with a 1-GB ModelInput (2 KJTs + float features), this approach reduced memory footprint by ~6 GB per rank. This optimization enables many memory-constrained use cases across platforms including APS, Pyper, and MVAI.

## Key Changes

1. **Added `inplace_copy_batch_to_gpu` parameter**: New boolean flag throughout the train pipeline infrastructure that enables switching between standard batch copying (direct allocation on side stream) and in-place copying (pre-allocation on main stream).
2. **New `inplace_copy_batch_to_gpu()` method**: Implemented in `TrainPipeline` class to handle the new data transfer pattern with proper stream synchronization, using `_to_device()` with the optional `data_copy_stream` parameter.
3. **Extended `Pipelineable.to()` interface**: Added optional `data_copy_stream` parameter to the abstract method, allowing implementations to specify which stream should execute the data copy operation (see #3510).
4. **Updated benchmark configuration** (`sparse_data_dist_base.yml`):
   - Increased `num_batches` from 5 to 10
   - Changed `feature_pooling_avg` from 10 to 30
   - Reduced `num_benchmarks` from 2 to 1
   - Added `num_profiles: 1` for profiling
   - Reduced `num_unweighted_features` from 100 to 90 and `num_weighted_features` from 100 to 80
5. **Enhanced table configuration**: Added `base_row_size` parameter (default: 100,000) to `EmbeddingTablesConfig` for more flexible embedding table sizing.

These changes enable performance and memory comparison between standard and in-place copy strategies, with proper benchmarking infrastructure to measure and trace the differences.

Reviewed By: aporialiao
Differential Revision: D86208714
fbshipit-source-id: c7bd9d46d1a9f98a68446b9d4be0f63208b626bf
1 parent 88c6058 commit da91c05

File tree

6 files changed

+115
-16
lines changed

6 files changed

+115
-16
lines changed

torchrec/distributed/benchmark/benchmark_train_pipeline.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -183,6 +183,7 @@ def _func_to_benchmark(
183183
model: nn.Module,
184184
pipeline: TrainPipeline,
185185
) -> None:
186+
pipeline.reset()
186187
dataloader = iter(bench_inputs)
187188
while True:
188189
try:

torchrec/distributed/benchmark/yaml/sparse_data_dist_base.yml

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -2,19 +2,20 @@
22
# runs on 2 ranks, showing traces with reasonable workloads
33
RunOptions:
44
world_size: 2
5-
num_batches: 5
6-
num_benchmarks: 2
5+
num_batches: 10
6+
num_benchmarks: 1
7+
num_profiles: 1
78
sharding_type: table_wise
89
profile_dir: "."
910
name: "sparse_data_dist_base"
1011
# export_stacks: True # enable this to export stack traces
1112
PipelineConfig:
1213
pipeline: "sparse"
1314
ModelInputConfig:
14-
feature_pooling_avg: 10
15+
feature_pooling_avg: 30
1516
EmbeddingTablesConfig:
16-
num_unweighted_features: 100
17-
num_weighted_features: 100
17+
num_unweighted_features: 90
18+
num_weighted_features: 80
1819
embedding_feature_dim: 256
1920
additional_tables:
2021
- - name: FP16_table

torchrec/distributed/test_utils/pipeline_config.py

Lines changed: 13 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -48,6 +48,7 @@ class PipelineConfig:
4848

4949
pipeline: str = "base"
5050
emb_lookup_stream: str = "data_dist"
51+
inplace_copy_batch_to_gpu: bool = False
5152
apply_jit: bool = False
5253

5354
def generate_pipeline(
@@ -111,14 +112,24 @@ def generate_pipeline(
111112
device=device,
112113
emb_lookup_stream=self.emb_lookup_stream,
113114
apply_jit=self.apply_jit,
115+
inplace_copy_batch_to_gpu=self.inplace_copy_batch_to_gpu,
114116
)
115117
elif self.pipeline == "base":
116118
assert self.apply_jit is False, "JIT is not supported for base pipeline"
117119

118-
return TrainPipelineBase(model=model, optimizer=opt, device=device)
120+
return TrainPipelineBase(
121+
model=model,
122+
optimizer=opt,
123+
device=device,
124+
inplace_copy_batch_to_gpu=self.inplace_copy_batch_to_gpu,
125+
)
119126
else:
120127
Pipeline = _pipeline_cls[self.pipeline]
121128
# pyre-ignore[28]
122129
return Pipeline(
123-
model=model, optimizer=opt, device=device, apply_jit=self.apply_jit
130+
model=model,
131+
optimizer=opt,
132+
device=device,
133+
apply_jit=self.apply_jit,
134+
inplace_copy_batch_to_gpu=self.inplace_copy_batch_to_gpu,
124135
)

torchrec/distributed/test_utils/table_config.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -36,6 +36,7 @@ class EmbeddingTablesConfig:
3636
num_unweighted_features: int = 100
3737
num_weighted_features: int = 100
3838
embedding_feature_dim: int = 128
39+
base_row_size: int = 100_000
3940
table_data_type: DataType = DataType.FP32
4041
additional_tables: List[List[Dict[str, Any]]] = field(default_factory=list)
4142

@@ -71,7 +72,7 @@ def generate_tables(
7172
"""
7273
unweighted_tables = [
7374
EmbeddingBagConfig(
74-
num_embeddings=max(i + 1, 100) * 2000,
75+
num_embeddings=max(i + 1, 100) * self.base_row_size // 100,
7576
embedding_dim=self.embedding_feature_dim,
7677
name="table_" + str(i),
7778
feature_names=["feature_" + str(i)],
@@ -81,7 +82,7 @@ def generate_tables(
8182
]
8283
weighted_tables = [
8384
EmbeddingBagConfig(
84-
num_embeddings=max(i + 1, 100) * 2000,
85+
num_embeddings=max(i + 1, 100) * self.base_row_size // 100,
8586
embedding_dim=self.embedding_feature_dim,
8687
name="weighted_table_" + str(i),
8788
feature_names=["weighted_feature_" + str(i)],

torchrec/distributed/train_pipeline/train_pipelines.py

Lines changed: 68 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -187,6 +187,7 @@ def __init__(
187187
custom_model_fwd: Optional[
188188
Callable[[In], Tuple[torch.Tensor, List[torch.Tensor]]]
189189
] = None,
190+
inplace_copy_batch_to_gpu: bool = False,
190191
) -> None:
191192
self._model = model
192193
self._optimizer = optimizer
@@ -196,6 +197,7 @@ def __init__(
196197
if device.type in ["cuda", "mtia"]
197198
else None
198199
)
200+
self._inplace_copy_batch_to_gpu = inplace_copy_batch_to_gpu
199201

200202
# pyre-ignore
201203
self._stream_context = (
@@ -217,8 +219,18 @@ def _connect(self, dataloader_iter: Iterator[In]) -> None:
217219
cur_batch = next(dataloader_iter)
218220
self._cur_batch = cur_batch
219221
if cur_batch is not None:
220-
with self._stream_context(self._memcpy_stream):
221-
self._cur_batch = _to_device(cur_batch, self._device, non_blocking=True)
222+
if self._inplace_copy_batch_to_gpu:
223+
self._cur_batch = _to_device(
224+
cur_batch,
225+
self._device,
226+
non_blocking=True,
227+
data_copy_stream=self._memcpy_stream,
228+
)
229+
else:
230+
with self._stream_context(self._memcpy_stream):
231+
self._cur_batch = _to_device(
232+
cur_batch, self._device, non_blocking=True
233+
)
222234
self._connected = True
223235

224236
def _next_batch(self, dataloader_iter: Iterator[In]) -> Optional[In]:
@@ -241,8 +253,18 @@ def _backward(self, losses: torch.Tensor) -> None:
241253

242254
def _copy_batch_to_gpu(self, cur_batch: In) -> None:
243255
with record_function("## copy_batch_to_gpu ##"):
244-
with self._stream_context(self._memcpy_stream):
245-
self._cur_batch = _to_device(cur_batch, self._device, non_blocking=True)
256+
if self._inplace_copy_batch_to_gpu:
257+
self._cur_batch = _to_device(
258+
cur_batch,
259+
self._device,
260+
non_blocking=True,
261+
data_copy_stream=self._memcpy_stream,
262+
)
263+
else:
264+
with self._stream_context(self._memcpy_stream):
265+
self._cur_batch = _to_device(
266+
cur_batch, self._device, non_blocking=True
267+
)
246268

247269
def progress(self, dataloader_iter: Iterator[In]) -> Out:
248270
if not self._connected:
@@ -440,13 +462,15 @@ def __init__(
440462
] = None,
441463
dmp_collection_sync_interval_batches: Optional[int] = 1,
442464
enqueue_batch_after_forward: bool = False,
465+
inplace_copy_batch_to_gpu: bool = False,
443466
) -> None:
444467
self._model = model
445468
self._optimizer = optimizer
446469
self._device = device
447470
self._execute_all_batches = execute_all_batches
448471
self._apply_jit = apply_jit
449472
self._enqueue_batch_after_forward = enqueue_batch_after_forward
473+
self._inplace_copy_batch_to_gpu = inplace_copy_batch_to_gpu
450474

451475
logger.info(
452476
f"enqueue_batch_after_forward: {self._enqueue_batch_after_forward} "
@@ -587,7 +611,10 @@ def enqueue_batch(self, dataloader_iter: Iterator[In]) -> bool:
587611
load a data batch from dataloader, and copy it from cpu to gpu
588612
also create the context for this batch.
589613
"""
590-
batch, context = self.copy_batch_to_gpu(dataloader_iter)
614+
if self._inplace_copy_batch_to_gpu:
615+
batch, context = self.inplace_copy_batch_to_gpu(dataloader_iter)
616+
else:
617+
batch, context = self.copy_batch_to_gpu(dataloader_iter)
591618
if batch is None:
592619
return False
593620
self.batches.append(batch)
@@ -820,6 +847,38 @@ def copy_batch_to_gpu(
820847
)
821848
return batch, context
822849

850+
def inplace_copy_batch_to_gpu(
851+
self,
852+
dataloader_iter: Iterator[In],
853+
) -> Tuple[Optional[In], Optional[TrainPipelineContext]]:
854+
"""
855+
Moves batch to the provided device on memcpy stream.
856+
857+
Raises:
858+
StopIteration: if the dataloader iterator is exhausted; unless
859+
`self._execute_all_batches=True`, then returns None.
860+
"""
861+
context = self._create_context()
862+
with record_function(f"## inplace_copy_batch_to_gpu {context.index} ##"):
863+
batch = self._next_batch(dataloader_iter)
864+
if batch is not None:
865+
batch = _to_device(
866+
batch,
867+
self._device,
868+
non_blocking=True,
869+
data_copy_stream=self._memcpy_stream,
870+
)
871+
elif not self._execute_all_batches:
872+
logger.info(
873+
"inplace_copy_batch_to_gpu: raising StopIteration for None Batch (execute_all_batches=False)"
874+
)
875+
raise StopIteration
876+
else:
877+
logger.info(
878+
"inplace_copy_batch_to_gpu: returning None batch (execute_all_batches=True)"
879+
)
880+
return batch, context
881+
823882
def _next_batch(self, dataloader_iter: Iterator[In]) -> Optional[In]:
824883
"""
825884
Retrieves next batch from dataloader and prevents calling `next` on an already
@@ -984,6 +1043,7 @@ def __init__(
9841043
strict: bool = False,
9851044
emb_lookup_stream: str = "data_dist", # new, current, data_dist (default)
9861045
embedding_lookup_after_data_dist: bool = False,
1046+
inplace_copy_batch_to_gpu: bool = False,
9871047
) -> None:
9881048
super().__init__(
9891049
model=model,
@@ -994,6 +1054,7 @@ def __init__(
9941054
context_type=EmbeddingTrainPipelineContext,
9951055
pipeline_postproc=pipeline_postproc,
9961056
custom_model_fwd=custom_model_fwd,
1057+
inplace_copy_batch_to_gpu=inplace_copy_batch_to_gpu,
9971058
)
9981059
self._embedding_lookup_after_data_dist = embedding_lookup_after_data_dist
9991060

@@ -1155,6 +1216,7 @@ def __init__(
11551216
] = None,
11561217
strict: bool = False,
11571218
dmp_collection_sync_interval_batches: Optional[int] = 1,
1219+
inplace_copy_batch_to_gpu: bool = False,
11581220
) -> None:
11591221
super().__init__(
11601222
model=model,
@@ -1166,6 +1228,7 @@ def __init__(
11661228
pipeline_postproc=pipeline_postproc,
11671229
custom_model_fwd=custom_model_fwd,
11681230
dmp_collection_sync_interval_batches=dmp_collection_sync_interval_batches,
1231+
inplace_copy_batch_to_gpu=inplace_copy_batch_to_gpu,
11691232
)
11701233
self._start_batch = start_batch
11711234
self._stash_gradients = stash_gradients

torchrec/distributed/train_pipeline/utils.py

Lines changed: 24 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -68,11 +68,33 @@
6868
logger: logging.Logger = logging.getLogger(__name__)
6969

7070

71-
def _to_device(batch: In, device: torch.device, non_blocking: bool) -> In:
71+
def _to_device(
72+
batch: In,
73+
device: torch.device,
74+
non_blocking: bool,
75+
data_copy_stream: Optional[torch.Stream] = None,
76+
) -> In:
7277
assert isinstance(
7378
batch, (torch.Tensor, Pipelineable)
7479
), f"{type(batch)} must implement Pipelineable interface"
75-
return cast(In, batch.to(device=device, non_blocking=non_blocking))
80+
if data_copy_stream is not None:
81+
return cast(
82+
In,
83+
# pyre-ignore[28]
84+
batch.to(
85+
device=device,
86+
non_blocking=non_blocking,
87+
data_copy_stream=data_copy_stream,
88+
),
89+
)
90+
else:
91+
return cast(
92+
In,
93+
batch.to(
94+
device=device,
95+
non_blocking=non_blocking,
96+
),
97+
)
7698

7799

78100
def _wait_for_batch(batch: In, stream: Optional[torch.Stream]) -> None:

0 commit comments

Comments
 (0)