66 changes: 64 additions & 2 deletions vllm_gaudi/extension/bucketing/common.py
@@ -39,20 +39,30 @@ class HPUBucketingManager():
prompt_buckets: List[Tuple[int, int, int]] = []
decode_buckets: List[Tuple[int, int, int]] = []
unified_buckets: List[Tuple[int, int, int]] = []
# Seed buckets are the buckets originally generated from the bucketing configuration.
# Spec decode may automatically add new buckets based on the seed buckets.
seed_decode_buckets: List[Tuple[int, int, int]] = None
initialized = False

def __new__(cls, *args, **kwargs):
if not cls._instance:
cls._instance = super(HPUBucketingManager, cls).__new__(cls)
return cls._instance

def initialize(self, max_num_seqs, max_num_prefill_seqs, block_size, max_num_batched_tokens, max_model_len):
def initialize(self,
max_num_seqs,
max_num_prefill_seqs,
block_size,
max_num_batched_tokens,
max_model_len,
num_speculative_tokens=0):
self.max_num_seqs = max_num_seqs
self.max_num_prefill_seqs = max_num_prefill_seqs
self.block_size = block_size
self.max_num_batched_tokens = max_num_batched_tokens
self.num_hpu_blocks = None
self.max_model_len = max_model_len
self.num_speculative_tokens = num_speculative_tokens
self.initialized = True
self.fallback_bs_base_step = 2
self.fallback_seq_base_step = 32
@@ -189,6 +199,12 @@ def generate_decode_buckets(self):
self.max_num_seqs, self.max_num_prefill_seqs,
self.max_num_batched_tokens, self.block_size, self.num_hpu_blocks,
buckets_from_file)
if self.num_speculative_tokens:
# The existing buckets are used as seed decode buckets
self.seed_decode_buckets = self.decode_buckets
# More buckets are added automatically for spec decode
self.decode_buckets = self.generate_spec_decode_buckets(self.decode_buckets)

self.log_generate_info(False)
else:
logger().info("Bucketing is off - skipping decode buckets generation")
@@ -232,8 +248,14 @@ def find_prompt_bucket(self, batch_size, seq_len, ctx=0):
return found_bucket
return (batch_size, seq_len, ctx)

def find_decode_bucket(self, batch_size, num_blocks):
def find_decode_bucket(self, batch_size, num_blocks, seed_buckets: bool = False):
if self.initialized:
if seed_buckets and self.seed_decode_buckets is not None:
found_bucket = find_equal_or_closest_greater_config(self.seed_decode_buckets,
(batch_size, 1, num_blocks))
if found_bucket is not None:
return found_bucket

found_bucket = find_equal_or_closest_greater_config(self.decode_buckets, (batch_size, 1, num_blocks))
if found_bucket is None:
new_bucket = self.generate_fallback_bucket(batch_size, 1, num_blocks)
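For clarity, a minimal standalone sketch of the two-step lookup added above; the helper below is a simplified stand-in for `find_equal_or_closest_greater_config`, not the actual implementation:

```python
def find_equal_or_greater(buckets, cfg):
    # Simplified stand-in: smallest bucket that covers cfg in every dimension.
    candidates = [b for b in buckets if all(x >= y for x, y in zip(b, cfg))]
    return min(candidates) if candidates else None

def find_decode_bucket(decode_buckets, seed_decode_buckets,
                       batch_size, num_blocks, seed_buckets=False):
    # When spec decode is active, look in the seed buckets first;
    # fall back to the full bucket list (which includes the auto-generated
    # spec decode buckets) if no seed bucket fits.
    if seed_buckets and seed_decode_buckets is not None:
        found = find_equal_or_greater(seed_decode_buckets, (batch_size, 1, num_blocks))
        if found is not None:
            return found
    found = find_equal_or_greater(decode_buckets, (batch_size, 1, num_blocks))
    return found if found is not None else (batch_size, 1, num_blocks)
```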
@@ -260,6 +282,46 @@ def get_max_prompt_shape(self):
return max(b[1] for b in self.prompt_buckets) \
if len(self.prompt_buckets) > 0 else self.max_model_len

def generate_spec_decode_buckets(self, seed_decode_buckets):
max_model_len = self.max_model_len
block_size = self.block_size

def no_corrections(bs, query, ctx):
return (bs, query, ctx)

def correct_for_max_model_len(bs, query, ctx):
return (bs, query, min(ctx, bs * math.ceil(max_model_len / block_size)))

def get_corrector(use_contiguous_pa):
if use_contiguous_pa:
return no_corrections
else:
return correct_for_max_model_len

use_contiguous_pa = get_config().use_contiguous_pa
corrector = get_corrector(use_contiguous_pa)

# If spec decode is enabled, generate buckets for batch_size * (1 + num_speculative_tokens)
num_tokens = 1 + self.num_speculative_tokens
buckets = set()
for bucket in seed_decode_buckets:
buckets.add(bucket)
bs, query, ctx = bucket
spec_decode_bs = bs * num_tokens
if spec_decode_bs <= ctx:
# Add a bucket with (batch_size * num_tokens, query, ctx)
buckets.add(corrector(spec_decode_bs, query, ctx))
# Add a bucket with (batch_size * num_tokens, query, ctx * num_tokens)
buckets.add(corrector(spec_decode_bs, query, ctx * num_tokens))

# Log the newly generated spec decode buckets
new_buckets = sorted(buckets - set(seed_decode_buckets))
msg = (f"Generated {len(new_buckets)} "
f"spec decode buckets [bs, query, num_blocks]: {list(new_buckets)}")
logger().info(msg)

return sorted(buckets)

@classmethod
def get_instance(cls):
"""
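To illustrate the spec-decode bucket expansion added in this file, here is a minimal standalone sketch (assuming `use_contiguous_pa=True`, so no context correction is applied); the seed buckets and numbers are illustrative, not taken from a real configuration:

```python
import math

def generate_spec_decode_buckets(seed_buckets, num_speculative_tokens,
                                 max_model_len, block_size,
                                 use_contiguous_pa=True):
    # Each (bs, query, ctx) seed bucket may spawn buckets with the batch size
    # scaled by (1 + num_speculative_tokens).
    def correct(bs, query, ctx):
        if use_contiguous_pa:
            return (bs, query, ctx)
        return (bs, query, min(ctx, bs * math.ceil(max_model_len / block_size)))

    num_tokens = 1 + num_speculative_tokens
    buckets = set(seed_buckets)
    for bs, query, ctx in seed_buckets:
        spec_bs = bs * num_tokens
        if spec_bs <= ctx:
            buckets.add(correct(spec_bs, query, ctx))
            buckets.add(correct(spec_bs, query, ctx * num_tokens))
    return sorted(buckets)

# Example: two seed decode buckets, 2 speculative tokens per request.
seeds = [(2, 1, 8), (4, 1, 16)]
print(generate_spec_decode_buckets(seeds, 2, max_model_len=2048, block_size=128))
# -> [(2, 1, 8), (4, 1, 16), (6, 1, 8), (6, 1, 24), (12, 1, 16), (12, 1, 48)]
```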
46 changes: 37 additions & 9 deletions vllm_gaudi/v1/worker/hpu_model_runner.py
@@ -891,11 +891,13 @@ def __init__(
else self.max_prefill_batch_size
if self.enable_bucketing:
logger.info("Bucketing is ON.")
num_speculative_tokens = self.speculative_config.num_speculative_tokens if self.speculative_config else 0
self.bucketing_manager.initialize(max_num_seqs=self.max_num_seqs,
max_num_prefill_seqs=max_num_prefill_seqs,
block_size=self.block_size,
max_num_batched_tokens=self.max_num_batched_tokens,
max_model_len=self.max_model_len)
max_model_len=self.max_model_len,
num_speculative_tokens=num_speculative_tokens)
self.graphed_buckets: set[Any] = set()
self.graphed_multimodal_buckets: set[Any] = set()
else:
@@ -2031,15 +2033,18 @@ def _create_decode_input_data(self,
# but also kvs for the current token
num_blocks = np.ceil((context_lens + 1) / self.block_size).astype(np.int32).tolist()

num_tokens_per_req = num_scheduled_tokens[:num_decodes]
num_tokens = max(num_tokens_per_req)
# Spec decode uses the seed buckets to get the padded batch size
seek_buckets = bool(num_tokens > 1)

# PAD FOR STATIC SHAPES.
padded_batch_size: int
padded_batch_size = self.bucketing_manager.find_decode_bucket(num_decodes, sum(num_blocks))[0]
padded_batch_size = self.bucketing_manager.find_decode_bucket(num_decodes, sum(num_blocks), seek_buckets)[0]

# dp aware padding
padded_batch_size += self.get_dp_padding(padded_batch_size)

num_tokens_per_req = num_scheduled_tokens[:num_decodes]
num_tokens = max(num_tokens_per_req)
total_num_scheduled_tokens = sum(num_tokens_per_req)
num_tokens_per_req = num_tokens_per_req + [0] * (padded_batch_size - num_decodes)

@@ -3504,7 +3509,9 @@ def sample_tokens(self, grammar_output: "GrammarOutput | None") -> ModelRunnerOut
for req_id in self.input_batch.req_ids[:num_reqs]:
req_state = self.requests[req_id]
i = self.input_batch.req_id_to_index[req_id]
seq_len = (req_state.num_computed_tokens + scheduler_output.num_scheduled_tokens[req_id])
# Cannot use num_computed_tokens + num_scheduled_tokens here
# as it may include rejected spec decode tokens
seq_len = self.input_batch.num_tokens_no_spec[i]
Collaborator review comment: Good catch!
token_ids = postprocessed_sampled_token_ids[i]
num_tokens = len(token_ids)
self.input_batch.token_ids_cpu[i, seq_len:seq_len + num_tokens] = token_ids
@@ -4089,7 +4096,16 @@ def _add_dummy_request(self,
scheduled_tokens,
is_prompt,
block_id=0):
num_blocks = round_up(total_tokens, self.block_size) // self.block_size
# Spec decode: blocks should include lookahead tokens (eagle)
total_tokens_for_blocks = total_tokens
if self.speculative_config and self.speculative_config.use_eagle():
# Consider the block space for draft tokens to propose
total_tokens_for_blocks += self.speculative_config.num_speculative_tokens
# Check the limit of the max model length
if total_tokens_for_blocks > self.max_model_len:
total_tokens_for_blocks = self.max_model_len

num_blocks = round_up(total_tokens_for_blocks, self.block_size) // self.block_size
prompt_token_ids = list(range(total_tokens))

req_id = f'{len(requests)}'
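A hedged sketch of the block-count adjustment in `_add_dummy_request`, using assumed values (`block_size=128`, 3 speculative tokens); it shows how the eagle lookahead can push a dummy request across a block boundary:

```python
def round_up(x, multiple):
    return (x + multiple - 1) // multiple * multiple

block_size = 128
max_model_len = 2048
num_speculative_tokens = 3

total_tokens = 255                      # dummy request length
# Reserve block space for the draft tokens, clamped to the max model length.
total_tokens_for_blocks = min(total_tokens + num_speculative_tokens, max_model_len)
num_blocks = round_up(total_tokens_for_blocks, block_size) // block_size
print(num_blocks)   # 3; without the lookahead, 255 tokens would need only 2 blocks
```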
@@ -4168,14 +4184,20 @@ def _add_dummy_unified_request(self, requests, is_prompt, is_unique, block_num,
requests.append(req)
scheduled_tokens[req_id] = num_scheduled_tokens

@staticmethod
def _generate_seq_lengths(num_samples, num_blocks, block_size):
def _generate_seq_lengths(self, num_samples, num_blocks, block_size):
assert num_samples <= num_blocks
blocks = [num_blocks // num_samples] * num_samples
missing_blocks = num_blocks - sum(blocks)
for i in range(missing_blocks):
blocks[i] += 1
seq_lengths = [b * block_size - 1 for b in blocks]

# Leave space for the output token and draft tokens to propose
num_lookahead_tokens = 1
if self.speculative_config and self.speculative_config.use_eagle():
Collaborator review comment: same question here
# Consider the token space for draft tokens to propose
# The draft tokens for eagle consume block table space
num_lookahead_tokens += self.speculative_config.num_speculative_tokens
seq_lengths = [b * block_size - num_lookahead_tokens for b in blocks]
return seq_lengths

def distribute_sum_evenly(self, total_sum, max_length):
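A standalone sketch of the sequence-length split in `_generate_seq_lengths`, assuming an eagle drafter with 2 speculative tokens (values are illustrative):

```python
def generate_seq_lengths(num_samples, num_blocks, block_size, num_speculative_tokens=0):
    assert num_samples <= num_blocks
    # Spread the blocks as evenly as possible across the samples.
    blocks = [num_blocks // num_samples] * num_samples
    for i in range(num_blocks - sum(blocks)):
        blocks[i] += 1
    # Leave room for the output token plus any draft tokens (eagle).
    num_lookahead_tokens = 1 + num_speculative_tokens
    return [b * block_size - num_lookahead_tokens for b in blocks]

print(generate_seq_lengths(3, 7, 128, num_speculative_tokens=2))
# -> [381, 253, 253]  (blocks split as [3, 2, 2], minus 3 lookahead tokens each)
```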
@@ -4315,6 +4337,12 @@ def _prepare_dummy_scenario(self, prompt_cfg, decode_cfg):
prompt_num_blocks)
for _ in range(prompt_bs):
for tokens, context_len in zip(prompt_total_tokens, prompt_num_context_blocks):
if self.speculative_config and self.speculative_config.use_eagle():
# Leave block space for the draft tokens to propose
# The draft tokens for eagle consume block table space
num_speculative_tokens = self.speculative_config.num_speculative_tokens
tokens -= num_speculative_tokens
prompt_query_len -= num_speculative_tokens
self._add_dummy_request(requests,
scheduled_tokens,
num_computed_tokens=(context_len * self.block_size),
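Finally, a small illustrative sketch of the prompt-side adjustment in `_prepare_dummy_scenario` (the values and the `use_eagle` flag below are assumptions for the example, not real configuration):

```python
num_speculative_tokens = 2
prompt_query_len = 512          # tokens scheduled for the dummy prompt
tokens = 512                    # total tokens for one dummy request
use_eagle = True

if use_eagle:
    # Reserve space in the block table for the draft tokens eagle will propose.
    tokens -= num_speculative_tokens
    prompt_query_len -= num_speculative_tokens

print(tokens, prompt_query_len)   # 510 510
```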