Skip to content

Commit 0c80339

Browse files
Merge pull request #139 from InfiniTensor/issue/9
feat: 添加昇腾swiglu算子
2 parents bd37042 + fafb22d commit 0c80339

File tree

10 files changed

+329
-21
lines changed

10 files changed

+329
-21
lines changed

src/infiniop/devices/ascend/CMakeLists.txt

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ cmake_minimum_required(VERSION 3.16.0)
33
# project information
44
project(Ascend_C)
55
set(SOC_VERSION "Ascend910B3" CACHE STRING "system on chip type")
6-
set(ASCEND_CANN_PACKAGE_PATH $ENV{ASCEND_HOME} CACHE PATH "ASCEND CANN package installation directory")
6+
set(ASCEND_CANN_PACKAGE_PATH $ENV{ASCEND_TOOLKIT_HOME} CACHE PATH "ASCEND CANN package installation directory")
77
set(RUN_MODE "npu" CACHE STRING "run mode: npu")
88
set(CMAKE_BUILD_TYPE "Release" CACHE STRING "Build type Release/Debug (default Debug)" FORCE)
99
set(CMAKE_INSTALL_PREFIX "${CMAKE_CURRENT_LIST_DIR}/out" CACHE STRING "path for install()" FORCE)
@@ -19,10 +19,14 @@ else()
1919
endif()
2020

2121
include(${ASCENDC_CMAKE_DIR}/ascendc.cmake)
22+
include_directories(
23+
${CMAKE_SOURCE_DIR}/../../../../include/infiniop/
24+
)
25+
2226

2327
ascendc_library(ascend_kernels STATIC
24-
../../ops/swiglu/ascend/swiglu_kernel.cpp
25-
../../ops/rotary_embedding/ascend/rotary_embedding_kernel.cpp
26-
../../ops/random_sample/ascend/random_sample_kernel.cpp
28+
../../ops/swiglu/ascend/swiglu_ascend_kernel.cpp
29+
# ../../ops/rotary_embedding/ascend/rotary_embedding_kernel.cpp
30+
# ../../ops/random_sample/ascend/random_sample_kernel.cpp
2731
)
2832

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
#ifndef __INFINIOP_ASCEND_KERNEL_COMMON_H__
#define __INFINIOP_ASCEND_KERNEL_COMMON_H__

#include "../../../../include/infinicore.h"
#include "kernel_operator.h"

// Shared tuning constants for Ascend C device kernels.

// Number of AI-core blocks a kernel is launched with (used as the block count
// in kernel launches).
constexpr int32_t BLOCK_NUM = 8;
// Queue depth used for double buffering in TQue pipes.
constexpr int32_t BUFFER_NUM = 2;
// Byte alignment required by DataCopy transfers; copy lengths are rounded up
// to this boundary.
constexpr int32_t BYTE_ALIGN = 32;

#endif

src/infiniop/ops/causal_softmax/ascend/causal_softmax_aclnn.cc renamed to src/infiniop/ops/causal_softmax/ascend/causal_softmax_ascend.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#include "causal_softmax_aclnn.h"
1+
#include "causal_softmax_ascend.h"
22
#include "../../../devices/ascend/common_ascend.h"
33
#include <aclnnop/aclnn_masked_fill_tensor.h>
44
#include <aclnnop/aclnn_softmax.h>

src/infiniop/ops/causal_softmax/operator.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
#include "cuda/causal_softmax_cuda.cuh"
1010
#endif
1111
#ifdef ENABLE_ASCEND_API
12-
#include "ascend/causal_softmax_aclnn.h"
12+
#include "ascend/causal_softmax_ascend.h"
1313
#endif
1414

