Commit 2c19d96

[Spec Decode] Integrate Suffix Decoding from Arctic Inference (vllm-project#25784)
Co-authored-by: Aurick Qiao <aurick.qiao@snowflake.com>
1 parent 4bc400f commit 2c19d96

8 files changed, 304 additions and 11 deletions

docs/features/spec_decode.md

Lines changed: 40 additions & 0 deletions
@@ -130,6 +130,46 @@ matching n-grams in the prompt. For more information read [this thread.](https:/
         print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
     ```
 
+## Speculating using Suffix Decoding
+
+The following code configures vLLM to use speculative decoding where proposals are generated using Suffix Decoding ([technical report](https://arxiv.org/abs/2411.04975)).
+
+Like n-gram, Suffix Decoding can generate draft tokens by pattern-matching using the last `n` generated tokens. Unlike n-gram, Suffix Decoding (1) can pattern-match against both the prompt and previous generations, (2) uses frequency counts to propose the most likely continuations, and (3) speculates an adaptive number of tokens for each request at each iteration to get better acceptance rates.
+
+Suffix Decoding can achieve better performance for tasks with high repetition, such as code-editing, agentic loops (e.g. self-reflection, self-consistency), and RL rollouts.
+
+!!! tip "Install Arctic Inference"
+    Suffix Decoding requires [Arctic Inference](https://github.com/snowflakedb/ArcticInference). You can install it with `pip install arctic-inference`.
+
+!!! tip "Suffix Decoding Speculative Tokens"
+    Suffix Decoding will speculate a dynamic number of tokens for each request at each decoding step, so the `num_speculative_tokens` configuration specifies the *maximum* number of speculative tokens. It is suggested to use a high number such as `16` or `32` (default).
+
+??? code
+
+    ```python
+    from vllm import LLM, SamplingParams
+
+    prompts = [
+        "The future of AI is",
+    ]
+    sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
+
+    llm = LLM(
+        model="facebook/opt-6.7b",
+        tensor_parallel_size=1,
+        speculative_config={
+            "method": "suffix",
+            "num_speculative_tokens": 32,
+        },
+    )
+    outputs = llm.generate(prompts, sampling_params)
+
+    for output in outputs:
+        prompt = output.prompt
+        generated_text = output.outputs[0].text
+        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+    ```
+
 ## Speculating using MLP speculators
 
 The following code configures vLLM to use speculative decoding where proposals are generated by
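The pattern-matching idea described in the new docs section can be sketched with a small, self-contained toy (an illustration only, not the Arctic Inference implementation and not a vLLM API): match the longest suffix of the recent tokens against previously seen sequences, follow the most frequent continuation, and stop once the estimated probability falls below a threshold or the adaptive budget `max_spec_factor * match_length` is spent. The helper names below (`build_continuations`, `propose_draft`) are hypothetical.

```python
# Toy sketch of the suffix-matching idea only; Arctic Inference's real
# implementation uses per-request and global suffix trees.
from collections import Counter, defaultdict


def build_continuations(corpus, max_depth=4):
    # Map each suffix pattern (up to max_depth tokens) to a Counter of next tokens.
    table = defaultdict(Counter)
    for seq in corpus:
        for i in range(1, len(seq)):
            for depth in range(1, min(max_depth, i) + 1):
                table[tuple(seq[i - depth:i])][seq[i]] += 1
    return table


def propose_draft(context, table, max_depth=4, max_spec_factor=2.0, min_token_prob=0.1):
    # Find the longest suffix of the context that has been seen before.
    match_len = 0
    for depth in range(min(max_depth, len(context)), 0, -1):
        if tuple(context[-depth:]) in table:
            match_len = depth
            break
    draft, current = [], list(context)
    budget = int(max_spec_factor * match_len)  # adaptive speculation length
    while len(draft) < budget:
        counts = table.get(tuple(current[-match_len:])) if match_len else None
        if not counts:
            break
        token, freq = counts.most_common(1)[0]
        if freq / sum(counts.values()) < min_token_prob:
            break  # estimated probability too low to keep speculating
        draft.append(token)
        current.append(token)
    return draft


corpus = [["the", "cat", "sat", "on", "the", "mat"],
          ["the", "cat", "sat", "on", "the", "sofa"]]
table = build_continuations(corpus)
print(propose_draft(["cat", "sat", "on"], table))  # e.g. ['the', 'mat']
```

A longer suffix match earns a larger speculation budget, which is the adaptive behavior the tip above refers to.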

requirements/test.in

Lines changed: 1 addition & 0 deletions
@@ -48,6 +48,7 @@ buildkite-test-collector==0.1.9
 genai_perf==0.0.8
 tritonclient==2.51.0
 
+arctic-inference == 0.1.0 # Required for suffix decoding test
 numba == 0.61.2 # Required for N-gram speculative decoding
 numpy
 runai-model-streamer[s3,gcs]==0.15.0

requirements/test.txt

Lines changed: 2 additions & 0 deletions
@@ -40,6 +40,8 @@ anyio==4.6.2.post1
     # via
     #   httpx
     #   starlette
+arctic-inference==0.1.0
+    # via -r requirements/test.in
 argcomplete==3.5.1
     # via datamodel-code-generator
 arrow==1.3.0

tests/v1/e2e/test_spec_decode.py

Lines changed: 78 additions & 7 deletions
@@ -75,7 +75,23 @@ def model_name():
     return "meta-llama/Llama-3.1-8B-Instruct"
 
 
-def test_ngram_correctness(
+@pytest.mark.parametrize(
+    "speculative_config",
+    [
+        {
+            "method": "ngram",
+            "prompt_lookup_max": 5,
+            "prompt_lookup_min": 3,
+            "num_speculative_tokens": 3,
+        },
+        {
+            "method": "suffix",
+            "suffix_decoding_max_spec_factor": 2.0,
+        },
+    ],
+)
+def test_ngram_and_suffix_correctness(
+    speculative_config: dict,
     monkeypatch: pytest.MonkeyPatch,
     sampling_config: SamplingParams,
     model_name: str,
@@ -94,12 +110,7 @@ def test_ngram_correctness(
 
     spec_llm = LLM(
         model=model_name,
-        speculative_config={
-            "method": "ngram",
-            "prompt_lookup_max": 5,
-            "prompt_lookup_min": 3,
-            "num_speculative_tokens": 3,
-        },
+        speculative_config=speculative_config,
         max_model_len=1024,
     )
     spec_outputs = spec_llm.chat(test_prompts, sampling_config)
@@ -121,6 +132,66 @@ def test_ngram_correctness(
     cleanup_dist_env_and_memory()
 
 
+def test_suffix_decoding_acceptance(
+    monkeypatch: pytest.MonkeyPatch,
+    sampling_config: SamplingParams,
+    model_name: str,
+):
+    """
+    Check that suffix decoding caching takes effect and improves acceptance
+    lengths and acceptance rates over multiple runs of the same prompts.
+    """
+    test_prompts = get_test_prompts(mm_enabled=False)
+
+    spec_llm = LLM(
+        model=model_name,
+        speculative_config={
+            "method": "suffix",
+            "suffix_decoding_max_spec_factor": 2.0,
+            "suffix_decoding_max_cached_requests": 1000,
+        },
+        max_model_len=1024,
+        disable_log_stats=False,
+    )
+
+    # Run several times and check that the accepted tokens increase.
+    spec_llm.chat(test_prompts, sampling_config)
+    num_draft = []
+    num_accept = []
+    for i in range(10):  # Run multiple times to warm up the cache.
+        spec_llm.chat(test_prompts, sampling_config)
+        # Collect draft and acceptance stats.
+        metrics = spec_llm.get_metrics()
+        for metric in metrics:
+            if metric.name == "vllm:spec_decode_num_draft_tokens":
+                num_draft.append(metric.value)
+            if metric.name == "vllm:spec_decode_num_accepted_tokens":
+                num_accept.append(metric.value)
+
+    # Calculate the acceptance rates for the first and last runs.
+    first_accept_tokens = num_accept[0]
+    first_draft_tokens = num_draft[0]
+    first_accept_rate = first_accept_tokens / first_draft_tokens
+
+    # Take the diff since the stats are cumulative.
+    last_accept_tokens = num_accept[-1] - num_accept[-2]
+    last_draft_tokens = num_draft[-1] - num_draft[-2]
+    last_accept_rate = last_accept_tokens / last_draft_tokens
+
+    # Expect the acceptance length to improve.
+    assert first_accept_tokens < last_accept_tokens
+
+    # Expect the acceptance rate to improve.
+    assert first_accept_rate < last_accept_rate
+
+    # Heuristic: expect at least 85% acceptance rate at the end.
+    assert last_accept_rate > 0.85
+
+    del spec_llm
+    torch.cuda.empty_cache()
+    cleanup_dist_env_and_memory()
+
+
 @pytest.mark.parametrize(
     "model_path",
     [
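The acceptance test above reads two cumulative counters through `LLM.get_metrics()`. A hedged sketch of consuming the same counters outside the test suite, assuming a build that exposes these metric names; the model, prompt, and sampling values are illustrative:

```python
# Derive a speculative-decoding acceptance rate from the cumulative counters
# read in the test above. Assumes the same metric names are exposed.
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",
    speculative_config={"method": "suffix", "num_speculative_tokens": 32},
    disable_log_stats=False,  # keep stats so get_metrics() has data to report
)
llm.chat(
    [[{"role": "user", "content": "Repeat the phrase 'hello world' ten times."}]],
    SamplingParams(temperature=0.0, max_tokens=128),
)

drafted = accepted = 0
for metric in llm.get_metrics():
    if metric.name == "vllm:spec_decode_num_draft_tokens":
        drafted = metric.value
    elif metric.name == "vllm:spec_decode_num_accepted_tokens":
        accepted = metric.value

if drafted:
    print(f"acceptance rate so far: {accepted / drafted:.2%}")
```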

vllm/config/speculative.py

Lines changed: 64 additions & 2 deletions
@@ -12,7 +12,7 @@
 from vllm.config.parallel import ParallelConfig
 from vllm.config.utils import config
 from vllm.logger import init_logger
-from vllm.utils.import_utils import LazyLoader
+from vllm.utils.import_utils import LazyLoader, has_arctic_inference
 
 if TYPE_CHECKING:
     from transformers import PretrainedConfig
@@ -42,6 +42,7 @@
     "mimo_mtp",
     "longcat_flash_mtp",
     "mtp",
+    "suffix",
 ]
 MTP_MODEL_TYPES = (
     "deepseek_mtp",
@@ -129,6 +130,27 @@ class SpeculativeConfig:
     draft_parallel_config: SkipValidation[ParallelConfig] = None  # type: ignore
     """The parallel configuration for the draft model initialized internal."""
 
+    # Suffix decoding configuration
+    suffix_decoding_max_tree_depth: int = 24
+    """The maximum depth of the suffix decoding global and prompt trees. The
+    tree depth limits the sum of the prefix match and speculation lengths."""
+
+    suffix_decoding_max_cached_requests: int = 10000
+    """The maximum number of requests to cache in the global suffix tree. If
+    exceeded, will trigger eviction in FIFO order. If set to 0, the global
+    suffix tree is disabled and past responses are not cached (prompt trees
+    are still used)."""
+
+    suffix_decoding_max_spec_factor: float = 1.0
+    """The maximum spec factor for suffix decoding. The spec factor controls
+    speculation lengths based on the prefix match length: max_spec_tokens =
+    max_spec_factor * prefix_match_length."""
+
+    suffix_decoding_min_token_prob: float = 0.1
+    """The minimum token probability for suffix decoding. Will only speculate
+    tokens with estimated probability (based on frequency counts) greater than
+    or equal to this value."""
+
     def compute_hash(self) -> str:
         """
         WARNING: Whenever a new field is added to this config,
@@ -235,6 +257,8 @@ def __post_init__(self):
                 self.quantization = self.target_model_config.quantization
             elif self.method in ("ngram", "[ngram]"):
                 self.model = "ngram"
+            elif self.method == "suffix":
+                self.model = "suffix"
             else:
                 raise ValueError(
                     "num_speculative_tokens was provided but without speculative model."
@@ -282,6 +306,8 @@ def __post_init__(self):
             # draft related config as None here.
             self.draft_model_config = self.target_model_config
             self.draft_parallel_config = self.target_parallel_config
+        elif self.method == "suffix":
+            self._validate_suffix_decoding()
         else:
             self.prompt_lookup_max = 0
             self.prompt_lookup_min = 0
@@ -430,6 +456,42 @@ def __post_init__(self):
             )
         return self
 
+    def _validate_suffix_decoding(self):
+        if not has_arctic_inference():
+            raise ImportError(
+                "Arctic Inference is required for suffix decoding. "
+                "Install via `pip install arctic-inference==0.1.0`."
+            )
+        if self.num_speculative_tokens is None:
+            # Suffix decoding decides the actual number of speculative tokens
+            # dynamically and treats num_speculative_tokens as a maximum limit.
+            self.num_speculative_tokens = self.suffix_decoding_max_tree_depth
+            logger.warning(
+                "Defaulted num_speculative_tokens to %s for suffix decoding.",
+                self.num_speculative_tokens,
+            )
+        # Validate values
+        if self.suffix_decoding_max_tree_depth < 1:
+            raise ValueError(
+                f"suffix_decoding_max_tree_depth="
+                f"{self.suffix_decoding_max_tree_depth} must be >= 1"
+            )
+        if self.suffix_decoding_max_cached_requests < 0:
+            raise ValueError(
+                f"suffix_decoding_max_cached_requests="
+                f"{self.suffix_decoding_max_cached_requests} must be >= 0"
+            )
+        if self.suffix_decoding_max_spec_factor < 0:
+            raise ValueError(
+                f"suffix_decoding_max_spec_factor="
+                f"{self.suffix_decoding_max_spec_factor} must be >= 0"
+            )
+        if not 0 <= self.suffix_decoding_min_token_prob <= 1:
+            raise ValueError(
+                f"suffix_decoding_min_token_prob="
+                f"{self.suffix_decoding_min_token_prob} must be in [0, 1]"
+            )
+
     @staticmethod
     def _maybe_override_draft_max_model_len(
         speculative_max_model_len: int | None,
@@ -582,6 +644,6 @@ def use_eagle(self) -> bool:
 
     def __repr__(self) -> str:
         method = self.method
-        model = None if method == "ngram" else self.draft_model_config.model
+        model = None if method in ("ngram", "suffix") else self.draft_model_config.model
         num_spec_tokens = self.num_speculative_tokens
         return f"SpeculativeConfig({method=}, {model=}, {num_spec_tokens=})"
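Reading the new fields together: the per-step budget grows with the prefix match length (`max_spec_tokens = max_spec_factor * prefix_match_length`) and is capped by `num_speculative_tokens`, which itself defaults to `suffix_decoding_max_tree_depth`. A hedged sketch of that arithmetic and of passing the knobs through `speculative_config` follows; the key names mirror the fields above and the pattern used in the test file, while the exact internal capping may differ.

```python
# Illustrative only: a simplified reading of the docstrings above,
# not vLLM's exact internal logic.
suffix_decoding_max_tree_depth = 24     # default
suffix_decoding_max_spec_factor = 2.0   # raised from the 1.0 default
num_speculative_tokens = suffix_decoding_max_tree_depth  # default chosen by _validate_suffix_decoding


def speculation_budget(prefix_match_length: int) -> int:
    # Longer suffix matches earn longer speculations, up to the configured cap.
    return min(int(suffix_decoding_max_spec_factor * prefix_match_length),
               num_speculative_tokens)


for match_len in (0, 3, 8, 20):
    print(match_len, "->", speculation_budget(match_len))  # 0, 6, 16, 24

# The knobs travel through speculative_config as plain keys, e.g.:
speculative_config = {
    "method": "suffix",
    "num_speculative_tokens": 32,
    "suffix_decoding_max_tree_depth": 24,
    "suffix_decoding_max_cached_requests": 10000,
    "suffix_decoding_max_spec_factor": 2.0,
    "suffix_decoding_min_token_prob": 0.1,
}
```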

vllm/utils/import_utils.py

Lines changed: 6 additions & 0 deletions
@@ -403,3 +403,9 @@ def has_triton_kernels() -> bool:
 def has_tilelang() -> bool:
     """Whether the optional `tilelang` package is available."""
     return _has_module("tilelang")
+
+
+def has_arctic_inference() -> bool:
+    """Whether the optional `arctic_inference` package is available."""
+    return _has_module("arctic_inference")
+
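`_has_module` itself is outside this hunk. As an assumption about how such a probe typically works (not necessarily vLLM's implementation), a standalone equivalent based on `importlib.util.find_spec`:

```python
# Standalone sketch of an optional-dependency probe; an assumption, not the
# actual _has_module from vllm/utils/import_utils.py.
from importlib.util import find_spec


def has_arctic_inference() -> bool:
    """Whether the optional `arctic_inference` package is available."""
    return find_spec("arctic_inference") is not None


if __name__ == "__main__":
    print("arctic_inference available:", has_arctic_inference())
```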
Lines changed: 101 additions & 0 deletions (new file)
@@ -0,0 +1,101 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from vllm.config import VllmConfig
+from vllm.v1.worker.gpu_input_batch import InputBatch
+
+
+class SuffixDecodingProposer:
+    """
+    Speculative decoding proposer for Suffix Decoding (https://arxiv.org/pdf/2411.04975).
+    This class imports and uses the official implementation from Arctic Inference
+    (https://github.com/snowflakedb/ArcticInference).
+    """
+
+    def __init__(self, vllm_config: VllmConfig):
+        config = vllm_config.speculative_config
+        self.num_speculative_tokens = config.num_speculative_tokens
+        self.max_tree_depth = config.suffix_decoding_max_tree_depth
+        self.max_spec_factor = config.suffix_decoding_max_spec_factor
+        self.min_token_prob = config.suffix_decoding_min_token_prob
+        self.max_model_len = vllm_config.model_config.max_model_len
+
+        # Lazy import to avoid an error when Suffix Decoding is not used.
+        from arctic_inference.suffix_decoding import SuffixDecodingCache
+
+        # Initialize an empty cache. This object takes care of caching request
+        # outputs, evicting old requests, and managing the per-prompt suffix trees.
+        self.suffix_cache = SuffixDecodingCache(
+            max_tree_depth=config.suffix_decoding_max_tree_depth,
+            max_cached_requests=config.suffix_decoding_max_cached_requests,
+        )
+
+    def propose(
+        self,
+        input_batch: InputBatch,
+        sampled_token_ids: list[list[int]],
+    ) -> list[list[int]]:
+        """
+        Propose speculative tokens for each request in the input batch. Suffix Decoding
+        speculates a dynamic number of tokens for each request at every decoding step,
+        so the entries in the returned list may have different lengths.
+        """
+        draft_token_ids: list[list[int]] = []
+        for i, sampled_ids in enumerate(sampled_token_ids):
+            if not sampled_ids:
+                # Skip speculative decoding for partial prefills.
+                draft_token_ids.append([])
+                continue
+
+            # Skip requests that require sampling parameters that are not
+            # supported with speculative decoding.
+            req_id = input_batch.req_ids[i]
+            if req_id in input_batch.spec_decode_unsupported_reqs:
+                draft_token_ids.append([])
+                continue
+
+            num_tokens = input_batch.num_tokens_no_spec[i]
+            if num_tokens >= self.max_model_len:
+                # Skip requests that have already reached the max model length.
+                draft_token_ids.append([])
+                continue
+
+            index = input_batch.req_id_to_index[req_id]
+            if req_id not in self.suffix_cache.active_requests:
+                if req_id in self.suffix_cache.cached_requests:
+                    # Reset the suffix cache for this request.
+                    self.suffix_cache.evict_cached_response(req_id)
+                num_prompt_tokens = input_batch.num_prompt_tokens[index]
+                prompt_token_ids = input_batch.token_ids_cpu[index, :num_prompt_tokens]
+                # Start a new request; this builds the suffix tree for that prompt.
+                self.suffix_cache.start_request(req_id, prompt_token_ids)
+
+            # Append the newly sampled ids to the suffix cache for this request.
+            self.suffix_cache.add_active_response(req_id, sampled_ids)
+
+            # Suffix decoding only uses the most recent tokens up to max_tree_depth, so
+            # we extract the pattern from the end of the input.
+            start = max(0, num_tokens - self.max_tree_depth)
+            pattern = input_batch.token_ids_cpu[i, start:num_tokens]
+            draft = self.suffix_cache.speculate(
+                req_id,
+                pattern,
+                max_spec_tokens=min(
+                    self.num_speculative_tokens, self.max_model_len - num_tokens - 1
+                ),
+                max_spec_factor=self.max_spec_factor,
+                min_token_prob=self.min_token_prob,
+            )
+
+            draft_token_ids.append(draft.token_ids)
+
+        # Stop requests that were not seen in the input batch.
+        for req_id in (
+            self.suffix_cache.active_requests - input_batch.req_id_to_index.keys()
+        ):
+            self.suffix_cache.stop_request(req_id)
+
+        return draft_token_ids
+
+    def load_model(self, *args, **kwargs):
+        # No model to load.
+        pass
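The proposer above exercises a small surface of the Arctic Inference cache: `start_request`, `add_active_response`, `speculate`, and `stop_request`. A minimal sketch of that call sequence in isolation, assuming `arctic-inference==0.1.0` is installed; the request id, token ids, and tuning values are illustrative, and argument types are inferred from how `propose()` calls the API:

```python
# Hedged sketch of the SuffixDecodingCache call sequence used by the proposer.
# Requires `pip install arctic-inference==0.1.0`; values are illustrative.
from arctic_inference.suffix_decoding import SuffixDecodingCache

cache = SuffixDecodingCache(max_tree_depth=24, max_cached_requests=1000)

req_id = "req-0"
prompt_token_ids = [11, 12, 13, 11, 12, 14, 11, 12]
cache.start_request(req_id, prompt_token_ids)       # builds the prompt suffix tree

cache.add_active_response(req_id, [13, 11, 12])     # tokens sampled so far

pattern = [11, 12]                                   # most recent tokens, at most max_tree_depth
draft = cache.speculate(
    req_id,
    pattern,
    max_spec_tokens=16,
    max_spec_factor=2.0,
    min_token_prob=0.1,
)
print(draft.token_ids)                               # variable-length draft, possibly empty

cache.stop_request(req_id)                           # finished; the response may stay cached
```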
