2 changes: 1 addition & 1 deletion xllm/core/framework/hf_model_loader.cpp
@@ -51,7 +51,7 @@ HFModelLoader::HFModelLoader(const std::string& model_weights_path)
       << "Failed to find model weights files in " << model_weights_path;
   // sort the model weights files by name
   std::sort(model_weights_files_.begin(), model_weights_files_.end());
-  threadpool_ = std::make_unique<ThreadPool>(32);
+  threadpool_ = std::make_unique<ThreadPool>(64);
 }
 
 std::unique_ptr<Tokenizer> HFModelLoader::tokenizer() const {
93 changes: 62 additions & 31 deletions xllm/core/layers/base_layer.cpp
@@ -85,35 +85,54 @@ void BaseLayer::correct_tensor_dtype(torch::Tensor& tensor,
 
 void BaseLayer::set_weight(const StateDict& state_dict,
                            const std::string& tensor_name,
-                           int weight_position) {
+                           int weight_position,
+                           bool to_host) {
+  auto device = to_host ? at::kCPU : device_;
   for (const auto& [name, tensor] : state_dict) {
     if (absl::EndsWith(name, tensor_name)) {
       at::Tensor mutable_tensor = tensor;
       correct_tensor_dtype(mutable_tensor, tensor_name);
-      at_weight_tensors_[weight_position] = mutable_tensor.to(device_);
+      if (to_host) {
+        at_host_weight_tensors_[weight_position] = mutable_tensor.to(device);
+      } else {
+        at_weight_tensors_[weight_position] = mutable_tensor.to(device);
+      }
     }
   }
 }
 
 void BaseLayer::set_weight(const StateDict& state_dict,
                            const std::string& tensor_name,
                            int weight_position,
-                           int dim) {
-  for (const auto& [name, tensor] : state_dict) {
-    if (absl::EndsWith(name, tensor_name)) {
-      if (parallel_args_.world_size() <= 1) {
+                           int dim,
+                           bool to_host) {
+  auto device = to_host ? at::kCPU : device_;
+  if (parallel_args_.world_size() <= 1) {
+    for (const auto& [name, tensor] : state_dict) {
+      if (absl::EndsWith(name, tensor_name)) {
         at::Tensor mutable_tensor = tensor;
         correct_tensor_dtype(mutable_tensor, tensor_name);
-        at_weight_tensors_[weight_position] = mutable_tensor.to(device_);
-      } else {
-        at_weight_tensors_[weight_position] =
-            state_dict
-                .get_sharded_tensor(tensor_name,
-                                    /*dim=*/dim,
-                                    /*rank=*/parallel_args_.rank(),
-                                    /*world_size=*/parallel_args_.world_size())
-                .to(device_);
-        correct_tensor_dtype(at_weight_tensors_[weight_position], tensor_name);
+        if (to_host) {
+          at_host_weight_tensors_[weight_position] = mutable_tensor.to(device);
+        } else {
+          at_weight_tensors_[weight_position] = mutable_tensor.to(device);
+        }
+      }
+    }
+  } else {
+    for (const auto& [name, tensor] : state_dict) {
+      if (absl::EndsWith(name, tensor_name)) {
+        at::Tensor mutable_tensor = state_dict.get_sharded_tensor(
+            tensor_name,
+            /*dim=*/dim,
+            /*rank=*/parallel_args_.rank(),
+            /*world_size=*/parallel_args_.world_size());
+        correct_tensor_dtype(mutable_tensor, tensor_name);
+        if (to_host) {
+          at_host_weight_tensors_[weight_position] = mutable_tensor.to(device);
+        } else {
+          at_weight_tensors_[weight_position] = mutable_tensor.to(device);
+        }
+      }
+    }
+  }
@@ -124,26 +143,38 @@ void BaseLayer::set_weight(const StateDict& state_dict,
                            int weight_position,
                            int dim,
                            int rank,
-                           int world_size) {
-  for (const auto& [name, tensor] : state_dict) {
-    if (absl::EndsWith(name, tensor_name)) {
-      if (world_size <= 1) {
+                           int world_size,
+                           bool to_host) {
+  auto device = to_host ? at::kCPU : device_;
+  if (world_size <= 1) {
+    for (const auto& [name, tensor] : state_dict) {
+      if (absl::EndsWith(name, tensor_name)) {
         at::Tensor mutable_tensor = tensor;
         correct_tensor_dtype(mutable_tensor, tensor_name);
-        at_weight_tensors_[weight_position] = mutable_tensor.to(device_);
-      } else {
-        at_weight_tensors_[weight_position] =
-            state_dict
-                .get_sharded_tensor(tensor_name,
-                                    /*dim=*/dim,
-                                    /*rank=*/rank,
-                                    /*world_size=*/world_size)
-                .to(device_);
-        correct_tensor_dtype(at_weight_tensors_[weight_position], tensor_name);
+        if (to_host) {
+          at_host_weight_tensors_[weight_position] = mutable_tensor.to(device);
+        } else {
+          at_weight_tensors_[weight_position] = mutable_tensor.to(device);
+        }
+      }
+    }
+  } else {
+    for (const auto& [name, tensor] : state_dict) {
+      if (absl::EndsWith(name, tensor_name)) {
+        at::Tensor mutable_tensor =
+            state_dict.get_sharded_tensor(tensor_name,
+                                          /*dim=*/dim,
+                                          /*rank=*/rank,
+                                          /*world_size=*/world_size);
+        correct_tensor_dtype(mutable_tensor, tensor_name);
+        if (to_host) {
+          at_host_weight_tensors_[weight_position] = mutable_tensor.to(device);
+        } else {
+          at_weight_tensors_[weight_position] = mutable_tensor.to(device);
+        }
       }
     }
   }
 }
 
 }  // namespace layer
 }  // namespace xllm
10 changes: 7 additions & 3 deletions xllm/core/layers/base_layer.h
@@ -103,18 +103,21 @@ class BaseLayer : public torch::nn::Module {
   void set_weight(const StateDict& state_dict,
                   const std::string& tensor_name,
                   int weight_position,
-                  int dim);
+                  int dim,
+                  bool to_host = false);
 
   void set_weight(const StateDict& state_dict,
                   const std::string& tensor_name,
-                  int weight_position);
+                  int weight_position,
+                  bool to_host = false);
 
   void set_weight(const StateDict& state_dict,
                   const std::string& tensor_name,
                   int weight_position,
                   int dim,
                   int rank,
-                  int world_size);
+                  int world_size,
+                  bool to_host = false);
 
   virtual void run_task(std::string taskName, std::function<int()> task) const {
   };
@@ -126,6 +129,7 @@
 
  protected:
   std::vector<at::Tensor> at_weight_tensors_;
+  std::vector<at::Tensor> at_host_weight_tensors_;
   at::Device device_;
   std::string name_;
   torch::ScalarType dtype_;
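To illustrate the new default `to_host` parameter, here is a minimal call-site sketch for a derived layer that stages a weight on the CPU first. The subclass name, tensor name, and weight index are hypothetical and not part of this diff; only the `to_host` argument comes from the change above.

// Hypothetical call site in a BaseLayer subclass with one weight at index 0.
void MyNpuLayer::load_state_dict(const StateDict& state_dict) {
  // Default behaviour (to_host = false) copies the tensor straight to device_:
  // set_weight(state_dict, "weight", /*weight_position=*/0);

  // New path: keep the tensor in at_host_weight_tensors_ on the CPU so it can
  // later be packed into one contiguous pinned buffer.
  set_weight(state_dict, "weight", /*weight_position=*/0, /*to_host=*/true);
}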
178 changes: 178 additions & 0 deletions xllm/core/layers/npu/npu_base_layer.cpp
@@ -27,11 +27,22 @@ limitations under the License.
 namespace xllm {
 namespace layer {
 
+namespace {
+inline size_t AlignUp(size_t value, size_t alignment) {
+  return (value + alignment - 1) & ~(alignment - 1);
+}
+}  // namespace
+
 NpuBaseLayer::NpuBaseLayer(const ModelContext& context) : BaseLayer(context) {
   context_ = const_cast<atb::Context*>(context.get_atb_context());
   work_space_ = AtbWorkspace(device_);
 }
 
+NpuBaseLayer::~NpuBaseLayer() {
+  release_host_storage();
+  release_device_storage();
+}
+
 atb::Status NpuBaseLayer::execute_node(atb_speed::Model::Node& node,
                                        int node_id,
                                        aclrtEvent* event,
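The AlignUp helper added above rounds a byte count up to the next multiple of a power-of-two alignment. A standalone sketch of the same arithmetic follows; the constexpr copy is used here purely so the values can be checked at compile time, whereas the helper in the diff is a plain inline function.

#include <cstddef>

// Same bit trick as AlignUp above; alignment must be a power of two.
constexpr std::size_t align_up(std::size_t value, std::size_t alignment) {
  return (value + alignment - 1) & ~(alignment - 1);
}
static_assert(align_up(100, 64) == 128, "100 bytes round up to 128");
static_assert(align_up(128, 64) == 128, "already-aligned sizes are unchanged");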
@@ -124,6 +135,173 @@ void NpuBaseLayer::run_task(std::string taskName,
   cmd.Run();
 }
 
+void NpuBaseLayer::init_weight_slices(int weight_count) {
+  weight_slices_.resize(weight_count);
+  size_t offset = 0;
+  for (size_t i = 0; i < weight_count; ++i) {
+    weight_slices_[i] = {};
+    const auto& tensor = at_host_weight_tensors_[i];
+    if (!tensor.defined() || tensor.numel() < 1) {
+      continue;
+    }
+    // offset = AlignUp(offset, kHostAlignment);
+    weight_slices_[i].offset = offset;
+    weight_slices_[i].bytes = tensor.nbytes();
+    weight_slices_[i].sizes = tensor.sizes().vec();
+    weight_slices_[i].dtype = tensor.scalar_type();
+    offset += weight_slices_[i].bytes;
+  }
+  size_t max_alignment = std::max(kHostAlignment, kDeviceAlignment);
+  // storage_size_ = AlignUp(offset, max_alignment);
+  LOG(INFO) << "NpuBaseLayer total weight size: " << offset;
+  storage_size_ = offset;
+}
+
+void NpuBaseLayer::copy_weights_to_pinned_host() {
+  CHECK_GT(storage_size_, 0) << "model size must be greater than 0.";
+  CHECK_EQ(weight_slices_.size(), at_host_weight_tensors_.size())
+      << "weight_slices_ size and at_host_weight_tensors_ size mismatch.";
+
+  size_t max_alignment = std::max(kHostAlignment, kDeviceAlignment);
+  storage_size_ = AlignUp(storage_size_, max_alignment);
+
+  auto ret = aclrtMallocHost(&host_pinned_storage_, storage_size_);
+  CHECK_EQ(ret, ACL_SUCCESS)
+      << "Failed to allocate pinned host storage size=" << storage_size_;
+
+  for (size_t i = 0; i < weight_slices_.size(); ++i) {
+    const auto& slice = weight_slices_[i];
+    if (!slice.bytes) {
+      continue;
+    }
+    auto host_tensor = at_host_weight_tensors_[i].to(torch::kCPU).contiguous();
+    void* dst = static_cast<char*>(host_pinned_storage_) +
+                static_cast<ptrdiff_t>(slice.offset);
+    std::memcpy(dst, host_tensor.data_ptr(), slice.bytes);
+    at_host_weight_tensors_[i] = at::Tensor();
+  }
+
+  ret = aclrtMallocAlign32(
+      &device_storage_, storage_size_, ACL_MEM_MALLOC_HUGE_FIRST);
+  CHECK_EQ(ret, ACL_SUCCESS)
+      << "Failed to allocate contiguous device storage size=" << storage_size_;
+}
+
+void NpuBaseLayer::copy_weights_to_device() {
+  CHECK_EQ(weight_slices_.size(), at_host_weight_tensors_.size())
+      << "weight_slices_ size and at_host_weight_tensors_ size mismatch.";
+  auto ret = aclrtMallocAlign32(
+      &device_storage_, storage_size_, ACL_MEM_MALLOC_HUGE_FIRST);
+  CHECK_EQ(ret, ACL_SUCCESS)
+      << "Failed to allocate contiguous device storage size=" << storage_size_;
+
+  for (size_t i = 0; i < weight_slices_.size(); ++i) {
+    const auto& slice = weight_slices_[i];
+    if (!slice.bytes) {
+      continue;
+    }
+    void* dst = static_cast<char*>(device_storage_) +
+                static_cast<ptrdiff_t>(slice.offset);
+    auto host_tensor = at_host_weight_tensors_[i].contiguous();
+    auto err = aclrtMemcpy(dst,
+                           slice.bytes,
+                           host_tensor.data_ptr(),
+                           slice.bytes,
+                           ACL_MEMCPY_HOST_TO_DEVICE);
+    CHECK_EQ(err, ACL_SUCCESS) << "aclrtMemcpy failed for tensor index " << i;
+    at_host_weight_tensors_[i] = at::Tensor();
+  }
+}
+
+torch::Tensor NpuBaseLayer::convert_to_torch_tensor(
+    const std::vector<int64_t>& dims,
+    const torch::ScalarType dtype,
+    const uintptr_t& dev_addr,
+    int acl_format) {
+  c10::DeviceType device_type = c10::DeviceType::PrivateUse1;
+  torch::TensorOptions option =
+      torch::TensorOptions().dtype(dtype).device(device_type);
+
+  auto tensor = torch::empty({0}, option);
+  auto address = reinterpret_cast<void*>(dev_addr);
+  torch::DataPtr c10_data_ptr(address, address, [](void*) {}, tensor.device());
+
+  size_t tensor_nbytes = at::detail::computeStorageNbytesContiguous(
+      dims, tensor.dtype().itemsize());
+  torch::Storage storage;
+  // get npu storage constructor from register and construct storage
+  auto fptr = c10::GetStorageImplCreate(device_type);
+  auto allocator = c10::GetAllocator(device_type);
+  storage = fptr(c10::StorageImpl::use_byte_size_t(), 0, allocator, true);
+  storage.unsafeGetStorageImpl()->set_nbytes(tensor_nbytes);
+  storage.set_data_ptr(std::move(c10_data_ptr));
+
+  tensor.set_(storage, 0, dims);
+  // cast npu format to nd
+  tensor = at_npu::native::npu_format_cast(tensor, acl_format);
+
+  return tensor;
+}
+
+void NpuBaseLayer::init_atb_tensors() {
+  for (size_t i = 0; i < weight_slices_.size(); ++i) {
+    const auto& slice = weight_slices_[i];
+    if (!slice.bytes) {
+      continue;
+    }
+    void* base = static_cast<char*>(device_storage_) +
+                 static_cast<ptrdiff_t>(slice.offset);
+    at_weight_tensors_[i] = convert_to_torch_tensor(
+        slice.sizes, slice.dtype, reinterpret_cast<uintptr_t>(base));
+  }
+
+  c10_npu::NPUCachingAllocator::emptyCache();
+
+  for (size_t i = 0; i < weight_slices_.size(); ++i) {
+    atb_weight_tensors_[i] =
+        atb_speed::Utils::AtTensor2Tensor(at_weight_tensors_[i]);
+  }
+}
+
+void NpuBaseLayer::copy_weights_to_device_async() {
+  CHECK_EQ(weight_slices_.size(), at_weight_tensors_.size())
+      << "weight_slices_ size and at_weight_tensors_ size mismatch.";
+  aclrtStream stream = c10_npu::getCurrentNPUStream().stream();
+
+  void* dst = static_cast<char*>(device_storage_);
+  void* src = static_cast<char*>(host_pinned_storage_);
+
+  auto err = aclrtMemcpyAsync(dst,
+                              storage_size_,
+                              src,
+                              storage_size_,
+                              ACL_MEMCPY_HOST_TO_DEVICE,
+                              stream);
+  CHECK_EQ(err, ACL_SUCCESS) << "aclrtMemcpyAsync failed";
+}
+
+void NpuBaseLayer::release_device_storage() {
+  if (device_storage_ == nullptr) {
+    return;
+  }
+  auto ret = aclrtFree(device_storage_);
+  if (ret != ACL_SUCCESS) {
+    LOG(ERROR) << "Failed to free contiguous layer storage, ret=" << ret;
+  }
+  device_storage_ = nullptr;
+}
+
+void NpuBaseLayer::release_host_storage() {
+  if (host_pinned_storage_ == nullptr) {
+    return;
+  }
+  auto ret = aclrtFreeHost(host_pinned_storage_);
+  if (ret != ACL_SUCCESS) {
+    LOG(ERROR) << "Failed to free pinned host storage, ret=" << ret;
+  }
+  host_pinned_storage_ = nullptr;
+}
+
 atb::Tensor NpuBaseLayer::XTensor2Tensor(
     const std::shared_ptr<xllm::XTensor>& xtensor) {
   static std::map<at::ScalarType, aclDataType> dtypeMap = {
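Taken together, the new hooks let a derived layer stage all of its weights in one pinned host buffer and move them to the NPU in a single async copy. The sketch below shows one plausible ordering; the subclass, method name, and the choice of the async path are assumptions for illustration, not something this diff prescribes.

// Hypothetical orchestration inside a NpuBaseLayer subclass, after every
// weight has been loaded with set_weight(..., /*to_host=*/true).
void MyNpuLayer::merge_loaded_weights() {
  const int weight_count = static_cast<int>(at_host_weight_tensors_.size());

  // 1. Record offset/size/shape/dtype of each host tensor in weight_slices_
  //    and compute the total contiguous storage size.
  init_weight_slices(weight_count);

  // 2. Pack the host tensors into one pinned buffer and allocate the matching
  //    contiguous device buffer (copy_weights_to_device() is the synchronous,
  //    non-pinned alternative).
  copy_weights_to_pinned_host();

  // 3. Move the whole buffer host -> device with one async memcpy on the
  //    current NPU stream.
  copy_weights_to_device_async();

  // 4. Rebuild at_weight_tensors_ as views over the device buffer and wrap
  //    them as ATB tensors.
  init_atb_tensors();
}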