
Commit 8edd904

TroyGarden authored and meta-codesync[bot] committed
Add device parameter to KeyedJaggedTensor.empty_like and copy_ method (meta-pytorch#3510)
Summary:
Pull Request resolved: meta-pytorch#3510

This diff enhances the KeyedJaggedTensor API to support device-aware operations, which is needed for efficient cross-device tensor management in TorchRec.

Reference: [memory snapshot and footprint for non-blocking copy](meta-pytorch#3485)

## Key Changes:

1. **Extended `empty_like` method**: Added an optional `device` parameter to support creating empty KJT structures on a different device. This enables two usage patterns:
   - Original: creates an empty KJT on the same device, preserving stride/stride_per_key_per_rank with empty data.
   - Device-copy: creates an empty KJT structure on a new device, useful for pre-allocating tensors before async copy operations.

2. **New `copy_` method**: Implements an in-place copy operation for KeyedJaggedTensor that:
   - Copies values, weights, lengths, and offsets from the source to the destination KJT.
   - Supports non-blocking (async) copies for better performance.
   - Assumes host-side metadata (keys, stride, etc.) is already configured.
   - Handles optional tensors (weights, lengths, offsets) appropriately.

3. **Refactored implementation**: Split the original `_kjt_empty_like` logic into:
   - `_kjt_empty_like_stride`: preserves the original behavior for a same-device empty KJT.
   - `_kjt_empty_like_device`: new function for cross-device empty KJT creation.

These changes enable more efficient device-to-device transfer patterns in distributed training scenarios (a usage sketch follows after this summary).

{F1983205769}

### Validation:

In a prototyping experiment with the sparse-data-dist pipeline (TrainPipelineSparseDist), the Memcpy HtoD speed (bandwidth) and the CUDA memory timeline profile are similar, but reserved memory is 79.7 GB vs 74.0 GB, a 5~6 GB benefit, while the input KJT per rank is about 1 GB.

* trace with direct copy {F1983200620}
* trace with inplace copy {F1983200591}
* snapshot with direct copy {F1983200644} {F1983200655}
* snapshot with inplace copy {F1983200664} {F1983200670}

Reviewed By: spmex

Differential Revision: D86068070

fbshipit-source-id: 0d1076fd192190b46eed4bda1d4e53b4b245d2a7
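As a quick illustration of the two API additions above, here is a minimal, hedged usage sketch (not part of the commit). It assumes a CUDA device is available and uses made-up feature names and values; a truly asynchronous HtoD copy would additionally require the CPU-side tensors to be in pinned memory.

```python
import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# A small CPU-side KJT standing in for one input batch (values are arbitrary).
kjt_cpu = KeyedJaggedTensor.from_lengths_sync(
    keys=["f1", "f2"],
    values=torch.tensor([1, 2, 3, 4, 5]),
    lengths=torch.tensor([2, 0, 1, 1, 1, 0]),
)

device = torch.device("cuda:0")

# 1) Pre-allocate an empty KJT with the same shapes/dtypes on the target device.
kjt_gpu = KeyedJaggedTensor.empty_like(kjt_cpu, device=device)

# 2) Fill it with an in-place, non-blocking copy; host-side metadata (keys,
#    stride, etc.) is taken over from the source, tensor storage is reused.
kjt_gpu.copy_(kjt_cpu, non_blocking=True)
```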
1 parent da8924a commit 8edd904

File tree: 3 files changed (+537, -17 lines)


torchrec/distributed/test_utils/model_input.py

Lines changed: 92 additions & 13 deletions

@@ -34,22 +34,101 @@ class ModelInput(Pipelineable):
     idscore_features: Optional[KeyedJaggedTensor]
     label: torch.Tensor
 
-    def to(self, device: torch.device, non_blocking: bool = False) -> "ModelInput":
-        return ModelInput(
-            float_features=self.float_features.to(
-                device=device, non_blocking=non_blocking
-            ),
-            idlist_features=(
-                self.idlist_features.to(device=device, non_blocking=non_blocking)
+    def to(
+        self,
+        device: torch.device,
+        non_blocking: bool = False,
+        data_copy_stream: Optional[torch.cuda.streams.Stream] = None,
+    ) -> "ModelInput":
+        """
+        Move ModelInput to the specified device.
+
+        Args:
+            device: Target device to move tensors to.
+            non_blocking: Whether to perform asynchronous copies.
+            data_copy_stream: Optional CUDA stream for async data copies. When provided,
+                tensors are pre-allocated on the target device and copied within this stream.
+                This enables pipelined data transfers with computation on other streams.
+
+        Returns:
+            ModelInput on the target device.
+
+        Example:
+            # Standard synchronous transfer
+            batch_gpu = batch_cpu.to(device="cuda")
+
+            # Async transfer with dedicated stream
+            copy_stream = torch.cuda.Stream()
+            batch_gpu = batch_cpu.to(device="cuda", non_blocking=True, data_copy_stream=copy_stream)
+        """
+        if data_copy_stream is None:
+            # Standard .to() method
+            float_features = self.float_features.to(
+                device=device,
+                non_blocking=non_blocking,
+            )
+            idlist_features = (
+                self.idlist_features.to(
+                    device=device,
+                    non_blocking=non_blocking,
+                )
                 if self.idlist_features is not None
                 else None
-            ),
-            idscore_features=(
-                self.idscore_features.to(device=device, non_blocking=non_blocking)
+            )
+            idscore_features = (
+                self.idscore_features.to(
+                    device=device,
+                    non_blocking=non_blocking,
+                )
                 if self.idscore_features is not None
                 else None
-            ),
-            label=self.label.to(device=device, non_blocking=non_blocking),
+            )
+            label = self.label.to(
+                device=device,
+                non_blocking=non_blocking,
+            )
+        else:
+            # Async copy using dedicated stream
+            current_stream = torch.cuda.current_stream(device)
+
+            # Pre-allocate tensors on target device
+            float_features = torch.empty_like(self.float_features, device=device)
+            label = torch.empty_like(self.label, device=device)
+            idlist_features = (
+                None
+                if self.idlist_features is None
+                else KeyedJaggedTensor.empty_like(self.idlist_features, device=device)
+            )
+            idscore_features = (
+                None
+                if self.idscore_features is None
+                else KeyedJaggedTensor.empty_like(self.idscore_features, device=device)
+            )
+
+            # Perform async copy in dedicated stream
+            with data_copy_stream:
+                # Wait for current stream to finish memory allocation
+                data_copy_stream.wait_stream(current_stream)
+
+                float_features.copy_(self.float_features, non_blocking=non_blocking)
+                label.copy_(self.label, non_blocking=non_blocking)
+                if idlist_features is not None:
+                    idlist_features.copy_(
+                        # pyre-ignore[6]: Pyre doesn't understand self.idlist_features is not None here
+                        self.idlist_features,
+                        non_blocking=non_blocking,
+                    )
+                if idscore_features is not None:
+                    idscore_features.copy_(
+                        # pyre-ignore[6]: Pyre doesn't understand self.idscore_features is not None here
+                        self.idscore_features,
+                        non_blocking=non_blocking,
+                    )
+        return ModelInput(
+            float_features=float_features,
+            idlist_features=idlist_features,
+            idscore_features=idscore_features,
+            label=label,
         )
 
     def record_stream(self, stream: torch.Stream) -> None:

@@ -299,7 +378,7 @@ def generate(
             tables=weighted_tables,
             pooling_avg=pooling_avg,
             tables_pooling=tables_pooling,
-            weighted=False,  # weighted
+            weighted=True,  # weighted
            max_feature_lengths=max_feature_lengths,
            use_offsets=use_offsets,
            device=device,
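The validation in the summary was done with the sparse-data-dist train pipeline; the snippet below is only a simplified, hypothetical prefetching loop (the `prefetching_loop` helper, `batches` iterator, and `train_step` callable are assumptions, not part of this commit) showing how the new `data_copy_stream` argument might be used to overlap input HtoD copies with compute:

```python
from typing import Callable, Iterator

import torch

from torchrec.distributed.test_utils.model_input import ModelInput


def prefetching_loop(
    batches: Iterator[ModelInput],              # assumed: yields CPU-side batches
    device: torch.device,
    num_steps: int,
    train_step: Callable[[ModelInput], None],   # assumed: user-defined step
) -> None:
    copy_stream = torch.cuda.Stream()
    # Kick off the copy of the first batch on the dedicated stream.
    next_batch = next(batches).to(device, non_blocking=True, data_copy_stream=copy_stream)
    for _ in range(num_steps):
        batch = next_batch
        # The compute (default) stream must wait for this batch's async copy.
        torch.cuda.current_stream(device).wait_stream(copy_stream)
        # Prefetch the following batch so its HtoD copy overlaps with this step.
        next_batch = next(batches).to(device, non_blocking=True, data_copy_stream=copy_stream)
        train_step(batch)
        # Production pipelines (e.g. TrainPipelineSparseDist) additionally call
        # record_stream(...) on the batch for caching-allocator safety.
```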

torchrec/sparse/jagged_tensor.py

Lines changed: 116 additions & 4 deletions

@@ -1466,8 +1466,9 @@ def _maybe_compute_kjt_to_jt_dict(
 
 
 @torch.fx.wrap
-def _kjt_empty_like(kjt: "KeyedJaggedTensor") -> "KeyedJaggedTensor":
+def _kjt_empty_like_stride(kjt: "KeyedJaggedTensor") -> "KeyedJaggedTensor":
     # empty like function fx wrapped, also avoids device hardcoding
+    # basically the empty KJT only preserves the stride and stride_per_key_per_rank
     stride, stride_per_key_per_rank = (
         (None, kjt._stride_per_key_per_rank)
         if kjt._stride_per_key_per_rank is not None and kjt.variable_stride_per_key()

@@ -1488,6 +1489,51 @@ def _kjt_empty_like(kjt: "KeyedJaggedTensor") -> "KeyedJaggedTensor":
     )
 
 
+@torch.fx.wrap
+def _kjt_empty_like_device(
+    kjt: "KeyedJaggedTensor", device: torch.device
+) -> "KeyedJaggedTensor":
+    # behaves more like torch.empty_like: allocates the memory on the given device
+    stride, stride_per_key_per_rank = (
+        (None, kjt._stride_per_key_per_rank)
+        if kjt._stride_per_key_per_rank is not None and kjt.variable_stride_per_key()
+        else (kjt.stride(), None)
+    )
+    inverse_indices = kjt._inverse_indices
+    return KeyedJaggedTensor(
+        keys=kjt.keys(),
+        values=torch.empty_like(kjt.values(), device=device),
+        weights=(
+            None
+            if kjt.weights_or_none() is None
+            else torch.empty_like(kjt.weights(), device=device)
+        ),
+        lengths=(
+            None
+            if kjt.lengths_or_none() is None
+            else torch.empty_like(kjt.lengths(), device=device)
+        ),
+        offsets=(
+            None
+            if kjt.offsets_or_none() is None
+            else torch.empty_like(kjt.offsets(), device=device)
+        ),
+        stride=stride,
+        inverse_indices=(
+            None
+            if inverse_indices is None
+            else (
+                inverse_indices[0],
+                torch.empty_like(inverse_indices[1], device=device),
+            )
+        ),
+        stride_per_key_per_rank=stride_per_key_per_rank,
+        stride_per_key=kjt._stride_per_key,
+        length_per_key=kjt._length_per_key,
+        offset_per_key=kjt._offset_per_key,
+    )
+
+
 def _sum_by_splits(input_list: List[int], splits: List[int]) -> List[int]:
     return [
         sum(input_list[sum(splits[:i]) : sum(splits[:i]) + n])

@@ -1940,17 +1986,83 @@ def empty(
         )
 
     @staticmethod
-    def empty_like(kjt: "KeyedJaggedTensor") -> "KeyedJaggedTensor":
+    def empty_like(
+        kjt: "KeyedJaggedTensor",
+        device: Optional[torch.device] = None,
+    ) -> "KeyedJaggedTensor":
         """
-        Constructs an empty KeyedJaggedTensor with the same device and dtypes as the input KeyedJaggedTensor.
+        Original usage:
+            Constructs an empty KeyedJaggedTensor with the same device and dtypes as the input KeyedJaggedTensor.
+            This preserves stride/stride_per_key_per_rank but the actual data (values, lengths, etc.) is empty.
+
+        Device-copy usage:
+            Constructs an empty KeyedJaggedTensor with the empty tensors on the new device.
 
         Args:
             kjt (KeyedJaggedTensor): input KeyedJaggedTensor.
+            device (Optional[torch.device]): device on which the KeyedJaggedTensor will be placed.
 
         Returns:
             KeyedJaggedTensor: empty KeyedJaggedTensor.
         """
-        return _kjt_empty_like(kjt)
+        if device is None:
+            return _kjt_empty_like_stride(kjt)
+        else:
+            return _kjt_empty_like_device(kjt, device)
+
+    def copy_(
+        self, kjt: "KeyedJaggedTensor", non_blocking: bool = False
+    ) -> "KeyedJaggedTensor":
+        """
+        Copies the values, weights, lengths, and offsets of the input KeyedJaggedTensor into the current KeyedJaggedTensor.
+        Assumes host-side metadata (keys, stride, stride_per_key, etc.) is already configured.
+
+        Args:
+            kjt (KeyedJaggedTensor): input KeyedJaggedTensor.
+            non_blocking (bool): whether to perform the copy asynchronously.
+
+        Returns:
+            KeyedJaggedTensor: copied KeyedJaggedTensor.
+        """
+        self._stride_per_key_per_rank = (
+            kjt._stride_per_key_per_rank if kjt.variable_stride_per_key() else None
+        )
+        self._length_per_key = kjt._length_per_key
+        self._lengths_offset_per_key = kjt._lengths_offset_per_key
+        self._offset_per_key = kjt._offset_per_key
+        self._index_per_key = kjt._index_per_key
+        self._stride_per_key = kjt._stride_per_key
+        self._jt_dict = kjt._jt_dict
+
+        # tensor in-place copy
+        self._values.copy_(kjt._values, non_blocking=non_blocking)
+
+        weights_self = self._weights
+        weights_kjt = kjt._weights
+        if weights_self is not None and weights_kjt is not None:
+            weights_self.copy_(weights_kjt, non_blocking=non_blocking)
+
+        lengths_self = self._lengths
+        lengths_kjt = kjt._lengths
+        if lengths_self is not None and lengths_kjt is not None:
+            lengths_self.copy_(lengths_kjt, non_blocking=non_blocking)
+
+        offsets_self = self._offsets
+        offsets_kjt = kjt._offsets
+        if offsets_self is not None and offsets_kjt is not None:
+            offsets_self.copy_(offsets_kjt, non_blocking=non_blocking)
+
+        inverse_indices_self = self._inverse_indices
+        inverse_indices_kjt = kjt._inverse_indices
+        if inverse_indices_self is not None and inverse_indices_kjt is not None:
+            self._inverse_indices = (
+                inverse_indices_kjt[0],
+                inverse_indices_self[1].copy_(
+                    inverse_indices_kjt[1], non_blocking=non_blocking
+                ),
+            )
+
+        return self
 
     @staticmethod
     def from_jt_dict(jt_dict: Dict[str, JaggedTensor]) -> "KeyedJaggedTensor":
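For a bare KeyedJaggedTensor, the same pre-allocate-then-copy pattern used by `ModelInput.to` above can be sketched as a standalone helper (a hedged illustration, not code from this commit; `kjt_to_device_async` is a hypothetical name, and truly async behavior assumes the CPU-side KJT is in pinned memory):

```python
import torch

from torchrec.sparse.jagged_tensor import KeyedJaggedTensor


def kjt_to_device_async(
    kjt_cpu: KeyedJaggedTensor,
    device: torch.device,
    copy_stream: torch.cuda.Stream,
) -> KeyedJaggedTensor:
    # Allocate the device-side KJT on the current (compute) stream.
    kjt_dev = KeyedJaggedTensor.empty_like(kjt_cpu, device=device)
    compute_stream = torch.cuda.current_stream(device)
    with torch.cuda.stream(copy_stream):
        # Do not start copying before the allocation is ordered on the compute stream.
        copy_stream.wait_stream(compute_stream)
        kjt_dev.copy_(kjt_cpu, non_blocking=True)
    # Callers must synchronize before consuming kjt_dev on the compute stream, e.g.:
    #     torch.cuda.current_stream(device).wait_stream(copy_stream)
    return kjt_dev
```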
