Commit 7530009
[SymmMem] op to get remote tensors (pytorch#167779)
To support use case in pytorch/helion#1122, i.e.
```
@helion.kernel
def foo(
x: Tensor,
group_name: str
):
x_remotes = torch.ops.symm_mem.get_remote_tensors(x, group_name)
for t in x_remotes:
...
````
Helion uses fake tensor to trace a program, thus we cannot use the following code in a Helion function:
```
hdl = rendezvous(tensor)
remote_tensors = tuple(
hdl.get_remote_tensor(peer, ...) for peer in range(world_size)
)
```
The reason is that when `tensor` is fake, the returned `hdl` is None, thus any subsequent call on it will fail.
This PR wraps the above functionality as an op:
```
lib.define("get_remote_tensors(Tensor x, str group_name) -> Tensor[]")
```
so that things like `hdl` is not exposed to Helion. The op also provides a `meta` implementation so that Helion can trace it without actually running the rendezvous.
Pull Request resolved: pytorch#167779
Approved by: https://github.com/yf2251 parent 2aa3a06 commit 7530009
File tree
2 files changed
+48
-0
lines changed- test/distributed
- torch/distributed/_symmetric_memory
2 files changed
+48
-0
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
208 | 208 | | |
209 | 209 | | |
210 | 210 | | |
| 211 | + | |
| 212 | + | |
| 213 | + | |
| 214 | + | |
| 215 | + | |
| 216 | + | |
| 217 | + | |
| 218 | + | |
| 219 | + | |
| 220 | + | |
| 221 | + | |
| 222 | + | |
| 223 | + | |
| 224 | + | |
| 225 | + | |
211 | 226 | | |
212 | 227 | | |
213 | 228 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
465 | 465 | | |
466 | 466 | | |
467 | 467 | | |
| 468 | + | |
| 469 | + | |
| 470 | + | |
| 471 | + | |
| 472 | + | |
| 473 | + | |
| 474 | + | |
| 475 | + | |
| 476 | + | |
| 477 | + | |
| 478 | + | |
| 479 | + | |
| 480 | + | |
| 481 | + | |
| 482 | + | |
| 483 | + | |
| 484 | + | |
| 485 | + | |
| 486 | + | |
| 487 | + | |
| 488 | + | |
| 489 | + | |
| 490 | + | |
| 491 | + | |
| 492 | + | |
| 493 | + | |
| 494 | + | |
| 495 | + | |
| 496 | + | |
| 497 | + | |
| 498 | + | |
| 499 | + | |
| 500 | + | |
468 | 501 | | |
469 | 502 | | |
470 | 503 | | |
| |||
0 commit comments