
Commit 7530009

kwen2501 authored and Silv3S committed
[SymmMem] op to get remote tensors (pytorch#167779)
To support the use case in pytorch/helion#1122, i.e.

```
@helion.kernel
def foo(
    x: Tensor,
    group_name: str
):
    x_remotes = torch.ops.symm_mem.get_remote_tensors(x, group_name)
    for t in x_remotes:
        ...
```

Helion uses fake tensors to trace a program, so we cannot use the following code in a Helion function:

```
hdl = rendezvous(tensor)
remote_tensors = tuple(
    hdl.get_remote_tensor(peer, ...) for peer in range(world_size)
)
```

The reason is that when `tensor` is fake, the returned `hdl` is None, so any subsequent call on it will fail.

This PR wraps the above functionality as an op:

```
lib.define("get_remote_tensors(Tensor x, str group_name) -> Tensor[]")
```

so that handles like `hdl` are not exposed to Helion. The op also provides a `meta` implementation so that Helion can trace it without actually running the rendezvous.

Pull Request resolved: pytorch#167779
Approved by: https://github.com/yf225
1 parent 2aa3a06 commit 7530009
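To make the tracing point concrete, here is a minimal sketch (not part of the commit) of calling the op under `FakeTensorMode`, which is what the `Meta` implementation enables. It assumes a process group has already been initialized so the group name resolves; `count_traced_remotes` is a hypothetical helper.

```python
# Hedged sketch: with the Meta kernel registered, the op can be traced under
# FakeTensorMode without performing a real rendezvous or CUDA allocation.
# Assumes torch.distributed.init_process_group() has already been called
# (any backend) so that `group_name` maps to an existing group.
import torch
import torch.distributed._symmetric_memory  # noqa: F401  (registers torch.ops.symm_mem)
from torch._subclasses.fake_tensor import FakeTensorMode


def count_traced_remotes(group_name: str) -> int:  # hypothetical helper
    with FakeTensorMode():
        x = torch.empty(8)  # fake stand-in for a symmetric-memory buffer
        remotes = torch.ops.symm_mem.get_remote_tensors(x, group_name)
    # Per the Meta implementation, there is one (fake) tensor per rank.
    return len(remotes)
```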

File tree

test/distributed/test_nvshmem.py
torch/distributed/_symmetric_memory/__init__.py

2 files changed: +48 -0 lines changed


test/distributed/test_nvshmem.py

Lines changed: 15 additions & 0 deletions
```diff
@@ -208,6 +208,21 @@ def test_get_remote_tensor(self) -> None:
         )
         self.assertEqual(y, expected)
 
+    def test_get_remote_tensors(self) -> None:
+        """
+        Get all remote tensors
+        """
+        self._init_device()
+        group_name = dist.group.WORLD.group_name
+        symm_mem.enable_symm_mem_for_group(group_name)
+
+        my_tensor = symm_mem.empty(1, device=self.device).fill_(self.rank)
+        remote_tensors = torch.ops.symm_mem.get_remote_tensors(my_tensor, group_name)
+        dist.barrier()
+
+        for peer, tensor in enumerate(remote_tensors):
+            self.assertEqual(tensor, peer)
+
     @skipIfRocm
     def test_nvshmem_put(self) -> None:
         self._init_device()
```
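For intuition: each rank fills its own symmetric buffer with its rank ID, and the op returns the buffers ordered by rank, so after the barrier every rank observes the same sequence. A tiny illustrative check of that invariant for a hypothetical 2-rank group (values are made up, not taken from a real run):

```python
import torch

# remote_tensors[i] stands for the buffer owned by rank i in a 2-rank group.
remote_tensors = (torch.tensor([0.0]), torch.tensor([1.0]))
for peer, tensor in enumerate(remote_tensors):
    assert tensor.item() == peer  # mirrors self.assertEqual(tensor, peer) above
```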

torch/distributed/_symmetric_memory/__init__.py

Lines changed: 33 additions & 0 deletions
```diff
@@ -465,6 +465,39 @@ def get_p2p_buf(rank: int, idx: int) -> torch.Tensor:
     "_low_contention_reduce_scatter(Tensor tensor, str reduce_op, str group_name) -> Tensor"
 )
 
+lib.define("get_remote_tensors(Tensor x, str group_name) -> Tensor[]")
+"""
+Given a local tensor and a group name, return a tuple of tensors that are
+symmetric on other devices. The returned tensors are ordered by rank IDs. The
+length of the tuple equals the size of the group.
+
+Note: this API works only when `world_within_direct_access()` returns True, i.e.
+only when the group is within an NVLink domain or similar. It does not work
+across network interfaces.
+"""
+
+
+@torch.library.impl(lib, "get_remote_tensors", "CUDA")
+def _get_remote_tensors_default(
+    local: torch.Tensor, group_name: str
+) -> tuple[torch.Tensor, ...]:
+    hdl = rendezvous(local, group_name)
+    if hdl is None:
+        raise ValueError("Tensor is not allocated from Symmetric Memory")
+
+    return tuple(
+        hdl.get_remote_tensor(peer, local.size(), local.dtype)
+        for peer in range(hdl.world_size)
+    )
+
+
+@torch.library.impl(lib, "get_remote_tensors", "Meta")
+def _get_remote_tensors_meta(
+    local: torch.Tensor, group_name: str
+) -> tuple[torch.Tensor, ...]:
+    group = c10d._resolve_process_group(group_name)
+    return tuple(torch.empty_like(local) for _ in range(group.size()))
+
 
 class _ScaleMode(Enum):
     UNSCALED = "unscaled"
```
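Putting the pieces together, a minimal end-to-end usage sketch (assuming an NVLink-connected multi-GPU host and a torchrun-style multi-process launch; this mirrors the new test rather than being part of the commit):

```python
import os

import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem


def main() -> None:
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    dist.init_process_group("nccl")

    group_name = dist.group.WORLD.group_name
    symm_mem.enable_symm_mem_for_group(group_name)

    # The input must be allocated via symm_mem.empty; a plain CUDA tensor is not
    # backed by symmetric memory, so the CUDA implementation would raise ValueError.
    x = symm_mem.empty(4, device=f"cuda:{local_rank}").fill_(dist.get_rank())
    remotes = torch.ops.symm_mem.get_remote_tensors(x, group_name)

    dist.barrier()
    for peer, t in enumerate(remotes):
        assert int(t[0].item()) == peer  # buffer owned by `peer` holds its rank ID

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```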
