From 8c0ada70c1bf57979b78ffa1a64518c950622841 Mon Sep 17 00:00:00 2001 From: Clement-Wang26 Date: Wed, 26 Nov 2025 16:53:46 +0800 Subject: [PATCH] feat: support loading model weights and forward overlap. --- xllm/core/framework/hf_model_loader.cpp | 2 +- xllm/core/layers/base_layer.cpp | 93 ++++-- xllm/core/layers/base_layer.h | 10 +- xllm/core/layers/npu/npu_base_layer.cpp | 178 +++++++++++ xllm/core/layers/npu/npu_base_layer.h | 31 +- .../npu_deepseek_v2_decoder_layer_impl.cpp | 290 +++++++++++------- .../npu/npu_deepseek_v2_decoder_layer_impl.h | 7 +- .../npu/npu_qwen2_decoder_layer_impl.cpp | 189 +++++++----- .../layers/npu/npu_qwen2_decoder_layer_impl.h | 7 +- .../npu/npu_qwen3_decoder_layer_impl.cpp | 153 ++++----- .../layers/npu/npu_qwen3_decoder_layer_impl.h | 3 + xllm/models/CMakeLists.txt | 2 + xllm/models/lazy_layer_loader.cpp | 90 ++++++ xllm/models/lazy_layer_loader.h | 90 ++++++ xllm/models/llm/llm_model_base.h | 13 + 15 files changed, 850 insertions(+), 308 deletions(-) create mode 100644 xllm/models/lazy_layer_loader.cpp create mode 100644 xllm/models/lazy_layer_loader.h diff --git a/xllm/core/framework/hf_model_loader.cpp b/xllm/core/framework/hf_model_loader.cpp index e5fd7c348..16fd58e30 100644 --- a/xllm/core/framework/hf_model_loader.cpp +++ b/xllm/core/framework/hf_model_loader.cpp @@ -51,7 +51,7 @@ HFModelLoader::HFModelLoader(const std::string& model_weights_path) << "Failed to find model weights files in " << model_weights_path; // sort the model weights files by name std::sort(model_weights_files_.begin(), model_weights_files_.end()); - threadpool_ = std::make_unique(32); + threadpool_ = std::make_unique(64); } std::unique_ptr HFModelLoader::tokenizer() const { diff --git a/xllm/core/layers/base_layer.cpp b/xllm/core/layers/base_layer.cpp index 16b7477e0..cdc81f1db 100755 --- a/xllm/core/layers/base_layer.cpp +++ b/xllm/core/layers/base_layer.cpp @@ -85,12 +85,18 @@ void BaseLayer::correct_tensor_dtype(torch::Tensor& tensor, void BaseLayer::set_weight(const StateDict& state_dict, const std::string& tensor_name, - int weight_position) { + int weight_position, + bool to_host) { + auto device = to_host ? at::kCPU : device_; for (const auto& [name, tensor] : state_dict) { if (absl::EndsWith(name, tensor_name)) { at::Tensor mutable_tensor = tensor; correct_tensor_dtype(mutable_tensor, tensor_name); - at_weight_tensors_[weight_position] = mutable_tensor.to(device_); + if (to_host) { + at_host_weight_tensors_[weight_position] = mutable_tensor.to(device); + } else { + at_weight_tensors_[weight_position] = mutable_tensor.to(device); + } } } } @@ -98,22 +104,35 @@ void BaseLayer::set_weight(const StateDict& state_dict, void BaseLayer::set_weight(const StateDict& state_dict, const std::string& tensor_name, int weight_position, - int dim) { - for (const auto& [name, tensor] : state_dict) { - if (absl::EndsWith(name, tensor_name)) { - if (parallel_args_.world_size() <= 1) { + int dim, + bool to_host) { + auto device = to_host ? 
at::kCPU : device_; + if (parallel_args_.world_size() <= 1) { + for (const auto& [name, tensor] : state_dict) { + if (absl::EndsWith(name, tensor_name)) { at::Tensor mutable_tensor = tensor; correct_tensor_dtype(mutable_tensor, tensor_name); - at_weight_tensors_[weight_position] = mutable_tensor.to(device_); - } else { - at_weight_tensors_[weight_position] = - state_dict - .get_sharded_tensor(tensor_name, - /*dim=*/dim, - /*rank=*/parallel_args_.rank(), - /*world_size=*/parallel_args_.world_size()) - .to(device_); - correct_tensor_dtype(at_weight_tensors_[weight_position], tensor_name); + if (to_host) { + at_host_weight_tensors_[weight_position] = mutable_tensor.to(device); + } else { + at_weight_tensors_[weight_position] = mutable_tensor.to(device); + } + } + } + } else { + for (const auto& [name, tensor] : state_dict) { + if (absl::EndsWith(name, tensor_name)) { + at::Tensor mutable_tensor = state_dict.get_sharded_tensor( + tensor_name, + /*dim=*/dim, + /*rank=*/parallel_args_.rank(), + /*world_size=*/parallel_args_.world_size()); + correct_tensor_dtype(mutable_tensor, tensor_name); + if (to_host) { + at_host_weight_tensors_[weight_position] = mutable_tensor.to(device); + } else { + at_weight_tensors_[weight_position] = mutable_tensor.to(device); + } } } } @@ -124,26 +143,38 @@ void BaseLayer::set_weight(const StateDict& state_dict, int weight_position, int dim, int rank, - int world_size) { - for (const auto& [name, tensor] : state_dict) { - if (absl::EndsWith(name, tensor_name)) { - if (world_size <= 1) { + int world_size, + bool to_host) { + auto device = to_host ? at::kCPU : device_; + if (world_size <= 1) { + for (const auto& [name, tensor] : state_dict) { + if (absl::EndsWith(name, tensor_name)) { at::Tensor mutable_tensor = tensor; correct_tensor_dtype(mutable_tensor, tensor_name); - at_weight_tensors_[weight_position] = mutable_tensor.to(device_); - } else { - at_weight_tensors_[weight_position] = - state_dict - .get_sharded_tensor(tensor_name, - /*dim=*/dim, - /*rank=*/rank, - /*world_size=*/world_size) - .to(device_); - correct_tensor_dtype(at_weight_tensors_[weight_position], tensor_name); + if (to_host) { + at_host_weight_tensors_[weight_position] = mutable_tensor.to(device); + } else { + at_weight_tensors_[weight_position] = mutable_tensor.to(device); + } + } + } + } else { + for (const auto& [name, tensor] : state_dict) { + if (absl::EndsWith(name, tensor_name)) { + at::Tensor mutable_tensor = + state_dict.get_sharded_tensor(tensor_name, + /*dim=*/dim, + /*rank=*/rank, + /*world_size=*/world_size); + correct_tensor_dtype(mutable_tensor, tensor_name); + if (to_host) { + at_host_weight_tensors_[weight_position] = mutable_tensor.to(device); + } else { + at_weight_tensors_[weight_position] = mutable_tensor.to(device); + } } } } } - } // namespace layer } // namespace xllm diff --git a/xllm/core/layers/base_layer.h b/xllm/core/layers/base_layer.h index 3691d1fc0..903e5463d 100644 --- a/xllm/core/layers/base_layer.h +++ b/xllm/core/layers/base_layer.h @@ -103,18 +103,21 @@ class BaseLayer : public torch::nn::Module { void set_weight(const StateDict& state_dict, const std::string& tensor_name, int weight_position, - int dim); + int dim, + bool to_host = false); void set_weight(const StateDict& state_dict, const std::string& tensor_name, - int weight_position); + int weight_position, + bool to_host = false); void set_weight(const StateDict& state_dict, const std::string& tensor_name, int weight_position, int dim, int rank, - int world_size); + int world_size, + bool to_host = 
false); virtual void run_task(std::string taskName, std::function task) const { }; @@ -126,6 +129,7 @@ class BaseLayer : public torch::nn::Module { protected: std::vector at_weight_tensors_; + std::vector at_host_weight_tensors_; at::Device device_; std::string name_; torch::ScalarType dtype_; diff --git a/xllm/core/layers/npu/npu_base_layer.cpp b/xllm/core/layers/npu/npu_base_layer.cpp index 72165dde9..9b598bbe7 100644 --- a/xllm/core/layers/npu/npu_base_layer.cpp +++ b/xllm/core/layers/npu/npu_base_layer.cpp @@ -27,11 +27,22 @@ limitations under the License. namespace xllm { namespace layer { +namespace { +inline size_t AlignUp(size_t value, size_t alignment) { + return (value + alignment - 1) & ~(alignment - 1); +} +} // namespace + NpuBaseLayer::NpuBaseLayer(const ModelContext& context) : BaseLayer(context) { context_ = const_cast(context.get_atb_context()); work_space_ = AtbWorkspace(device_); } +NpuBaseLayer::~NpuBaseLayer() { + release_host_storage(); + release_device_storage(); +} + atb::Status NpuBaseLayer::execute_node(atb_speed::Model::Node& node, int node_id, aclrtEvent* event, @@ -124,6 +135,173 @@ void NpuBaseLayer::run_task(std::string taskName, cmd.Run(); } +void NpuBaseLayer::init_weight_slices(int weight_count) { + weight_slices_.resize(weight_count); + size_t offset = 0; + for (size_t i = 0; i < weight_count; ++i) { + weight_slices_[i] = {}; + const auto& tensor = at_host_weight_tensors_[i]; + if (!tensor.defined() || tensor.numel() < 1) { + continue; + } + // offset = AlignUp(offset, kHostAlignment); + weight_slices_[i].offset = offset; + weight_slices_[i].bytes = tensor.nbytes(); + weight_slices_[i].sizes = tensor.sizes().vec(); + weight_slices_[i].dtype = tensor.scalar_type(); + offset += weight_slices_[i].bytes; + } + size_t max_alignment = std::max(kHostAlignment, kDeviceAlignment); + // storage_size_ = AlignUp(offset, max_alignment); + LOG(INFO) << "NpuBaseLayer total weight size: " << offset; + storage_size_ = offset; +} + +void NpuBaseLayer::copy_weights_to_pinned_host() { + CHECK_GT(storage_size_, 0) << "model size must be greater than 0."; + CHECK_EQ(weight_slices_.size(), at_host_weight_tensors_.size()) + << "weight_slices_ size and at_host_weight_tensors_ size mismatch."; + + size_t max_alignment = std::max(kHostAlignment, kDeviceAlignment); + storage_size_ = AlignUp(storage_size_, max_alignment); + + auto ret = aclrtMallocHost(&host_pinned_storage_, storage_size_); + CHECK_EQ(ret, ACL_SUCCESS) + << "Failed to allocate pinned host storage size=" << storage_size_; + + for (size_t i = 0; i < weight_slices_.size(); ++i) { + const auto& slice = weight_slices_[i]; + if (!slice.bytes) { + continue; + } + auto host_tensor = at_host_weight_tensors_[i].to(torch::kCPU).contiguous(); + void* dst = static_cast(host_pinned_storage_) + + static_cast(slice.offset); + std::memcpy(dst, host_tensor.data_ptr(), slice.bytes); + at_host_weight_tensors_[i] = at::Tensor(); + } + + ret = aclrtMallocAlign32( + &device_storage_, storage_size_, ACL_MEM_MALLOC_HUGE_FIRST); + CHECK_EQ(ret, ACL_SUCCESS) + << "Failed to allocate contiguous device storage size=" << storage_size_; +} + +void NpuBaseLayer::copy_weights_to_device() { + CHECK_EQ(weight_slices_.size(), at_host_weight_tensors_.size()) + << "weight_slices_ size and at_host_weight_tensors_ size mismatch."; + auto ret = aclrtMallocAlign32( + &device_storage_, storage_size_, ACL_MEM_MALLOC_HUGE_FIRST); + CHECK_EQ(ret, ACL_SUCCESS) + << "Failed to allocate contiguous device storage size=" << storage_size_; + + for (size_t i = 0; i < 
weight_slices_.size(); ++i) { + const auto& slice = weight_slices_[i]; + if (!slice.bytes) { + continue; + } + void* dst = static_cast(device_storage_) + + static_cast(slice.offset); + auto host_tensor = at_host_weight_tensors_[i].contiguous(); + auto err = aclrtMemcpy(dst, + slice.bytes, + host_tensor.data_ptr(), + slice.bytes, + ACL_MEMCPY_HOST_TO_DEVICE); + CHECK_EQ(err, ACL_SUCCESS) << "aclrtMemcpy failed for tensor index " << i; + at_host_weight_tensors_[i] = at::Tensor(); + } +} + +torch::Tensor NpuBaseLayer::convert_to_torch_tensor( + const std::vector& dims, + const torch::ScalarType dtype, + const uintptr_t& dev_addr, + int acl_format) { + c10::DeviceType device_type = c10::DeviceType::PrivateUse1; + torch::TensorOptions option = + torch::TensorOptions().dtype(dtype).device(device_type); + + auto tensor = torch::empty({0}, option); + auto address = reinterpret_cast(dev_addr); + torch::DataPtr c10_data_ptr(address, address, [](void*) {}, tensor.device()); + + size_t tensor_nbytes = at::detail::computeStorageNbytesContiguous( + dims, tensor.dtype().itemsize()); + torch::Storage storage; + // get npu storage constructor from register and construct storage + auto fptr = c10::GetStorageImplCreate(device_type); + auto allocator = c10::GetAllocator(device_type); + storage = fptr(c10::StorageImpl::use_byte_size_t(), 0, allocator, true); + storage.unsafeGetStorageImpl()->set_nbytes(tensor_nbytes); + storage.set_data_ptr(std::move(c10_data_ptr)); + + tensor.set_(storage, 0, dims); + // cast npu format to nd + tensor = at_npu::native::npu_format_cast(tensor, acl_format); + + return tensor; +} + +void NpuBaseLayer::init_atb_tensors() { + for (size_t i = 0; i < weight_slices_.size(); ++i) { + const auto& slice = weight_slices_[i]; + if (!slice.bytes) { + continue; + } + void* base = static_cast(device_storage_) + + static_cast(slice.offset); + at_weight_tensors_[i] = convert_to_torch_tensor( + slice.sizes, slice.dtype, reinterpret_cast(base)); + } + + c10_npu::NPUCachingAllocator::emptyCache(); + + for (size_t i = 0; i < weight_slices_.size(); ++i) { + atb_weight_tensors_[i] = + atb_speed::Utils::AtTensor2Tensor(at_weight_tensors_[i]); + } +} + +void NpuBaseLayer::copy_weights_to_device_async() { + CHECK_EQ(weight_slices_.size(), at_weight_tensors_.size()) + << "weight_slices_ size and at_weight_tensors_ size mismatch."; + aclrtStream stream = c10_npu::getCurrentNPUStream().stream(); + + void* dst = static_cast(device_storage_); + void* src = static_cast(host_pinned_storage_); + + auto err = aclrtMemcpyAsync(dst, + storage_size_, + src, + storage_size_, + ACL_MEMCPY_HOST_TO_DEVICE, + stream); + CHECK_EQ(err, ACL_SUCCESS) << "aclrtMemcpyAsync failed"; +} + +void NpuBaseLayer::release_device_storage() { + if (device_storage_ == nullptr) { + return; + } + auto ret = aclrtFree(device_storage_); + if (ret != ACL_SUCCESS) { + LOG(ERROR) << "Failed to free contiguous layer storage, ret=" << ret; + } + device_storage_ = nullptr; +} + +void NpuBaseLayer::release_host_storage() { + if (host_pinned_storage_ == nullptr) { + return; + } + auto ret = aclrtFreeHost(host_pinned_storage_); + if (ret != ACL_SUCCESS) { + LOG(ERROR) << "Failed to free pinned host storage, ret=" << ret; + } + host_pinned_storage_ = nullptr; +} + atb::Tensor NpuBaseLayer::XTensor2Tensor( const std::shared_ptr& xtensor) { static std::map dtypeMap = { diff --git a/xllm/core/layers/npu/npu_base_layer.h b/xllm/core/layers/npu/npu_base_layer.h index cb87e2274..f81b81007 100644 --- a/xllm/core/layers/npu/npu_base_layer.h +++ 
b/xllm/core/layers/npu/npu_base_layer.h @@ -57,7 +57,7 @@ namespace layer { class NpuBaseLayer : public BaseLayer { public: explicit NpuBaseLayer(const ModelContext& context); - ~NpuBaseLayer() = default; + ~NpuBaseLayer() override; atb::Status execute_node(atb_speed::Model::Node& node, int nodeId = 0, @@ -72,15 +72,42 @@ class NpuBaseLayer : public BaseLayer { virtual void run_task(std::string taskName, std::function task) const override; + void init_weight_slices(int weight_count); + + void copy_weights_to_pinned_host(); + + void copy_weights_to_device(); + + void copy_weights_to_device_async(); + + virtual void init_atb_tensors(); + protected: atb::Tensor XTensor2Tensor(const std::shared_ptr& xtensor); protected: + struct WeightSlice { + size_t offset = 0; + size_t bytes = 0; + std::vector sizes; + torch::ScalarType dtype = torch::kFloat16; + }; + void* host_pinned_storage_ = nullptr; + void* device_storage_ = nullptr; + size_t storage_size_ = 0; + std::vector weight_slices_; atb::Context* context_; AtbWorkspace work_space_; - // std::vector at_weight_tensors_; std::vector atb_weight_tensors_; bool graph_captured_{false}; + static constexpr size_t kDeviceAlignment = 64; + static constexpr size_t kHostAlignment = 64; + void release_device_storage(); + void release_host_storage(); + torch::Tensor convert_to_torch_tensor(const std::vector& dims, + const torch::ScalarType dtype, + const uintptr_t& dev_addr, + int acl_format = 2); }; } // namespace layer diff --git a/xllm/core/layers/npu/npu_deepseek_v2_decoder_layer_impl.cpp b/xllm/core/layers/npu/npu_deepseek_v2_decoder_layer_impl.cpp index f321d1401..c77a8f6d8 100644 --- a/xllm/core/layers/npu/npu_deepseek_v2_decoder_layer_impl.cpp +++ b/xllm/core/layers/npu/npu_deepseek_v2_decoder_layer_impl.cpp @@ -296,6 +296,7 @@ void NpuDeepseekV2DecoderLayerImpl::initialize_tensors( const torch::TensorOptions& options) { // initializ placeholder at_weight_tensors_.resize(WEIGHT_COUNT_PER_LAYER); + at_host_weight_tensors_.resize(WEIGHT_COUNT_PER_LAYER); atb_weight_tensors_.resize(WEIGHT_COUNT_PER_LAYER); placeholder_vec_ = {1}; int_tensor_placeholder_ = torch::ones({1}).to(torch::kInt32).to(device_); @@ -374,7 +375,6 @@ void NpuDeepseekV2DecoderLayerImpl::reserve_experts_weights( for (const auto& weight_name : weight_names) { experts_weights_[weight_name] = std::vector(num_of_device_experts); - ; } } @@ -382,6 +382,7 @@ void NpuDeepseekV2DecoderLayerImpl::initialize_weight_tensors( const torch::TensorOptions& options) { for (int i = 0; i < WEIGHT_COUNT_PER_LAYER; ++i) { at_weight_tensors_[i] = torch::zeros({1}).to(options); + at_host_weight_tensors_[i] = torch::zeros({1}).to(torch::kFloat16); } if (FLAGS_enable_eplb) { const int64_t size = @@ -646,7 +647,6 @@ void NpuDeepseekV2DecoderLayerImpl::load_state_dict( for (const auto& [name, tensor] : state_dict) { bool is_sharded = false; int index = 0; - if (absl::EndsWith(name, "self_attn.kv_b_proj.weight")) { index = WEIGHT_MAPPING_W8A8.at(name); set_kv_weight(state_dict, name, index, WEIGHT_SHARD_W8A8.at(index)); @@ -788,16 +788,15 @@ void NpuDeepseekV2DecoderLayerImpl::process_shared_expert_weights( return; } if (FLAGS_expert_parallel_degree == 2) { - tmp_tensor = tensor.to(device_); + tmp_tensor = tensor; } else { const bool is_sharded = WEIGHT_SHARD_W8A8.count(index); tmp_tensor = is_sharded ? 
get_sharded_tensor( state_dict, name, WEIGHT_SHARD_W8A8.at(index)) - .to(device_) - : tensor.to(device_); + : tensor; } if (absl::StrContains(name, "down_proj")) { - at_weight_tensors_[index] = tmp_tensor; + at_host_weight_tensors_[index] = tmp_tensor; } else { shared_experts_weights_[name] = tmp_tensor; } @@ -820,10 +819,9 @@ void NpuDeepseekV2DecoderLayerImpl::process_mlp_common_weights( WEIGHT_SHARD_W8A8.at(index), dp_local_tp_rank_, dp_local_tp_size_) - .to(device_) - : tensor.to(device_); + : tensor; if (absl::StrContains(name, "down_proj")) { - at_weight_tensors_[index] = tmp_tensor; + at_host_weight_tensors_[index] = tmp_tensor; } else { shared_experts_weights_[name] = tmp_tensor; } @@ -845,11 +843,10 @@ void NpuDeepseekV2DecoderLayerImpl::process_general_weights( WEIGHT_SHARD_W8A8.at(index), dp_local_tp_rank_, dp_local_tp_size_) - .to(device_) - : tensor.to(device_); + : tensor; correct_tensor_dtype(tmp_tensor, name); - at_weight_tensors_[index] = tmp_tensor; + at_host_weight_tensors_[index] = tmp_tensor; } void NpuDeepseekV2DecoderLayerImpl::set_kv_weight( @@ -859,16 +856,12 @@ void NpuDeepseekV2DecoderLayerImpl::set_kv_weight( int dim) { torch::Tensor mutable_tensor; if (parallel_args_.world_size() <= 1) { - mutable_tensor = state_dict.get_tensor(tensor_name).to(device_); - correct_tensor_dtype(mutable_tensor, tensor_name); + mutable_tensor = state_dict.get_tensor(tensor_name); } else { - mutable_tensor = - get_sharded_tensor( - state_dict, tensor_name, dim, dp_local_tp_rank_, dp_local_tp_size_) - .to(device_); - // mutable_tensor = get_sharded_tensor(state_dict, tensor_name, dim); - correct_tensor_dtype(mutable_tensor, tensor_name); + mutable_tensor = get_sharded_tensor( + state_dict, tensor_name, dim, dp_local_tp_rank_, dp_local_tp_size_); } + correct_tensor_dtype(mutable_tensor, tensor_name); torch::Tensor kv_b_proj_weight = mutable_tensor.reshape({num_key_value_heads_ / dp_local_tp_size_, @@ -881,8 +874,8 @@ void NpuDeepseekV2DecoderLayerImpl::set_kv_weight( .slice(1, qk_nope_head_dim_, qk_nope_head_dim_ + v_head_dim_) .transpose(1, 2) .contiguous(); - at_weight_tensors_[weight_position] = k_b_proj_preprocessed.to(device_); - at_weight_tensors_[weight_position + 6] = v_b_proj_preprocessed.to(device_); + at_host_weight_tensors_[weight_position] = k_b_proj_preprocessed; + at_host_weight_tensors_[weight_position + 6] = v_b_proj_preprocessed; } void NpuDeepseekV2DecoderLayerImpl::preprocess_linear_for_rope() { @@ -893,13 +886,14 @@ void NpuDeepseekV2DecoderLayerImpl::preprocess_linear_for_rope() { } } int index = WEIGHT_MAPPING_W8A8.at(name); - at_weight_tensors_[index] = - view_tensor(at_weight_tensors_[index], name, true); - at_weight_tensors_[index] = trans_rope_weight(at_weight_tensors_[index]); - at_weight_tensors_[index] = + at_host_weight_tensors_[index] = + view_tensor(at_host_weight_tensors_[index], name, true); + at_host_weight_tensors_[index] = + trans_rope_weight(at_host_weight_tensors_[index]); + at_host_weight_tensors_[index] = (!absl::EndsWith(name, "weight")) - ? view_tensor(at_weight_tensors_[index], name, false).flatten() - : view_tensor(at_weight_tensors_[index], name, false); + ? 
view_tensor(at_host_weight_tensors_[index], name, false).flatten() + : view_tensor(at_host_weight_tensors_[index], name, false); } } @@ -930,8 +924,10 @@ torch::Tensor NpuDeepseekV2DecoderLayerImpl::view_tensor( torch::Tensor NpuDeepseekV2DecoderLayerImpl::trans_rope_weight( torch::Tensor weight) { + auto new_weight = weight.clone(); int64_t d = weight.size(-2); int64_t rope_dim = prefill_param_.qkRopeHeadDim; + torch::Tensor weight_1 = weight.slice(-2, d - rope_dim, torch::indexing::None, 2).contiguous(); @@ -939,10 +935,9 @@ torch::Tensor NpuDeepseekV2DecoderLayerImpl::trans_rope_weight( weight.slice(-2, d - rope_dim + 1, torch::indexing::None, 2).contiguous(); torch::Tensor combined = torch::cat({weight_1, weight_2}, -2); + new_weight.slice(-2, d - rope_dim, d).copy_(combined); - weight.slice(-2, d - rope_dim, d).copy_(combined); - - return weight.contiguous(); + return new_weight.contiguous(); } torch::Tensor NpuDeepseekV2DecoderLayerImpl::get_sharded_tensor( @@ -1011,7 +1006,70 @@ void NpuDeepseekV2DecoderLayerImpl::verify_loaded_weights( } } +void NpuDeepseekV2DecoderLayerImpl::merge_and_move_pinned_host() { + merge_loaded_at_weights(); + init_weight_slices(WEIGHT_COUNT_PER_LAYER); + copy_weights_to_pinned_host(); + init_atb_tensors(); + init_layer(); +} + void NpuDeepseekV2DecoderLayerImpl::merge_loaded_weights() { + merge_loaded_at_weights(); + init_weight_slices(WEIGHT_COUNT_PER_LAYER); + copy_weights_to_device(); + init_atb_tensors(); + init_layer(); +} + +void NpuDeepseekV2DecoderLayerImpl::init_atb_tensors() { + for (size_t i = 0; i < weight_slices_.size(); ++i) { + const auto& slice = weight_slices_[i]; + if (!slice.bytes) { + continue; + } + void* base = static_cast(device_storage_) + + static_cast(slice.offset); + if (layer_id_ >= prefill_param_.firstKDenseReplace) { + if (i == IN_MLP_GATEUP_WEIGHT_EXPERT) { + at_weight_tensors_[i] = convert_to_torch_tensor( + slice.sizes, slice.dtype, reinterpret_cast(base), 29); + continue; + } else if (i == IN_MLP_DOWN_WEIGHT_EXPERT) { +#if defined(USE_A3) + at_weight_tensors_[i] = convert_to_torch_tensor( + slice.sizes, slice.dtype, reinterpret_cast(base), 29); +#else + if (decode_param_.isBF16) { + at_weight_tensors_[i] = convert_to_torch_tensor( + slice.sizes, slice.dtype, reinterpret_cast(base), 29); + } else { + at_weight_tensors_[i] = convert_to_torch_tensor( + slice.sizes, slice.dtype, reinterpret_cast(base)); + } +#endif + continue; + } + } + if (i == IN_Q_PROJ_A_WEIGHT || i == IN_Q_PROJ_B_WEIGHT) { + at_weight_tensors_[i] = convert_to_torch_tensor( + slice.sizes, slice.dtype, reinterpret_cast(base), 29); + continue; + } else { + at_weight_tensors_[i] = convert_to_torch_tensor( + slice.sizes, slice.dtype, reinterpret_cast(base)); + } + } + + c10_npu::NPUCachingAllocator::emptyCache(); + + for (size_t i = 0; i < weight_slices_.size(); ++i) { + atb_weight_tensors_[i] = + atb_speed::Utils::AtTensor2Tensor(at_weight_tensors_[i]); + } +} + +void NpuDeepseekV2DecoderLayerImpl::merge_loaded_at_weights() { if (quantize_type_ == "w8a8_dynamic") { if (prefill_param_.isBF16) { convert_descaled_weights_to_float(); @@ -1019,7 +1077,6 @@ void NpuDeepseekV2DecoderLayerImpl::merge_loaded_weights() { convert_offsets_to_int8(); handle_device_specific_bias(); } - merge_shared_experts_weights(); if (layer_id_ >= prefill_param_.firstKDenseReplace) { merge_experts_weights(); @@ -1029,86 +1086,85 @@ void NpuDeepseekV2DecoderLayerImpl::merge_loaded_weights() { preprocess_linear_for_rope(); - at_weight_tensors_[IN_Q_PROJ_A_WEIGHT] = - 
torch::cat({at_weight_tensors_[IN_KV_PROJ_WITH_MQA_WEIGHT], - at_weight_tensors_[IN_Q_PROJ_A_WEIGHT]}, + at_host_weight_tensors_[IN_Q_PROJ_A_WEIGHT] = + torch::cat({at_host_weight_tensors_[IN_KV_PROJ_WITH_MQA_WEIGHT], + at_host_weight_tensors_[IN_Q_PROJ_A_WEIGHT]}, 0) .contiguous(); + if (quantize_type_ == "w8a8_dynamic") { - at_weight_tensors_[IN_Q_PROJ_A_BIAS] = - torch::cat({at_weight_tensors_[IN_KV_PROJ_WITH_MQA_BIAS], - at_weight_tensors_[IN_Q_PROJ_A_BIAS]}, + at_host_weight_tensors_[IN_Q_PROJ_A_BIAS] = + torch::cat({at_host_weight_tensors_[IN_KV_PROJ_WITH_MQA_BIAS], + at_host_weight_tensors_[IN_Q_PROJ_A_BIAS]}, 0) .contiguous(); - at_weight_tensors_[IN_Q_PROJ_A_DESCALE] = - torch::cat({at_weight_tensors_[IN_KV_PROJ_WITH_MQA_DESCALE], - at_weight_tensors_[IN_Q_PROJ_A_DESCALE]}, + at_host_weight_tensors_[IN_Q_PROJ_A_DESCALE] = + torch::cat({at_host_weight_tensors_[IN_KV_PROJ_WITH_MQA_DESCALE], + at_host_weight_tensors_[IN_Q_PROJ_A_DESCALE]}, 0) .contiguous(); } - at_weight_tensors_[IN_Q_PROJ_A_WEIGHT] = at_npu::native::npu_format_cast( - at_weight_tensors_[IN_Q_PROJ_A_WEIGHT], 29); - at_weight_tensors_[IN_Q_PROJ_B_WEIGHT] = at_npu::native::npu_format_cast( - at_weight_tensors_[IN_Q_PROJ_B_WEIGHT], 29); - - at_weight_tensors_[IN_KV_PROJ_WITH_MQA_WEIGHT] = tensor_placeholder_; - at_weight_tensors_[IN_KV_PROJ_WITH_MQA_BIAS] = tensor_placeholder_; - at_weight_tensors_[IN_KV_PROJ_WITH_MQA_DESCALE] = tensor_placeholder_; - at_weight_tensors_[IN_KV_PROJ_WITH_MQA_OFFSET] = tensor_placeholder_; - at_weight_tensors_[IN_KV_PROJ_WITH_MQA_SCALE] = tensor_placeholder_; + at_host_weight_tensors_[IN_KV_PROJ_WITH_MQA_WEIGHT] = + torch::zeros({1}).to(torch::kFloat16); + at_host_weight_tensors_[IN_KV_PROJ_WITH_MQA_BIAS] = + torch::zeros({1}).to(torch::kFloat16); + at_host_weight_tensors_[IN_KV_PROJ_WITH_MQA_DESCALE] = + torch::zeros({1}).to(torch::kFloat16); + at_host_weight_tensors_[IN_KV_PROJ_WITH_MQA_OFFSET] = + torch::zeros({1}).to(torch::kFloat16); + at_host_weight_tensors_[IN_KV_PROJ_WITH_MQA_SCALE] = + torch::zeros({1}).to(torch::kFloat16); if (FLAGS_expert_parallel_degree != 2) { - at_weight_tensors_[IN_BLOCK_SPARSE_MOE_GATE_WEIGHT] = - torch::roll(at_weight_tensors_[IN_BLOCK_SPARSE_MOE_GATE_WEIGHT], + at_host_weight_tensors_[IN_BLOCK_SPARSE_MOE_GATE_WEIGHT] = + torch::roll(at_host_weight_tensors_[IN_BLOCK_SPARSE_MOE_GATE_WEIGHT], {-1 * ep_rank_ * num_experts_per_partition_}, {0}) .contiguous(); - at_weight_tensors_[IN_BLOCK_SPARSE_MOE_GATE_BIAS] = - torch::roll(at_weight_tensors_[IN_BLOCK_SPARSE_MOE_GATE_BIAS], + + at_host_weight_tensors_[IN_BLOCK_SPARSE_MOE_GATE_BIAS] = + torch::roll(at_host_weight_tensors_[IN_BLOCK_SPARSE_MOE_GATE_BIAS], {-1 * ep_rank_ * num_experts_per_partition_}, {0}) .contiguous(); } + // at_weight_tensors_[IN_MLP_DOWN_WEIGHT_SHARED_EXPERT] = // at_weight_tensors_[IN_MLP_DOWN_WEIGHT_SHARED_EXPERT].transpose(0, 1); - at_weight_tensors_[IN_BLOCK_SPARSE_MOE_GATE_WEIGHT] = - at_weight_tensors_[IN_BLOCK_SPARSE_MOE_GATE_WEIGHT].to(torch::kFloat32); + at_host_weight_tensors_[IN_BLOCK_SPARSE_MOE_GATE_WEIGHT] = + at_host_weight_tensors_[IN_BLOCK_SPARSE_MOE_GATE_WEIGHT].to( + torch::kFloat32); if (quantize_type_ == "w8a8_dynamic") { - // at_weight_tensors_[IN_BLOCK_SPARSE_MOE_GATE_WEIGHT] = - // at_weight_tensors_[IN_BLOCK_SPARSE_MOE_GATE_WEIGHT].to(torch::kFloat32); if (!prefill_param_.isBF16) { - at_weight_tensors_[IN_Q_PROJ_A_DESCALE] = - convert_fp16_to_int64(at_weight_tensors_[IN_Q_PROJ_A_DESCALE]); - at_weight_tensors_[IN_Q_PROJ_B_DESCALE] = - 
convert_fp16_to_int64(at_weight_tensors_[IN_Q_PROJ_B_DESCALE]); - at_weight_tensors_[IN_ATTENTION_OUT_DESCALE] = - convert_fp16_to_int64(at_weight_tensors_[IN_ATTENTION_OUT_DESCALE]); - - at_weight_tensors_[IN_MLP_GATEUP_OFFSET_SHARED_EXPERT] = - at_weight_tensors_[IN_MLP_GATEUP_OFFSET_SHARED_EXPERT].to( + at_host_weight_tensors_[IN_Q_PROJ_A_DESCALE] = + convert_fp16_to_int64(at_host_weight_tensors_[IN_Q_PROJ_A_DESCALE]); + at_host_weight_tensors_[IN_Q_PROJ_B_DESCALE] = + convert_fp16_to_int64(at_host_weight_tensors_[IN_Q_PROJ_B_DESCALE]); + at_host_weight_tensors_[IN_ATTENTION_OUT_DESCALE] = convert_fp16_to_int64( + at_host_weight_tensors_[IN_ATTENTION_OUT_DESCALE]); + + at_host_weight_tensors_[IN_MLP_GATEUP_OFFSET_SHARED_EXPERT] = + at_host_weight_tensors_[IN_MLP_GATEUP_OFFSET_SHARED_EXPERT].to( torch::kFloat16); - at_weight_tensors_[IN_MLP_GATEUP_SCALE_SHARED_EXPERT] = - at_weight_tensors_[IN_MLP_GATEUP_SCALE_SHARED_EXPERT].to( + at_host_weight_tensors_[IN_MLP_GATEUP_SCALE_SHARED_EXPERT] = + at_host_weight_tensors_[IN_MLP_GATEUP_SCALE_SHARED_EXPERT].to( + torch::kFloat32); + at_host_weight_tensors_[IN_MLP_DOWN_SCALE_SHARED_EXPERT] = + at_host_weight_tensors_[IN_MLP_DOWN_SCALE_SHARED_EXPERT].to( torch::kFloat32); - at_weight_tensors_[IN_MLP_DOWN_SCALE_SHARED_EXPERT] = - at_weight_tensors_[IN_MLP_DOWN_SCALE_SHARED_EXPERT].to( + at_host_weight_tensors_[IN_MLP_GATEUP_OFFSET_EXPERT] = + at_host_weight_tensors_[IN_MLP_GATEUP_OFFSET_EXPERT].to( + torch::kFloat16); + at_host_weight_tensors_[IN_MLP_GATEUP_SCALE_EXPERT] = + at_host_weight_tensors_[IN_MLP_GATEUP_SCALE_EXPERT].to( torch::kFloat32); - at_weight_tensors_[IN_MLP_GATEUP_OFFSET_EXPERT] = - at_weight_tensors_[IN_MLP_GATEUP_OFFSET_EXPERT].to(torch::kFloat16); - at_weight_tensors_[IN_MLP_GATEUP_SCALE_EXPERT] = - at_weight_tensors_[IN_MLP_GATEUP_SCALE_EXPERT].to(torch::kFloat32); - at_weight_tensors_[IN_MLP_DOWN_OFFSET_EXPERT] = - at_weight_tensors_[IN_MLP_DOWN_OFFSET_EXPERT].to(torch::kFloat16); - at_weight_tensors_[IN_MLP_DOWN_SCALE_EXPERT] = - at_weight_tensors_[IN_MLP_DOWN_SCALE_EXPERT].to(torch::kFloat32); + at_host_weight_tensors_[IN_MLP_DOWN_OFFSET_EXPERT] = + at_host_weight_tensors_[IN_MLP_DOWN_OFFSET_EXPERT].to( + torch::kFloat16); + at_host_weight_tensors_[IN_MLP_DOWN_SCALE_EXPERT] = + at_host_weight_tensors_[IN_MLP_DOWN_SCALE_EXPERT].to(torch::kFloat32); } } - c10_npu::NPUCachingAllocator::emptyCache(); - for (int i = 0; i < WEIGHT_COUNT_PER_LAYER; ++i) { - atb_weight_tensors_[i] = - atb_speed::Utils::AtTensor2Tensor(at_weight_tensors_[i]); - } - init_layer(); } torch::Tensor NpuDeepseekV2DecoderLayerImpl::convert_fp16_to_int64( @@ -1121,7 +1177,8 @@ torch::Tensor NpuDeepseekV2DecoderLayerImpl::convert_fp16_to_int64( void NpuDeepseekV2DecoderLayerImpl::convert_descaled_weights_to_float() { auto convert_to_float = [this](int index) { - at_weight_tensors_[index] = at_weight_tensors_[index].to(torch::kFloat32); + at_host_weight_tensors_[index] = + at_host_weight_tensors_[index].to(torch::kFloat32); }; convert_to_float(IN_Q_PROJ_A_DESCALE); convert_to_float(IN_Q_PROJ_B_DESCALE); @@ -1131,8 +1188,8 @@ void NpuDeepseekV2DecoderLayerImpl::convert_descaled_weights_to_float() { void NpuDeepseekV2DecoderLayerImpl::convert_offsets_to_int8() { auto convert_to_int8 = [this](int index) { - at_weight_tensors_[index] = - at_weight_tensors_[index].to(torch::kInt8).to(device_); + at_host_weight_tensors_[index] = + at_host_weight_tensors_[index].to(torch::kInt8); }; convert_to_int8(IN_Q_PROJ_A_OFFSET); convert_to_int8(IN_Q_PROJ_B_OFFSET); @@ 
-1142,8 +1199,9 @@ void NpuDeepseekV2DecoderLayerImpl::convert_offsets_to_int8() { void NpuDeepseekV2DecoderLayerImpl::handle_device_specific_bias() { if (dp_local_tp_rank_ != 0) { - torch::Tensor original_tensor = at_weight_tensors_[IN_ATTENTION_OUT_BIAS]; - at_weight_tensors_[IN_ATTENTION_OUT_BIAS] = + torch::Tensor original_tensor = + at_host_weight_tensors_[IN_ATTENTION_OUT_BIAS]; + at_host_weight_tensors_[IN_ATTENTION_OUT_BIAS] = torch::zeros(original_tensor.sizes(), torch::TensorOptions() .dtype(original_tensor.dtype()) @@ -1155,10 +1213,8 @@ void NpuDeepseekV2DecoderLayerImpl::merge_shared_experts_weights() { auto merge_and_clear = [this](int index, torch::Tensor& shared_experts_gate, torch::Tensor& shared_experts_up) { - at_weight_tensors_[index] = - torch::cat({shared_experts_gate, shared_experts_up}, 0) - .to(device_) - .contiguous(); + at_host_weight_tensors_[index] = + torch::cat({shared_experts_gate, shared_experts_up}, 0).contiguous(); shared_experts_gate = tensor_placeholder_; shared_experts_up = tensor_placeholder_; }; @@ -1197,55 +1253,51 @@ void NpuDeepseekV2DecoderLayerImpl::merge_experts_weights() { torch::Tensor mlp_gateup_weight = merge_experts_weights(experts_weights_["gate_proj.weight"], experts_weights_["up_proj.weight"], - device_, + at::kCPU, /*transpose=*/true); - at_weight_tensors_[IN_MLP_GATEUP_WEIGHT_EXPERT] = - at_npu::native::npu_format_cast(mlp_gateup_weight, 29); + at_host_weight_tensors_[IN_MLP_GATEUP_WEIGHT_EXPERT] = mlp_gateup_weight; // at_weight_tensors_[IN_MLP_GATEUP_WEIGHT_EXPERT] = // at_npu::native::npu_format_cast(mlp_gateup_weight, 2).contiguous(); if (quantize_type_ == "w8a8_dynamic") { - at_weight_tensors_[IN_MLP_GATEUP_OFFSET_EXPERT] = + at_host_weight_tensors_[IN_MLP_GATEUP_OFFSET_EXPERT] = merge_experts_weights(experts_weights_["gate_proj.weight_offset"], experts_weights_["up_proj.weight_offset"], - device_); - at_weight_tensors_[IN_MLP_GATEUP_SCALE_EXPERT] = + at::kCPU); + at_host_weight_tensors_[IN_MLP_GATEUP_SCALE_EXPERT] = merge_experts_weights(experts_weights_["gate_proj.weight_scale"], experts_weights_["up_proj.weight_scale"], - device_); + at::kCPU); } #if defined(USE_A3) torch::Tensor mlp_down_weight = merge_experts_weights(experts_weights_["down_proj.weight"], - device_, + at::kCPU, /*transpose=*/false); - // at_weight_tensors_[IN_MLP_DOWN_WEIGHT_EXPERT] = - // at_npu::native::npu_format_cast(mlp_down_weight, 29); - at_weight_tensors_[IN_MLP_DOWN_WEIGHT_EXPERT] = - at_npu::native::npu_format_cast(mlp_down_weight, 2).contiguous(); + at_host_weight_tensors_[IN_MLP_DOWN_WEIGHT_EXPERT] = + mlp_down_weight.contiguous(); #else // TODO: xllm ops's GMM need to support MTP. 
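  // NOTE: `isBF16 && false` keeps the transposed BF16 GMM path disabled until
  // the op gains MTP support (see the TODO above); either branch now only
  // stages the merged expert down-projection weight on the host, and any NPU
  // format cast is deferred to init_atb_tensors().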
if (decode_param_.isBF16 && false) { torch::Tensor mlp_down_weight = merge_experts_weights(experts_weights_["down_proj.weight"], - device_, + at::kCPU, /*transpose=*/true); - at_weight_tensors_[IN_MLP_DOWN_WEIGHT_EXPERT] = - at_npu::native::npu_format_cast(mlp_down_weight, 29); + at_host_weight_tensors_[IN_MLP_DOWN_WEIGHT_EXPERT] = mlp_down_weight; } else { torch::Tensor mlp_down_weight = merge_experts_weights(experts_weights_["down_proj.weight"], - device_, + at::kCPU, /*transpose=*/false); - at_weight_tensors_[IN_MLP_DOWN_WEIGHT_EXPERT] = - at_npu::native::npu_format_cast(mlp_down_weight, 2).contiguous(); + at_host_weight_tensors_[IN_MLP_DOWN_WEIGHT_EXPERT] = + mlp_down_weight.contiguous(); } #endif if (quantize_type_ == "w8a8_dynamic") { - at_weight_tensors_[IN_MLP_DOWN_OFFSET_EXPERT] = merge_experts_weights( - experts_weights_["down_proj.weight_offset"], device_); - at_weight_tensors_[IN_MLP_DOWN_SCALE_EXPERT] = merge_experts_weights( - experts_weights_["down_proj.weight_scale"], device_); + at_host_weight_tensors_[IN_MLP_DOWN_OFFSET_EXPERT] = merge_experts_weights( + experts_weights_["down_proj.weight_offset"], at::kCPU); + at_host_weight_tensors_[IN_MLP_DOWN_SCALE_EXPERT] = merge_experts_weights( + experts_weights_["down_proj.weight_scale"], at::kCPU); } } @@ -1452,8 +1504,8 @@ void NpuDeepseekV2DecoderLayerImpl::update_expert_weight() { void NpuDeepseekV2DecoderLayerImpl::squeeze_experts_weights() { for (const auto& index : SQUEEZE_WEIGHT_VEC) { - if (at_weight_tensors_[index].dim() > 1) { - at_weight_tensors_[index] = at_weight_tensors_[index].squeeze(); + if (at_host_weight_tensors_[index].dim() > 1) { + at_host_weight_tensors_[index] = at_host_weight_tensors_[index].squeeze(); } } } diff --git a/xllm/core/layers/npu/npu_deepseek_v2_decoder_layer_impl.h b/xllm/core/layers/npu/npu_deepseek_v2_decoder_layer_impl.h index c57964882..627c54061 100644 --- a/xllm/core/layers/npu/npu_deepseek_v2_decoder_layer_impl.h +++ b/xllm/core/layers/npu/npu_deepseek_v2_decoder_layer_impl.h @@ -56,7 +56,6 @@ class ExpertBuffer { const torch::TensorOptions& weight_options, const torch::TensorOptions& offset_options, const torch::TensorOptions& scale_options, - bool force_reinit = false) { std::lock_guard lock(mutex_); @@ -124,6 +123,8 @@ class NpuDeepseekV2DecoderLayerImpl : public NpuBaseLayer { void update_expert_weight(); + void merge_and_move_pinned_host(); + virtual int64_t init_layer() override; torch::Tensor forward(torch::Tensor& x, @@ -143,6 +144,10 @@ class NpuDeepseekV2DecoderLayerImpl : public NpuBaseLayer { bool use_dp_sharding = false; }; + virtual void init_atb_tensors() override; + + void merge_loaded_at_weights(); + void initialize_tensors(const torch::TensorOptions& options); void initialize_weight_tensors(const torch::TensorOptions& options); diff --git a/xllm/core/layers/npu/npu_qwen2_decoder_layer_impl.cpp b/xllm/core/layers/npu/npu_qwen2_decoder_layer_impl.cpp index ffbf45bd0..bef858687 100644 --- a/xllm/core/layers/npu/npu_qwen2_decoder_layer_impl.cpp +++ b/xllm/core/layers/npu/npu_qwen2_decoder_layer_impl.cpp @@ -164,6 +164,7 @@ NpuQwen2DecoderLayerImpl::NpuQwen2DecoderLayerImpl(const ModelContext& context) param_from_args(prefill_param_, model_args, parallel_args, true); param_from_args(decode_param_, model_args, parallel_args, false); at_weight_tensors_.resize(WEIGHT_COUNT_PER_LAYER); + at_host_weight_tensors_.resize(WEIGHT_COUNT_PER_LAYER); atb_weight_tensors_.resize(WEIGHT_COUNT_PER_LAYER); placeholder_vec_ = {1}; dtype_ = c10::typeMetaToScalarType(options.dtype()); @@ 
-182,7 +183,7 @@ NpuQwen2DecoderLayerImpl::NpuQwen2DecoderLayerImpl(const ModelContext& context) void NpuQwen2DecoderLayerImpl::verify_loaded_weights() const { for (const auto& [index, name] : WEIGHT_MAPPING) { - CHECK(at_weight_tensors_[index].sizes() != std::vector({1})) + CHECK(at_host_weight_tensors_[index].sizes() != std::vector({1})) << "weight is not loaded for " << name; } } @@ -199,121 +200,148 @@ TransposeType NpuQwen2DecoderLayerImpl::check_transpose(at::Tensor& tensor) { } void NpuQwen2DecoderLayerImpl::merge_loaded_weights() { + merge_loaded_at_weights(); + init_weight_slices(WEIGHT_COUNT_PER_LAYER); + copy_weights_to_device(); + init_attn_mask(); + init_atb_tensors(); + init_layer(); +} + +void NpuQwen2DecoderLayerImpl::merge_and_move_pinned_host() { + merge_loaded_at_weights(); + init_weight_slices(WEIGHT_COUNT_PER_LAYER); + copy_weights_to_pinned_host(); + init_attn_mask(); + init_atb_tensors(); + init_layer(); +} + +void NpuQwen2DecoderLayerImpl::merge_loaded_at_weights() { + auto make_zero_like = [](const torch::Tensor& ref) { + return torch::zeros( + {1}, + torch::TensorOptions().dtype(ref.scalar_type()).device(torch::kCPU)); + }; + if (quantize_type_ == "w8a8") { - at_weight_tensors_[IN_ATTENTION_OUT_DEQSCALE] = - at_weight_tensors_[IN_ATTENTION_OUT_DEQSCALE].to(torch::kFloat32); - at_weight_tensors_[IN_Q_DEQSCALE] = - torch::cat({at_weight_tensors_[IN_Q_DEQSCALE], - at_weight_tensors_[IN_K_DEQSCALE], - at_weight_tensors_[IN_V_DEQSCALE]}, + at_host_weight_tensors_[IN_ATTENTION_OUT_DEQSCALE] = + at_host_weight_tensors_[IN_ATTENTION_OUT_DEQSCALE].to(torch::kFloat32); + at_host_weight_tensors_[IN_Q_DEQSCALE] = + torch::cat({at_host_weight_tensors_[IN_Q_DEQSCALE], + at_host_weight_tensors_[IN_K_DEQSCALE], + at_host_weight_tensors_[IN_V_DEQSCALE]}, 0) .to(torch::kFloat32); - at_weight_tensors_[IN_K_DEQSCALE] = torch::zeros({1}).to(device_); - at_weight_tensors_[IN_V_DEQSCALE] = torch::zeros({1}).to(device_); - at_weight_tensors_[IN_K_OFFSET] = torch::zeros({1}).to(device_); - at_weight_tensors_[IN_V_OFFSET] = torch::zeros({1}).to(device_); - - at_weight_tensors_[IN_K_SCALE] = torch::zeros({1}).to(device_); - at_weight_tensors_[IN_V_SCALE] = torch::zeros({1}).to(device_); - at_weight_tensors_[IN_MLP_W2_BIAS] = - torch::cat({at_weight_tensors_[IN_MLP_W2_BIAS], - at_weight_tensors_[IN_MLP_W1_BIAS]}, + at_host_weight_tensors_[IN_K_DEQSCALE] = + make_zero_like(at_host_weight_tensors_[IN_K_DEQSCALE]); + at_host_weight_tensors_[IN_V_DEQSCALE] = + make_zero_like(at_host_weight_tensors_[IN_V_DEQSCALE]); + at_host_weight_tensors_[IN_K_OFFSET] = + make_zero_like(at_host_weight_tensors_[IN_K_OFFSET]); + at_host_weight_tensors_[IN_V_OFFSET] = + make_zero_like(at_host_weight_tensors_[IN_V_OFFSET]); + at_host_weight_tensors_[IN_K_SCALE] = + make_zero_like(at_host_weight_tensors_[IN_K_SCALE]); + at_host_weight_tensors_[IN_V_SCALE] = + make_zero_like(at_host_weight_tensors_[IN_V_SCALE]); + at_host_weight_tensors_[IN_MLP_W2_BIAS] = + torch::cat({at_host_weight_tensors_[IN_MLP_W2_BIAS], + at_host_weight_tensors_[IN_MLP_W1_BIAS]}, 0); - at_weight_tensors_[IN_MLP_W1_BIAS] = torch::zeros({1}).to(device_); - at_weight_tensors_[IN_MLP_W2_DEQSCALE] = - torch::cat({at_weight_tensors_[IN_MLP_W2_DEQSCALE], - at_weight_tensors_[IN_MLP_W1_DEQSCALE]}, + at_host_weight_tensors_[IN_MLP_W1_BIAS] = + make_zero_like(at_host_weight_tensors_[IN_MLP_W1_BIAS]); + at_host_weight_tensors_[IN_MLP_W2_DEQSCALE] = + torch::cat({at_host_weight_tensors_[IN_MLP_W2_DEQSCALE], + at_host_weight_tensors_[IN_MLP_W1_DEQSCALE]}, 
0) .to(torch::kFloat32); - at_weight_tensors_[IN_MLP_W1_DEQSCALE] = torch::zeros({1}).to(device_); - - at_weight_tensors_[IN_MLP_W1_OFFSET] = torch::zeros({1}).to(device_); - at_weight_tensors_[IN_MLP_W1_SCALE] = torch::zeros({1}).to(device_); - at_weight_tensors_[IN_Q_OFFSET] = - at_weight_tensors_[IN_Q_OFFSET].to(torch::kInt8).to(device_); - at_weight_tensors_[IN_ATTENTION_OUT_OFFSET] = - at_weight_tensors_[IN_ATTENTION_OUT_OFFSET] - .to(torch::kInt8) - .to(device_); - at_weight_tensors_[IN_MLP_W2_OFFSET] = - at_weight_tensors_[IN_MLP_W2_OFFSET].to(torch::kInt8).to(device_); + at_host_weight_tensors_[IN_MLP_W1_DEQSCALE] = + make_zero_like(at_host_weight_tensors_[IN_MLP_W1_DEQSCALE]); + at_host_weight_tensors_[IN_MLP_W1_OFFSET] = + make_zero_like(at_host_weight_tensors_[IN_MLP_W1_OFFSET]); + at_host_weight_tensors_[IN_MLP_W1_SCALE] = + make_zero_like(at_host_weight_tensors_[IN_MLP_W1_SCALE]); + at_host_weight_tensors_[IN_Q_OFFSET] = + at_host_weight_tensors_[IN_Q_OFFSET].to(torch::kInt8); + at_host_weight_tensors_[IN_ATTENTION_OUT_OFFSET] = + at_host_weight_tensors_[IN_ATTENTION_OUT_OFFSET].to(torch::kInt8); + at_host_weight_tensors_[IN_MLP_W2_OFFSET] = + at_host_weight_tensors_[IN_MLP_W2_OFFSET].to(torch::kInt8); if (device_id_ != 0) { - torch::Tensor original_tensor = at_weight_tensors_[IN_ATTENTION_OUT_BIAS]; + torch::Tensor original_tensor = + at_host_weight_tensors_[IN_ATTENTION_OUT_BIAS]; auto shape = original_tensor.sizes(); auto dtype = original_tensor.dtype(); - auto device = original_tensor.device(); - at_weight_tensors_[IN_ATTENTION_OUT_BIAS] = torch::zeros( - shape, torch::TensorOptions().dtype(dtype).device(device)); + at_host_weight_tensors_[IN_ATTENTION_OUT_BIAS] = torch::zeros( + shape, torch::TensorOptions().dtype(dtype).device(torch::kCPU)); } } - auto new_q_weight = torch::cat({at_weight_tensors_[IN_Q_WEIGHT], - at_weight_tensors_[IN_K_WEIGHT], - at_weight_tensors_[IN_V_WEIGHT]}, + auto new_q_weight = torch::cat({at_host_weight_tensors_[IN_Q_WEIGHT], + at_host_weight_tensors_[IN_K_WEIGHT], + at_host_weight_tensors_[IN_V_WEIGHT]}, 0); - at_weight_tensors_[IN_Q_WEIGHT] = new_q_weight; + at_host_weight_tensors_[IN_Q_WEIGHT] = new_q_weight; - at_weight_tensors_[IN_K_WEIGHT] = torch::zeros({1}).to(device_); - at_weight_tensors_[IN_V_WEIGHT] = torch::zeros({1}).to(device_); + at_host_weight_tensors_[IN_K_WEIGHT] = + make_zero_like(at_host_weight_tensors_[IN_K_WEIGHT]); + at_host_weight_tensors_[IN_V_WEIGHT] = + make_zero_like(at_host_weight_tensors_[IN_V_WEIGHT]); - auto new_q_bias = torch::cat({at_weight_tensors_[IN_Q_BIAS], - at_weight_tensors_[IN_K_BIAS], - at_weight_tensors_[IN_V_BIAS]}, + auto new_q_bias = torch::cat({at_host_weight_tensors_[IN_Q_BIAS], + at_host_weight_tensors_[IN_K_BIAS], + at_host_weight_tensors_[IN_V_BIAS]}, 0); - at_weight_tensors_[IN_Q_BIAS] = new_q_bias; + at_host_weight_tensors_[IN_Q_BIAS] = new_q_bias; - at_weight_tensors_[IN_K_BIAS] = torch::zeros({1}).to(device_); - at_weight_tensors_[IN_V_BIAS] = torch::zeros({1}).to(device_); + at_host_weight_tensors_[IN_K_BIAS] = + make_zero_like(at_host_weight_tensors_[IN_K_BIAS]); + at_host_weight_tensors_[IN_V_BIAS] = + make_zero_like(at_host_weight_tensors_[IN_V_BIAS]); TransposeType transpose_type = - check_transpose(at_weight_tensors_[IN_MLP_W2_WEIGHT]); + check_transpose(at_host_weight_tensors_[IN_MLP_W2_WEIGHT]); int transpose_value = static_cast(transpose_type); prefill_param_.linearTransposeType[4] = transpose_value; decode_param_.linearTransposeType[4] = transpose_value; if (transpose_type == 
TransposeType::TRANSPOSE) { - auto new_mlp_weight = torch::cat({at_weight_tensors_[IN_MLP_W2_WEIGHT], - at_weight_tensors_[IN_MLP_W1_WEIGHT]}, - 0); - at_weight_tensors_[IN_MLP_W2_WEIGHT] = new_mlp_weight.contiguous(); + auto new_mlp_weight = + torch::cat({at_host_weight_tensors_[IN_MLP_W2_WEIGHT], + at_host_weight_tensors_[IN_MLP_W1_WEIGHT]}, + 0); + at_host_weight_tensors_[IN_MLP_W2_WEIGHT] = new_mlp_weight.contiguous(); } else { - auto new_mlp_weight = torch::cat({at_weight_tensors_[IN_MLP_W2_WEIGHT], - at_weight_tensors_[IN_MLP_W1_WEIGHT]}, - 0) - .transpose(0, 1); - at_weight_tensors_[IN_MLP_W2_WEIGHT] = new_mlp_weight.contiguous(); - } - - at_weight_tensors_[IN_MLP_W1_WEIGHT] = torch::zeros({1}).to(device_); - - c10_npu::NPUCachingAllocator::emptyCache(); - for (int i = 0; i < WEIGHT_COUNT_PER_LAYER; ++i) { - atb_weight_tensors_[i] = - atb_speed::Utils::AtTensor2Tensor(at_weight_tensors_[i]); + auto new_mlp_weight = + torch::cat({at_host_weight_tensors_[IN_MLP_W2_WEIGHT], + at_host_weight_tensors_[IN_MLP_W1_WEIGHT]}, + 0) + .transpose(0, 1); + at_host_weight_tensors_[IN_MLP_W2_WEIGHT] = new_mlp_weight.contiguous(); } - init_layer(); + at_host_weight_tensors_[IN_MLP_W1_WEIGHT] = + make_zero_like(at_host_weight_tensors_[IN_MLP_W1_WEIGHT]); } void NpuQwen2DecoderLayerImpl::load_state_dict(const StateDict& state_dict) { if (quantize_type_ == "w8a8") { for (const auto& [index, name] : WEIGHT_MAPPING_W8A8) { if (WEIGHT_SHARD_W8A8.find(index) != WEIGHT_SHARD_W8A8.end()) { - set_weight(state_dict, name, index, WEIGHT_SHARD_W8A8[index]); + set_weight(state_dict, name, index, WEIGHT_SHARD_W8A8[index], true); } else { - set_weight(state_dict, name, index); + set_weight(state_dict, name, index, true); } } - at_weight_tensors_[IN_NORM_BIAS] = - torch::zeros(at_weight_tensors_[IN_NORM_WEIGHT].sizes(), - at_weight_tensors_[IN_NORM_WEIGHT].options()) - .to(device_); + at_host_weight_tensors_[IN_NORM_BIAS] = + torch::zeros(at_host_weight_tensors_[IN_NORM_WEIGHT].sizes(), + at_host_weight_tensors_[IN_NORM_WEIGHT].options()); - at_weight_tensors_[IN_SELFOUT_NORM_BIAS] = - torch::zeros(at_weight_tensors_[IN_SELFOUT_NORM_WEIGHT].sizes(), - at_weight_tensors_[IN_SELFOUT_NORM_WEIGHT].options()) - .to(device_); + at_host_weight_tensors_[IN_SELFOUT_NORM_BIAS] = + torch::zeros(at_host_weight_tensors_[IN_SELFOUT_NORM_WEIGHT].sizes(), + at_host_weight_tensors_[IN_SELFOUT_NORM_WEIGHT].options()); prefill_param_.packQuantType = {static_cast(PackType::ALL_W8A8), static_cast(PackType::ALL_W8A8)}; @@ -338,15 +366,14 @@ void NpuQwen2DecoderLayerImpl::load_state_dict(const StateDict& state_dict) { for (const auto& [index, name] : WEIGHT_MAPPING) { if (WEIGHT_SHARD.find(index) != WEIGHT_SHARD.end()) { - set_weight(state_dict, name, index, WEIGHT_SHARD[index]); + set_weight(state_dict, name, index, WEIGHT_SHARD[index], true); } else { - set_weight(state_dict, name, index); + set_weight(state_dict, name, index, true); } } } int64_t NpuQwen2DecoderLayerImpl::init_layer() { - init_attn_mask(); name_ = "qwen2_decoder_layer"; model_name_ = "qwen2"; CHECK_OPERATION_STATUS_RETURN(init_node(prefill_node_, prefill_param_)); diff --git a/xllm/core/layers/npu/npu_qwen2_decoder_layer_impl.h b/xllm/core/layers/npu/npu_qwen2_decoder_layer_impl.h index 17d2b15ac..1ca91dab8 100644 --- a/xllm/core/layers/npu/npu_qwen2_decoder_layer_impl.h +++ b/xllm/core/layers/npu/npu_qwen2_decoder_layer_impl.h @@ -108,7 +108,7 @@ class NpuQwen2DecoderLayerImpl : public NpuBaseLayer { public: explicit NpuQwen2DecoderLayerImpl(const ModelContext& 
context); - ~NpuQwen2DecoderLayerImpl() {}; + ~NpuQwen2DecoderLayerImpl() = default; TransposeType check_transpose(at::Tensor& tensor); @@ -120,6 +120,10 @@ class NpuQwen2DecoderLayerImpl : public NpuBaseLayer { virtual int64_t init_layer() override; + void move_device_and_init_layer(); + + void merge_and_move_pinned_host(); + torch::Tensor forward(torch::Tensor& x, torch::Tensor& cos_pos, torch::Tensor& sin_pos, @@ -139,6 +143,7 @@ class NpuQwen2DecoderLayerImpl : public NpuBaseLayer { KVCache& kv_cache, ModelInputParams& input_params, bool is_prefill); + void merge_loaded_at_weights(); void param_from_args(atb_speed::qwen::DecoderLayerParam& param, const ModelArgs& args, diff --git a/xllm/core/layers/npu/npu_qwen3_decoder_layer_impl.cpp b/xllm/core/layers/npu/npu_qwen3_decoder_layer_impl.cpp index f7ae89231..5a3f69a30 100644 --- a/xllm/core/layers/npu/npu_qwen3_decoder_layer_impl.cpp +++ b/xllm/core/layers/npu/npu_qwen3_decoder_layer_impl.cpp @@ -281,6 +281,7 @@ NpuQwen3DecoderLayerImpl::NpuQwen3DecoderLayerImpl(const ModelContext& context) param_from_args(prefill_param_, model_args, parallel_args, true); param_from_args(decode_param_, model_args, parallel_args, false); at_weight_tensors_.resize(WEIGHT_COUNT_PER_LAYER); + at_host_weight_tensors_.resize(WEIGHT_COUNT_PER_LAYER); atb_weight_tensors_.resize(WEIGHT_COUNT_PER_LAYER); placeholder_vec_ = {1}; dtype_ = c10::typeMetaToScalarType(options.dtype()); @@ -299,27 +300,28 @@ NpuQwen3DecoderLayerImpl::NpuQwen3DecoderLayerImpl(const ModelContext& context) void NpuQwen3DecoderLayerImpl::verify_loaded_weights() const { for (const auto& [index, name] : WEIGHT_MAPPING) { - CHECK(at_weight_tensors_[index].sizes() != std::vector({1})) + CHECK(at_host_weight_tensors_[index].sizes() != std::vector({1})) << "weight is not loaded for " << name; } } -void NpuQwen3DecoderLayerImpl::merge_loaded_weights() { +void NpuQwen3DecoderLayerImpl::merge_loaded_at_weights() { if (quantize_type_.compare("w8a8") == 0) { - at_weight_tensors_[IN_ATTENTION_OUT_DEQSCALE] = - at_weight_tensors_[IN_ATTENTION_OUT_DEQSCALE].to(torch::kFloat32); - at_weight_tensors_[IN_Q_DEQSCALE] = - torch::cat({at_weight_tensors_[IN_Q_DEQSCALE], - at_weight_tensors_[IN_K_DEQSCALE], - at_weight_tensors_[IN_V_DEQSCALE]}, + at_host_weight_tensors_[IN_ATTENTION_OUT_DEQSCALE] = + at_host_weight_tensors_[IN_ATTENTION_OUT_DEQSCALE].to(torch::kFloat32); + at_host_weight_tensors_[IN_Q_DEQSCALE] = + torch::cat({at_host_weight_tensors_[IN_Q_DEQSCALE], + at_host_weight_tensors_[IN_K_DEQSCALE], + at_host_weight_tensors_[IN_V_DEQSCALE]}, 0) .to(torch::kFloat32); - at_weight_tensors_[IN_Q_BIAS] = torch::cat({at_weight_tensors_[IN_Q_BIAS], - at_weight_tensors_[IN_K_BIAS], - at_weight_tensors_[IN_V_BIAS]}, - 0) - .to(torch::kInt32); + at_host_weight_tensors_[IN_Q_BIAS] = + torch::cat({at_host_weight_tensors_[IN_Q_BIAS], + at_host_weight_tensors_[IN_K_BIAS], + at_host_weight_tensors_[IN_V_BIAS]}, + 0) + .to(torch::kInt32); for (auto idx : {IN_K_DEQSCALE, IN_V_DEQSCALE, @@ -329,17 +331,17 @@ void NpuQwen3DecoderLayerImpl::merge_loaded_weights() { IN_V_OFFSET, IN_K_SCALE, IN_V_SCALE}) { - at_weight_tensors_[idx] = at_placeholder_; + at_host_weight_tensors_[idx] = at_placeholder_; } - at_weight_tensors_[IN_MLP_W2_BIAS] = - torch::cat({at_weight_tensors_[IN_MLP_W2_BIAS], - at_weight_tensors_[IN_MLP_W1_BIAS]}, + at_host_weight_tensors_[IN_MLP_W2_BIAS] = + torch::cat({at_host_weight_tensors_[IN_MLP_W2_BIAS], + at_host_weight_tensors_[IN_MLP_W1_BIAS]}, 0); - at_weight_tensors_[IN_MLP_W2_DEQSCALE] = - 
torch::cat({at_weight_tensors_[IN_MLP_W2_DEQSCALE], - at_weight_tensors_[IN_MLP_W1_DEQSCALE]}, + at_host_weight_tensors_[IN_MLP_W2_DEQSCALE] = + torch::cat({at_host_weight_tensors_[IN_MLP_W2_DEQSCALE], + at_host_weight_tensors_[IN_MLP_W1_DEQSCALE]}, 0) .to(torch::kFloat32); @@ -347,45 +349,46 @@ void NpuQwen3DecoderLayerImpl::merge_loaded_weights() { IN_MLP_W1_OFFSET, IN_MLP_W1_SCALE, IN_MLP_W1_DEQSCALE}) { - at_weight_tensors_[idx] = at_placeholder_; + at_host_weight_tensors_[idx] = at_placeholder_; } - at_weight_tensors_[IN_Q_OFFSET] = - at_weight_tensors_[IN_Q_OFFSET].to(torch::kInt8).to(device_); - at_weight_tensors_[IN_ATTENTION_OUT_OFFSET] = - at_weight_tensors_[IN_ATTENTION_OUT_OFFSET] + at_host_weight_tensors_[IN_Q_OFFSET] = + at_host_weight_tensors_[IN_Q_OFFSET].to(torch::kInt8).to(device_); + at_host_weight_tensors_[IN_ATTENTION_OUT_OFFSET] = + at_host_weight_tensors_[IN_ATTENTION_OUT_OFFSET] .to(torch::kInt8) .to(device_); - at_weight_tensors_[IN_MLP_W2_OFFSET] = - at_weight_tensors_[IN_MLP_W2_OFFSET].to(torch::kInt8).to(device_); + at_host_weight_tensors_[IN_MLP_W2_OFFSET] = + at_host_weight_tensors_[IN_MLP_W2_OFFSET].to(torch::kInt8).to(device_); if (rank_id_ != 0) { - torch::Tensor original_tensor = at_weight_tensors_[IN_ATTENTION_OUT_BIAS]; + torch::Tensor original_tensor = + at_host_weight_tensors_[IN_ATTENTION_OUT_BIAS]; auto shape = original_tensor.sizes(); auto dtype = original_tensor.dtype(); auto device = original_tensor.device(); - at_weight_tensors_[IN_ATTENTION_OUT_BIAS] = torch::zeros( + at_host_weight_tensors_[IN_ATTENTION_OUT_BIAS] = torch::zeros( shape, torch::TensorOptions().dtype(dtype).device(device)); } } - at_weight_tensors_[IN_Q_WEIGHT] = - torch::cat({at_weight_tensors_[IN_Q_WEIGHT], - at_weight_tensors_[IN_K_WEIGHT], - at_weight_tensors_[IN_V_WEIGHT]}, + at_host_weight_tensors_[IN_Q_WEIGHT] = + torch::cat({at_host_weight_tensors_[IN_Q_WEIGHT], + at_host_weight_tensors_[IN_K_WEIGHT], + at_host_weight_tensors_[IN_V_WEIGHT]}, 0) .contiguous(); - at_weight_tensors_[IN_MLP_W2_WEIGHT] = - torch::cat({at_weight_tensors_[IN_MLP_W2_WEIGHT], - at_weight_tensors_[IN_MLP_W1_WEIGHT]}, + at_host_weight_tensors_[IN_MLP_W2_WEIGHT] = + torch::cat({at_host_weight_tensors_[IN_MLP_W2_WEIGHT], + at_host_weight_tensors_[IN_MLP_W1_WEIGHT]}, 0) .contiguous(); for (auto idx : {IN_MLP_W1_WEIGHT, IN_K_WEIGHT, IN_V_WEIGHT, IN_K_BIAS, IN_V_BIAS}) { - at_weight_tensors_[idx] = at_placeholder_; + at_host_weight_tensors_[idx] = at_placeholder_; } if (prefill_param_.enableIntraLayerAddNorm || @@ -393,20 +396,23 @@ void NpuQwen3DecoderLayerImpl::merge_loaded_weights() { if (quantize_type_.compare("w8a8") == 0) { // quantize torch::ScalarType weight_fill_dtype = torch::kBFloat16; - int64_t weight_attn_shape = at_weight_tensors_[IN_Q_WEIGHT].size(-1); - int64_t weight_mlp_shape = at_weight_tensors_[IN_MLP_W2_WEIGHT].size(-1); - at_weight_tensors_[IN_QKV_SCALE_FILL] = at_weight_tensors_[IN_Q_SCALE] - .repeat(weight_attn_shape) - .to(weight_fill_dtype); - at_weight_tensors_[IN_MLP_SCALE_FILL] = - at_weight_tensors_[IN_MLP_W2_SCALE] + int64_t weight_attn_shape = at_host_weight_tensors_[IN_Q_WEIGHT].size(-1); + int64_t weight_mlp_shape = + at_host_weight_tensors_[IN_MLP_W2_WEIGHT].size(-1); + at_host_weight_tensors_[IN_QKV_SCALE_FILL] = + at_host_weight_tensors_[IN_Q_SCALE] + .repeat(weight_attn_shape) + .to(weight_fill_dtype); + at_host_weight_tensors_[IN_MLP_SCALE_FILL] = + at_host_weight_tensors_[IN_MLP_W2_SCALE] .repeat(weight_mlp_shape) .to(weight_fill_dtype); - 
at_weight_tensors_[IN_QKV_OFFSET_FILL] = at_weight_tensors_[IN_Q_OFFSET] - .repeat(weight_attn_shape) - .to(weight_fill_dtype); - at_weight_tensors_[IN_MLP_OFFSET_FILL] = - at_weight_tensors_[IN_MLP_W2_OFFSET] + at_host_weight_tensors_[IN_QKV_OFFSET_FILL] = + at_host_weight_tensors_[IN_Q_OFFSET] + .repeat(weight_attn_shape) + .to(weight_fill_dtype); + at_host_weight_tensors_[IN_MLP_OFFSET_FILL] = + at_host_weight_tensors_[IN_MLP_W2_OFFSET] .repeat(weight_mlp_shape) .to(weight_fill_dtype); } else { @@ -415,52 +421,61 @@ void NpuQwen3DecoderLayerImpl::merge_loaded_weights() { IN_QKV_OFFSET_FILL, IN_MLP_SCALE_FILL, IN_MLP_OFFSET_FILL}) { - at_weight_tensors_[idx] = at_placeholder_; + at_host_weight_tensors_[idx] = at_placeholder_; } } } - - c10_npu::NPUCachingAllocator::emptyCache(); - for (int i = 0; i < WEIGHT_COUNT_PER_LAYER; ++i) { - atb_weight_tensors_[i] = - atb_speed::Utils::AtTensor2Tensor(at_weight_tensors_[i]); - } - - init_layer(); } void NpuQwen3DecoderLayerImpl::load_state_dict(const StateDict& state_dict) { if (quantize_type_.compare("w8a8") == 0) { for (const auto& [index, name] : WEIGHT_MAPPING_W8A8) { if (WEIGHT_SHARD_W8A8.find(index) != WEIGHT_SHARD_W8A8.end()) { - set_weight(state_dict, name, index, WEIGHT_SHARD_W8A8[index]); + set_weight(state_dict, name, index, WEIGHT_SHARD_W8A8[index], true); } else { - set_weight(state_dict, name, index); + set_weight(state_dict, name, index, true); } } - at_weight_tensors_[IN_NORM_BIAS] = - torch::zeros(at_weight_tensors_[IN_NORM_WEIGHT].sizes(), - at_weight_tensors_[IN_NORM_WEIGHT].options()) + at_host_weight_tensors_[IN_NORM_BIAS] = + torch::zeros(at_host_weight_tensors_[IN_NORM_WEIGHT].sizes(), + at_host_weight_tensors_[IN_NORM_WEIGHT].options()) .to(device_); - at_weight_tensors_[IN_SELFOUT_NORM_BIAS] = - torch::zeros(at_weight_tensors_[IN_SELFOUT_NORM_WEIGHT].sizes(), - at_weight_tensors_[IN_SELFOUT_NORM_WEIGHT].options()) + at_host_weight_tensors_[IN_SELFOUT_NORM_BIAS] = + torch::zeros(at_host_weight_tensors_[IN_SELFOUT_NORM_WEIGHT].sizes(), + at_host_weight_tensors_[IN_SELFOUT_NORM_WEIGHT].options()) .to(device_); return; } for (const auto& [index, name] : WEIGHT_MAPPING) { if (WEIGHT_SHARD.find(index) != WEIGHT_SHARD.end()) { - set_weight(state_dict, name, index, WEIGHT_SHARD[index]); + set_weight(state_dict, name, index, WEIGHT_SHARD[index], true); } else { - set_weight(state_dict, name, index); + set_weight(state_dict, name, index, true); } } } -int64_t NpuQwen3DecoderLayerImpl::init_layer() { +void NpuQwen3DecoderLayerImpl::merge_loaded_weights() { + merge_loaded_at_weights(); + init_weight_slices(WEIGHT_COUNT_PER_LAYER); + copy_weights_to_device(); + init_attn_mask(); + init_atb_tensors(); + init_layer(); +} + +void NpuQwen3DecoderLayerImpl::merge_and_move_pinned_host() { + merge_loaded_at_weights(); + init_weight_slices(WEIGHT_COUNT_PER_LAYER); + copy_weights_to_pinned_host(); init_attn_mask(); + init_atb_tensors(); + init_layer(); +} + +int64_t NpuQwen3DecoderLayerImpl::init_layer() { name_ = "qwen3_decoder_layer"; model_name_ = "qwen3"; CHECK_OPERATION_STATUS_RETURN(init_node(prefill_node_, prefill_param_)); diff --git a/xllm/core/layers/npu/npu_qwen3_decoder_layer_impl.h b/xllm/core/layers/npu/npu_qwen3_decoder_layer_impl.h index 785f43a16..91a351370 100644 --- a/xllm/core/layers/npu/npu_qwen3_decoder_layer_impl.h +++ b/xllm/core/layers/npu/npu_qwen3_decoder_layer_impl.h @@ -57,6 +57,8 @@ class NpuQwen3DecoderLayerImpl : public NpuBaseLayer { virtual int64_t init_layer() override; + void merge_and_move_pinned_host(); + 
  torch::Tensor forward(torch::Tensor& x,
                        torch::Tensor& cos_pos,
                        torch::Tensor& sin_pos,
@@ -68,6 +70,7 @@ class NpuQwen3DecoderLayerImpl : public NpuBaseLayer {
                         int node_id = 0);

  private:
+  void merge_loaded_at_weights();
   void param_from_args(atb_speed::qwen::QwenLayerParam& param,
                        const ModelArgs& args,
                        const ParallelArgs& parallel_args,
diff --git a/xllm/models/CMakeLists.txt b/xllm/models/CMakeLists.txt
index ed638c539..69ef139ff 100644
--- a/xllm/models/CMakeLists.txt
+++ b/xllm/models/CMakeLists.txt
@@ -7,8 +7,10 @@ cc_library(
   HDRS
     model_registry.h
     models.h
+    lazy_layer_loader.h
   SRCS
     model_registry.cpp
+    lazy_layer_loader.cpp
   DEPS
     :model
 )
diff --git a/xllm/models/lazy_layer_loader.cpp b/xllm/models/lazy_layer_loader.cpp
new file mode 100644
index 000000000..33fe8bcfd
--- /dev/null
+++ b/xllm/models/lazy_layer_loader.cpp
@@ -0,0 +1,90 @@
+/* Copyright 2025 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "lazy_layer_loader.h"
+
+#include <glog/logging.h>
+
+namespace xllm {
+
+LazyLayerLoader::LazyLayerLoader(int32_t num_layers, int32_t device_id)
+    : num_layers_(num_layers),
+      device_id_(device_id),
+      load_stream_(c10_npu::getNPUStreamFromPool()),
+      events_(num_layers, nullptr),
+      event_recorded_(num_layers) {
+  uint32_t flags = ACL_EVENT_SYNC;
+  threadpool_ = std::make_unique<ThreadPool>(1);
+
+  for (int32_t i = 0; i < num_layers; ++i) {
+    auto ret = aclrtCreateEventWithFlag(&events_[i], flags);
+    CHECK_EQ(ret, ACL_SUCCESS) << "Failed to create event for layer " << i;
+    event_recorded_[i].store(false, std::memory_order_relaxed);
+  }
+
+  LOG(INFO) << "lazy layer loader initialized for " << num_layers << " layers";
+}
+
+LazyLayerLoader::~LazyLayerLoader() {
+  for (size_t i = 0; i < events_.size(); i++) {
+    if (events_[i] != nullptr) {
+      aclrtDestroyEvent(events_[i]);
+    }
+  }
+}
+
+void LazyLayerLoader::start_async_loading(LayerLoader handle) {
+  LOG(INFO) << "starting asynchronous layer loading for " << num_layers_
+            << " layers";
+  threadpool_->schedule([this, handle]() {
+    for (int32_t i = 0; i < num_layers_; ++i) {
+      load_layer(i, handle);
+    }
+  });
+}
+
+void LazyLayerLoader::load_layer(int32_t layer_idx, LayerLoader handle) {
+  c10_npu::SetDevice(device_id_);
+  auto stream_guard = c10::StreamGuard(load_stream_.unwrap());
+  aclrtStream stream = c10_npu::getCurrentNPUStream().stream();
+
+  handle(layer_idx);
+
+  // Record an event on the load stream to mark this layer's load as complete.
+  auto ret = aclrtRecordEvent(events_[layer_idx], load_stream_.stream());
+  CHECK_EQ(ret, ACL_SUCCESS)
+      << "failed to record event for layer " << layer_idx;
+
+  event_recorded_[layer_idx].store(true, std::memory_order_release);
+}
+
+void LazyLayerLoader::wait_for_layer(int32_t layer_idx) {
+  while (!event_recorded_[layer_idx].load(std::memory_order_acquire)) {
+    // busy wait.
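+    // The acquire load here pairs with the release store in load_layer(), so
+    // once the flag is observed, the event for this layer has already been
+    // recorded on the load stream and the stream wait below will see it.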
+  }
+
+  aclrtStream stream = c10_npu::getCurrentNPUStream().stream();
+  auto ret = aclrtStreamWaitEvent(stream, events_[layer_idx]);
+  CHECK_EQ(ret, ACL_SUCCESS) << "failed to sync layer " << layer_idx;
+  ret = aclrtResetEvent(events_[layer_idx], stream);
+  CHECK_EQ(ret, ACL_SUCCESS) << "failed to reset event " << layer_idx;
+}
+
+void LazyLayerLoader::reset_events() {
+  for (int32_t i = 0; i < num_layers_; ++i) {
+    event_recorded_[i].store(false, std::memory_order_relaxed);
+  }
+}
+}  // namespace xllm
diff --git a/xllm/models/lazy_layer_loader.h b/xllm/models/lazy_layer_loader.h
new file mode 100644
index 000000000..21576a2e2
--- /dev/null
+++ b/xllm/models/lazy_layer_loader.h
@@ -0,0 +1,90 @@
+/* Copyright 2025 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#pragma once
+
+#include <acl/acl.h>
+#include <torch_npu/csrc/core/npu/NPUStream.h>
+
+#include <atomic>
+#include <functional>
+#include <memory>
+#include <vector>
+
+#include "core/platform/stream.h"
+#include "core/util/threadpool.h"
+
+namespace xllm {
+
+/**
+ * LazyLayerLoader - on-demand, asynchronous loading of per-layer weights
+ *
+ * Design principles:
+ * - Defers layer weight loading until the first forward pass
+ * - Loads layers sequentially on a dedicated load stream
+ * - Uses one ACL event per layer for fine-grained synchronization with the
+ *   compute stream
+ */
+
+class LazyLayerLoader {
+ public:
+  /**
+   * Callback that loads, verifies, and merges a single layer
+   * @param layer_idx Index of the layer to process
+   */
+  using LayerLoader = std::function<void(int32_t)>;
+
+  /**
+   * Constructor
+   * @param num_layers Total number of layers in the model
+   * @param device_id NPU device ID
+   */
+  LazyLayerLoader(int32_t num_layers, int32_t device_id);
+
+  ~LazyLayerLoader();
+
+  LazyLayerLoader(const LazyLayerLoader&) = delete;
+  LazyLayerLoader& operator=(const LazyLayerLoader&) = delete;
+  LazyLayerLoader(LazyLayerLoader&&) = delete;
+  LazyLayerLoader& operator=(LazyLayerLoader&&) = delete;
+
+  /**
+   * Reset all events to the unrecorded state.
+   */
+  void reset_events();
+  /**
+   * Start asynchronous loading of all layers
+   * @param handle Callback that loads/verifies/merges a layer
+   */
+  void start_async_loading(LayerLoader handle);
+
+  /**
+   * Wait until the specified layer is fully loaded and ready
+   * @param layer_idx Index of the layer to wait for
+   */
+  void wait_for_layer(int32_t layer_idx);
+
+ private:
+  void load_layer(int32_t layer_idx, LayerLoader handle);
+
+  const int32_t num_layers_;
+  const int32_t device_id_;
+
+  c10_npu::NPUStream load_stream_;
+  std::unique_ptr<ThreadPool> threadpool_;
+  std::vector<aclrtEvent> events_;
+  std::vector<std::atomic<bool>> event_recorded_;
+};
+
+}  // namespace xllm
diff --git a/xllm/models/llm/llm_model_base.h b/xllm/models/llm/llm_model_base.h
index 441b52004..64f892ed8 100644
--- a/xllm/models/llm/llm_model_base.h
+++ b/xllm/models/llm/llm_model_base.h
@@ -35,6 +35,9 @@ limitations under the License.
#include "core/layers/lm_head.h" #include "core/layers/pos_embedding.h" #include "core/layers/rms_norm.h" +#include "core/util/blocking_counter.h" +#include "core/util/threadpool.h" +#include "models/lazy_layer_loader.h" #include "models/model_registry.h" #if defined(USE_NPU) #include "xllm_kernels/core/include/atb_speed/log.h" @@ -125,6 +128,7 @@ class LlmDecoderLayerImplBase : public torch::nn::Module { virtual void verify_loaded_weights(const std::string& prefix) const { decoder_layer_->verify_loaded_weights(); } + virtual void merge_loaded_weights() { decoder_layer_->merge_loaded_weights(); block_copy_->merge_loaded_weights(); @@ -162,6 +166,11 @@ class LlmModelImplBase : public torch::nn::Module { this->layer_forward_interrupted_ = interrupted; }); mrope_section_ = args.rope_scaling_mrope_section(); +#if defined(USE_NPU) + aclrtGetDevice(&device_id_); + threadpool_ = std::make_unique( + args.n_layers(), [this]() mutable { c10_npu::SetDevice(device_id_); }); +#endif } torch::Tensor get_input_embeddings(torch::Tensor input_ids) { @@ -320,6 +329,7 @@ class LlmModelImplBase : public torch::nn::Module { layers_[i]->load_state_dict( state_dict.get_dict_with_prefix("layers." + std::to_string(i) + ".")); } + norm_->load_state_dict(state_dict.get_dict_with_prefix("norm.")); } @@ -340,6 +350,7 @@ class LlmModelImplBase : public torch::nn::Module { for (int i = 0; i < layers_.size(); i++) { layers_[i]->merge_loaded_weights(); } + norm_->merge_loaded_weights(); } #endif @@ -375,7 +386,9 @@ class LlmModelImplBase : public torch::nn::Module { bool layer_forward_interrupted_ = false; private: + std::unique_ptr threadpool_; std::string model_type_; + int32_t device_id_; }; template