|
12 | 12 | from vllm.config.parallel import ParallelConfig |
13 | 13 | from vllm.config.utils import config |
14 | 14 | from vllm.logger import init_logger |
15 | | -from vllm.utils.import_utils import LazyLoader |
| 15 | +from vllm.utils.import_utils import LazyLoader, has_arctic_inference |
16 | 16 |
|
17 | 17 | if TYPE_CHECKING: |
18 | 18 | from transformers import PretrainedConfig |
|
42 | 42 | "mimo_mtp", |
43 | 43 | "longcat_flash_mtp", |
44 | 44 | "mtp", |
| 45 | + "suffix", |
45 | 46 | ] |
46 | 47 | MTP_MODEL_TYPES = ( |
47 | 48 | "deepseek_mtp", |
@@ -129,6 +130,27 @@ class SpeculativeConfig: |
129 | 130 | draft_parallel_config: SkipValidation[ParallelConfig] = None # type: ignore |
130 | 131 | """The parallel configuration for the draft model initialized internal.""" |
131 | 132 |
|
| 133 | + # Suffix decoding configuration |
| 134 | + suffix_decoding_max_tree_depth: int = 24 |
| 135 | + """The maximum depth of the suffix decoding global and prompt trees. The |
| 136 | + tree depth limits the sum of the prefix match and speculation lengths.""" |
| 137 | + |
| 138 | + suffix_decoding_max_cached_requests: int = 10000 |
| 139 | + """The maximum number of requests to cache in the global suffix tree. If |
| 140 | + exceeded, will trigger eviction in FIFO order. If set to 0, the global |
| 141 | + suffix tree is disabled and past responses are not cached (prompt trees |
| 142 | + are still used).""" |
| 143 | + |
| 144 | + suffix_decoding_max_spec_factor: float = 1.0 |
| 145 | + """The maximum spec factor for suffix decoding. The spec factor controls |
| 146 | + speculation lengths based on the prefix match length: max_spec_tokens = |
| 147 | + max_spec_factor * prefix_match_length.""" |
| 148 | + |
| 149 | + suffix_decoding_min_token_prob: float = 0.1 |
| 150 | + """The minimum token probability for suffix decoding. Will only speculate |
| 151 | + tokens with estimated probability (based on frequency counts) greater than |
| 152 | + or equal to this value.""" |
| 153 | + |
132 | 154 | def compute_hash(self) -> str: |
133 | 155 | """ |
134 | 156 | WARNING: Whenever a new field is added to this config, |
@@ -235,6 +257,8 @@ def __post_init__(self): |
235 | 257 | self.quantization = self.target_model_config.quantization |
236 | 258 | elif self.method in ("ngram", "[ngram]"): |
237 | 259 | self.model = "ngram" |
| 260 | + elif self.method == "suffix": |
| 261 | + self.model = "suffix" |
238 | 262 | else: |
239 | 263 | raise ValueError( |
240 | 264 | "num_speculative_tokens was provided but without speculative model." |
@@ -282,6 +306,8 @@ def __post_init__(self): |
282 | 306 | # draft related config as None here. |
283 | 307 | self.draft_model_config = self.target_model_config |
284 | 308 | self.draft_parallel_config = self.target_parallel_config |
| 309 | + elif self.method == "suffix": |
| 310 | + self._validate_suffix_decoding() |
285 | 311 | else: |
286 | 312 | self.prompt_lookup_max = 0 |
287 | 313 | self.prompt_lookup_min = 0 |
@@ -430,6 +456,42 @@ def __post_init__(self): |
430 | 456 | ) |
431 | 457 | return self |
432 | 458 |
|
| 459 | + def _validate_suffix_decoding(self): |
| 460 | + if not has_arctic_inference(): |
| 461 | + raise ImportError( |
| 462 | + "Arctic Inference is required for suffix decoding. " |
| 463 | + "Install via `pip install arctic-inference==0.1.0`." |
| 464 | + ) |
| 465 | + if self.num_speculative_tokens is None: |
| 466 | + # Suffix decoding decides the actual number of speculative tokens |
| 467 | + # dynamically and treats num_speculative_tokens as a maximum limit. |
| 468 | + self.num_speculative_tokens = self.suffix_decoding_max_tree_depth |
| 469 | + logger.warning( |
| 470 | + "Defaulted num_speculative_tokens to %s for suffix decoding.", |
| 471 | + self.num_speculative_tokens, |
| 472 | + ) |
| 473 | + # Validate values |
| 474 | + if self.suffix_decoding_max_tree_depth < 1: |
| 475 | + raise ValueError( |
| 476 | + f"suffix_decoding_max_tree_depth=" |
| 477 | + f"{self.suffix_decoding_max_tree_depth} must be >= 1" |
| 478 | + ) |
| 479 | + if self.suffix_decoding_max_cached_requests < 0: |
| 480 | + raise ValueError( |
| 481 | + f"suffix_decoding_max_cached_requests=" |
| 482 | + f"{self.suffix_decoding_max_cached_requests} must be >= 0" |
| 483 | + ) |
| 484 | + if self.suffix_decoding_max_spec_factor < 0: |
| 485 | + raise ValueError( |
| 486 | + f"suffix_decoding_max_spec_factor=" |
| 487 | + f"{self.suffix_decoding_max_spec_factor} must be >= 0" |
| 488 | + ) |
| 489 | + if not 0 <= self.suffix_decoding_min_token_prob <= 1: |
| 490 | + raise ValueError( |
| 491 | + f"suffix_decoding_min_token_prob=" |
| 492 | + f"{self.suffix_decoding_min_token_prob} must be in [0, 1]" |
| 493 | + ) |
| 494 | + |
433 | 495 | @staticmethod |
434 | 496 | def _maybe_override_draft_max_model_len( |
435 | 497 | speculative_max_model_len: int | None, |
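To make the new code path concrete, here is a hedged sketch of how suffix decoding might be enabled end to end once this lands. It assumes the existing `speculative_config` dict plumbing in `LLM`/`EngineArgs` (not part of this diff), and the model name and parameter values are placeholders rather than recommendations.

```python
from vllm import LLM, SamplingParams

# Assumes `pip install arctic-inference==0.1.0` and that the engine accepts a
# speculative_config dict, as it already does for the "ngram" method.
llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
    speculative_config={
        "method": "suffix",
        "suffix_decoding_max_tree_depth": 24,
        "suffix_decoding_max_cached_requests": 10000,
        "suffix_decoding_max_spec_factor": 1.0,
        "suffix_decoding_min_token_prob": 0.1,
    },
)
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
```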
@@ -582,6 +644,6 @@ def use_eagle(self) -> bool: |
582 | 644 |
|
583 | 645 | def __repr__(self) -> str: |
584 | 646 | method = self.method |
585 | | - model = None if method == "ngram" else self.draft_model_config.model |
| 647 | + model = None if method in ("ngram", "suffix") else self.draft_model_config.model |
586 | 648 | num_spec_tokens = self.num_speculative_tokens |
587 | 649 | return f"SpeculativeConfig({method=}, {model=}, {num_spec_tokens=})" |