[Kernels] Migrate sampling to WebGPU #737
base: main
Conversation
```typescript
      this.getTokenLogprob(sampledToken, top_logprobs!),
    );
  }
} else {
```
Quick question: why can't we use the GPU sample kernel when logprobs is False?
IIRC, the flow for the logprobs-False case involves invoking _attach_multinomial_sampling_func / parallel_sampling_from_prob, which contain int8s that are not supported by WGSL / WebGPU yet. I experimented with enabling some experimental flags at the beginning of the relevant kernels, but I wasn't able to get them to work. One thing I haven't tried yet, though, is replacing the int8s with another supported datatype at line 131 here: https://github.com/apache/tvm/blob/26db8bfd7e527198f43f3cc379f404c7513a82ef/python/tvm/relax/backend/gpu_generic/sampling.py#L131C1-L132C1.
I see. Ideally we could modify those kernels in TVM to avoid int8s when the backend is WebGPU.
Let's leave a TODO at the start of this `else {` and mention it somewhere in this PR's description.
I suppose the else branch is the more canonical codepath, since local deployment rarely uses logprobs.
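A minimal sketch of the dispatch being discussed, with the suggested TODO in place. The function name and return values here are hypothetical, for illustration only; the actual PR dispatches between a WebGPU sampling kernel and the existing CPU-side multinomial sampling path:

```typescript
// Hypothetical sketch: GPU sampling is currently only used on the
// logprobs path; the else branch carries the suggested TODO.
function chooseSamplingPath(logprobs: boolean): "gpu" | "cpu" {
  if (logprobs) {
    // WebGPU sampling kernel path introduced by this PR.
    return "gpu";
  } else {
    // TODO: use the GPU sample kernel here as well once the TVM
    // sampling kernels (_attach_multinomial_sampling_func /
    // parallel_sampling_from_prob) no longer rely on int8, which
    // WGSL does not support yet.
    return "cpu";
  }
}
```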
But this PR is great!
```diff
-    let sampledToken: number;
-    if (logprobs) {
+    let sampledTokensDevice: tvmjs.Tensor;
+    if (logprobs && _hasValue(top_p)) {
```
Could you remind me why we add `_hasValue(top_p)` here? If a user wants logprobs but does not provide a top_p, execution would go to the else branch and thus never populate tokenLogprobArray.
Let's set top_p to 1.0 (the default value) at the start, when we pre-process the sampling parameters. Then we can remove this condition change.
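A sketch of the suggested fix: default top_p to 1.0 during sampling-parameter pre-processing, so downstream code can branch on `logprobs` alone. The interface and function name here are hypothetical, not the project's actual API:

```typescript
// Hypothetical sampling-parameter shape, for illustration.
interface SamplingParams {
  logprobs?: boolean;
  top_p?: number;
}

// Normalize parameters up front: with top_p defaulted to 1.0
// (no nucleus truncation), later branches no longer need a
// _hasValue(top_p) guard alongside the logprobs check.
function preprocessSamplingParams(
  params: SamplingParams,
): Required<SamplingParams> {
  return {
    logprobs: params.logprobs ?? false,
    top_p: params.top_p ?? 1.0,
  };
}
```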
Performance comparison with v0.2.79: compared performance for the "canonical" flow, averaged across 20 runs.
v0.2.79 performance: ~38.17 decode tokens/s
Post-PR performance: ~38.99 decode tokens/s
Notes: