@@ -30,6 +30,31 @@ def get_cpu_arch():
3030 else :
3131 raise ValueError (f"Unsupported architecture: { arch } " )
3232
33+ # get device type
34+ def get_device_type ():
35+ import torch
36+
37+ if torch .cuda .is_available ():
38+ return "cuda"
39+
40+ try :
41+ import torch_mlu
42+ if torch .mlu .is_available ():
43+ return "mlu"
44+ except ImportError :
45+ pass
46+
47+ try :
48+ import torch_npu
49+ if torch .npu .is_available ():
50+ return "a2"
51+ except ImportError :
52+ pass
53+
54+ print ("Unsupported device type, please install torch, torch_mlu or torch_npu" )
55+ exit (1 )
56+
57+
3358def get_cxx_abi ():
3459 try :
3560 import torch
@@ -224,8 +249,6 @@ def set_cuda_envs():
224249 os .environ ["LIBTORCH_ROOT" ] = get_torch_root_path ()
225250 os .environ ["PYTORCH_INSTALL_PATH" ] = get_torch_root_path ()
226251 os .environ ["CUDA_TOOLKIT_ROOT_DIR" ] = "/usr/local/cuda"
227- os .environ ["NCCL_ROOT" ] = get_nccl_root_path ()
228- os .environ ["NCCL_VERSION" ] = "2"
229252
230253class CMakeExtension (Extension ):
231254 def __init__ (self , name : str , path : str , sourcedir : str = "" ) -> None :
@@ -551,39 +574,46 @@ def pre_build():
551574 if not run_shell_command ("sh third_party/dependencies.sh" , cwd = script_path ):
552575 print ("❌ Failed to reset changes!" )
553576 exit (0 )
577+
578+ def parse_arguments ():
579+ parser = argparse .ArgumentParser (add_help = False )
580+
581+ parser .add_argument (
582+ '--device' ,
583+ type = str .lower ,
584+ choices = ['a2' , 'a3' , 'mlu' , 'cuda' ],
585+ default = 'auto' ,
586+ help = 'Device type: a2, a3, mlu, or cuda (case-insensitive)'
587+ )
588+
589+ parser .add_argument (
590+ '--dry-run' ,
591+ action = 'store_true' ,
592+ help = 'Dry run mode (do not execute pre_build)'
593+ )
594+
595+ parser .add_argument (
596+ '--install-xllm-kernels' ,
597+ type = str .lower ,
598+ choices = ['true' , 'false' , '1' , '0' , 'yes' , 'no' , 'y' , 'n' , 'on' , 'off' ],
599+ default = 'true' ,
600+ help = 'Whether to install XLLM kernels'
601+ )
602+
603+ return parser
554604
555605if __name__ == "__main__" :
556- device = 'a2' # default
606+ args = parse_arguments ()
557607 arch = get_cpu_arch ()
558- install_kernels = True
559- if '--device' in sys .argv :
560- idx = sys .argv .index ('--device' )
561- if idx + 1 < len (sys .argv ):
562- device = sys .argv [idx + 1 ].lower ()
563- if device not in ('a2' , 'a3' , 'mlu' , 'cuda' ):
564- print ("Error: --device must be a2 or a3 or mlu (case-insensitive)" )
565- sys .exit (1 )
566- # Remove the arguments so setup() doesn't see them
567- del sys .argv [idx ]
568- del sys .argv [idx ]
569- if '--dry_run' not in sys .argv :
570- pre_build ()
571- else :
572- sys .argv .remove ("--dry_run" )
608+ device = args .device
609+ if device == 'auto' :
610+ device = get_device_type ()
611+ print (f"🚀 Build xllm with CPU arch: { arch } and target device: { device } " )
573612
574- if '--install-xllm-kernels' in sys .argv :
575- idx = sys .argv .index ('--install-xllm-kernels' )
576- if idx + 1 < len (sys .argv ):
577- install_kernels = sys .argv [idx + 1 ].lower ()
578- if install_kernels in ('true' , '1' , 'yes' , 'y' , 'on' ):
579- install_kernels = True
580- elif install_kernels in ('false' , '0' , 'no' , 'n' , 'off' ):
581- install_kernels = False
582- else :
583- print ("Error: --install-xllm-kernels must be true or false" )
584- sys .exit (1 )
585- sys .argv .pop (idx )
586- sys .argv .pop (idx )
613+ if not args .dry_run :
614+ pre_build ()
615+
616+ install_kernels = args .install_xllm_kernels in ('true' , '1' , 'yes' , 'y' , 'on' )
587617
588618 if "SKIP_TEST" in os .environ :
589619 BUILD_TEST_FILE = False
0 commit comments