
Commit 1be004c

Merge pull request #203 from InfiniTensor/ascend-rope
feat: add the Ascend RoPE operator
2 parents 5beab8c + 8727bdc commit 1be004c

File tree

10 files changed: +434 −67 lines


src/infiniop/devices/ascend/CMakeLists.txt

Lines changed: 1 addition & 2 deletions
@@ -23,10 +23,9 @@ include_directories(
     ${CMAKE_SOURCE_DIR}/../../../../include/infiniop/
 )

-
 ascendc_library(ascend_kernels STATIC
     ../../ops/swiglu/ascend/swiglu_ascend_kernel.cpp
-    # ../../ops/rotary_embedding/ascend/rotary_embedding_kernel.cpp
+    ../../ops/rope/ascend/rope_ascend_kernel.cpp
     # ../../ops/random_sample/ascend/random_sample_kernel.cpp
 )

src/infiniop/devices/ascend/ascend_kernel_common.h

Lines changed: 12 additions & 3 deletions
@@ -4,8 +4,17 @@
 #include "../../../../include/infinicore.h"
 #include "kernel_operator.h"

-constexpr int32_t BLOCK_NUM = 8;
-constexpr int32_t BUFFER_NUM = 2;
-constexpr int32_t BYTE_ALIGN = 32;
+constexpr size_t BLOCK_NUM = 8;
+constexpr size_t BUFFER_NUM = 2;
+constexpr size_t BYTE_ALIGN = 32;
+
+template <typename T>
+__aicore__ inline size_t alignTileLen(size_t tile_len, size_t byte_align) {
+    size_t bytes = tile_len * sizeof(T);
+    size_t aligned_bytes = (bytes % byte_align == 0)
+                               ? bytes
+                               : (bytes + (byte_align - bytes % byte_align));
+    return aligned_bytes / sizeof(T);
+}

 #endif
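
For reference, alignTileLen rounds an element count up so that a tile occupies a whole number of BYTE_ALIGN-byte blocks, which the aligned DataCopy transfers in the kernel below expect. A minimal host-side sketch of the same rounding; the helper name and the sample sizes here are illustrative, not part of the commit:

    #include <cstddef>
    #include <cstdio>

    // Host-side re-statement of alignTileLen for quick checking; byte_align plays
    // the role of the device-side BYTE_ALIGN (32).
    template <typename T>
    size_t align_tile_len(size_t tile_len, size_t byte_align) {
        size_t bytes = tile_len * sizeof(T);
        size_t aligned_bytes = (bytes + byte_align - 1) / byte_align * byte_align; // round up
        return aligned_bytes / sizeof(T);
    }

    int main() {
        // dh = 100 two-byte elements (e.g. half): 200 bytes -> 224 bytes -> 112 elements
        std::printf("%zu\n", align_tile_len<short>(100, 32));
        // dh = 128 floats: 512 bytes is already a multiple of 32 -> stays 128
        std::printf("%zu\n", align_tile_len<float>(128, 32));
        return 0;
    }
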
Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
#include "rope_ascend.h"
#include "../../../devices/ascend/common_ascend.h"

namespace op::rope::ascend {

Descriptor::~Descriptor() = default;

infiniStatus_t Descriptor::create(
    infiniopHandle_t handle,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t y_desc,
    infiniopTensorDescriptor_t x_desc,
    infiniopTensorDescriptor_t pos_desc,
    infiniopTensorDescriptor_t sin_desc,
    infiniopTensorDescriptor_t cos_desc) {
    auto handle_ascned = reinterpret_cast<device::ascend::Handle *>(handle);
    auto result = RoPEInfo::createRoPEInfo(y_desc, x_desc, pos_desc, sin_desc, cos_desc);
    CHECK_RESULT(result);

    size_t workspace_size = 0;
    *desc_ptr = new Descriptor(std::move(result.take()), workspace_size, nullptr, handle_ascned->device, handle_ascned->device_id);
    return INFINI_STATUS_SUCCESS;
}

infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *y,
    const void *x,
    const void *pos_ids,
    const void *sin_table,
    const void *cos_table,
    void *stream) const {
    CHECK_DTYPE(_info.data_type, INFINI_DTYPE_F32, INFINI_DTYPE_F16);

    auto data_type = _info.data_type;
    auto pos_type = _info.pos_type;
    auto seq_len = _info.seqlen;
    auto nhead = _info.nhead;
    auto dhead = _info.dhead;

    auto y_stride_seqlen = _info.y_stride_seqlen;
    auto y_stride_nhead = _info.y_stride_nhead;
    auto x_stride_seqlen = _info.x_stride_seqlen;
    auto x_stride_nhead = _info.x_stride_nhead;

    return rope_kernel_launch(y, (void *)x, (void *)pos_ids, (void *)sin_table, (void *)cos_table, seq_len, nhead, dhead, data_type, pos_type, y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead, stream);
}
} // namespace op::rope::ascend
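
As a rough sketch of how this descriptor is driven from host code; the handle, the five tensor descriptors, the device buffers and the stream below are assumed to exist already, and only Descriptor::create and Descriptor::calculate come from this file:

    // Minimal host-side sketch, not part of the commit. `handle`, the tensor
    // descriptors, the device pointers and `stream` are assumed to be set up elsewhere.
    op::rope::ascend::Descriptor *desc = nullptr;
    infiniStatus_t status = op::rope::ascend::Descriptor::create(
        handle, &desc, y_desc, x_desc, pos_desc, sin_desc, cos_desc);
    if (status == INFINI_STATUS_SUCCESS) {
        // create() fixes workspace_size to 0, so no workspace buffer is needed.
        status = desc->calculate(/*workspace=*/nullptr, /*workspace_size=*/0,
                                 y_data, x_data, pos_data, sin_data, cos_data, stream);
        delete desc;
    }
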
Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
#ifndef __ACLNN_ROPE_H__
#define __ACLNN_ROPE_H__

#include "../rope.h"

extern "C" infiniStatus_t rope_kernel_launch(
    void *y,
    void *x,
    void *pos,
    void *sin,
    void *cos,
    size_t seq_len,
    size_t nhead,
    size_t dhead,
    infiniDtype_t data_type,
    infiniDtype_t pos_type,
    ptrdiff_t y_stride_seqlen,
    ptrdiff_t y_stride_nhead,
    ptrdiff_t x_stride_seqlen,
    ptrdiff_t x_stride_nhead,
    void *stream);

DESCRIPTOR(ascend)

#endif // __ACLNN_ROPE_H__
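
The kernel in the next file reads the tables as sin_gm[pos * dhead/2 + j] and cos_gm[pos * dhead/2 + j], i.e. one row of dhead/2 angles per position. The commit does not build these tables itself; assuming the usual RoPE convention (base 10000), a host-side sketch of a compatible layout could look like this:

    #include <cmath>
    #include <cstddef>
    #include <vector>

    // Hypothetical table builder, not part of the commit: fills sin/cos tables of
    // shape [max_pos, dhead / 2] in the row-major layout the kernel indexes.
    // The angle formula assumes the standard RoPE convention with base 10000.
    void build_rope_tables(size_t max_pos, size_t dhead,
                           std::vector<float> &sin_table,
                           std::vector<float> &cos_table) {
        size_t half = dhead / 2;
        sin_table.assign(max_pos * half, 0.0f);
        cos_table.assign(max_pos * half, 0.0f);
        for (size_t pos = 0; pos < max_pos; ++pos) {
            for (size_t j = 0; j < half; ++j) {
                double angle = double(pos) * std::pow(10000.0, -2.0 * double(j) / double(dhead));
                sin_table[pos * half + j] = float(std::sin(angle));
                cos_table[pos * half + j] = float(std::cos(angle));
            }
        }
    }
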
Lines changed: 280 additions & 0 deletions
@@ -0,0 +1,280 @@
#include "../../../devices/ascend/ascend_kernel_common.h"

using namespace AscendC;

template <typename T, typename U>
class RoPEKernel {
public:
    __aicore__ inline RoPEKernel() {}
    // Init op
    // pos: position vector
    // x: input tensor
    // y: output tensor
    // tensor shape [nt, nh, dh]
    // make block_num = nh, tile_len = dh
    __aicore__ inline void init(GM_ADDR y,
                                GM_ADDR x,
                                GM_ADDR pos,
                                GM_ADDR sin,
                                GM_ADDR cos,
                                size_t dh,
                                ptrdiff_t st_ynt,
                                ptrdiff_t st_ynh,
                                ptrdiff_t st_xnt,
                                ptrdiff_t st_xnh);
    __aicore__ inline void process(size_t seq_len);

private:
    // Copy a tile into UB
    __aicore__ inline void copyIn(size_t i);
    __aicore__ inline void compute(size_t i);
    __aicore__ inline void copyOut(size_t i);

private:
    TPipe pipe;
    TQue<QuePosition::VECIN, BUFFER_NUM> _in_que;
    TQue<QuePosition::VECIN, BUFFER_NUM> _sin_que;
    TQue<QuePosition::VECIN, BUFFER_NUM> _cos_que;
    TQue<QuePosition::VECOUT, BUFFER_NUM> _out_que;
    TBuf<TPosition::VECCALC> _tmp_odd_buf;
    TBuf<TPosition::VECCALC> _tmp_even_buf;
    TBuf<TPosition::VECCALC> _tmp_odd_buf1;
    TBuf<TPosition::VECCALC> _tmp_odd_buf2;
    TBuf<TPosition::VECCALC> _tmp_even_buf1;
    TBuf<TPosition::VECCALC> _tmp_even_buf2;

    GlobalTensor<T> _x_gm, _y_gm;
    GlobalTensor<U> _p_gm;
    GlobalTensor<T> _sin_gm;
    GlobalTensor<T> _cos_gm;

    size_t _block_idx;
    size_t _tile_len;
    size_t _copy_len;
    size_t _half_copy_len;

    // stridey[_st_ynt, _st_ynh, 1]
    ptrdiff_t _st_ynt;
    ptrdiff_t _st_ynh;
    // stridex[_st_xnt, _st_xnh, 1]
    ptrdiff_t _st_xnt;
    ptrdiff_t _st_xnh;
};

template <typename T, typename U>
__aicore__ inline void RoPEKernel<T, U>::init(GM_ADDR y,
                                              GM_ADDR x,
                                              GM_ADDR pos,
                                              GM_ADDR sin,
                                              GM_ADDR cos,
                                              size_t dh,
                                              ptrdiff_t st_ynt,
                                              ptrdiff_t st_ynh,
                                              ptrdiff_t st_xnt,
                                              ptrdiff_t st_xnh) {
    this->_tile_len = dh;
    this->_st_ynt = st_ynt;
    this->_st_ynh = st_ynh;
    this->_st_xnt = st_xnt;
    this->_st_xnh = st_xnh;
    _copy_len = alignTileLen<T>(dh, BYTE_ALIGN);
    _half_copy_len = alignTileLen<T>(dh, BYTE_ALIGN);

    _block_idx = GetBlockIdx();

    // Init global buffer
    _x_gm.SetGlobalBuffer((__gm__ T *)x);
    _p_gm.SetGlobalBuffer((__gm__ U *)pos);
    _sin_gm.SetGlobalBuffer((__gm__ T *)sin);
    _cos_gm.SetGlobalBuffer((__gm__ T *)cos);
    _y_gm.SetGlobalBuffer((__gm__ T *)y);

    // Init Queue buffer
    pipe.InitBuffer(_in_que, BUFFER_NUM, _copy_len * sizeof(T));
    pipe.InitBuffer(_out_que, BUFFER_NUM, _tile_len * sizeof(T));
    pipe.InitBuffer(_sin_que, BUFFER_NUM, _half_copy_len * sizeof(T));
    pipe.InitBuffer(_cos_que, BUFFER_NUM, _half_copy_len * sizeof(T));
    pipe.InitBuffer(_tmp_odd_buf, _tile_len / 2 * sizeof(T));
    pipe.InitBuffer(_tmp_even_buf, _tile_len / 2 * sizeof(T));
    pipe.InitBuffer(_tmp_odd_buf1, _tile_len / 2 * sizeof(T));
    pipe.InitBuffer(_tmp_odd_buf2, _tile_len / 2 * sizeof(T));
    pipe.InitBuffer(_tmp_even_buf1, _tile_len / 2 * sizeof(T));
    pipe.InitBuffer(_tmp_even_buf2, _tile_len / 2 * sizeof(T));
}

template <typename T, typename U>
__aicore__ inline void RoPEKernel<T, U>::copyIn(size_t i) {
    LocalTensor<T> input_ub = _in_que.AllocTensor<T>();
    LocalTensor<T> sin_ub = _sin_que.AllocTensor<T>();
    LocalTensor<T> cos_ub = _cos_que.AllocTensor<T>();
    // Get idx of current tile in total input
    auto idx = i * _st_xnt + _block_idx * _st_xnh;
    // Copy current tile into UB
    DataCopy(input_ub, _x_gm[idx], _copy_len);
    // Copy sin/cos tile
    auto pos_idx = _p_gm(i);
    DataCopy(sin_ub, _sin_gm[pos_idx * _tile_len / 2], _half_copy_len);
    DataCopy(cos_ub, _cos_gm[pos_idx * _tile_len / 2], _half_copy_len);
    // Push in operands
    _in_que.EnQue(input_ub);
    _sin_que.EnQue(sin_ub);
    _cos_que.EnQue(cos_ub);
}

template <typename T, typename U>
__aicore__ inline void RoPEKernel<T, U>::compute(size_t i) {
    LocalTensor<T> input_ub = _in_que.DeQue<T>();
    LocalTensor<T> sin_ub = _sin_que.DeQue<T>();
    LocalTensor<T> cos_ub = _cos_que.DeQue<T>();
    LocalTensor<T> output_ub = _out_que.AllocTensor<T>();

    LocalTensor<T> tmp_odd = _tmp_odd_buf.Get<T>();
    LocalTensor<T> tmp_even = _tmp_even_buf.Get<T>();
    LocalTensor<T> tmp_odd1 = _tmp_odd_buf1.Get<T>();
    LocalTensor<T> tmp_odd2 = _tmp_odd_buf2.Get<T>();
    LocalTensor<T> tmp_even1 = _tmp_even_buf1.Get<T>();
    LocalTensor<T> tmp_even2 = _tmp_even_buf2.Get<T>();

    // separate odd and even bit elements
    uint64_t rsvdCnt = 0;
    GatherMaskParams gMaskParams = {
        1,
        static_cast<uint16_t>((_tile_len * sizeof(T) + 255) / 256), // no more than 256 (<=255)
        8,
        8,
    };
    GatherMask<T>(tmp_odd, input_ub, 1, false, 0, gMaskParams, rsvdCnt);
    GatherMask<T>(tmp_even, input_ub, 2, false, 0, gMaskParams, rsvdCnt);
    PipeBarrier<PIPE_V>();

    // compute odd bit elements
    // y_odd = x_odd * cos - x_even * sin
    Mul<T>(tmp_odd1, tmp_odd, cos_ub, _tile_len / 2);
    Mul<T>(tmp_odd2, tmp_even, sin_ub, _tile_len / 2);
    PipeBarrier<PIPE_V>();
    Sub<T>(tmp_odd1, tmp_odd1, tmp_odd2, _tile_len / 2);

    // compute even bit elements
    // y_even = x_odd * sin + x_even * cos
    Mul<T>(tmp_even1, tmp_odd, sin_ub, _tile_len / 2);
    Mul<T>(tmp_even2, tmp_even, cos_ub, _tile_len / 2);
    PipeBarrier<PIPE_V>();
    Add<T>(tmp_even1, tmp_even1, tmp_even2, _tile_len / 2);

    // combine odd and even bit elements
    for (uint32_t j = 0; j < _tile_len / 2; j += 1) {
        output_ub(j * 2) = tmp_odd1(j);
        output_ub(j * 2 + 1) = tmp_even1(j);
    }

    _out_que.EnQue<T>(output_ub);
    _in_que.FreeTensor(input_ub);
    _sin_que.FreeTensor(sin_ub);
    _cos_que.FreeTensor(cos_ub);
}

template <typename T, typename U>
__aicore__ inline void RoPEKernel<T, U>::copyOut(size_t i) {
    LocalTensor<T> output_ub = _out_que.DeQue<T>();
    auto idy = i * _st_ynt + _block_idx * _st_ynh;
    DataCopyExtParams params = {1, static_cast<uint32_t>(_tile_len * sizeof(T)), 0, 0, 0};
    DataCopyPad(_y_gm[idy], output_ub, params);
    _out_que.FreeTensor(output_ub);
}

template <typename T, typename U>
__aicore__ inline void RoPEKernel<T, U>::process(size_t seq_len) {
    for (size_t i = 0; i < seq_len; ++i) {
        copyIn(i);
        compute(i);
        copyOut(i);
    }
}

#define ROPE_KERNEL_INIT_ARGS y, x, pos, sin, cos, dhead, \
                              y_stride_seqlen, y_stride_nhead, \
                              x_stride_seqlen, x_stride_nhead

#define CASE_POSTYPE(POS_TYPE_ENUM, TYPE, POS_T) \
    case POS_TYPE_ENUM: {                        \
        RoPEKernel<TYPE, POS_T> op;              \
        op.init(ROPE_KERNEL_INIT_ARGS);          \
        op.process(seq_len);                     \
        break;                                   \
    }

#define ROPE_KERNEL(TYPE, POSTYPE)                           \
    switch (POSTYPE) {                                       \
        CASE_POSTYPE(INFINI_DTYPE_I8, TYPE, int8_t)          \
        CASE_POSTYPE(INFINI_DTYPE_I16, TYPE, int16_t)        \
        CASE_POSTYPE(INFINI_DTYPE_I32, TYPE, int32_t)        \
        CASE_POSTYPE(INFINI_DTYPE_I64, TYPE, int64_t)        \
        CASE_POSTYPE(INFINI_DTYPE_U8, TYPE, uint8_t)         \
        CASE_POSTYPE(INFINI_DTYPE_U16, TYPE, uint16_t)       \
        CASE_POSTYPE(INFINI_DTYPE_U32, TYPE, uint32_t)       \
        CASE_POSTYPE(INFINI_DTYPE_U64, TYPE, uint64_t)       \
    default:                                                 \
        break;                                               \
    }

#define DEFINE_ROPE_KERNEL(KERNEL_NAME, TYPE)              \
    __global__ __aicore__ void KERNEL_NAME(GM_ADDR y,      \
                                           GM_ADDR x,      \
                                           GM_ADDR pos,    \
                                           GM_ADDR sin,    \
                                           GM_ADDR cos,    \
                                           size_t seq_len, \
                                           size_t dhead,   \
                                           ptrdiff_t y_stride_seqlen, \
                                           ptrdiff_t y_stride_nhead,  \
                                           ptrdiff_t x_stride_seqlen, \
                                           ptrdiff_t x_stride_nhead,  \
                                           int32_t pos_type) {        \
        ROPE_KERNEL(TYPE, pos_type)                                   \
    }

DEFINE_ROPE_KERNEL(rope_kernel_float, float)
DEFINE_ROPE_KERNEL(rope_kernel_half, half)

#undef DEFINE_ROPE_KERNEL
#undef ROPE_KERNEL
#undef CASE_POSTYPE
#undef ROPE_KERNEL_INIT_ARGS

extern "C" infiniStatus_t rope_kernel_launch(
    void *y,
    void *x,
    void *pos,
    void *sin,
    void *cos,
    size_t seq_len,
    size_t nhead,
    size_t dhead,
    infiniDtype_t dtype,
    infiniDtype_t pos_type,
    ptrdiff_t y_stride_seqlen,
    ptrdiff_t y_stride_nhead,
    ptrdiff_t x_stride_seqlen,
    ptrdiff_t x_stride_nhead,
    void *stream) {

#define LAUNCH_ROPE_KERNEL(DTYPE_ENUM, KERNEL_NAME)                  \
    case DTYPE_ENUM:                                                 \
        KERNEL_NAME<<<nhead, nullptr, stream>>>(y, x, pos, sin, cos, \
                                                seq_len,             \
                                                dhead,               \
                                                y_stride_seqlen,     \
                                                y_stride_nhead,      \
                                                x_stride_seqlen,     \
                                                x_stride_nhead,      \
                                                pos_type);           \
        return INFINI_STATUS_SUCCESS;

    switch (dtype) {
        LAUNCH_ROPE_KERNEL(INFINI_DTYPE_F16, rope_kernel_half)
        LAUNCH_ROPE_KERNEL(INFINI_DTYPE_F32, rope_kernel_float)
    default:
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
}
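
For checking the numerics, the rotation performed in compute() (y_odd = x_odd * cos - x_even * sin, y_even = x_odd * sin + x_even * cos on interleaved pairs) can be stated as a scalar reference. The sketch below is for host-side comparison of a single (token, head) slice and is not part of the commit; strides are omitted for brevity:

    #include <cstddef>
    #include <vector>

    // Scalar reference mirroring RoPEKernel::compute() for one (token, head) slice:
    // the head dimension is split into dhead/2 interleaved (odd, even) pairs and each
    // pair is rotated by the angle whose sin/cos are stored for this token's position.
    void rope_reference(std::vector<float> &y, const std::vector<float> &x,
                        const std::vector<float> &sin_row,  // dhead/2 values for this position
                        const std::vector<float> &cos_row,  // dhead/2 values for this position
                        size_t dhead) {
        for (size_t j = 0; j < dhead / 2; ++j) {
            float x_odd = x[2 * j];      // gathered into tmp_odd by GatherMask pattern 1
            float x_even = x[2 * j + 1]; // gathered into tmp_even by GatherMask pattern 2
            y[2 * j] = x_odd * cos_row[j] - x_even * sin_row[j];
            y[2 * j + 1] = x_odd * sin_row[j] + x_even * cos_row[j];
        }
    }
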