1515
__C infiniStatus_t infiniopCreateCausalSoftmaxDescriptor(
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
#include "swiglu_ascend.h"
#include "../../../devices/ascend/common_ascend.h"

namespace op::swiglu::ascend {
Descriptor::~Descriptor() = default;

/// Builds a SwiGLU descriptor for the Ascend backend.
///
/// @param handle       infiniop handle; must refer to an Ascend device handle.
/// @param desc_ptr     receives the newly allocated descriptor (caller owns it).
/// @param c_desc       output tensor descriptor.
/// @param input_descs  exactly two inputs {a, b}, matching c in shape/dtype/layout
///                     (validated by SwigluInfo::create).
/// @return INFINI_STATUS_SUCCESS or a validation error code.
infiniStatus_t Descriptor::create(infiniopHandle_t handle, Descriptor **desc_ptr,
                                  infiniopTensorDescriptor_t c_desc,
                                  std::vector<infiniopTensorDescriptor_t> input_descs) {
    auto handle_ascend = reinterpret_cast<device::ascend::Handle *>(handle);

    // SwiGLU is a binary op: require exactly {a, b} before indexing, and a
    // valid output descriptor before dereferencing it below.
    CHECK_OR_RETURN(c_desc != nullptr && input_descs.size() == 2, INFINI_STATUS_BAD_PARAM);

    auto dtype = c_desc->dtype();
    CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32);

    const auto &a_desc = input_descs[0];
    const auto &b_desc = input_descs[1];

    auto result = SwigluInfo::create(c_desc, a_desc, b_desc);
    CHECK_RESULT(result);
    SwigluInfo info = result.take();

    // No device workspace is required by the Ascend C SwiGLU primitive; see
    // https://www.hiascend.com/document/detail/zh/canncommercial/800/apiref/ascendcopapi/atlasascendc_api_07_0777.html
    size_t workspace_size = 0;

    *desc_ptr = new Descriptor(std::move(info), workspace_size, handle_ascend->device, handle_ascend->device_id);
    return INFINI_STATUS_SUCCESS;
}

// Implemented in swiglu_ascend_kernel.cpp, compiled by the Ascend C toolchain.
extern "C" infiniStatus_t swiglu_kernel_launch(
    void *c, void *a, void *b,
    infiniDtype_t dtype, size_t batch, size_t seq, size_t hd,
    ptrdiff_t stride_batch_c, ptrdiff_t stride_batch_a, ptrdiff_t stride_batch_b,
    ptrdiff_t stride_seq_c, ptrdiff_t stride_seq_a, ptrdiff_t stride_seq_b, void *stream);

/// Enqueues c = swiglu(inputs[0], inputs[1]) on `stream`.
/// A 2-D tensor is treated as a single batch of rows; 3-D as (batch, seq, hidden).
infiniStatus_t Descriptor::calculate(void *workspace,
                                     size_t workspace_size,
                                     void *c,
                                     std::vector<const void *> inputs,
                                     void *stream) const {
    (void)workspace;      // no workspace needed — see create()
    (void)workspace_size;

    CHECK_OR_RETURN(inputs.size() == 2, INFINI_STATUS_BAD_PARAM);

    // Collapse the 2-D case into batch == 1 so the kernel handles one layout.
    auto batch = _info.ndim == 2 ? 1 : _info.shape[0];
    auto seq_len = _info.ndim == 2 ? _info.shape[0] : _info.shape[1];
    auto hidden_size = _info.shape[_info.ndim - 1];
    // With batch == 1 the batch stride is only ever multiplied by index 0,
    // so the placeholder value 1 is inert.
    auto stride_batch_c = _info.ndim == 2 ? 1 : _info.c_strides[0];
    auto stride_batch_a = _info.ndim == 2 ? 1 : _info.a_strides[0];
    auto stride_batch_b = _info.ndim == 2 ? 1 : _info.b_strides[0];
    auto stride_seq_c = _info.ndim == 2 ? _info.c_strides[0] : _info.c_strides[1];
    auto stride_seq_a = _info.ndim == 2 ? _info.a_strides[0] : _info.a_strides[1];
    auto stride_seq_b = _info.ndim == 2 ? _info.b_strides[0] : _info.b_strides[1];

    // The kernel does not mutate its inputs; const is dropped only to match
    // the C launch signature.
    return swiglu_kernel_launch(c,
                                const_cast<void *>(inputs[0]),
                                const_cast<void *>(inputs[1]),
                                _info.dtype, batch, seq_len, hidden_size,
                                stride_batch_c, stride_batch_a, stride_batch_b,
                                stride_seq_c, stride_seq_a, stride_seq_b, stream);
}
} // namespace op::swiglu::ascend
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
#ifndef __INFINIOP_SWIGLU_ASCEND_H__
#define __INFINIOP_SWIGLU_ASCEND_H__

