Add device parameter to KeyedJaggedTensor.empty_like and copy_ method (meta-pytorch#3510)
Summary:
Pull Request resolved: meta-pytorch#3510
This diff enhances the KeyedJaggedTensor API to support device-aware operations, which is needed for efficient cross-device tensor management in TorchRec.
reference: [memory snapshot and footprint for non-blocking copy](meta-pytorch#3485)
## Key Changes:
1. **Extended `empty_like` method**: Added an optional `device` parameter to support creating empty KJT structures on a different device. This enables two usage patterns:
- Original: Creates empty KJT on the same device, preserving stride/stride_per_key_per_rank with empty data
- Device-copy: Creates empty KJT structure on a new device, useful for pre-allocating tensors before async copy operations
2. **New `copy_` method**: Implements an in-place copy operation for KeyedJaggedTensor that:
- Copies values, weights, lengths, and offsets from source to destination KJT
- Supports non-blocking (async) copies for better performance
- Assumes host-side metadata (keys, stride, etc.) is already configured
- Handles optional tensors (weights, lengths, offsets) appropriately
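The behavior described above can be sketched at the tensor level. This is a minimal illustration of the in-place copy pattern, not the actual TorchRec implementation; the dict-of-tensors stand-in and the `kjt_copy_sketch` name are assumptions for the example.

```python
import torch

# Minimal sketch (not the actual TorchRec implementation) of the in-place
# copy pattern: each component tensor is copied into a pre-allocated
# destination, and optional tensors that are absent are skipped.
def kjt_copy_sketch(dst: dict, src: dict, non_blocking: bool = False) -> None:
    for name in ("values", "weights", "lengths", "offsets"):
        s = src.get(name)
        if s is not None:
            # copy_ writes into the existing buffer; non_blocking=True allows
            # an async HtoD transfer when the source is in pinned host memory.
            dst[name].copy_(s, non_blocking=non_blocking)

src = {"values": torch.arange(4.0), "weights": None,
       "lengths": torch.tensor([2, 2]), "offsets": torch.tensor([0, 2, 4])}
# Destination buffers are pre-allocated; host-side metadata (keys, stride)
# would already be configured on the destination KJT.
dst = {k: torch.empty_like(v) if v is not None else None for k, v in src.items()}
kjt_copy_sketch(dst, src)
```

Because the destination buffers already exist, repeated batches reuse the same device memory instead of allocating fresh tensors each step.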
3. **Refactored implementation**: Split the original `_kjt_empty_like` logic into:
- `_kjt_empty_like_stride`: Preserves original behavior for same-device empty KJT
- `_kjt_empty_like_device`: New function for cross-device empty KJT creation
These changes enable more efficient device-to-device transfer patterns in distributed training scenarios.
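Putting the two pieces together, the transfer pattern looks roughly like the following, sketched with plain tensors (the real code operates on whole KeyedJaggedTensors); the CUDA branch is guarded so the sketch also runs on CPU-only machines.

```python
import torch

# The pre-allocate-then-async-copy pattern this diff enables, at tensor level.
target = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

host_values = torch.arange(8, dtype=torch.float32)
if torch.cuda.is_available():
    host_values = host_values.pin_memory()  # pinned memory enables async HtoD

# Step 1: allocate the destination buffer once on the target device
# (the role of empty_like(..., device=...) for a whole KJT).
dev_values = torch.empty_like(host_values, device=target)

# Step 2: in-place, possibly asynchronous copy (the role of copy_).
dev_values.copy_(host_values, non_blocking=True)
if torch.cuda.is_available():
    torch.cuda.synchronize()  # wait for the async copy before reading
```

Reusing one pre-allocated destination per rank avoids the allocator churn of creating a fresh device-side KJT every iteration, which is where the reserved-memory saving below comes from.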
{F1983205769}
### Validation:
In a prototyping experiment with the sparse-data-dist pipeline (TrainPipelineSparseDist), the Memcpy HtoD speed (bandwidth) and the CUDA memory timeline profile are similar, but reserved memory drops from 79.7GB to 74.0GB, a 5-6GB saving, with an input KJT of about 1GB per rank.
* trace with direct copy
{F1983200620}
* trace with inplace copy
{F1983200591}
* snapshot with direct copy
{F1983200644} {F1983200655}
* snapshot with inplace copy
{F1983200664} {F1983200670}
Reviewed By: spmex
Differential Revision: D86068070
fbshipit-source-id: 0d1076fd192190b46eed4bda1d4e53b4b245d2a7