From ecbf0086e4cd549fc8b34ad78735d7bc9880f3db Mon Sep 17 00:00:00 2001
From: Chen Haifeng
Date: Tue, 25 Nov 2025 09:57:53 +0800
Subject: [PATCH 1/4] Spec decode warmup support

Signed-off-by: Chen Haifeng
---
 vllm_gaudi/v1/worker/hpu_model_runner.py | 33 ++++++++++++++++++++----
 1 file changed, 28 insertions(+), 5 deletions(-)

diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py
index df42af89f..b7433b0ef 100644
--- a/vllm_gaudi/v1/worker/hpu_model_runner.py
+++ b/vllm_gaudi/v1/worker/hpu_model_runner.py
@@ -3504,7 +3504,9 @@ def sample_tokens(self, grammar_output: "GrammarOutput | None") -> ModelRunnerOu
         for req_id in self.input_batch.req_ids[:num_reqs]:
             req_state = self.requests[req_id]
             i = self.input_batch.req_id_to_index[req_id]
-            seq_len = (req_state.num_computed_tokens + scheduler_output.num_scheduled_tokens[req_id])
+            # Cannot use num_computed_tokens + num_scheduled_tokens here,
+            # as it may include rejected spec decode tokens
+            seq_len = self.input_batch.num_tokens_no_spec[i]
             token_ids = postprocessed_sampled_token_ids[i]
             num_tokens = len(token_ids)
             self.input_batch.token_ids_cpu[i, seq_len:seq_len + num_tokens] = token_ids
@@ -4089,7 +4091,16 @@ def _add_dummy_request(self,
                            scheduled_tokens,
                            is_prompt,
                            block_id=0):
-        num_blocks = round_up(total_tokens, self.block_size) // self.block_size
+        # Spec decode: blocks should include look-ahead tokens (eagle)
+        total_tokens_for_blocks = total_tokens
+        if self.speculative_config and self.speculative_config.use_eagle():
+            # Consider the block space for draft tokens to propose
+            total_tokens_for_blocks += self.speculative_config.num_speculative_tokens
+            # Check the limit of the max model length
+            if total_tokens_for_blocks > self.max_model_len:
+                total_tokens_for_blocks = self.max_model_len
+
+        num_blocks = round_up(total_tokens_for_blocks, self.block_size) // self.block_size

         prompt_token_ids = list(range(total_tokens))
         req_id = f'{len(requests)}'
@@ -4168,14 +4179,20 @@ def _add_dummy_unified_request(self, requests, is_prompt, is_unique, block_num,
         requests.append(req)
         scheduled_tokens[req_id] = num_scheduled_tokens

-    @staticmethod
-    def _generate_seq_lengths(num_samples, num_blocks, block_size):
+    def _generate_seq_lengths(self, num_samples, num_blocks, block_size):
         assert num_samples <= num_blocks
         blocks = [num_blocks // num_samples] * num_samples
         missing_blocks = num_blocks - sum(blocks)
         for i in range(missing_blocks):
             blocks[i] += 1
-        seq_lengths = [b * block_size - 1 for b in blocks]
+
+        # Leave space for the output token and draft tokens to propose
+        num_lookahead_tokens = 1
+        if self.speculative_config and self.speculative_config.use_eagle():
+            # Consider the token space for draft tokens to propose;
+            # the draft tokens for eagle consume block table space
+            num_lookahead_tokens += self.speculative_config.num_speculative_tokens
+        seq_lengths = [b * block_size - num_lookahead_tokens for b in blocks]
         return seq_lengths

     def distribute_sum_evenly(self, total_sum, max_length):
@@ -4315,6 +4332,12 @@ def _prepare_dummy_scenario(self, prompt_cfg, decode_cfg):
                                                                  prompt_num_blocks)
             for _ in range(prompt_bs):
                 for tokens, context_len in zip(prompt_total_tokens, prompt_num_context_blocks):
+                    if self.speculative_config and self.speculative_config.use_eagle():
+                        # Leave the block space for draft tokens to propose;
+                        # the draft tokens for eagle consume block table space
+                        num_speculative_tokens = self.speculative_config.num_speculative_tokens
+                        tokens -= num_speculative_tokens
+                        prompt_query_len -= num_speculative_tokens
                     self._add_dummy_request(requests,
                                             scheduled_tokens,
                                             num_computed_tokens=(context_len * self.block_size),
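Note (not part of the patch): as a rough standalone sketch, the warmup block
sizing above boils down to the following; block_size, max_model_len, and
num_speculative_tokens are assumed example values, and round_up mirrors the
helper used in the runner:

    import math

    def round_up(value, multiple):
        # Round value up to the nearest multiple of `multiple`.
        return math.ceil(value / multiple) * multiple

    def warmup_num_blocks(total_tokens, block_size=128, max_model_len=2048,
                          use_eagle=True, num_speculative_tokens=3):
        # With eagle spec decode, reserve block space for the draft tokens
        # that may be proposed, capped at the max model length.
        total_tokens_for_blocks = total_tokens
        if use_eagle:
            total_tokens_for_blocks = min(total_tokens + num_speculative_tokens,
                                          max_model_len)
        return round_up(total_tokens_for_blocks, block_size) // block_size

    # 128 prompt tokens alone fit in one block, but the 3 draft tokens
    # spill into a second block, so warmup must graph the larger shape.
    print(warmup_num_blocks(128))  # -> 2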
From e5f7747c9da6110bc33b291b6e26c73ea8be1610 Mon Sep 17 00:00:00 2001
From: Chen Haifeng
Date: Tue, 2 Dec 2025 15:40:22 +0800
Subject: [PATCH 2/4] Automatically generate new buckets for spec decode based
 on seed buckets

Signed-off-by: Chen Haifeng
---
 vllm_gaudi/extension/bucketing/common.py | 55 +++++++++++++++++++++++-
 vllm_gaudi/v1/worker/hpu_model_runner.py | 13 ++++--
 2 files changed, 62 insertions(+), 6 deletions(-)

diff --git a/vllm_gaudi/extension/bucketing/common.py b/vllm_gaudi/extension/bucketing/common.py
index 15fa063e5..d9cf88da4 100644
--- a/vllm_gaudi/extension/bucketing/common.py
+++ b/vllm_gaudi/extension/bucketing/common.py
@@ -39,6 +39,9 @@ class HPUBucketingManager():
     prompt_buckets: List[Tuple[int, int, int]] = []
     decode_buckets: List[Tuple[int, int, int]] = []
     unified_buckets: List[Tuple[int, int, int]] = []
+    # Seed buckets are the buckets originally generated from the bucketing
+    # configuration; spec decode may automatically add new buckets based on them
+    seed_decode_buckets: List[Tuple[int, int, int]] = None
     initialized = False

     def __new__(cls, *args, **kwargs):
@@ -46,13 +49,15 @@ def __new__(cls, *args, **kwargs):
             cls._instance = super(HPUBucketingManager, cls).__new__(cls)
         return cls._instance

-    def initialize(self, max_num_seqs, max_num_prefill_seqs, block_size, max_num_batched_tokens, max_model_len):
+    def initialize(self, max_num_seqs, max_num_prefill_seqs, block_size, max_num_batched_tokens, max_model_len,
+                   num_speculative_tokens=0):
         self.max_num_seqs = max_num_seqs
         self.max_num_prefill_seqs = max_num_prefill_seqs
         self.block_size = block_size
         self.max_num_batched_tokens = max_num_batched_tokens
         self.num_hpu_blocks = None
         self.max_model_len = max_model_len
+        self.num_speculative_tokens = num_speculative_tokens
         self.initialized = True
         self.fallback_bs_base_step = 2
         self.fallback_seq_base_step = 32
@@ -189,6 +194,12 @@ def generate_decode_buckets(self):
                                                          self.max_num_seqs, self.max_num_prefill_seqs,
                                                          self.max_num_batched_tokens, self.block_size,
                                                          self.num_hpu_blocks, buckets_from_file)
+            if self.num_speculative_tokens:
+                # The existing buckets are used as seed decode buckets
+                self.seed_decode_buckets = self.decode_buckets
+                # More buckets are added automatically for spec decode
+                self.decode_buckets = self.generate_spec_decode_buckets(self.decode_buckets)
+
             self.log_generate_info(False)
         else:
             logger().info("Bucketing is off - skipping decode buckets generation")
@@ -232,8 +243,14 @@ def find_prompt_bucket(self, batch_size, seq_len, ctx=0):
                 return found_bucket
         return (batch_size, seq_len, ctx)

-    def find_decode_bucket(self, batch_size, num_blocks):
+    def find_decode_bucket(self, batch_size, num_blocks, seed_buckets: bool = False):
         if self.initialized:
+            if seed_buckets and self.seed_decode_buckets is not None:
+                found_bucket = find_equal_or_closest_greater_config(self.seed_decode_buckets,
+                                                                    (batch_size, 1, num_blocks))
+                if found_bucket is not None:
+                    return found_bucket
+
             found_bucket = find_equal_or_closest_greater_config(self.decode_buckets, (batch_size, 1, num_blocks))
             if found_bucket is None:
                 new_bucket = self.generate_fallback_bucket(batch_size, 1, num_blocks)
@@ -260,6 +277,40 @@ def get_max_prompt_shape(self):
         return max(b[1] for b in self.prompt_buckets) \
             if len(self.prompt_buckets) > 0 else self.max_model_len

+    def generate_spec_decode_buckets(self, seed_decode_buckets):
+        max_model_len = self.max_model_len
+        block_size = self.block_size
+
+        def no_corrections(bs, query, ctx):
+            return (bs, query, ctx)
+
+        def correct_for_max_model_len(bs, query, ctx):
+            return (bs, query, min(ctx, bs * math.ceil(max_model_len / block_size)))
+
+        def get_corrector(use_contiguous_pa):
+            if use_contiguous_pa:
+                return no_corrections
+            else:
+                return correct_for_max_model_len
+
+        use_contiguous_pa = get_config().use_contiguous_pa
+        corrector = get_corrector(use_contiguous_pa)
+
+        # If spec decode is enabled, generate buckets for batch_size * (1 + num_speculative_tokens)
+        num_tokens = 1 + self.num_speculative_tokens
+        buckets = set()
+        for bucket in seed_decode_buckets:
+            buckets.add(bucket)
+            bs, query, ctx = bucket
+            spec_decode_bs = bs * num_tokens
+            if spec_decode_bs <= ctx:
+                # Add a bucket with (batch_size * num_tokens, query, ctx)
+                buckets.add(corrector(spec_decode_bs, query, ctx))
+                # Add a bucket with (batch_size * num_tokens, query, ctx * num_tokens)
+                buckets.add(corrector(spec_decode_bs, query, ctx * num_tokens))
+
+        return sorted(buckets)
+
     @classmethod
     def get_instance(cls):
         """
diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py
index b7433b0ef..1f9fef8aa 100644
--- a/vllm_gaudi/v1/worker/hpu_model_runner.py
+++ b/vllm_gaudi/v1/worker/hpu_model_runner.py
@@ -891,11 +891,13 @@ def __init__(
             else self.max_prefill_batch_size
         if self.enable_bucketing:
             logger.info("Bucketing is ON.")
+            num_speculative_tokens = self.speculative_config.num_speculative_tokens if self.speculative_config else 0
             self.bucketing_manager.initialize(max_num_seqs=self.max_num_seqs,
                                               max_num_prefill_seqs=max_num_prefill_seqs,
                                               block_size=self.block_size,
                                               max_num_batched_tokens=self.max_num_batched_tokens,
-                                              max_model_len=self.max_model_len)
+                                              max_model_len=self.max_model_len,
+                                              num_speculative_tokens=num_speculative_tokens)
             self.graphed_buckets: set[Any] = set()
             self.graphed_multimodal_buckets: set[Any] = set()
         else:
@@ -2031,15 +2033,18 @@ def _create_decode_input_data(self,
         # but also kvs for the current token
         num_blocks = np.ceil((context_lens + 1) / self.block_size).astype(np.int32).tolist()

+        num_tokens_per_req = num_scheduled_tokens[:num_decodes]
+        num_tokens = max(num_tokens_per_req)
+        # Spec decode uses seed buckets to get the padded batch size
+        seek_buckets = True if num_tokens > 1 else False
+
         # PAD FOR STATIC SHAPES.
         padded_batch_size: int
-        padded_batch_size = self.bucketing_manager.find_decode_bucket(num_decodes, sum(num_blocks))[0]
+        padded_batch_size = self.bucketing_manager.find_decode_bucket(num_decodes, sum(num_blocks), seek_buckets)[0]

         # dp aware padding
         padded_batch_size += self.get_dp_padding(padded_batch_size)

-        num_tokens_per_req = num_scheduled_tokens[:num_decodes]
-        num_tokens = max(num_tokens_per_req)
         total_num_scheduled_tokens = sum(num_tokens_per_req)
         num_tokens_per_req = num_tokens_per_req + [0] * (padded_batch_size - num_decodes)
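For intuition (a toy sketch, not the code above verbatim): assuming
num_speculative_tokens=1 and contiguous PA enabled, so no max-model-len
correction applies, each seed bucket (bs, query, ctx) can contribute up to two
extra buckets with the batch dimension scaled by (1 + num_speculative_tokens):

    def generate_spec_decode_buckets(seed_buckets, num_speculative_tokens=1):
        # A decode step may carry 1 target token plus the draft tokens,
        # so the batch dimension grows by this factor.
        num_tokens = 1 + num_speculative_tokens
        buckets = set(seed_buckets)
        for bs, query, ctx in seed_buckets:
            spec_decode_bs = bs * num_tokens
            if spec_decode_bs <= ctx:
                buckets.add((spec_decode_bs, query, ctx))
                buckets.add((spec_decode_bs, query, ctx * num_tokens))
        return sorted(buckets)

    seeds = [(1, 1, 4), (2, 1, 8)]
    print(generate_spec_decode_buckets(seeds))
    # -> [(1, 1, 4), (2, 1, 4), (2, 1, 8), (4, 1, 8), (4, 1, 16)]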
From d8dfe24eecb9f148e94b0498e3dc9a6852654f28 Mon Sep 17 00:00:00 2001
From: Chen Haifeng
Date: Tue, 2 Dec 2025 17:57:15 +0800
Subject: [PATCH 3/4] Fix formatting

Signed-off-by: Chen Haifeng
---
 vllm_gaudi/extension/bucketing/common.py | 7 ++++++-
 vllm_gaudi/v1/worker/hpu_model_runner.py | 2 +-
 2 files changed, 7 insertions(+), 2 deletions(-)

diff --git a/vllm_gaudi/extension/bucketing/common.py b/vllm_gaudi/extension/bucketing/common.py
index d9cf88da4..8239b0fff 100644
--- a/vllm_gaudi/extension/bucketing/common.py
+++ b/vllm_gaudi/extension/bucketing/common.py
@@ -49,7 +49,12 @@ def __new__(cls, *args, **kwargs):
             cls._instance = super(HPUBucketingManager, cls).__new__(cls)
         return cls._instance

-    def initialize(self, max_num_seqs, max_num_prefill_seqs, block_size, max_num_batched_tokens, max_model_len,
+    def initialize(self,
+                   max_num_seqs,
+                   max_num_prefill_seqs,
+                   block_size,
+                   max_num_batched_tokens,
+                   max_model_len,
                    num_speculative_tokens=0):
         self.max_num_seqs = max_num_seqs
         self.max_num_prefill_seqs = max_num_prefill_seqs
diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py
index 1f9fef8aa..40134df36 100644
--- a/vllm_gaudi/v1/worker/hpu_model_runner.py
+++ b/vllm_gaudi/v1/worker/hpu_model_runner.py
@@ -2036,7 +2036,7 @@ def _create_decode_input_data(self,
         num_tokens_per_req = num_scheduled_tokens[:num_decodes]
         num_tokens = max(num_tokens_per_req)
         # Spec decode uses seed buckets to get the padded batch size
-        seek_buckets = True if num_tokens > 1 else False
+        seek_buckets = bool(num_tokens > 1)

         # PAD FOR STATIC SHAPES.
         padded_batch_size: int
From 7061d1b0ab0b072921f6f8de20c93e93cfe7384c Mon Sep 17 00:00:00 2001
From: Chen Haifeng
Date: Wed, 3 Dec 2025 10:55:35 +0800
Subject: [PATCH 4/4] Add info log for newly generated spec decode buckets

Signed-off-by: Chen Haifeng
---
 vllm_gaudi/extension/bucketing/common.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/vllm_gaudi/extension/bucketing/common.py b/vllm_gaudi/extension/bucketing/common.py
index 8239b0fff..4c94f9f28 100644
--- a/vllm_gaudi/extension/bucketing/common.py
+++ b/vllm_gaudi/extension/bucketing/common.py
@@ -314,6 +314,12 @@ def get_corrector(use_contiguous_pa):
                 # Add a bucket with (batch_size * num_tokens, query, ctx * num_tokens)
                 buckets.add(corrector(spec_decode_bs, query, ctx * num_tokens))

+        # Log the newly generated spec decode buckets
+        new_buckets = sorted(buckets - set(seed_decode_buckets))
+        msg = (f"Generated {len(new_buckets)} "
+               f"spec decode buckets [bs, query, num_blocks]: {list(new_buckets)}")
+        logger().info(msg)
+
         return sorted(buckets)

     @classmethod
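End to end, with the toy seed buckets from the earlier sketch and
num_speculative_tokens=1 (illustrative values, not measured output), the log
added here would report the three buckets that were not in the seed set:

    seed_decode_buckets = [(1, 1, 4), (2, 1, 8)]
    buckets = set(generate_spec_decode_buckets(seed_decode_buckets))
    # Mirror the logging added in this patch.
    new_buckets = sorted(buckets - set(seed_decode_buckets))
    msg = (f"Generated {len(new_buckets)} "
           f"spec decode buckets [bs, query, num_blocks]: {list(new_buckets)}")
    print(msg)
    # Generated 3 spec decode buckets [bs, query, num_blocks]: [(2, 1, 4), (4, 1, 8), (4, 1, 16)]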