
Commit 47227f4

remi-or and McPatate authored
Add prefix sharing to continuous batching (#42094)
* Fix a bug in the CB memory calculation
* Nit in example
* Replace _free_blocks with a proper object BlockManager
* Removed dead code
* Added hashing mechanism (wip)
* Added de-duplication
* Add de-initialization mechanism
* Add prefix detection
* Ensure we always keep 1 token for decode start
* Removed some todos and small fix
* Update src/transformers/generation/continuous_batching/cache.py
  Co-authored-by: Luc Georges <McPatate@users.noreply.github.com>
* Update src/transformers/generation/continuous_batching/continuous_api.py
  Co-authored-by: Luc Georges <McPatate@users.noreply.github.com>
* Docs
* Review comments
* Style
* Added a flag to allow prefix sharing
* [IMPORTANT] bug fix for prefix length memoization
* Added a test for CB prefix sharing
* Example, start of refactor
* End of refactor for example script
* Added a do-sample arg
* Added reporting on prefix sharing
* Added a context manager option for CB manager
* Nit and style
* Review comment from ArthurZucker

---------

Co-authored-by: Luc Georges <McPatate@users.noreply.github.com>
1 parent 7f9f4d9 commit 47227f4
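The commit message above sketches the mechanism: the KV cache is split into fixed-size blocks managed by a BlockManager, each block is identified by a hash of the whole prefix it covers, identical blocks are de-duplicated so requests that share a prompt prefix reuse the same physical memory, and at least one token is always kept private so every request has a block to start decoding into. The code below is a minimal conceptual sketch of that idea, assuming a fixed block size and cumulative prefix hashes; SimpleBlockManager, BLOCK_SIZE and allocate_prompt are illustrative names, not the API added in src/transformers/generation/continuous_batching/cache.py.

# Conceptual sketch of prefix sharing via block hashing and de-duplication.
# NOT the transformers BlockManager; names and details are illustrative only.
from dataclasses import dataclass, field

BLOCK_SIZE = 16  # tokens per KV-cache block (illustrative value)


@dataclass
class SimpleBlockManager:
    hash_to_block: dict[int, int] = field(default_factory=dict)  # cumulative prefix hash -> physical block id
    ref_counts: dict[int, int] = field(default_factory=dict)  # physical block id -> number of requests using it
    next_block_id: int = 0

    def allocate_prompt(self, prompt_ids: list[int]) -> list[int]:
        """Return the physical blocks backing a prompt, reusing any full block whose
        cumulative prefix hash has already been seen (i.e. an identical shared prefix)."""
        blocks: list[int] = []
        prefix_hash = hash(())
        shareable = len(prompt_ids) - 1  # always keep at least 1 token un-shared for decode start
        for start in range(0, len(prompt_ids), BLOCK_SIZE):
            chunk = tuple(prompt_ids[start : start + BLOCK_SIZE])
            prefix_hash = hash((prefix_hash, chunk))  # hash depends on the whole prefix, not just this chunk
            sharable_block = len(chunk) == BLOCK_SIZE and start + BLOCK_SIZE <= shareable
            if sharable_block and prefix_hash in self.hash_to_block:
                block_id = self.hash_to_block[prefix_hash]  # de-duplication: reuse the existing block
            else:
                block_id = self.next_block_id
                self.next_block_id += 1
                if sharable_block:
                    self.hash_to_block[prefix_hash] = block_id
            self.ref_counts[block_id] = self.ref_counts.get(block_id, 0) + 1
            blocks.append(block_id)
        return blocks


manager = SimpleBlockManager()
shared_prefix = list(range(40))  # two requests that share a 40-token prefix
blocks_a = manager.allocate_prompt(shared_prefix + [100, 101, 102])
blocks_b = manager.allocate_prompt(shared_prefix + [200, 201])
print(blocks_a, blocks_b)  # the first two blocks are reused; only the tails get private blocks

Hashing the cumulative prefix rather than each chunk in isolation is what makes reuse safe: a block is shared only when every token before it matches as well.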

File tree

7 files changed: +570 -231 lines changed


examples/pytorch/continuous_batching.py

Lines changed: 95 additions & 81 deletions
@@ -17,6 +17,7 @@
 import json
 import os
 import time
+from itertools import cycle
 from typing import Optional
 
 import datasets
@@ -29,42 +30,32 @@
 from transformers.generation.continuous_batching.requests import logger
 
 
-# MODEL_ID = "Qwen/Qwen3-4B-Instruct-2507"
-SLIDING_WINDOW = 0
-MODEL_ID = "google/gemma-2-2b-it" if SLIDING_WINDOW > 0 else "meta-llama/Meta-Llama-3-8B"
-FORCE_MAX_LENGTH = False  # should be False unless you are debugging sliding window features
-SKIP_SPECIAL_TOKENS = False
-
-
-def generate_simple(
-    attn_impl: str, simple_batch_inputs: list[int], generation_config: GenerationConfig
+def generate_without_cb(
+    model_id: str, sliding_window: int, attn_impl: str, batched_inputs: list[int], generation_config: GenerationConfig
 ) -> dict[str, str]:
-    attn_impl = {
-        "sdpa": "sdpa",
-        "eager": "eager",
-        "paged_attention": "eager",  # TODO: this does not work on AMD docker
-        "flash_paged": "flash_attention_2",  # TODO: this does not work on AMD docker
-        "kernels-community/flash-attn": "eager",
-    }[attn_impl]
-
-    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, dtype=torch.bfloat16, attn_implementation=attn_impl)
+    # Setup model and tokenizer
+    model = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.bfloat16, attn_implementation=attn_impl)
     model = model.cuda().eval()
-    if getattr(model.config, "sliding_window", None) is not None:
-        model.config.sliding_window = SLIDING_WINDOW
-
+    if sliding_window > 0 and getattr(model.config, "sliding_window", None) is not None:
+        model.config.sliding_window = sliding_window
+    tokenizer = AutoTokenizer.from_pretrained(model_id)
+    # Generate one by one
     decoded_outputs = {}
-    for input_ids in tqdm(simple_batch_inputs, desc="Generating outputs without CB"):
+    for input_ids in tqdm(batched_inputs, desc="Generating outputs without CB"):
         key = " ".join(map(str, input_ids))  # This will be used to identify the output after batched generation
         input_ids = torch.tensor([input_ids]).to("cuda")
-        # attention_mask = torch.ones_like(input_ids)
-        outputs = model.generate(input_ids, generation_config=generation_config, use_model_defaults=False)
+        attention_mask = torch.ones_like(input_ids)
+        outputs = model.generate(
+            input_ids, attention_mask=attention_mask, generation_config=generation_config, use_model_defaults=False
+        )
         generated_tokens = outputs[0][input_ids.shape[1] :]
-        decoded_output = tokenizer.decode(generated_tokens, skip_special_tokens=SKIP_SPECIAL_TOKENS)
-        decoded_outputs[key] = decoded_output
+        decoded_outputs[key] = tokenizer.decode(generated_tokens, skip_special_tokens=False)
     return decoded_outputs
 
 
-def setup_metrics():
+def maybe_setup_metrics(use_metrics: bool) -> None:
+    if not use_metrics:
+        return
     try:
         from opentelemetry import metrics, trace
         from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
@@ -119,16 +110,14 @@ def batch_generate(
     token_count = 0
     data = []
     for i, request in enumerate(batch_outputs):
-        input_text = tokenizer.decode(batch_outputs[request].prompt_ids, skip_special_tokens=SKIP_SPECIAL_TOKENS)
+        input_text = tokenizer.decode(batch_outputs[request].prompt_ids, skip_special_tokens=False)
         # The key is used to tie back to the output of unbatched generation
         key = " ".join(map(str, batch_outputs[request].prompt_ids))
         data.append({"input": input_text, "key": key})
 
         # Try to decode the output
         try:
-            output_text = tokenizer.decode(
-                batch_outputs[request].generated_tokens, skip_special_tokens=SKIP_SPECIAL_TOKENS
-            )
+            output_text = tokenizer.decode(batch_outputs[request].generated_tokens, skip_special_tokens=False)
             token_count += len(batch_outputs[request].generated_tokens[1:])
             data[-1]["cb_outputs"] = output_text
         except Exception as e:
@@ -138,14 +127,7 @@ def batch_generate(
 
         # Display sample if asked
         if i < displayed_samples:
-            if len(output_text) > 0:
-                print("-" * 20)
-                print(f"{request} Input: {input_text}")
-                print(f"{request} Output: {output_text}")
-            else:
-                print(f"{request} Input: {input_text}")
-                print("[WARN]")
-                print(f"{request} Output was empty!")
+            print("-" * 20, f"{request} Input: {input_text}", f"{request} Output: {output_text}", sep="\n")
 
         # Compare with classic generate if asked
         if expected_outputs is not None:
@@ -182,83 +164,115 @@ def batch_generate(
 
 
 if __name__ == "__main__":
-    # Parse args
     parser = argparse.ArgumentParser()
+
+    # Continuous batching parameters
     parser.add_argument("--num-blocks", "-n", type=int, default=None)
     parser.add_argument("--max-batch-tokens", "-b", type=int, default=None)
 
+    # Model parameters
+    parser.add_argument("--sliding-window", type=int, default=0)
     parser.add_argument("--attn", type=str, default="kernels-community/flash-attn", help="Attention implementation")
+
+    # Performance parameters
     parser.add_argument("--matmul-precision", "-mp", type=str, default="high")  # set to "none" to disable
     parser.add_argument("--cuda-graph", "-cg", help="Use cuda graphs", type=str, default=None)
     parser.add_argument("--compile", action="store_true", help="Compile the model using torch.compile")
+    parser.add_argument("--do-sample", action="store_true", help="Activate sampling")
 
+    # Benchmark parameters
     parser.add_argument("--samples", type=int, default=500, help="Number of samples to generate")
+    parser.add_argument("--add-prefix", action="store_true", help="Add a prefix to the samples")
+    parser.add_argument("--compare", action="store_true", help="Compare CB generation with classic generate")
+    parser.add_argument("--profile", type=str, default=None)
+    parser.add_argument("--metrics", action="store_true")
+    parser.add_argument("--force-max-length", action="store_true", help="Force generation to stop at max length")
+
+    # Display parameters
     parser.add_argument("--displayed", type=int, default=0, help="Number of samples to display")
     parser.add_argument("--log-level", type=str, default="INFO")
     parser.add_argument("--output-file", type=str, default=None)
-    parser.add_argument("--compare", action="store_true")
-    parser.add_argument("--metrics", action="store_true")
-    parser.add_argument("--profile", type=str, default=None)
+
     args = parser.parse_args()
 
-    # Set log level
-    logger.setLevel(args.log_level.upper())
+    # Create model
+    model_id = "google/gemma-2-2b-it" if args.sliding_window > 0 else "meta-llama/Llama-3.1-8B-Instruct"
+    has_system_role = args.sliding_window == 0
+
+    model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation=args.attn, dtype=torch.bfloat16)
+    model = model.cuda().eval()
 
-    # If turned on, we setup metrics
-    if args.metrics:
-        setup_metrics()
+    if args.sliding_window > 0 and getattr(model.config, "sliding_window", None) is not None:
+        print(f"Setting sliding window from {model.config.sliding_window} to {args.sliding_window}")
+        model.config.sliding_window = args.sliding_window
 
-    # Set matmul precision if not none
+    # Set up diagnostics
+    logger.setLevel(args.log_level.upper())
+    maybe_setup_metrics(args.metrics)
+
+    # Set up performance
     if args.matmul_precision != "none":
         torch.set_float32_matmul_precision(args.matmul_precision)
-    # Parse cuda graph argument
-    if args.cuda_graph is not None:
-        use_cuda_graph = {
-            "none": None,
-            "yes": True, "y": True, "true": True, "t": True, "1": True,
-            "no": False, "n": False, "false": False, "f": False, "0": False,
-        }[args.cuda_graph.lower()]  # fmt: skip
-    else:
-        use_cuda_graph = None
 
-    # Prepare model
-    model = AutoModelForCausalLM.from_pretrained(
-        MODEL_ID,
-        attn_implementation=args.attn,
-        dtype=torch.bfloat16,
-    )
-    model = model.cuda().eval()
-    if getattr(model.config, "sliding_window", None) is not None:
-        print(f"Setting sliding window from {model.config.sliding_window} to {SLIDING_WINDOW}")
-        model.config.sliding_window = SLIDING_WINDOW
+    cuda_graph_arg = args.cuda_graph.lower() if args.cuda_graph is not None else None
+    use_cuda_graph = {
+        "none": None, None: None,
+        "yes": True, "y": True, "true": True, "t": True, "1": True,
+        "no": False, "n": False, "false": False, "f": False, "0": False,
+    }[cuda_graph_arg]  # fmt: skip
 
-    # If turned on, we compile the model
     if args.compile:
         model.forward = torch.compile(model.forward, mode="max-autotune-no-cudagraphs")
 
     # Prepare tokenizer and dataset
-    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, padding_side="left")
+    tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
 
     dataset = datasets.load_dataset("openai/gsm8k", "socratic", split="test")
     dataset = dataset.select(range(args.samples))
 
-    simple_batch_inputs = [tokenizer(item["question"])["input_ids"] for item in dataset]
+    if args.add_prefix:
+        possible_prefixes = [
+            None,
+            "You are a bot that solves math problems.",
+            "You are a bot who solves math problems. Try to make your answer clear and understandable, and include your stages of reasoning.",
+            "You are a bot with the aim to solves math problems. Try to make your answer clear and understandable, and include your stages of reasoning. No loud words or emojis, all responses must be readable by a child. Here is now the problem:",
+        ]  # fmt: skip
+    else:
+        possible_prefixes = [None]
+
+    batched_inputs = []
+    for item, prefix in zip(dataset, cycle(possible_prefixes)):
+        messages = []
+        question = item["question"]
+        if prefix is not None:
+            if has_system_role:
+                messages.append({"role": "system", "content": prefix})
+            else:
+                question = prefix + "\n\n" + question
+        messages.append({"role": "user", "content": question})
+        inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
+        batched_inputs.append(inputs["input_ids"])
 
     # Prepare generation config
-    generation_config = GenerationConfig(
+    generation_cfg = GenerationConfig(
         max_new_tokens=512,
         use_cuda_graph=use_cuda_graph,
-        eos_token_id=tokenizer.pad_token_id if FORCE_MAX_LENGTH else tokenizer.eos_token_id,
+        eos_token_id=tokenizer.pad_token_id if args.force_max_length else tokenizer.eos_token_id,
         pad_token_id=tokenizer.pad_token_id,
-        do_sample=not args.compare,
+        do_sample=args.do_sample,
         temperature=0.8,
         top_p=0.9,
         num_blocks=args.num_blocks,
         max_batch_tokens=args.max_batch_tokens,
     )
 
     # If we need to compare, we need to generate the reference outputs
-    expected_outputs = generate_simple(args.attn, simple_batch_inputs, generation_config) if args.compare else None
+    if args.compare:
+        expected_outputs = generate_without_cb(
+            model_id, args.sliding_window, args.attn, batched_inputs, generation_cfg
+        )
+    else:
+        expected_outputs = None
 
     # If no output file is provided, we pick a name based on the args
     if args.output_file is None:
@@ -271,8 +285,8 @@ def batch_generate(
     # Run warmup batch generation  # TODO: understand why warmup incurs a large overhead during cache creation
     batch_generate(
         model,
-        simple_batch_inputs[: min(5, args.samples)],
-        generation_config,
+        batched_inputs[: min(5, args.samples)],
+        generation_cfg,
         tokenizer,
         displayed_samples=-1,
     )
@@ -285,8 +299,8 @@ def batch_generate(
    # Run batch generation
     gen_time, tok_per_sec = batch_generate(
         model,
-        simple_batch_inputs,
-        generation_config,
+        batched_inputs,
+        generation_cfg,
         tokenizer,
         displayed_samples=args.displayed,
         output_file=args.output_file,
@@ -297,5 +311,5 @@ def batch_generate(
         prof.export_chrome_trace(filename)
 
 # Example usage:
-# python examples/pytorch/continuous_batching.py --attn sdpa_paged -mp none --samples 3 --compare
-# python examples/pytorch/continuous_batching.py --num-blocks 369 --max-batch-tokens 23 --attn sdpa_paged -mp none --samples 1 --displayed 0 --output-file sliced.json
+# python examples/pytorch/continuous_batching.py --attn sdpa --add-prefix --samples 10 --compare
+# python examples/pytorch/continuous_batching.py --attn flash_attention_2 -mp none --add-prefix --samples 500
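The example commands above pair with the commit's new reporting on prefix sharing. As a rough offline estimate of how much the --add-prefix prompts can benefit, one can count how many block-aligned prompt prefixes repeat across the batch. The helper below is a hypothetical sketch, not part of the script; shared_prefix_tokens and block_size are made-up names, and it assumes block-aligned sharing with the final token of each prompt kept private, mirroring the behaviour described in the commit message.

# Hypothetical helper: estimate how many prompt tokens could be served from shared KV-cache blocks.
from collections.abc import Sequence


def shared_prefix_tokens(batched_inputs: Sequence[Sequence[int]], block_size: int = 16) -> int:
    """Count prompt tokens whose block-aligned prefix was already seen in an earlier prompt."""
    seen: set[tuple[int, ...]] = set()
    shared = 0
    for ids in batched_inputs:
        # Only full blocks that do not contain the final token are candidates for sharing.
        shareable_len = ((len(ids) - 1) // block_size) * block_size
        for end in range(block_size, shareable_len + 1, block_size):
            prefix = tuple(ids[:end])  # cumulative prefix up to this block boundary
            if prefix in seen:
                shared += block_size
            else:
                seen.add(prefix)
    return shared


# Example: report the estimate after building `batched_inputs` as in the script above.
# print(f"~{shared_prefix_tokens(batched_inputs)} prompt tokens could be served from shared blocks")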
