1313import torch ._dynamo .config
1414import torch ._inductor .config
1515
def device_sync(device):
    """Block until pending work on ``device`` has finished.

    Used around timing measurements so that asynchronous device work is
    not attributed to the wrong interval.

    Args:
        device: Device identifier, either a string such as ``"cuda"``,
            ``"cuda:0"`` or ``"cpu"``, or a ``torch.device`` instance.

    Returns:
        None. CUDA devices are synchronized via ``torch.cuda.synchronize()``;
        CPU requires no action; any other device only prints a notice
        (best-effort, no exception is raised).
    """
    # Normalize to text first: a substring test directly on a torch.device
    # object raises TypeError because torch.device is not iterable.
    device_name = str(device)

    if "cuda" in device_name:
        torch.cuda.synchronize()
    elif "cpu" not in device_name:
        # Unknown backend: report it but keep going, as callers only use
        # this helper for timing boundaries.
        print(f"device={device} is not yet supported")
23+
24+
1625torch ._inductor .config .coordinate_descent_tuning = True
1726torch ._inductor .config .triton .unique_kernel_names = True
1827torch ._inductor .config .fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future
@@ -65,11 +74,12 @@ def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torc
6574 next_token , next_prob = decode_one_token (
6675 model , cur_token , input_pos , ** sampling_kwargs
6776 )
68- input_pos += 1
69- new_tokens .append (next_token .clone ())
70- callback (new_tokens [- 1 ])
71- new_probs .append (next_prob .clone ())
72- cur_token = next_token .view (1 , - 1 )
77+ input_pos += 1
78+ new_tokens .append (next_token .clone ())
79+ callback (new_tokens [- 1 ])
80+ new_probs .append (next_prob .clone ())
81+ cur_token = next_token .view (1 , - 1 )
82+
7383 return new_tokens , new_probs
7484
7585
@@ -248,6 +258,7 @@ def main(
248258 profile : Optional [Path ] = None ,
249259 draft_checkpoint_path : Optional [Path ] = None ,
250260 speculate_k : int = 5 ,
261+ device = 'cuda' ,
251262) -> None :
252263 """Generates text samples based on a pre-trained Transformer model and tokenizer.
253264 """
@@ -264,7 +275,7 @@ def main(
264275 # only print on rank 0
265276 print = lambda * args , ** kwargs : None
266277
267- device = 'cuda'
278+ print ( f"Using device= { device } " )
268279 precision = torch .bfloat16
269280 is_speculative = draft_checkpoint_path is not None
270281 is_chat = "chat" in str (checkpoint_path )
@@ -278,7 +289,7 @@ def main(
278289 else :
279290 draft_model = None
280291
281- torch . cuda . synchronize ()
292+ device_sync ( device = device ) # MKG
282293 print (f"Time to load model: { time .time () - t0 :.02f} seconds" )
283294
284295 tokenizer = SentencePieceProcessor (model_file = str (tokenizer_path ))
@@ -288,7 +299,7 @@ def main(
288299 torch .manual_seed (1234 )
289300 model_size = sum ([p .numel () * p .dtype .itemsize for p in itertools .chain (model .parameters (), model .buffers ())])
290301 if compile :
291- if is_speculative and use_tp :
302+ if is_speculative and use_tp : # and ("cuda" in device):
292303 torch ._inductor .config .triton .cudagraph_trees = False # Bug with cudagraph trees in this case
293304
294305 if is_speculative :
@@ -310,7 +321,7 @@ def main(
310321 start = - 1 if compile else 0
311322
312323 for i in range (start , num_samples ):
313- torch . cuda . synchronize ()
324+ device_sync ( device = device ) # MKG
314325 if i >= 0 and interactive :
315326 prompt = input ("What is your prompt? " )
316327 if is_chat :
@@ -362,7 +373,7 @@ def callback(x):
362373 prof .export_chrome_trace (f"{ profile } _rank_{ rank } .json" )
363374 else :
364375 prof .export_chrome_trace (f"{ profile } .json" )
365- torch . cuda . synchronize ()
376+ device_sync ( device = device ) # MKG
366377 t = time .perf_counter () - t0
367378
368379 if not interactive :
@@ -401,9 +412,11 @@ def callback(x):
401412 parser .add_argument ('--profile' , type = Path , default = None , help = 'Profile path.' )
402413 parser .add_argument ('--speculate_k' , type = int , default = 5 , help = 'Speculative execution depth.' )
403414 parser .add_argument ('--draft_checkpoint_path' , type = Path , default = None , help = 'Draft checkpoint path.' )
415+ parser .add_argument ('--device' , type = str , default = "cuda" , help = 'device to use' )
404416
405417 args = parser .parse_args ()
406418 main (
407419 args .prompt , args .interactive , args .num_samples , args .max_new_tokens , args .top_k ,
408- args .temperature , args .checkpoint_path , args .compile , args .compile_prefill , args .profile , args .draft_checkpoint_path , args .speculate_k
420+ args .temperature , args .checkpoint_path , args .compile , args .compile_prefill , args .profile , args .draft_checkpoint_path ,
421+ args .speculate_k , args .device
409422 )
0 commit comments