diff --git a/CMakeLists.txt b/CMakeLists.txt index f18bca2ff..052cc3819 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -344,6 +344,7 @@ if(USE_NPU) $ENV{PYTORCH_INSTALL_PATH}/include $ENV{PYTORCH_INSTALL_PATH}/include/torch/csrc/api/include $ENV{PYTORCH_NPU_INSTALL_PATH}/include + $ENV{PYTORCH_INSTALL_PATH}/include/torch/csrc/distributed $ENV{NPU_HOME_PATH}/include $ENV{ATB_HOME_PATH}/include $ENV{NPU_HOME_PATH}/opp/vendors/xllm/op_api/include/ diff --git a/xllm/CMakeLists.txt b/xllm/CMakeLists.txt index b31f3f239..0c86f08c6 100644 --- a/xllm/CMakeLists.txt +++ b/xllm/CMakeLists.txt @@ -34,7 +34,7 @@ target_link_libraries(xllm PRIVATE glog::glog brpc leveldb::leveldb ZLIB::ZLIB p add_dependencies(xllm brpc-static) if(USE_NPU) - set(COMMON_LIBS Python::Python ascendcl atb_customize hccl c_sec nnopbase ms_tools_ext) + set(COMMON_LIBS Python::Python ascendcl atb_customize hccl c_sec nnopbase ms_tools_ext torch_npu torch_python) elseif(USE_MLU) set(COMMON_LIBS Python::Python) endif() diff --git a/xllm/core/common/global_flags.cpp b/xllm/core/common/global_flags.cpp index 715dd2308..9b030cf02 100644 --- a/xllm/core/common/global_flags.cpp +++ b/xllm/core/common/global_flags.cpp @@ -464,3 +464,10 @@ DEFINE_bool(enable_constrained_decoding, "Whether to enable constrained decoding, which is used to ensure " "that the output meets specific format or structural requirements " "through pre-defined rules."); + +#if defined(USE_NPU) +DEFINE_string( + npu_kernel_backend, + "ATB", + "NPU kernel backend. Supported options: ATB, TORCH. Default is ATB."); +#endif diff --git a/xllm/core/common/global_flags.h b/xllm/core/common/global_flags.h index 417ed41bd..95326fbd7 100644 --- a/xllm/core/common/global_flags.h +++ b/xllm/core/common/global_flags.h @@ -226,3 +226,7 @@ DECLARE_int64(dit_cache_skip_interval_steps); DECLARE_double(dit_cache_residual_diff_threshold); DECLARE_bool(enable_constrained_decoding); + +#if defined(USE_NPU) +DECLARE_string(npu_kernel_backend); +#endif diff --git a/xllm/core/distributed_runtime/worker_server.cpp b/xllm/core/distributed_runtime/worker_server.cpp index 0177d779f..7f2ed9cb0 100644 --- a/xllm/core/distributed_runtime/worker_server.cpp +++ b/xllm/core/distributed_runtime/worker_server.cpp @@ -104,9 +104,7 @@ void WorkerServer::create_server( CollectiveCommunicator comm(worker_global_rank, world_size, dp_size, ep_size); const ParallelArgs* parallel_args = comm.parallel_args(); -#if defined(USE_MLU) || defined(USE_CUDA) comm.create_process_groups(master_node_addr, device); -#endif std::unique_ptr worker = std::make_unique(*parallel_args, device, options, worker_type); diff --git a/xllm/core/framework/parallel_state/collective_communicator.cpp b/xllm/core/framework/parallel_state/collective_communicator.cpp index c0066be0a..2f08fb389 100644 --- a/xllm/core/framework/parallel_state/collective_communicator.cpp +++ b/xllm/core/framework/parallel_state/collective_communicator.cpp @@ -18,9 +18,9 @@ limitations under the License. #include "mapping_npu.h" #if defined(USE_NPU) +#include "npu_process_group.h" #include "xllm_kernels/core/include/atb_speed/base/external_comm_manager.h" #include "xllm_kernels/core/include/atb_speed/utils/singleton.h" -#include "xllm_kernels/models/base/param/mapping.h" #elif defined(USE_MLU) #include "mlu_process_group.h" #elif defined(USE_CUDA) @@ -28,25 +28,9 @@ limitations under the License. #endif #include "common/global_flags.h" #include "parallel_args.h" +#include "process_group.h" #include "util/net.h" -namespace { -#if defined(USE_NPU) -std::unique_ptr create_process_group( - int rank, - int world_size, - int rank_size, - int port, - bool trans, - const std::string& host, - const std::string& group_name, - const torch::Device& device) { - LOG(FATAL) << "Unsupported device type"; - return nullptr; -} -#endif -} // namespace - namespace xllm { CollectiveCommunicator::CollectiveCommunicator(int global_rank, @@ -72,6 +56,13 @@ CollectiveCommunicator::CollectiveCommunicator(int global_rank, // std::make_unique( // global_rank, world_size, device, comm); + // comunicator will be inited in torch. + if (FLAGS_npu_kernel_backend == "TORCH") { + parallel_args_ = std::make_unique( + global_rank, world_size, dp_size, nullptr, ep_size); + return; + } + // comunicator will be inited in atb. MappingNPU::Options mapping_options; mapping_options.dp_size(dp_size) @@ -116,6 +107,11 @@ CollectiveCommunicator::CollectiveCommunicator(int global_rank, void CollectiveCommunicator::create_process_groups( const std::string& master_addr, const torch::Device& device) { +#if defined(USE_NPU) + if (FLAGS_npu_kernel_backend == "ATB") { + return; + } +#endif std::string host; int port; net::parse_host_port_from_addr(master_addr, host, port); diff --git a/xllm/core/framework/parallel_state/cuda_process_group.h b/xllm/core/framework/parallel_state/cuda_process_group.h index 349cf0083..bc71da369 100644 --- a/xllm/core/framework/parallel_state/cuda_process_group.h +++ b/xllm/core/framework/parallel_state/cuda_process_group.h @@ -21,12 +21,12 @@ limitations under the License. namespace xllm { -class ProcessGroupNccl : public ProcessGroup { +class ProcessGroupImpl : public ProcessGroup { public: - ProcessGroupNccl(int global_rank, - int world_size, - int rank_size, - int port, + ProcessGroupImpl(int32_t global_rank, + int32_t world_size, + int32_t rank_size, + int32_t port, bool trans, const std::string& host, const std::string& group_name, @@ -34,10 +34,11 @@ class ProcessGroupNccl : public ProcessGroup { : ProcessGroup(device) { c10::intrusive_ptr pg_options = c10d::ProcessGroupNCCL::Options::create(); -#if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR >= 7 +#if TORCH_VERSION_MAJOR > 2 || \ + (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 7) pg_options->group_name = group_name; #endif - int rank = global_rank; + int32_t rank = global_rank; if (world_size != rank_size) { auto [local_rank, group_ranks] = get_group_rank(world_size, global_rank, rank_size, trans); @@ -51,16 +52,4 @@ class ProcessGroupNccl : public ProcessGroup { } }; -std::unique_ptr create_process_group( - int rank, - int world_size, - int rank_size, - int port, - bool trans, - const std::string& host, - const std::string& group_name, - const torch::Device& device) { - return std::make_unique( - rank, world_size, rank_size, port, trans, host, group_name, device); -} } // namespace xllm diff --git a/xllm/core/framework/parallel_state/mlu_process_group.h b/xllm/core/framework/parallel_state/mlu_process_group.h index c2b0bf711..525ac0205 100644 --- a/xllm/core/framework/parallel_state/mlu_process_group.h +++ b/xllm/core/framework/parallel_state/mlu_process_group.h @@ -23,9 +23,9 @@ namespace xllm { constexpr int32_t local_device_count = 8; -class ProcessGroupCncl : public ProcessGroup { +class ProcessGroupImpl : public ProcessGroup { public: - ProcessGroupCncl(int32_t global_rank, + ProcessGroupImpl(int32_t global_rank, int32_t world_size, int32_t rank_size, int32_t port, @@ -57,17 +57,4 @@ class ProcessGroupCncl : public ProcessGroup { } }; -std::unique_ptr create_process_group( - int32_t rank, - int32_t world_size, - int32_t rank_size, - int32_t port, - bool trans, - const std::string& host, - const std::string& group_name, - const torch::Device& device) { - return std::make_unique( - rank, world_size, rank_size, port, trans, host, group_name, device); -} - } // namespace xllm \ No newline at end of file diff --git a/xllm/core/framework/parallel_state/npu_process_group.cpp b/xllm/core/framework/parallel_state/npu_process_group.cpp index fceaa9d00..c09c06a10 100644 --- a/xllm/core/framework/parallel_state/npu_process_group.cpp +++ b/xllm/core/framework/parallel_state/npu_process_group.cpp @@ -15,6 +15,10 @@ limitations under the License. #include "npu_process_group.h" +#include +#include +#include + namespace { #define HCCLCHECK(cmd) \ @@ -24,113 +28,55 @@ namespace { LOG(FATAL) << "Failed, HCCL error :" << HcclGetErrorString(r); \ } \ } while (0) +} // namespace -inline bool is_npu(const at::Tensor& tensor) { - if (!tensor.defined()) { - return false; - } - return tensor.device().is_privateuseone(); -} - -inline bool is_npu(const at::TensorOptions& options) { - return options.device().is_privateuseone(); -} +namespace xllm { -inline bool is_npu(const at::Device& device) { - return device.is_privateuseone(); -} +ProcessGroupImpl::ProcessGroupImpl(int32_t global_rank, + int32_t world_size, + int32_t rank_size, + int32_t port, + bool trans, + const std::string& host, + const std::string& group_name, + const torch::Device& device) + : ProcessGroup(device) { + c10::intrusive_ptr hccl_pg_options = + c10d_npu::ProcessGroupHCCL::Options::create(); +#if TORCH_VERSION_MAJOR > 2 || \ + (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 7) + hccl_pg_options->group_name = group_name; +#endif + int32_t rank = global_rank; + if (world_size != rank_size) { + auto [local_rank, group_ranks] = + get_group_rank(world_size, global_rank, rank_size, trans); + std::vector uint32_ranks; + for (auto rank : group_ranks) { + uint32_ranks.push_back(static_cast(rank)); + } + hccl_pg_options->global_ranks_in_group = uint32_ranks; + rank = local_rank; + } -at::Tensor flatten_for_scatter_gather(std::vector& tensors) { - auto& t = tensors[0]; - std::vector sizes{static_cast(tensors.size())}; - sizes.insert(sizes.end(), t.sizes().begin(), t.sizes().end()); - return at::empty(sizes, t.options()); + auto store = create_tcp_store(host, port, rank); + pg_ = std::make_unique( + store, rank, rank_size, hccl_pg_options); } -HcclDataType to_hccl_data_type(const torch::Tensor& input) { - const auto type = input.scalar_type(); - switch (type) { - case at::kFloat: - return HCCL_DATA_TYPE_FP32; - case at::kHalf: - return HCCL_DATA_TYPE_FP16; - case at::kDouble: - return HCCL_DATA_TYPE_FP64; - case at::kLong: - return HCCL_DATA_TYPE_INT64; - case at::kInt: - return HCCL_DATA_TYPE_INT32; - case at::kChar: - return HCCL_DATA_TYPE_INT8; - case at::kByte: - return HCCL_DATA_TYPE_UINT8; - case at::kBool: - return HCCL_DATA_TYPE_UINT8; - case at::kBFloat16: - return HCCL_DATA_TYPE_BFP16; - default: - LOG(FATAL) << "Unconvertible HCCL type: " << type; +// Destructor. +ProcessGroupImpl::~ProcessGroupImpl() { + if (pg_) { + pg_->shutdown(); + } else { + HCCLCHECK(HcclCommDestroy(comm_)); } } -void check_input(torch::Tensor input) { - CHECK(is_npu(input)) << "input should be npu tensor"; - CHECK(input.is_contiguous()) << "input should be contiguous"; - CHECK(!input.is_sparse()) << "input have to be npu dense tensor"; -} - -} // namespace - -namespace xllm { - -ProcessGroupHCCL::ProcessGroupHCCL(int rank, +ProcessGroupImpl::ProcessGroupImpl(int rank, int world_size, const torch::Device& device, HcclComm comm) : ProcessGroup(device), comm_(comm) {} -// Destructor. -ProcessGroupHCCL::~ProcessGroupHCCL() { HCCLCHECK(HcclCommDestroy(comm_)); } -void ProcessGroupHCCL::allreduce(torch::Tensor& input) { - DCHECK(input.device() == device()) - << "input should be on the same device as the process group"; - check_input(input); - // inplace all reduce - // const auto count = input.numel(); - // const auto data_type = to_hccl_data_type(input); - // auto stream = c10_npu::getCurrentNPUStream(); - // torch::DeviceGuard device_guard(device()); - // HCCLCHECK(HcclAllReduce( - // /*sendbuff=*/input.data_ptr(), - // /*recvbuff=*/input.data_ptr(), - // /*count=*/count, - // /*datatype=*/data_type, - // /*op=*/HCCL_REDUCE_SUM, - // /*comm=*/comm_, - // /*stream=*/stream)); -} -void ProcessGroupHCCL::allgather(const torch::Tensor& input, - std::vector& outputs) { - check_input(input); - // CHECK(outputs.size() == world_size()) - // << "outputs should have the same size as world_size"; - // DCHECK(input.device() == device()) - // << "input should be on the same device as the process group"; - // torch::DeviceGuard device_guard(device()); - // torch::Tensor flattened_output = flatten_for_scatter_gather(outputs); - // const auto count = input.numel(); - // const auto data_type = to_hccl_data_type(input); - // auto stream = c10_npu::getCurrentNPUStream(); - // HCCLCHECK(HcclAllGather( - // /*sendbuff=*/input.data_ptr(), - // /*recvbuff=*/flattened_output.data_ptr(), - // /*sendcount=*/count, - // /*datatype=*/data_type, - // /*comm=*/comm_, - // /*stream=*/stream)); - // // copy the flattened output tensors to the outputs. - // for (int i = 0; i < outputs.size(); ++i) { - // outputs[i].copy_(flattened_output[i], /*non_blocking=*/true); - // } -} } // namespace xllm \ No newline at end of file diff --git a/xllm/core/framework/parallel_state/npu_process_group.h b/xllm/core/framework/parallel_state/npu_process_group.h index 7ca7d23b7..daf9fa112 100644 --- a/xllm/core/framework/parallel_state/npu_process_group.h +++ b/xllm/core/framework/parallel_state/npu_process_group.h @@ -20,21 +20,25 @@ limitations under the License. namespace xllm { -class ProcessGroupHCCL : public ProcessGroup { +class ProcessGroupImpl : public ProcessGroup { public: // Constructor. - ProcessGroupHCCL(int rank, + ProcessGroupImpl(int rank, int world_size, const torch::Device& device, HcclComm comm); - // Destructor. - ~ProcessGroupHCCL() override; - - void allreduce(torch::Tensor& input) override; + ProcessGroupImpl(int rank, + int world_size, + int rank_size, + int port, + bool trans, + const std::string& host, + const std::string& group_name, + const torch::Device& device); - void allgather(const torch::Tensor& input, - std::vector& outputs) override; + // Destructor. + ~ProcessGroupImpl() override; private: HcclComm comm_ = nullptr; diff --git a/xllm/core/framework/parallel_state/parallel_state.cpp b/xllm/core/framework/parallel_state/parallel_state.cpp index 0678b188f..84cfb7eb9 100644 --- a/xllm/core/framework/parallel_state/parallel_state.cpp +++ b/xllm/core/framework/parallel_state/parallel_state.cpp @@ -215,7 +215,7 @@ std::vector> create_npu_process_groups( std::vector> process_groups; process_groups.reserve(devices.size()); for (int i = 0; i < world_size; ++i) { - process_groups.emplace_back(std::make_unique( + process_groups.emplace_back(std::make_unique( /*rank=*/i, world_size, devices[i], comms[i])); } diff --git a/xllm/core/framework/parallel_state/process_group.cpp b/xllm/core/framework/parallel_state/process_group.cpp index f43c09e39..1b8789305 100644 --- a/xllm/core/framework/parallel_state/process_group.cpp +++ b/xllm/core/framework/parallel_state/process_group.cpp @@ -15,6 +15,14 @@ limitations under the License. #include "process_group.h" +#if defined(USE_NPU) +#include "npu_process_group.h" +#elif defined(USE_MLU) +#include "mlu_process_group.h" +#elif defined(USE_CUDA) +#include "cuda_process_group.h" +#endif + namespace { std::pair> get_trans_group_rank(int world_size, int global_rank, @@ -75,4 +83,18 @@ void ProcessGroup::allgather(const torch::Tensor& input, std::vector> output_tensors = {outputs}; pg_->allgather(output_tensors, input_tensors)->wait(); } + +std::unique_ptr create_process_group( + int32_t rank, + int32_t world_size, + int32_t rank_size, + int32_t port, + bool trans, + const std::string& host, + const std::string& group_name, + const torch::Device& device) { + return std::make_unique( + rank, world_size, rank_size, port, trans, host, group_name, device); +} + } // namespace xllm diff --git a/xllm/core/framework/parallel_state/process_group.h b/xllm/core/framework/parallel_state/process_group.h index ba1d67a9e..25fe00cab 100644 --- a/xllm/core/framework/parallel_state/process_group.h +++ b/xllm/core/framework/parallel_state/process_group.h @@ -19,7 +19,15 @@ limitations under the License. #include #include + +#if defined(USE_NPU) +#include +#endif + namespace xllm { + +class ProcessGroupImpl; + std::pair> get_group_rank(int world_size, int global_rank, int split_size, @@ -60,7 +68,26 @@ class ProcessGroup { torch::Device device_; protected: +#if defined(USE_NPU) && \ + (TORCH_VERSION_MAJOR < 2 || \ + (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR < 7)) + // Using ProcessGroupHCCL for NPU devices + // Note: torch_npu uses an older torch version where c10d::Backend lacks + // shutdown() method + std::unique_ptr pg_{nullptr}; +#else std::unique_ptr pg_{nullptr}; +#endif }; +std::unique_ptr create_process_group( + int32_t rank, + int32_t world_size, + int32_t rank_size, + int32_t port, + bool trans, + const std::string& host, + const std::string& group_name, + const torch::Device& device); + } // namespace xllm \ No newline at end of file diff --git a/xllm/core/layers/common/tests/tests_utils.h b/xllm/core/layers/common/tests/tests_utils.h index 8fdf56f5b..923d48937 100644 --- a/xllm/core/layers/common/tests/tests_utils.h +++ b/xllm/core/layers/common/tests/tests_utils.h @@ -125,7 +125,8 @@ class MockBackend : public c10d::Backend { int64_t getSize() const { return world_size_; } -#if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR >= 7 +#if TORCH_VERSION_MAJOR > 2 || \ + (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 7) void shutdown() override { // Mock implementation - do nothing }