 import sys
 import time
 from pathlib import Path
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union
 
 import torch
 import torch._dynamo.config
@@ -24,7 +24,9 @@ def device_sync(device):
 
 torch._inductor.config.coordinate_descent_tuning = True
 torch._inductor.config.triton.unique_kernel_names = True
-torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future
+# Experimental features to reduce compilation times, will be on by default in future
+torch._inductor.config.fx_graph_cache = True
+torch._functorch.config.enable_autograd_cache = True
 
 default_device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
@@ -50,7 +52,7 @@ def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = Non
     return probs
 
 def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
-    probs = logits_to_probs(logits[0, -1], temperature, top_k)
+    probs = logits_to_probs(logits[:, -1], temperature, top_k)
     idx_next = multinomial_sample_one_no_sync(probs)
     return idx_next, probs
 
@@ -76,7 +78,7 @@ def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torc
             new_tokens.append(next_token.clone())
             callback(new_tokens[-1])
             new_probs.append(next_prob.clone())
-            cur_token = next_token.view(1, -1)
+            cur_token = next_token.clone()
 
     return new_tokens, new_probs
 
@@ -139,6 +141,7 @@ def generate(
     model: Transformer,
     prompt: torch.Tensor,
     max_new_tokens: int,
+    batch_size: int,
     *,
     interactive: bool,
     draft_model: Transformer,
@@ -152,7 +155,7 @@ def generate(
 
     is_speculative = draft_model is not None
     # create an empty tensor of the expected final shape and fill in the current tokens
-    T = prompt.size(0)
+    T = prompt.size(-1)
     T_new = T + max_new_tokens
     if interactive:
         max_seq_length = 350
@@ -162,20 +165,22 @@ def generate(
     device, dtype = prompt.device, prompt.dtype
     max_seq_length = max_seq_length + speculate_k + 1 if is_speculative else max_seq_length
     with torch.device(device):
-        model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
+        model.setup_caches(max_batch_size=batch_size, max_seq_length=max_seq_length)
         if is_speculative and draft_model is not model:
-            draft_model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
+            draft_model.setup_caches(max_batch_size=batch_size, max_seq_length=max_seq_length)
 
     # create an empty tensor of the expected final shape and fill in the current tokens
-    empty = torch.empty(T_new, dtype=dtype, device=device)
-    empty[:T] = prompt
+    empty = torch.empty(batch_size, T_new, dtype=dtype, device=device)
+    # We are just making the same prompt for every batch
+    prompt = prompt.view(1, -1).repeat(batch_size, 1)
+    empty[:, :T] = prompt
     seq = empty
     input_pos = torch.arange(0, T, device=device)
 
-    next_token = prefill(model, prompt.view(1, -1), input_pos, **sampling_kwargs).clone()
+    next_token = prefill(model, prompt.view(batch_size, -1), input_pos, **sampling_kwargs).clone()
     if is_speculative:
-        prefill(draft_model, prompt.view(1, -1), input_pos, **sampling_kwargs)
-    seq[T] = next_token
+        prefill(draft_model, prompt.view(batch_size, -1), input_pos, **sampling_kwargs)
+    seq[:, T] = next_token.squeeze()
 
     input_pos = torch.tensor([T], device=device, dtype=torch.int)
     accept_counts = [0] * (speculate_k + 1)
@@ -197,8 +202,8 @@ def generate(
             input_pos = input_pos + num_added
             next_token = next_tokens[-1]
     else:
-        generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, max_new_tokens - 1, callback=callback, **sampling_kwargs)
-        seq[T + 1:] = torch.cat(generated_tokens)
+        generated_tokens, _ = decode_n_tokens(model, next_token.view(batch_size, -1), input_pos, max_new_tokens - 1, callback=callback, **sampling_kwargs)
+        seq[:, T + 1:] = torch.cat(generated_tokens, dim=-1)
 
     generate_stats = {
         'accept_counts': accept_counts
@@ -245,6 +250,7 @@ def _load_model(checkpoint_path, device, precision, use_tp):
 
 def _get_model_size(model):
     model_size = 0
+    params = 0
     for name, child in model.named_children():
         if not isinstance(child, torch.nn.Embedding):
             model_size += sum(
@@ -253,15 +259,22 @@ def _get_model_size(model):
                     for p in itertools.chain(child.parameters(), child.buffers())
                 ]
             )
-    return model_size
+            params += sum(
+                [
+                    p.numel()
+                    for p in itertools.chain(child.parameters(), child.buffers())
+                ]
+            )
+    return model_size, params
 
 B_INST, E_INST = "[INST]", "[/INST]"
 
 def main(
-    prompt: str = "Hello, my name is",
+    prompt: Union[int, str] = "Hello, my name is",
     interactive: bool = False,
     num_samples: int = 5,
     max_new_tokens: int = 100,
+    batch_size: int = 1,
     top_k: int = 200,
     temperature: float = 0.8,
     checkpoint_path: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"),
@@ -307,11 +320,15 @@ def main(
 
     tokenizer = get_tokenizer(tokenizer_path, checkpoint_path)
 
-    encoded = encode_tokens(tokenizer, prompt, bos=True, device=device)
-    prompt_length = encoded.size(0)
+    if isinstance(prompt, str):
+        encoded = encode_tokens(tokenizer, prompt, bos=True, device=device)
+    else:
+        # generate a fully synthetic prompt
+        encoded = torch.randint(0, 1024, (prompt,), device=device, dtype=torch.int64)
+    prompt_length = encoded.size(-1)
 
     torch.manual_seed(1234)
-    model_size = _get_model_size(model)
+    model_size, params = _get_model_size(model)
     if compile:
         if is_speculative and use_tp: # and ("cuda" in device):
             torch._inductor.config.triton.cudagraph_trees = False # Bug with cudagraph trees in this case
@@ -371,6 +388,7 @@ def callback(x):
                 model,
                 encoded,
                 max_new_tokens,
+                batch_size=batch_size,
                 draft_model=draft_model,
                 speculate_k=speculate_k,
                 interactive=interactive,
@@ -391,21 +409,30 @@ def callback(x):
         t = time.perf_counter() - t0
 
         if not interactive:
-            print(tokenizer.decode(y.tolist()))
+            # Just displaying the first generation
+            if batch_size > 1:
+                print("Only displaying the first generation of the batch")
+            print(tokenizer.decode(y[0].tolist()))
         else:
             print()
-        tokens_generated = y.size(0) - prompt_length
-        tokens_sec = tokens_generated / t
-        aggregate_metrics['tokens_per_sec'].append(tokens_sec)
-        print(f"Time for inference {i+1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec")
-        print(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s")
+        tokens_generated = y.size(-1) - prompt_length
+        generated_tokens_sec = tokens_generated / t
+        aggregate_metrics['tokens_per_sec'].append(generated_tokens_sec)
+        print(f"Time for inference {i+1}: {t:.02f} sec total, {generated_tokens_sec:.02f} tokens/sec")
+        print(f"Bandwidth achieved: {model_size * generated_tokens_sec / 1e9:.02f} GB/s")
+        total_tokens_sec = y.numel() / t
+        print(f"FLOPS achieved: {params * total_tokens_sec * 2 / 1e12:.02f} TF/s")
+        print()
     print("==========")
     if is_speculative:
         counts_aggregated = [sum(i) for i in zip(*aggregate_metrics['accept_counts'])]
         acceptance_probs = [i / sum(counts_aggregated) for i in counts_aggregated]
         print(f"Acceptance probs: {acceptance_probs}")
         print(f"Mean Accepted: {sum([idx * i for idx, i in enumerate(counts_aggregated)])/sum(counts_aggregated)}")
 
+    print(f"Batch Size: {batch_size}")
+    print(f"Prompt Length: {prompt_length}")
+    print(f"Generated tokens: {max_new_tokens}")
     print(f"Average tokens/sec: {torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item():.2f}")
     print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")
 
@@ -414,10 +441,17 @@ def callback(x):
     import argparse
     parser = argparse.ArgumentParser(description='Your CLI description.')
 
-    parser.add_argument('--prompt', type=str, default="Hello, my name is", help='Input prompt.')
+    def int_or_str(x):
+        try:
+            return int(x)
+        except:
+            return x
+
+    parser.add_argument('--prompt', type=int_or_str, default="Hello, my name is", help="Input prompt. If it's an integer, will instead generate a synthetic prompt.")
     parser.add_argument('--interactive', action='store_true', help='Whether to launch in interactive mode')
     parser.add_argument('--num_samples', type=int, default=5, help='Number of samples.')
     parser.add_argument('--max_new_tokens', type=int, default=200, help='Maximum number of new tokens.')
+    parser.add_argument('--batch_size', type=int, default=1, help='Batch size to benchmark with')
    parser.add_argument('--top_k', type=int, default=200, help='Top-k for sampling.')
     parser.add_argument('--temperature', type=float, default=0.8, help='Temperature for sampling.')
     parser.add_argument('--checkpoint_path', type=Path, default=Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"), help='Model checkpoint path.')
@@ -430,7 +464,7 @@ def callback(x):
 
     args = parser.parse_args()
     main(
-        args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k,
+        args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.batch_size, args.top_k,
         args.temperature, args.checkpoint_path, args.compile, args.compile_prefill, args.profile, args.draft_checkpoint_path,
         args.speculate_k, args.device
     )
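
For reference, a minimal way to exercise the new options (this assumes the patch applies to the repo's generate.py script and uses the script's default checkpoint path, which may differ locally):

    python generate.py --compile --batch_size 4 --prompt 512 --max_new_tokens 200

Passing an integer --prompt builds a synthetic prompt of that many random token ids, which is then repeated across the batch, so the run measures raw throughput without a tokenizer-dependent prompt. The new FLOPS line follows the usual 2 * params * tokens/sec estimate, where tokens/sec here counts every token in the output tensor (prompt plus generated) across the whole batch; for example, a 7e9-parameter model processing 400 tokens/sec works out to roughly 2 * 7e9 * 400 / 1e12 ≈ 5.6 TF/s.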