Get remote tensors inside @helion.kernel
#1122
Conversation
yf225
left a comment
Thanks! I believe we need to update test_examples_dist.py as well, but Helion distributed CI currently has a bug causing the test error to not surface. I will land a PR to fix the bug, and then we can rebase this PR.
@kwen2501 in case adding this helps:

```
import torch

lib = torch.library.Library("symm_mem", "FRAGMENT")  # noqa: TOR901
lib.define(
    "get_remote_tensors(Tensor x, str group_name) -> Tensor[]"
)

@torch.library.impl(lib, "get_remote_tensors", "CUDA")
def _get_remote_tensors_default(
    local: torch.Tensor,
    group_name: str
):
    hdl = torch.distributed._symmetric_memory.rendezvous(local, group_name)
    return tuple(
        hdl.get_remote_tensor(peer, local.size(), local.dtype)
        for peer in range(hdl.world_size)
    )

@torch.library.impl(lib, "get_remote_tensors", "Meta")
def _get_remote_tensors_meta(
    local: torch.Tensor,
    group_name: str
):
    # TODO: correct world size
    world_size = torch.distributed.get_world_size()
    return (local,) * world_size
```
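For context, a rough usage sketch of how the op is intended to be called from a Helion function, mirroring the example in the commit message later in this thread (the kernel body is elided):

```
import torch
import helion

@helion.kernel
def foo(x: torch.Tensor, group_name: str):
    # Runs in the host/CPU part of the Helion function; returns one remote
    # tensor per rank in the symmetric-memory group.
    x_remotes = torch.ops.symm_mem.get_remote_tensors(x, group_name)
    for t in x_remotes:
        ...  # use each peer's tensor inside the kernel
```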
Force-pushed 9ec6b6e to 7fda194
@kwen2501 I'll rebase this PR so that it has the distributed CI error propagation fix. Thanks!
Force-pushed 7fda194 to 47f5294
@yf225 What torch version does CI use?
Yes, it uses torch nightly - should be able to pick it up very soon.
yf225
left a comment
Thanks a lot @kwen2501!
To support the use case in pytorch/helion#1122, i.e.

```
@helion.kernel
def foo(
    x: Tensor,
    group_name: str
):
    x_remotes = torch.ops.symm_mem.get_remote_tensors(x, group_name)
    for t in x_remotes:
        ...
```

Helion uses fake tensors to trace a program, thus we cannot use the following code in a Helion function:

```
hdl = rendezvous(tensor)
remote_tensors = tuple(
    hdl.get_remote_tensor(peer, ...) for peer in range(world_size)
)
```

The reason is that when `tensor` is fake, the returned `hdl` is None, thus any subsequent call on it will fail.

This PR wraps the above functionality as an op:

```
lib.define("get_remote_tensors(Tensor x, str group_name) -> Tensor[]")
```

so that things like `hdl` are not exposed to Helion. The op also provides a `meta` implementation so that Helion can trace it without actually running the rendezvous.

Pull Request resolved: #167779
Approved by: https://github.com/yf225
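As an aside (not part of the PR), a minimal sketch of the failure mode that motivates the Meta implementation; the tensor shape, group name, and helper name here are illustrative:

```
import torch
import torch.distributed._symmetric_memory as symm_mem
from torch._subclasses.fake_tensor import FakeTensorMode

def naive_get_remote_tensors(x: torch.Tensor, group_name: str, world_size: int):
    # Under FakeTensorMode (which Helion uses for tracing), rendezvous()
    # returns None, so the attribute accesses below fail; a dedicated op with
    # a Meta implementation avoids this.
    hdl = symm_mem.rendezvous(x, group_name)
    return tuple(
        hdl.get_remote_tensor(peer, x.size(), x.dtype)
        for peer in range(world_size)
    )

with FakeTensorMode():
    fake_x = torch.empty(16)
    # naive_get_remote_tensors(fake_x, "0", 2)  # would fail: hdl is None under fake mode
```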
Force-pushed f342333 to 1db590a
Hi @kwen2501! Thank you for your pull request. We require contributors to sign our Contributor License Agreement, and yours needs attention. You currently have a record in our system, but the CLA is no longer valid and will need to be resubmitted.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged accordingly. If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks!
Instead of the user passing a tuple of tensors into the kernel, we get the tuple of remote tensors by calling `torch.ops.symm_mem.get_remote_tensors` in the CPU part of the Helion function. This op is yet to be upstreamed on the PyTorch side. Naively, it is nothing but a thin wrapper around the symmetric-memory rendezvous (see the sketch below).
The "Meta" impl is necessary because Helion seems to trace the function in fake mode.
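For completeness, a sketch of the op registration, mirroring the snippet shared earlier in this thread (the exact upstreamed registration may differ):

```
import torch

lib = torch.library.Library("symm_mem", "FRAGMENT")  # noqa: TOR901
lib.define("get_remote_tensors(Tensor x, str group_name) -> Tensor[]")

@torch.library.impl(lib, "get_remote_tensors", "CUDA")
def _get_remote_tensors_default(local: torch.Tensor, group_name: str):
    # Real impl: rendezvous over the group and collect one remote tensor per peer.
    hdl = torch.distributed._symmetric_memory.rendezvous(local, group_name)
    return tuple(
        hdl.get_remote_tensor(peer, local.size(), local.dtype)
        for peer in range(hdl.world_size)
    )

@torch.library.impl(lib, "get_remote_tensors", "Meta")
def _get_remote_tensors_meta(local: torch.Tensor, group_name: str):
    # Meta impl: return placeholders so fake-mode tracing never runs a real rendezvous.
    world_size = torch.distributed.get_world_size()  # TODO: use the group's world size
    return (local,) * world_size
```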