#include "../../../../utils.h"
#include "../../../../utils/check.h"
#include "../../../operator.h"
#include "../../../tensor.h"

#include <vector>

namespace op::swiglu::ascend {

/// Validated shape/stride/dtype metadata for c = swiglu(a, b).
class SwigluInfo {

    // NOTE(review): this defaulted ctor is user-declared but not user-provided,
    // so SwigluInfo is still a C++17 aggregate and the braced init in create()
    // below is legal. Under C++20 aggregate rules it would stop compiling —
    // confirm the project's language standard before upgrading.
    SwigluInfo() = default;

public:
    infiniDtype_t dtype;              // element type shared by c, a and b
    std::vector<size_t> shape;        // output shape; ndim is 2 or 3
    int32_t ndim;                     // 2 or 3
    std::vector<ptrdiff_t> c_strides; // strides in elements, not bytes
    std::vector<ptrdiff_t> a_strides;
    std::vector<ptrdiff_t> b_strides;

    /// Validates the three descriptors and captures their metadata.
    ///
    /// Requirements: non-null descriptors; identical shapes and dtypes;
    /// ndim of 2 or 3; dense (stride 1) last dimension; no broadcast dims on c.
    /// @return the populated SwigluInfo, or a status describing the violation.
    static utils::Result<SwigluInfo> create(infiniopTensorDescriptor_t c_desc, infiniopTensorDescriptor_t a_desc, infiniopTensorDescriptor_t b_desc) {
        CHECK_OR_RETURN(c_desc && a_desc && b_desc, INFINI_STATUS_BAD_PARAM);
        CHECK_OR_RETURN(!c_desc->hasBroadcastDim(), INFINI_STATUS_BAD_TENSOR_STRIDES);
        CHECK_OR_RETURN(c_desc->ndim() == a_desc->ndim()
                            && c_desc->ndim() == b_desc->ndim()
                            && (c_desc->ndim() == 2 || c_desc->ndim() == 3),
                        INFINI_STATUS_BAD_TENSOR_SHAPE);
        CHECK_SAME_SHAPE(c_desc->shape(), a_desc->shape(), b_desc->shape());
        int32_t ndim = c_desc->ndim();
        // The kernel copies whole rows of the last dimension, so it must be dense.
        CHECK_OR_RETURN(c_desc->stride(ndim - 1) == 1
                            && a_desc->stride(ndim - 1) == 1
                            && b_desc->stride(ndim - 1) == 1,
                        INFINI_STATUS_BAD_TENSOR_STRIDES);
        CHECK_OR_RETURN(c_desc->dtype() == a_desc->dtype()
                            && c_desc->dtype() == b_desc->dtype(),
                        INFINI_STATUS_BAD_TENSOR_DTYPE);

        return utils::Result<SwigluInfo>(SwigluInfo{
            c_desc->dtype(),
            c_desc->shape(),
            ndim,
            c_desc->strides(),
            a_desc->strides(),
            b_desc->strides(),
        });
    }
};

/// Host-side descriptor for the Ascend SwiGLU operator.
class Descriptor final : public InfiniopDescriptor {
    SwigluInfo _info;
    size_t _workspace_size;

