Commit 1dcafb3

[Model Runner V2] Support penalties using bin counts (#29703)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
1 parent ea3370b commit 1dcafb3

5 files changed: +280 -14 lines changed

vllm/v1/worker/gpu/input_batch.py

Lines changed: 15 additions & 0 deletions
@@ -341,6 +341,8 @@ def _post_update_kernel(
     idx_mapping_ptr,
     num_computed_tokens_ptr,
     last_sampled_tokens_ptr,
+    output_bin_counts_ptr,
+    output_bin_counts_stride,
     sampled_tokens_ptr,
     sampled_tokens_stride,
     num_sampled_ptr,
@@ -357,6 +359,15 @@ def _post_update_kernel(
         )
         tl.store(last_sampled_tokens_ptr + req_state_idx, token_id)

+        for i in range(num_sampled):
+            token_id = tl.load(sampled_tokens_ptr + req_id * sampled_tokens_stride + i)
+            token_ptr = (
+                output_bin_counts_ptr + req_state_idx * output_bin_counts_stride + token_id
+            )
+            count = tl.load(token_ptr)
+            count += 1
+            tl.store(token_ptr, count)
+
     query_start = tl.load(query_start_loc_ptr + req_id)
     query_end = tl.load(query_start_loc_ptr + req_id + 1)
     query_len = query_end - query_start
@@ -374,6 +385,8 @@ def post_update(
     num_computed_tokens: torch.Tensor,
     # [max_num_reqs]
     last_sampled_tokens: torch.Tensor,
+    # [max_num_reqs, vocab_size]
+    output_bin_counts: torch.Tensor,
     # [num_reqs, num_speculative_steps + 1]
     sampled_tokens: torch.Tensor,
     # [num_reqs]
@@ -388,6 +401,8 @@ def post_update(
         idx_mapping,
         num_computed_tokens,
         last_sampled_tokens,
+        output_bin_counts,
+        output_bin_counts.stride(0),
         sampled_tokens,
         sampled_tokens.stride(0),
         num_sampled,
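
For reference, the loop added to _post_update_kernel keeps a running histogram of generated tokens: each token counted by num_sampled increments its slot in that request's output_bin_counts row. Below is a minimal PyTorch sketch of the same update, illustrative only; the helper name update_output_bin_counts is not part of the commit, and the shapes follow the comments in post_update.

import torch


def update_output_bin_counts(
    output_bin_counts: torch.Tensor,  # [max_num_reqs, vocab_size], integer counts
    idx_mapping: torch.Tensor,        # [num_reqs], batch index -> request state slot
    sampled_tokens: torch.Tensor,     # [num_reqs, num_speculative_steps + 1]
    num_sampled: torch.Tensor,        # [num_reqs], tokens to count per request
) -> None:
    for req_id in range(sampled_tokens.shape[0]):
        req_state_idx = int(idx_mapping[req_id])
        n = int(num_sampled[req_id])
        tokens = sampled_tokens[req_id, :n].long()
        # Same effect as the kernel's per-token load/increment/store;
        # scatter_add_ also handles duplicate token ids within one step.
        ones = torch.ones_like(tokens, dtype=output_bin_counts.dtype)
        output_bin_counts[req_state_idx].scatter_add_(0, tokens, ones)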

vllm/v1/worker/gpu/model_runner.py

Lines changed: 5 additions & 4 deletions
@@ -512,7 +512,7 @@ def prepare_inputs(
             idx_mapping_np,
             num_scheduled_tokens,
             query_start_loc_np,
-            self.req_states.prefill_token_ids,
+            self.req_states.prefill_token_ids.np,
             self.req_states.num_computed_prefill_tokens,
             self.input_buffers.input_ids.np,
         )
@@ -681,7 +681,7 @@ def compute_prompt_logprobs(
         # Handle chunked prompts.
         pos_after_step = computed_prefill + input_batch.num_scheduled_tokens
         is_prompt_chunked = pos_after_step < prompt_lens
-        prefill_token_ids = self.req_states.prefill_token_ids
+        prefill_token_ids = self.req_states.prefill_token_ids.np
         query_start_loc = self.input_buffers.query_start_loc.np
         for i, req_id in enumerate(input_batch.req_ids):
             if not needs_prompt_logprobs[i]:
@@ -756,6 +756,7 @@ def postprocess(
             input_batch.idx_mapping,
             self.req_states.num_computed_tokens,
             self.req_states.last_sampled_tokens,
+            self.req_states.output_bin_counts,
             sampled_tokens,
             num_sampled,
             num_rejected,
@@ -785,7 +786,7 @@ def propose_draft(
         idx_mapping_np = input_batch.idx_mapping_np
         with async_barrier(self.spec_decode_event):
             self.input_buffers.next_prefill_tokens.np[:num_reqs] = (
-                self.req_states.prefill_token_ids[
+                self.req_states.prefill_token_ids.np[
                     idx_mapping_np,
                     self.req_states.num_computed_prefill_tokens[idx_mapping_np],
                 ]
@@ -896,7 +897,7 @@ def execute_model(
         # barrier to avoid race conditions.
         pos = input_batch.positions[input_batch.logits_indices]
         sampling_metadata = self.req_states.make_sampling_metadata(
-            input_batch.idx_mapping_np, pos
+            input_batch.idx_mapping, input_batch.idx_mapping_np, pos
         )
         if input_batch.num_draft_tokens > 0:
             sampling_metadata = self.req_states.expand_sampling_metadata(

vllm/v1/worker/gpu/penalties.py

Lines changed: 85 additions & 0 deletions
@@ -0,0 +1,85 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import torch
+
+from vllm.triton_utils import tl, triton
+from vllm.v1.worker.gpu.states import SamplingMetadata
+
+
+@triton.jit
+def _penalties_kernel(
+    logits_ptr,
+    logits_stride,
+    repetition_penalty_ptr,
+    frequency_penalty_ptr,
+    presence_penalty_ptr,
+    idx_mapping_ptr,
+    prompt_bin_counts_ptr,
+    prompt_bin_counts_stride,
+    output_bin_counts_ptr,
+    output_bin_counts_stride,
+    vocab_size,
+    BLOCK_SIZE: tl.constexpr,
+):
+    batch_idx = tl.program_id(0)
+    rep_penalty = tl.load(repetition_penalty_ptr + batch_idx)
+    freq_penalty = tl.load(frequency_penalty_ptr + batch_idx)
+    pres_penalty = tl.load(presence_penalty_ptr + batch_idx)
+
+    use_rep_penalty = rep_penalty != 1.0
+    use_freq_penalty = freq_penalty != 0.0
+    use_pres_penalty = pres_penalty != 0.0
+    if not (use_rep_penalty or use_freq_penalty or use_pres_penalty):
+        # No penalties to apply. Early return.
+        return
+
+    block_idx = tl.program_id(1)
+    block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    mask = block < vocab_size
+    logits = tl.load(logits_ptr + batch_idx * logits_stride + block, mask=mask)
+    logits = logits.to(tl.float32)
+
+    req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
+    output_bin_counts = tl.load(
+        output_bin_counts_ptr + req_state_idx * output_bin_counts_stride + block,
+        mask=mask,
+    )
+
+    # Apply repetition penalties.
+    if use_rep_penalty:
+        prompt_bin_counts = tl.load(
+            prompt_bin_counts_ptr + req_state_idx * prompt_bin_counts_stride + block,
+            mask=mask,
+        )
+        # If token appears in prompt or output, apply, otherwise use 1.0 for no-op.
+        scale = tl.where((prompt_bin_counts + output_bin_counts) > 0, rep_penalty, 1.0)
+        # If logits are positive, divide by penalty, otherwise multiply by penalty.
+        scale = tl.where(logits > 0, 1.0 / scale, scale)
+        logits *= scale
+
+    # Apply frequency penalties.
+    logits -= freq_penalty * output_bin_counts
+    # Apply presence penalties.
+    logits -= pres_penalty * (output_bin_counts > 0)
+    # Store back to logits.
+    tl.store(logits_ptr + batch_idx * logits_stride + block, logits, mask=mask)
+
+
+def apply_penalties(logits: torch.Tensor, sampling_metadata: SamplingMetadata) -> None:
+    num_reqs, vocab_size = logits.shape
+    BLOCK_SIZE = 8192
+    num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
+    _penalties_kernel[(num_reqs, num_blocks)](
+        logits,
+        logits.stride(0),
+        sampling_metadata.repetition_penalty,
+        sampling_metadata.frequency_penalty,
+        sampling_metadata.presence_penalty,
+        sampling_metadata.idx_mapping,
+        sampling_metadata.prompt_bin_counts,
+        sampling_metadata.prompt_bin_counts.stride(0),
+        sampling_metadata.output_bin_counts,
+        sampling_metadata.output_bin_counts.stride(0),
+        vocab_size,
+        BLOCK_SIZE=BLOCK_SIZE,
+    )
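
For clarity, the arithmetic inside _penalties_kernel corresponds to the following pure-PyTorch sketch. It is illustrative only: it omits the kernel's early-exit fast path and assumes the prompt/output bin-count rows have already been gathered per request via idx_mapping.

import torch


def apply_penalties_reference(
    logits: torch.Tensor,              # [num_reqs, vocab_size], float
    repetition_penalty: torch.Tensor,  # [num_reqs]
    frequency_penalty: torch.Tensor,   # [num_reqs]
    presence_penalty: torch.Tensor,    # [num_reqs]
    prompt_counts: torch.Tensor,       # [num_reqs, vocab_size], prompt token counts
    output_counts: torch.Tensor,       # [num_reqs, vocab_size], generated token counts
) -> torch.Tensor:
    rep = repetition_penalty.unsqueeze(1)
    # Repetition penalty: only tokens seen in the prompt or the output are scaled.
    seen = (prompt_counts + output_counts) > 0
    scale = torch.where(seen, rep, torch.ones_like(rep))
    # Positive logits are divided by the penalty, negative ones multiplied by it.
    scale = torch.where(logits > 0, 1.0 / scale, scale)
    logits = logits * scale
    # Frequency penalty grows with how many times a token has been generated.
    logits = logits - frequency_penalty.unsqueeze(1) * output_counts
    # Presence penalty is applied once per token that has appeared at all.
    logits = logits - presence_penalty.unsqueeze(1) * (output_counts > 0)
    return logits

Because the kernel writes the result back through logits_ptr, the penalties are applied in place on the [num_reqs, vocab_size] logits buffer.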

vllm/v1/worker/gpu/sampler.py

Lines changed: 3 additions & 0 deletions
@@ -8,6 +8,7 @@
 from vllm.triton_utils import tl, triton
 from vllm.v1.outputs import LogprobsTensors, SamplerOutput
 from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
+from vllm.v1.worker.gpu.penalties import apply_penalties
 from vllm.v1.worker.gpu.states import SamplingMetadata


@@ -65,6 +66,8 @@ def sample(
     logits = apply_top_k_top_p(
         logits, sampling_metadata.top_k, sampling_metadata.top_p
     )
+    # Apply penalties in place.
+    apply_penalties(logits, sampling_metadata)

     sampled = gumbel_sample(
         logits,
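
As a standalone usage illustration (not from the commit), apply_penalties can be driven with toy tensors on a CUDA device. The SimpleNamespace below is a hypothetical stand-in for SamplingMetadata that fills in only the fields the kernel reads, and the dtypes shown are assumptions.

from types import SimpleNamespace

import torch

from vllm.v1.worker.gpu.penalties import apply_penalties

num_reqs, max_num_reqs, vocab_size = 2, 4, 32000
device = "cuda"

logits = torch.randn(num_reqs, vocab_size, device=device)
meta = SimpleNamespace(
    repetition_penalty=torch.full((num_reqs,), 1.2, device=device),
    frequency_penalty=torch.zeros(num_reqs, device=device),
    presence_penalty=torch.zeros(num_reqs, device=device),
    # Batch slot -> persistent request state row.
    idx_mapping=torch.tensor([0, 3], dtype=torch.int32, device=device),
    prompt_bin_counts=torch.zeros(
        max_num_reqs, vocab_size, dtype=torch.int32, device=device
    ),
    output_bin_counts=torch.zeros(
        max_num_reqs, vocab_size, dtype=torch.int32, device=device
    ),
)
# Modifies logits in place; with all-zero bin counts this is effectively a no-op.
apply_penalties(logits, meta)

With a repetition penalty of 1.2 but all-zero bin counts, the call leaves the logits unchanged; it is the bin counts maintained by post_update after each step that actually activate the penalties.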
