
Commit 23ba198

jerrychenhf authored and mhelf-intel committed
Spec decode warmup support (vllm-project#624)
GAUDISW-242931

Because spec decode currently flattens the speculative tokens into [batch_size * num_tokens, 1], the decode shapes can be warmed up as before. What changes is the maximum batch_size to warm up from the configuration, because the real batch size is batch_size * num_tokens, i.e. num_tokens (1 + num_speculative_tokens) times the original batch size.

The warmup also has to account for the draft token (and block) space needed by the proposing step in eagle: num_speculative_tokens worth of space must be left free for eagle's propose.

Another point to take care of (already handled in the PR that adds support for num_speculative_tokens > 1) is that warmup runs in compile-only mode without real computation, so the operations in the drafter's prepare_attn_metadata that depend on real position values must be done on CPU.

Handling decode batches with no spec decode tokens was already addressed in vllm-project#593.

---------

Signed-off-by: Chen Haifeng <haifeng.chen@intel.com>
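A rough sketch of the batch-size scaling described above; the concrete values below are hypothetical:

    # Spec decode flattens [batch_size, num_tokens] into [batch_size * num_tokens, 1],
    # so decode warmup must cover batch sizes up to this effective maximum.
    max_num_seqs = 8               # configured decode batch size (illustrative)
    num_speculative_tokens = 3     # draft tokens per request (illustrative)
    num_tokens = 1 + num_speculative_tokens
    max_flattened_batch_size = max_num_seqs * num_tokens   # 32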
1 parent efa7c83 commit 23ba198

File tree

2 files changed, +101 -11 lines changed

vllm_gaudi/extension/bucketing/common.py

Lines changed: 64 additions & 2 deletions
@@ -39,20 +39,30 @@ class HPUBucketingManager():
     prompt_buckets: List[Tuple[int, int, int]] = []
     decode_buckets: List[Tuple[int, int, int]] = []
     unified_buckets: List[Tuple[int, int, int]] = []
+    # Seed buckets are the buckets originally generated from bucketing configuration
+    # Spec decode may automatically add new buckets based on the seed buckets
+    seed_decode_buckets: List[Tuple[int, int, int]] = None
     initialized = False

     def __new__(cls, *args, **kwargs):
         if not cls._instance:
             cls._instance = super(HPUBucketingManager, cls).__new__(cls)
         return cls._instance

-    def initialize(self, max_num_seqs, max_num_prefill_seqs, block_size, max_num_batched_tokens, max_model_len):
+    def initialize(self,
+                   max_num_seqs,
+                   max_num_prefill_seqs,
+                   block_size,
+                   max_num_batched_tokens,
+                   max_model_len,
+                   num_speculative_tokens=0):
         self.max_num_seqs = max_num_seqs
         self.max_num_prefill_seqs = max_num_prefill_seqs
         self.block_size = block_size
         self.max_num_batched_tokens = max_num_batched_tokens
         self.num_hpu_blocks = None
         self.max_model_len = max_model_len
+        self.num_speculative_tokens = num_speculative_tokens
         self.initialized = True
         self.fallback_bs_base_step = 2
         self.fallback_seq_base_step = 32
@@ -189,6 +199,12 @@ def generate_decode_buckets(self):
                 self.max_num_seqs, self.max_num_prefill_seqs,
                 self.max_num_batched_tokens, self.block_size, self.num_hpu_blocks,
                 buckets_from_file)
+            if self.num_speculative_tokens:
+                # The existing buckets are used as seed decode buckets
+                self.seed_decode_buckets = self.decode_buckets
+                # More buckets are added automatically for spec decode
+                self.decode_buckets = self.generate_spec_decode_buckets(self.decode_buckets)
+
             self.log_generate_info(False)
         else:
             logger().info("Bucketing is off - skipping decode buckets generation")
@@ -232,8 +248,14 @@ def find_prompt_bucket(self, batch_size, seq_len, ctx=0):
                 return found_bucket
         return (batch_size, seq_len, ctx)

-    def find_decode_bucket(self, batch_size, num_blocks):
+    def find_decode_bucket(self, batch_size, num_blocks, seed_buckets: bool = False):
         if self.initialized:
+            if seed_buckets and self.seed_decode_buckets is not None:
+                found_bucket = find_equal_or_closest_greater_config(self.seed_decode_buckets,
+                                                                    (batch_size, 1, num_blocks))
+                if found_bucket is not None:
+                    return found_bucket
+
             found_bucket = find_equal_or_closest_greater_config(self.decode_buckets, (batch_size, 1, num_blocks))
             if found_bucket is None:
                 new_bucket = self.generate_fallback_bucket(batch_size, 1, num_blocks)
@@ -260,6 +282,46 @@ def get_max_prompt_shape(self):
         return max(b[1] for b in self.prompt_buckets) \
             if len(self.prompt_buckets) > 0 else self.max_model_len

+    def generate_spec_decode_buckets(self, seed_decode_buckets):
+        max_model_len = self.max_model_len
+        block_size = self.block_size
+
+        def no_corrections(bs, query, ctx):
+            return (bs, query, ctx)
+
+        def correct_for_max_model_len(bs, query, ctx):
+            return (bs, query, min(ctx, bs * math.ceil(max_model_len / block_size)))
+
+        def get_corrector(use_contiguous_pa):
+            if use_contiguous_pa:
+                return no_corrections
+            else:
+                return correct_for_max_model_len
+
+        use_contiguous_pa = get_config().use_contiguous_pa
+        corrector = get_corrector(use_contiguous_pa)
+
+        # If spec decode enabled, generate buckets for batch_size * (1 + num_speculative_tokens)
+        num_tokens = 1 + self.num_speculative_tokens
+        buckets = set()
+        for bucket in seed_decode_buckets:
+            buckets.add(bucket)
+            bs, query, ctx = bucket
+            spec_decode_bs = bs * num_tokens
+            if spec_decode_bs <= ctx:
+                # Add a bucket with (batch_size * num_tokens, query, ctx)
+                buckets.add(corrector(spec_decode_bs, query, ctx))
+                # Add a bucket with (batch_size * num_tokens, query, ctx * num_tokens)
+                buckets.add(corrector(spec_decode_bs, query, ctx * num_tokens))
+
+        # Log the new generated spec decode buckets
+        new_buckets = sorted(buckets - set(seed_decode_buckets))
+        msg = (f"Generated {len(new_buckets)} "
+               f"spec decode buckets [bs, query, num_blocks]: {list(new_buckets)}")
+        logger().info(msg)
+
+        return sorted(buckets)
+
     @classmethod
     def get_instance(cls):
         """

vllm_gaudi/v1/worker/hpu_model_runner.py

Lines changed: 37 additions & 9 deletions
@@ -891,11 +891,13 @@ def __init__(
             else self.max_prefill_batch_size
         if self.enable_bucketing:
             logger.info("Bucketing is ON.")
+            num_speculative_tokens = self.speculative_config.num_speculative_tokens if self.speculative_config else 0
             self.bucketing_manager.initialize(max_num_seqs=self.max_num_seqs,
                                               max_num_prefill_seqs=max_num_prefill_seqs,
                                               block_size=self.block_size,
                                               max_num_batched_tokens=self.max_num_batched_tokens,
-                                              max_model_len=self.max_model_len)
+                                              max_model_len=self.max_model_len,
+                                              num_speculative_tokens=num_speculative_tokens)
             self.graphed_buckets: set[Any] = set()
             self.graphed_multimodal_buckets: set[Any] = set()
         else:
@@ -2031,15 +2033,18 @@ def _create_decode_input_data(self,
         # but also kvs for the current token
         num_blocks = np.ceil((context_lens + 1) / self.block_size).astype(np.int32).tolist()

+        num_tokens_per_req = num_scheduled_tokens[:num_decodes]
+        num_tokens = max(num_tokens_per_req)
+        # Spec decode to use seed buckets to get padded batch size
+        seek_buckets = bool(num_tokens > 1)
+
         # PAD FOR STATIC SHAPES.
         padded_batch_size: int
-        padded_batch_size = self.bucketing_manager.find_decode_bucket(num_decodes, sum(num_blocks))[0]
+        padded_batch_size = self.bucketing_manager.find_decode_bucket(num_decodes, sum(num_blocks), seek_buckets)[0]

         # dp aware padding
         padded_batch_size += self.get_dp_padding(padded_batch_size)

-        num_tokens_per_req = num_scheduled_tokens[:num_decodes]
-        num_tokens = max(num_tokens_per_req)
         total_num_scheduled_tokens = sum(num_tokens_per_req)
         num_tokens_per_req = num_tokens_per_req + [0] * (padded_batch_size - num_decodes)

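A minimal sketch of how the padded decode batch size is chosen in the hunk above; the helper name and values are hypothetical, only find_decode_bucket comes from the diff:

    def pick_padded_batch_size(bucketing_manager, num_decodes, num_blocks, num_tokens_per_req):
        # More than one scheduled token per decode request means spec decode tokens
        # are present, so the padded batch size is looked up from the seed buckets.
        seek_buckets = max(num_tokens_per_req) > 1
        return bucketing_manager.find_decode_bucket(num_decodes, sum(num_blocks), seek_buckets)[0]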
@@ -3513,7 +3518,9 @@ def sample_tokens(self, grammar_output: "GrammarOutput | None") -> ModelRunnerOutput:
         for req_id in self.input_batch.req_ids[:num_reqs]:
             req_state = self.requests[req_id]
             i = self.input_batch.req_id_to_index[req_id]
-            seq_len = (req_state.num_computed_tokens + scheduler_output.num_scheduled_tokens[req_id])
+            # Cannot use num_computed_tokens + num_scheduled_tokens here
+            # as it may include rejected spec decode tokens
+            seq_len = self.input_batch.num_tokens_no_spec[i]
             token_ids = postprocessed_sampled_token_ids[i]
             num_tokens = len(token_ids)
             self.input_batch.token_ids_cpu[i, seq_len:seq_len + num_tokens] = token_ids
@@ -4098,7 +4105,16 @@ def _add_dummy_request(self,
                           scheduled_tokens,
                           is_prompt,
                           block_id=0):
-        num_blocks = round_up(total_tokens, self.block_size) // self.block_size
+        # Spec decode: blocks should include look ahead tokens (eagle)
+        total_tokens_for_blocks = total_tokens
+        if self.speculative_config and self.speculative_config.use_eagle():
+            # Consider the block space for draft tokens to propose
+            total_tokens_for_blocks += self.speculative_config.num_speculative_tokens
+            # Check the limit of the max model length
+            if total_tokens_for_blocks > self.max_model_len:
+                total_tokens_for_blocks = self.max_model_len
+
+        num_blocks = round_up(total_tokens_for_blocks, self.block_size) // self.block_size
         prompt_token_ids = list(range(total_tokens))

         req_id = f'{len(requests)}'
@@ -4177,14 +4193,20 @@ def _add_dummy_unified_request(self, requests, is_prompt, is_unique, block_num,
             requests.append(req)
             scheduled_tokens[req_id] = num_scheduled_tokens

-    @staticmethod
-    def _generate_seq_lengths(num_samples, num_blocks, block_size):
+    def _generate_seq_lengths(self, num_samples, num_blocks, block_size):
         assert num_samples <= num_blocks
         blocks = [num_blocks // num_samples] * num_samples
         missing_blocks = num_blocks - sum(blocks)
         for i in range(missing_blocks):
             blocks[i] += 1
-        seq_lengths = [b * block_size - 1 for b in blocks]
+
+        # Leave space for the output token and draft tokens to propose
+        num_lookahead_tokens = 1
+        if self.speculative_config and self.speculative_config.use_eagle():
+            # Consider the token space for draft tokens to propose
+            # The draft tokens for eagle consumes block table space
+            num_lookahead_tokens += self.speculative_config.num_speculative_tokens
+        seq_lengths = [b * block_size - num_lookahead_tokens for b in blocks]
         return seq_lengths

     def distribute_sum_evenly(self, total_sum, max_length):
@@ -4324,6 +4346,12 @@ def _prepare_dummy_scenario(self, prompt_cfg, decode_cfg):
                                                                      prompt_num_blocks)
             for _ in range(prompt_bs):
                 for tokens, context_len in zip(prompt_total_tokens, prompt_num_context_blocks):
+                    if self.speculative_config and self.speculative_config.use_eagle():
+                        # Leave the block space for draft tokens to propose
+                        # The draft tokens for eagle consumes block table space
+                        num_speculative_tokens = self.speculative_config.num_speculative_tokens
+                        tokens -= num_speculative_tokens
+                        prompt_query_len -= num_speculative_tokens
                     self._add_dummy_request(requests,
                                             scheduled_tokens,
                                             num_computed_tokens=(context_len * self.block_size),
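
A small sketch of the lookahead-space arithmetic used by the warmup changes above; block_size, max_model_len and num_speculative_tokens are illustrative, and eagle is assumed enabled:

    import math

    block_size = 128
    max_model_len = 2048
    num_speculative_tokens = 3          # eagle draft tokens proposed per step

    # Dummy decode sequences leave room for the sampled token plus the draft tokens,
    # so eagle's propose step fits without allocating extra blocks during warmup.
    num_lookahead_tokens = 1 + num_speculative_tokens
    blocks_per_seq = 4
    seq_len = blocks_per_seq * block_size - num_lookahead_tokens       # 508

    # Dummy requests likewise size their block tables to cover the draft tokens,
    # capped at the model length.
    total_tokens = 500
    total_tokens_for_blocks = min(total_tokens + num_speculative_tokens, max_model_len)
    num_blocks = math.ceil(total_tokens_for_blocks / block_size)       # 4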