    Descriptor(SwigluInfo info, size_t workspace_size, infiniDevice_t device_type, int device_id)
        : InfiniopDescriptor{device_type, device_id},
          // Move instead of copy: info owns a shape vector and three stride
          // vectors, and create() already passes it with std::move.
          _info(std::move(info)), _workspace_size(workspace_size) {}

public:
    ~Descriptor();

    /// Validates the descriptors and allocates a new Descriptor into *desc_ptr.
    /// `input_descs` must hold exactly {a, b}.
    static infiniStatus_t create(infiniopHandle_t handle, Descriptor **desc_ptr,
                                 infiniopTensorDescriptor_t c_desc,
                                 std::vector<infiniopTensorDescriptor_t> input_descs);

    /// Device workspace bytes calculate() requires.
    size_t workspaceSize() const { return _workspace_size; }

    /// Enqueues c = swiglu(inputs[0], inputs[1]) on `stream`.
    infiniStatus_t calculate(
        void *workspace,
        size_t workspace_size,
        void *c,
        std::vector<const void *> inputs,
        void *stream) const;
};

} // namespace op::swiglu::ascend

#endif // __INFINIOP_SWIGLU_ASCEND_H__
Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
#include "../../../devices/ascend/ascend_kernel_common.h"

using namespace AscendC;

// SwiGLU device kernel: c = SwiGLU(a, b) over a (batch, seq, hidden) layout.
// The hidden dimension is partitioned across BLOCK_NUM AI cores; every core
// iterates over all (batch, seq) rows and handles only its own hidden slice.
template <typename T>
class SwigluKernel {
public:
    __aicore__ inline SwigluKernel() {}
    // Stores shapes/strides, computes this core's tile bounds, binds the
    // global buffers and initializes the double-buffered queues.
    __aicore__ inline void init(GM_ADDR c, GM_ADDR a, GM_ADDR b, int64_t batch_, int64_t seq, int64_t hd,
                                int64_t stride_batch_c, int64_t stride_batch_a, int64_t stride_batch_b,
                                int64_t stride_seq_c, int64_t stride_seq_a, int64_t stride_seq_b);
    __aicore__ inline void process();

private:
    __aicore__ inline void copyIn(int64_t i);
    __aicore__ inline void compute(int64_t i);
    __aicore__ inline void copyOut(int64_t i);

private:
    GlobalTensor<T> _c_gm, _a_gm, _b_gm;
    TQue<QuePosition::VECIN, BUFFER_NUM> _in_queue_a, _in_queue_b;
    TQue<QuePosition::VECOUT, BUFFER_NUM> _out_queue_c;

    TPipe _pipe;
    float _beta_value = 1.0f; // beta argument forwarded to the SwiGLU API
    int64_t _block_idx, _tile_len, _tile_offset, _copy_len,
        _batch, _seq_len, _hidden_size,
        _stride_seq_a, _stride_seq_b, _stride_seq_c;
    int64_t _stride_batch_a = 1, _stride_batch_b = 1, _stride_batch_c = 1;
};

