1111from pathlib import Path
1212from typing import List
1313from jinja2 import Template
14+ import argparse
1415
1516from distutils .core import Command
1617from setuptools import Extension , setup , find_packages
@@ -30,6 +31,31 @@ def get_cpu_arch():
3031 else :
3132 raise ValueError (f"Unsupported architecture: { arch } " )
3233
34+ # get device type
35+ def get_device_type ():
36+ import torch
37+
38+ if torch .cuda .is_available ():
39+ return "cuda"
40+
41+ try :
42+ import torch_mlu
43+ if torch .mlu .is_available ():
44+ return "mlu"
45+ except ImportError :
46+ pass
47+
48+ try :
49+ import torch_npu
50+ if torch .npu .is_available ():
51+ return "a2"
52+ except ImportError :
53+ pass
54+
55+ print ("Unsupported device type, please install torch, torch_mlu or torch_npu" )
56+ exit (1 )
57+
58+
3359def get_cxx_abi ():
3460 try :
3561 import torch
@@ -224,8 +250,6 @@ def set_cuda_envs():
224250 os .environ ["LIBTORCH_ROOT" ] = get_torch_root_path ()
225251 os .environ ["PYTORCH_INSTALL_PATH" ] = get_torch_root_path ()
226252 os .environ ["CUDA_TOOLKIT_ROOT_DIR" ] = "/usr/local/cuda"
227- os .environ ["NCCL_ROOT" ] = get_nccl_root_path ()
228- os .environ ["NCCL_VERSION" ] = "2"
229253
230254class CMakeExtension (Extension ):
231255 def __init__ (self , name : str , path : str , sourcedir : str = "" ) -> None :
@@ -551,39 +575,48 @@ def pre_build():
551575 if not run_shell_command ("sh third_party/dependencies.sh" , cwd = script_path ):
552576 print ("❌ Failed to reset changes!" )
553577 exit (0 )
578+
579+ def parse_arguments ():
580+ parser = argparse .ArgumentParser (add_help = False )
581+
582+ parser .add_argument (
583+ '--device' ,
584+ type = str .lower ,
585+ choices = ['a2' , 'a3' , 'mlu' , 'cuda' ],
586+ default = 'auto' ,
587+ help = 'Device type: a2, a3, mlu, or cuda (case-insensitive)'
588+ )
589+
590+ parser .add_argument (
591+ '--dry-run' ,
592+ action = 'store_true' ,
593+ help = 'Dry run mode (do not execute pre_build)'
594+ )
595+
596+ parser .add_argument (
597+ '--install-xllm-kernels' ,
598+ type = str .lower ,
599+ choices = ['true' , 'false' , '1' , '0' , 'yes' , 'no' , 'y' , 'n' , 'on' , 'off' ],
600+ default = 'true' ,
601+ help = 'Whether to install xllm kernels'
602+ )
603+
604+ return parser
554605
555606if __name__ == "__main__" :
556- device = 'a2' # default
607+ parser = parse_arguments ()
608+ args = parser .parse_args ()
609+
557610 arch = get_cpu_arch ()
558- install_kernels = True
559- if '--device' in sys .argv :
560- idx = sys .argv .index ('--device' )
561- if idx + 1 < len (sys .argv ):
562- device = sys .argv [idx + 1 ].lower ()
563- if device not in ('a2' , 'a3' , 'mlu' , 'cuda' ):
564- print ("Error: --device must be a2 or a3 or mlu (case-insensitive)" )
565- sys .exit (1 )
566- # Remove the arguments so setup() doesn't see them
567- del sys .argv [idx ]
568- del sys .argv [idx ]
569- if '--dry_run' not in sys .argv :
570- pre_build ()
571- else :
572- sys .argv .remove ("--dry_run" )
611+ device = args .device
612+ if device == 'auto' :
613+ device = get_device_type ()
614+ print (f"🚀 Build xllm with CPU arch: { arch } and target device: { device } " )
573615
574- if '--install-xllm-kernels' in sys .argv :
575- idx = sys .argv .index ('--install-xllm-kernels' )
576- if idx + 1 < len (sys .argv ):
577- install_kernels = sys .argv [idx + 1 ].lower ()
578- if install_kernels in ('true' , '1' , 'yes' , 'y' , 'on' ):
579- install_kernels = True
580- elif install_kernels in ('false' , '0' , 'no' , 'n' , 'off' ):
581- install_kernels = False
582- else :
583- print ("Error: --install-xllm-kernels must be true or false" )
584- sys .exit (1 )
585- sys .argv .pop (idx )
586- sys .argv .pop (idx )
616+ if not args .dry_run :
617+ pre_build ()
618+
619+ install_kernels = args .install_xllm_kernels in ('true' , '1' , 'yes' , 'y' , 'on' )
587620
588621 if "SKIP_TEST" in os .environ :
589622 BUILD_TEST_FILE = False
0 commit comments