@@ -16,7 +16,7 @@
 def device_sync(device):
     if "cuda" in device:
         torch.cuda.synchronize(device)
-    elif "cpu" in device:
+    elif ("cpu" in device) or ("mps" in device):
        pass
     else:
         print(f"device={device} is not yet supported")
@@ -26,6 +26,7 @@ def device_sync(device):
 torch._inductor.config.triton.unique_kernel_names = True
 torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future
 
+default_device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
 # support running without installing as a package
 wd = Path(__file__).parent.parent.resolve()
@@ -206,7 +207,7 @@ def generate(
     }
     return seq, generate_stats
 
-def encode_tokens(tokenizer, string, bos=True, device='cuda'):
+def encode_tokens(tokenizer, string, bos=True, device=default_device):
     tokens = tokenizer.encode(string)
     if bos:
         tokens = [tokenizer.bos_id()] + tokens
@@ -259,7 +260,7 @@ def main(
     profile: Optional[Path] = None,
     draft_checkpoint_path: Optional[Path] = None,
     speculate_k: int = 5,
-    device='cuda',
+    device=default_device,
 ) -> None:
     """Generates text samples based on a pre-trained Transformer model and tokenizer.
     """
@@ -414,7 +415,7 @@ def callback(x):
     parser.add_argument('--profile', type=Path, default=None, help='Profile path.')
     parser.add_argument('--speculate_k', type=int, default=5, help='Speculative execution depth.')
     parser.add_argument('--draft_checkpoint_path', type=Path, default=None, help='Draft checkpoint path.')
-    parser.add_argument('--device', type=str, default="cuda", help='Device to use')
+    parser.add_argument('--device', type=str, default=default_device, help='Device to use')
 
     args = parser.parse_args()
     main(
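Taken together, the diff replaces every hard-coded `'cuda'` default with a module-level `default_device` that degrades to CPU on machines without a CUDA GPU, and teaches `device_sync` to treat `mps` (like `cpu`) as a no-op. A minimal sketch of the same selection pattern follows; the MPS-preference branch is an illustrative extension (assuming PyTorch >= 1.12 for `torch.backends.mps`), not something this diff adds:

```python
import torch

# Same fallback the diff introduces: prefer CUDA when present, else CPU.
default_device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Illustrative extension (NOT part of this diff): prefer Apple-silicon MPS
# over CPU when CUDA is absent. device_sync() above already accepts "mps"
# as a no-op, so the rest of the script would work unchanged.
if default_device == 'cpu' and torch.backends.mps.is_available():
    default_device = 'mps'
```

One subtlety worth noting: `device=default_device` in the changed signatures is a Python default argument, so it is evaluated once at module import time; the CUDA/CPU decision is made when the script loads, not on each call.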