template <typename T>
__aicore__ inline void SwigluKernel<T>::init(GM_ADDR c, GM_ADDR a, GM_ADDR b, int64_t batch_, int64_t seq, int64_t hd,
                                             int64_t stride_batch_c, int64_t stride_batch_a, int64_t stride_batch_b,
                                             int64_t stride_seq_c, int64_t stride_seq_a, int64_t stride_seq_b) {
    // Init shape & stride variables
    _batch = batch_;
    _seq_len = seq;
    _hidden_size = hd;
    _stride_batch_a = stride_batch_a;
    _stride_batch_b = stride_batch_b;
    _stride_batch_c = stride_batch_c;
    _stride_seq_a = stride_seq_a;
    _stride_seq_b = stride_seq_b;
    _stride_seq_c = stride_seq_c;

    _block_idx = GetBlockIdx();

    // Split the hidden dim across cores: the first (hd % BLOCK_NUM) cores each
    // take one extra element.
    int64_t base_len = _hidden_size / BLOCK_NUM;
    int64_t remainder = _hidden_size % BLOCK_NUM;
    _tile_len = _block_idx < remainder ? base_len + 1 : base_len;
    // BUGFIX: the element offset of this core's slice must account for the
    // longer tiles owned by the first `remainder` cores. The previous code
    // used `_block_idx * _tile_len`, which for cores past the remainder
    // shifted the slice left by `remainder` elements — overlapping earlier
    // slices and never processing the last `remainder` elements of each row
    // whenever hd was not a multiple of BLOCK_NUM.
    _tile_offset = _block_idx * base_len + (_block_idx < remainder ? _block_idx : remainder);
    // Round the transfer length up to the 32-byte DataCopy alignment.
    _copy_len = (_tile_len * sizeof(T)) % BYTE_ALIGN == 0
                    ? _tile_len
                    : (_tile_len * sizeof(T) + (BYTE_ALIGN - _tile_len * sizeof(T) % BYTE_ALIGN)) / sizeof(T);

    // Set global tensors
    _a_gm.SetGlobalBuffer((__gm__ T *)a);
    _b_gm.SetGlobalBuffer((__gm__ T *)b);
    _c_gm.SetGlobalBuffer((__gm__ T *)c);

    // _pipe allocates queue memory; the unit is bytes.
    _pipe.InitBuffer(_in_queue_a, BUFFER_NUM, _copy_len * sizeof(T));
    _pipe.InitBuffer(_in_queue_b, BUFFER_NUM, _copy_len * sizeof(T));
    _pipe.InitBuffer(_out_queue_c, BUFFER_NUM, _copy_len * sizeof(T));
}

template <typename T>
__aicore__ inline void SwigluKernel<T>::copyIn(int64_t i) {
    // Alloc tensors from queue memory
    LocalTensor<T> aLocal = _in_queue_a.AllocTensor<T>();
    LocalTensor<T> bLocal = _in_queue_b.AllocTensor<T>();
    // Decompose the flat row index i into (batch, seq) coordinates.
    auto batch_idx = _batch == 1 ? 0 : i / _seq_len;
    auto seq_idx = _batch == 1 ? i : i % _seq_len;

    int64_t idxa = batch_idx * _stride_batch_a + seq_idx * _stride_seq_a + _tile_offset;
    int64_t idxb = batch_idx * _stride_batch_b + seq_idx * _stride_seq_b + _tile_offset;
    // Copy this core's slice of row i from global to local memory.
    // NOTE(review): _copy_len may exceed _tile_len by up to the alignment
    // padding, so this can read a few elements past the slice — confirm this
    // over-read is acceptable for the trailing core.
    DataCopy(aLocal, _a_gm[idxa], _copy_len);
    DataCopy(bLocal, _b_gm[idxb], _copy_len);

    // Enqueue the inputs to the VECIN queues
    _in_queue_a.EnQue(aLocal);
    _in_queue_b.EnQue(bLocal);
}

template <typename T>
__aicore__ inline void SwigluKernel<T>::compute(int64_t i) {
    // Dequeue inputs from the VECIN queues
    LocalTensor<T> aLocal = _in_queue_a.DeQue<T>();
    LocalTensor<T> bLocal = _in_queue_b.DeQue<T>();
    LocalTensor<T> cLocal = _out_queue_c.AllocTensor<T>();
    // Call the Ascend C SwiGLU API; padding elements beyond _tile_len are
    // computed too but never copied out.
    SwiGLU<T, false>(cLocal, aLocal, bLocal, _beta_value, _copy_len);
    // Enqueue the result and free the inputs
    _out_queue_c.EnQue<T>(cLocal);
    _in_queue_a.FreeTensor(aLocal);
    _in_queue_b.FreeTensor(bLocal);
}

