
Conversation

Contributor

@akaashrp akaashrp commented Nov 2, 2025

Performance Comparison with v0.2.79: compared performance for the "canonical" flow, averaged across 20 runs:

  • No logit_bias
  • No logitProcessor
  • Frequency, presence, and repetition penalties applied
  • logprobs enabled
  • No top_logprobs

v0.2.79 performance: ~38.17 decode tokens/s
Post-PR performance: ~38.99 decode tokens/s (~2% faster)

Notes:

  1. The minimal performance improvement is likely due to kernel launch overhead. Specifically, sampling requires three kernel launches (fsoftmaxWithTemperature, fargsortProbs, fSampleWithTopP); a sketch of this sequence follows these notes.
  2. This approach will likely scale better when sampling from multiple sequences simultaneously.
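
For reference, a minimal, hypothetical sketch of that three-kernel sequence. The kernel names come from note 1 above; the wrapper function, its signature, and the DeviceTensor placeholder type are illustrative assumptions, not this PR's actual code.

// Hypothetical sketch: each call below corresponds to one GPU kernel launch,
// which is where the per-token launch overhead comes from.
type DeviceTensor = object; // stands in for a tvmjs device tensor
type GpuKernel = (...args: Array<DeviceTensor | number>) => DeviceTensor;

function sampleOnGpu(
  fsoftmaxWithTemperature: GpuKernel, // kernel 1: temperature-scaled softmax
  fargsortProbs: GpuKernel,           // kernel 2: sort probs in descending order
  fSampleWithTopP: GpuKernel,         // kernel 3: top-p sample from sorted probs
  logitsOnDevice: DeviceTensor,       // [1, vocabSize] logits already on the GPU
  temperature: number,
  top_p: number,
): DeviceTensor {
  const probs = fsoftmaxWithTemperature(logitsOnDevice, temperature);
  const sortedProbs = fargsortProbs(probs);
  // A single uniform random draw drives the top-p sampling kernel.
  return fSampleWithTopP(sortedProbs, top_p, Math.random());
}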

@akaashrp akaashrp requested a review from CharlieFRuan November 2, 2025 05:57
this.getTokenLogprob(sampledToken, top_logprobs!),
);
}
} else {
Member
Quick question: why can't we use the GPU sample kernel when logprobs is false?

Contributor Author

@akaashrp akaashrp Nov 10, 2025

IIRC, the flow for the logprobs=false case invokes _attach_multinomial_sampling_func / parallel_sampling_from_prob, which contains int8s that are not yet supported by WGSL / WebGPU. I experimented with enabling some experimental flags at the beginning of the relevant kernels, but I wasn't able to get them to work. One thing I haven't tried yet is replacing the int8s with another supported datatype at line 131 here: https://github.com/apache/tvm/blob/26db8bfd7e527198f43f3cc379f404c7513a82ef/python/tvm/relax/backend/gpu_generic/sampling.py#L131C1-L132C1.

Member

@CharlieFRuan CharlieFRuan Nov 11, 2025

I see. Ideally we could modify those kernels in TVM so they don't use int8s when the backend is WebGPU.

Let's leave a TODO at the start of this else { and somewhere in this PR's description.

I suppose the else branch is the more canonical code path, since local deployment rarely uses logprobs.

But this PR is great!
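
For illustration, a minimal sketch of where such a TODO could sit, assuming a simplified version of the if/else structure quoted below; the helper names here are hypothetical, not this PR's code.

// Hypothetical branch structure; sampleTokenOnGpu / sampleTokenOnCpu are
// stand-ins for the PR's actual GPU sampling path and CPU fallback.
declare function sampleTokenOnGpu(): number;
declare function sampleTokenOnCpu(): number;

function sampleToken(logprobs: boolean): number {
  if (logprobs) {
    // GPU path: fsoftmaxWithTemperature -> fargsortProbs -> fSampleWithTopP
    return sampleTokenOnGpu();
  } else {
    // TODO(webgpu): this branch falls back to CPU sampling because TVM's
    // parallel_sampling_from_prob kernel uses int8, which WGSL/WebGPU does
    // not support yet; switch to the GPU kernel once the TVM kernels avoid
    // int8 on the WebGPU backend.
    return sampleTokenOnCpu();
  }
}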

let sampledToken: number;
if (logprobs) {
let sampledTokensDevice: tvmjs.Tensor;
if (logprobs && _hasValue(top_p)) {
Member

Could you remind me why we add the _hasValue(top_p) check here? If a user wants logprobs but does not provide a top_p, execution would go to the else branch and thus not populate tokenLogprobArray.

Let's set top_p to 1.0 (the default value) at the start, when we pre-process the sampling parameters. Then we can remove this condition change.
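
A minimal sketch of the suggested pre-processing, assuming an illustrative SamplingParams shape (web-llm's actual generation config may differ):

// Illustrative parameter shape; the real config type in web-llm may differ.
interface SamplingParams {
  temperature?: number;
  top_p?: number;
}

function normalizeSamplingParams(p: SamplingParams): Required<SamplingParams> {
  return {
    temperature: p.temperature ?? 1.0,
    // With top_p defaulting to 1.0, top-p keeps the full distribution,
    // so the `logprobs && _hasValue(top_p)` guard can become just `logprobs`.
    top_p: p.top_p ?? 1.0,
  };
}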

@CharlieFRuan CharlieFRuan mentioned this pull request Nov 11, 2025