 import json
 import os
 import time
+from itertools import cycle
 from typing import Optional
 
 import datasets
 from transformers.generation.continuous_batching.requests import logger
 
 
-# MODEL_ID = "Qwen/Qwen3-4B-Instruct-2507"
-SLIDING_WINDOW = 0
-MODEL_ID = "google/gemma-2-2b-it" if SLIDING_WINDOW > 0 else "meta-llama/Meta-Llama-3-8B"
-FORCE_MAX_LENGTH = False  # should be False unless you are debugging sliding window features
-SKIP_SPECIAL_TOKENS = False
-
-
-def generate_simple(
-    attn_impl: str, simple_batch_inputs: list[int], generation_config: GenerationConfig
+def generate_without_cb(
+    model_id: str, sliding_window: int, attn_impl: str, batched_inputs: list[int], generation_config: GenerationConfig
 ) -> dict[str, str]:
-    attn_impl = {
-        "sdpa": "sdpa",
-        "eager": "eager",
-        "paged_attention": "eager",  # TODO: this does not work on AMD docker
-        "flash_paged": "flash_attention_2",  # TODO: this does not work on AMD docker
-        "kernels-community/flash-attn": "eager",
-    }[attn_impl]
-
-    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, dtype=torch.bfloat16, attn_implementation=attn_impl)
+    # Setup model and tokenizer
+    model = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.bfloat16, attn_implementation=attn_impl)
     model = model.cuda().eval()
-    if getattr(model.config, "sliding_window", None) is not None:
-        model.config.sliding_window = SLIDING_WINDOW
-
+    if sliding_window > 0 and getattr(model.config, "sliding_window", None) is not None:
+        model.config.sliding_window = sliding_window
+    tokenizer = AutoTokenizer.from_pretrained(model_id)
+    # Generate one by one
     decoded_outputs = {}
-    for input_ids in tqdm(simple_batch_inputs, desc="Generating outputs without CB"):
+    for input_ids in tqdm(batched_inputs, desc="Generating outputs without CB"):
         key = " ".join(map(str, input_ids))  # This will be used to identify the output after batched generation
         input_ids = torch.tensor([input_ids]).to("cuda")
-        # attention_mask = torch.ones_like(input_ids)
-        outputs = model.generate(input_ids, generation_config=generation_config, use_model_defaults=False)
+        attention_mask = torch.ones_like(input_ids)
+        outputs = model.generate(
+            input_ids, attention_mask=attention_mask, generation_config=generation_config, use_model_defaults=False
+        )
         generated_tokens = outputs[0][input_ids.shape[1] :]
-        decoded_output = tokenizer.decode(generated_tokens, skip_special_tokens=SKIP_SPECIAL_TOKENS)
-        decoded_outputs[key] = decoded_output
+        decoded_outputs[key] = tokenizer.decode(generated_tokens, skip_special_tokens=False)
     return decoded_outputs
 
 
-def setup_metrics():
+def maybe_setup_metrics(use_metrics: bool) -> None:
+    if not use_metrics:
+        return
     try:
         from opentelemetry import metrics, trace
         from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
@@ -119,16 +110,14 @@ def batch_generate(
     token_count = 0
     data = []
     for i, request in enumerate(batch_outputs):
-        input_text = tokenizer.decode(batch_outputs[request].prompt_ids, skip_special_tokens=SKIP_SPECIAL_TOKENS)
+        input_text = tokenizer.decode(batch_outputs[request].prompt_ids, skip_special_tokens=False)
         # The key is used to tie back to the output of unbatched generation
         key = " ".join(map(str, batch_outputs[request].prompt_ids))
         data.append({"input": input_text, "key": key})
 
         # Try to decode the output
         try:
-            output_text = tokenizer.decode(
-                batch_outputs[request].generated_tokens, skip_special_tokens=SKIP_SPECIAL_TOKENS
-            )
+            output_text = tokenizer.decode(batch_outputs[request].generated_tokens, skip_special_tokens=False)
             token_count += len(batch_outputs[request].generated_tokens[1:])
             data[-1]["cb_outputs"] = output_text
         except Exception as e:
@@ -138,14 +127,7 @@ def batch_generate(
 
         # Display sample if asked
         if i < displayed_samples:
-            if len(output_text) > 0:
-                print("-" * 20)
-                print(f"{request} Input: {input_text}")
-                print(f"{request} Output: {output_text}")
-            else:
-                print(f"{request} Input: {input_text}")
-                print("[WARN]")
-                print(f"{request} Output was empty!")
+            print("-" * 20, f"{request} Input: {input_text}", f"{request} Output: {output_text}", sep="\n")
 
         # Compare with classic generate if asked
         if expected_outputs is not None:
@@ -182,83 +164,115 @@ def batch_generate(
 
 
 if __name__ == "__main__":
-    # Parse args
     parser = argparse.ArgumentParser()
+
+    # Continuous batching parameters
     parser.add_argument("--num-blocks", "-n", type=int, default=None)
     parser.add_argument("--max-batch-tokens", "-b", type=int, default=None)
 
+    # Model parameters
+    parser.add_argument("--sliding-window", type=int, default=0)
     parser.add_argument("--attn", type=str, default="kernels-community/flash-attn", help="Attention implementation")
+
+    # Performance parameters
     parser.add_argument("--matmul-precision", "-mp", type=str, default="high")  # set to "none" to disable
     parser.add_argument("--cuda-graph", "-cg", help="Use cuda graphs", type=str, default=None)
     parser.add_argument("--compile", action="store_true", help="Compile the model using torch.compile")
+    parser.add_argument("--do-sample", action="store_true", help="Activate sampling")
 
+    # Benchmark parameters
     parser.add_argument("--samples", type=int, default=500, help="Number of samples to generate")
+    parser.add_argument("--add-prefix", action="store_true", help="Add a prefix to the samples")
+    parser.add_argument("--compare", action="store_true", help="Compare CB generation with classic generate")
+    parser.add_argument("--profile", type=str, default=None)
+    parser.add_argument("--metrics", action="store_true")
+    parser.add_argument("--force-max-length", action="store_true", help="Force generation to stop at max length")
+
+    # Display parameters
     parser.add_argument("--displayed", type=int, default=0, help="Number of samples to display")
     parser.add_argument("--log-level", type=str, default="INFO")
     parser.add_argument("--output-file", type=str, default=None)
-    parser.add_argument("--compare", action="store_true")
-    parser.add_argument("--metrics", action="store_true")
-    parser.add_argument("--profile", type=str, default=None)
+
     args = parser.parse_args()
 
-    # Set log level
-    logger.setLevel(args.log_level.upper())
+    # Create model
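+    # The gemma-2 model is only used to exercise sliding-window attention; its chat template does not
+    # accept a system role, hence the has_system_role flag below.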
+    model_id = "google/gemma-2-2b-it" if args.sliding_window > 0 else "meta-llama/Llama-3.1-8B-Instruct"
+    has_system_role = args.sliding_window == 0
+
+    model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation=args.attn, dtype=torch.bfloat16)
+    model = model.cuda().eval()
 
-    # If turned on, we setup metrics
-    if args.metrics:
-        setup_metrics()
+    if args.sliding_window > 0 and getattr(model.config, "sliding_window", None) is not None:
+        print(f"Setting sliding window from {model.config.sliding_window} to {args.sliding_window}")
+        model.config.sliding_window = args.sliding_window
 
-    # Set matmul precision if not none
+    # Set up diagnostics
+    logger.setLevel(args.log_level.upper())
+    maybe_setup_metrics(args.metrics)
+
+    # Set up performance
     if args.matmul_precision != "none":
         torch.set_float32_matmul_precision(args.matmul_precision)
-    # Parse cuda graph argument
-    if args.cuda_graph is not None:
-        use_cuda_graph = {
-            "none": None,
-            "yes": True, "y": True, "true": True, "t": True, "1": True,
-            "no": False, "n": False, "false": False, "f": False, "0": False,
-        }[args.cuda_graph.lower()]  # fmt: skip
-    else:
-        use_cuda_graph = None
 
-    # Prepare model
-    model = AutoModelForCausalLM.from_pretrained(
-        MODEL_ID,
-        attn_implementation=args.attn,
-        dtype=torch.bfloat16,
-    )
-    model = model.cuda().eval()
-    if getattr(model.config, "sliding_window", None) is not None:
-        print(f"Setting sliding window from {model.config.sliding_window} to {SLIDING_WINDOW}")
-        model.config.sliding_window = SLIDING_WINDOW
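+    # Map the --cuda-graph string to a tri-state value: None (unset or "none"), True, or False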
+    cuda_graph_arg = args.cuda_graph.lower() if args.cuda_graph is not None else None
+    use_cuda_graph = {
+        "none": None, None: None,
+        "yes": True, "y": True, "true": True, "t": True, "1": True,
+        "no": False, "n": False, "false": False, "f": False, "0": False,
+    }[cuda_graph_arg]  # fmt: skip
 
-    # If turned on, we compile the model
     if args.compile:
         model.forward = torch.compile(model.forward, mode="max-autotune-no-cudagraphs")
 
     # Prepare tokenizer and dataset
-    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, padding_side="left")
+    tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
 
     dataset = datasets.load_dataset("openai/gsm8k", "socratic", split="test")
     dataset = dataset.select(range(args.samples))
 
-    simple_batch_inputs = [tokenizer(item["question"])["input_ids"] for item in dataset]
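+    # Build the prompts: optionally cycle a few prefixes over the samples and tokenize with the chat template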
+    if args.add_prefix:
+        possible_prefixes = [
+            None,
+            "You are a bot that solves math problems.",
+            "You are a bot who solves math problems. Try to make your answer clear and understandable, and include your stages of reasoning.",
+            "You are a bot with the aim to solve math problems. Try to make your answer clear and understandable, and include your stages of reasoning. No loud words or emojis, all responses must be readable by a child. Here is now the problem:",
+        ]  # fmt: skip
+    else:
+        possible_prefixes = [None]
+
+    batched_inputs = []
+    for item, prefix in zip(dataset, cycle(possible_prefixes)):
+        messages = []
+        question = item["question"]
+        if prefix is not None:
+            if has_system_role:
+                messages.append({"role": "system", "content": prefix})
+            else:
+                question = prefix + "\n\n" + question
+        messages.append({"role": "user", "content": question})
+        inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
+        batched_inputs.append(inputs["input_ids"])
 
     # Prepare generation config
-    generation_config = GenerationConfig(
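+    # Note: --force-max-length swaps EOS for the pad token, which in practice prevents early stopping,
+    # so every request generates the full max_new_tokens.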
+    generation_cfg = GenerationConfig(
         max_new_tokens=512,
         use_cuda_graph=use_cuda_graph,
-        eos_token_id=tokenizer.pad_token_id if FORCE_MAX_LENGTH else tokenizer.eos_token_id,
+        eos_token_id=tokenizer.pad_token_id if args.force_max_length else tokenizer.eos_token_id,
         pad_token_id=tokenizer.pad_token_id,
-        do_sample=not args.compare,
+        do_sample=args.do_sample,
         temperature=0.8,
         top_p=0.9,
         num_blocks=args.num_blocks,
         max_batch_tokens=args.max_batch_tokens,
     )
 
     # If we need to compare, we need to generate the reference outputs
-    expected_outputs = generate_simple(args.attn, simple_batch_inputs, generation_config) if args.compare else None
+    if args.compare:
+        expected_outputs = generate_without_cb(
+            model_id, args.sliding_window, args.attn, batched_inputs, generation_cfg
+        )
+    else:
+        expected_outputs = None
 
     # If no output file is provided, we pick a name based on the args
     if args.output_file is None:
@@ -271,8 +285,8 @@ def batch_generate(
     # Run warmup batch generation  # TODO: understand why warmup incurs a large overhead during cache creation
     batch_generate(
         model,
-        simple_batch_inputs[: min(5, args.samples)],
-        generation_config,
+        batched_inputs[: min(5, args.samples)],
+        generation_cfg,
         tokenizer,
         displayed_samples=-1,
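+        # displayed_samples=-1 keeps the warmup run silent (the display check i < -1 is never true)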
     )
@@ -285,8 +299,8 @@ def batch_generate(
     # Run batch generation
     gen_time, tok_per_sec = batch_generate(
         model,
-        simple_batch_inputs,
-        generation_config,
+        batched_inputs,
+        generation_cfg,
         tokenizer,
         displayed_samples=args.displayed,
         output_file=args.output_file,
@@ -297,5 +311,5 @@ def batch_generate(
         prof.export_chrome_trace(filename)
 
 # Example usage:
-# python examples/pytorch/continuous_batching.py --attn sdpa_paged -mp none --samples 3 --compare
-# python examples/pytorch/continuous_batching.py --num-blocks 369 --max-batch-tokens 23 --attn sdpa_paged -mp none --samples 1 --displayed 0 --output-file sliced.json
+# python examples/pytorch/continuous_batching.py --attn sdpa --add-prefix --samples 10 --compare
+# python examples/pytorch/continuous_batching.py --attn flash_attention_2 -mp none --add-prefix --samples 500