template <typename T>
__aicore__ inline void SwigluKernel<T>::copyOut(int64_t i) {
    // Dequeue the output tensor from the VECOUT queue
    LocalTensor<T> cLocal = _out_queue_c.DeQue<T>();
    auto batch_idx = _batch == 1 ? 0 : i / _seq_len;
    auto seq_idx = _batch == 1 ? i : i % _seq_len;
    int64_t idxc = batch_idx * _stride_batch_c + seq_idx * _stride_seq_c + _tile_offset;
    // Write back exactly _tile_len elements; DataCopyPad handles the
    // non-32-byte-aligned tail, plain DataCopy the aligned case.
    if (_tile_len * sizeof(T) % BYTE_ALIGN != 0) {
        DataCopyExtParams dcep = {1, static_cast<uint32_t>(_tile_len * sizeof(T)), 0, 0, 0};
        DataCopyPad(_c_gm[idxc], cLocal, dcep);
    } else {
        DataCopy(_c_gm[idxc], cLocal, _tile_len);
    }
    // Free the output local tensor
    _out_queue_c.FreeTensor(cLocal);
}

template <typename T>
__aicore__ inline void SwigluKernel<T>::process() {
    // One copyIn/compute/copyOut round per (batch, seq) row.
    for (int64_t i = 0; i < _batch * _seq_len; ++i) {
        copyIn(i);
        compute(i);
        copyOut(i);
    }
}

// Instantiates one __global__ kernel entry point per element type.
#define DEFINE_SWIGLU_KERNEL(KERNEL_NAME, TYPE)                            \
    __global__ __aicore__ void KERNEL_NAME(GM_ADDR c, GM_ADDR a, GM_ADDR b, \
                                           int64_t batch, int64_t seq, int64_t hd, \
                                           int64_t stride_batch_c,         \
                                           int64_t stride_batch_a,         \
                                           int64_t stride_batch_b,         \
                                           int64_t stride_seq_c,           \
                                           int64_t stride_seq_a,           \
                                           int64_t stride_seq_b) {         \
        SwigluKernel<TYPE> op;                                             \
        op.init(c, a, b,                                                   \
                batch, seq, hd,                                            \
                stride_batch_c, stride_batch_a, stride_batch_b,            \
                stride_seq_c, stride_seq_a, stride_seq_b);                 \
        op.process();                                                      \
    }

DEFINE_SWIGLU_KERNEL(swiglu_kernel_half, half)
DEFINE_SWIGLU_KERNEL(swiglu_kernel_float, float)

#undef DEFINE_SWIGLU_KERNEL

/// Host-callable launcher: dispatches on dtype and enqueues the kernel on
/// `stream` with BLOCK_NUM cores. Returns INFINI_STATUS_BAD_TENSOR_DTYPE for
/// unsupported element types.
extern "C" infiniStatus_t swiglu_kernel_launch(
    void *c, void *a, void *b,
    infiniDtype_t dtype, size_t batch, size_t seq, size_t hd,
    ptrdiff_t stride_batch_c, ptrdiff_t stride_batch_a, ptrdiff_t stride_batch_b,
    ptrdiff_t stride_seq_c, ptrdiff_t stride_seq_a, ptrdiff_t stride_seq_b, void *stream) {

#define LAUNCH_SWIGLU_KERNEL(DTYPE_ENUM, KERNEL_NAME)  \
    case DTYPE_ENUM:                                   \
        KERNEL_NAME<<<BLOCK_NUM, nullptr, stream>>>(   \
            c, a, b,                                   \
            static_cast<int64_t>(batch),               \
            static_cast<int64_t>(seq),                 \
            static_cast<int64_t>(hd),                  \
            stride_batch_c, stride_batch_a, stride_batch_b, \
            stride_seq_c, stride_seq_a, stride_seq_b); \
        return INFINI_STATUS_SUCCESS;

    switch (dtype) {
        LAUNCH_SWIGLU_KERNEL(INFINI_DTYPE_F16, swiglu_kernel_half)
        LAUNCH_SWIGLU_KERNEL(INFINI_DTYPE_F32, swiglu_kernel_float)
    default:
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }

#undef LAUNCH_SWIGLU_KERNEL
}

0 commit comments

Comments
 (0)