Commit 15e75af
Add RotaryEmbedding(23) - CUDA (microsoft#25178)
Follow-up to microsoft#24980; fixes microsoft#24556. Adds ONNX RotaryEmbedding(23) following https://github.com/onnx/onnx/blob/main/docs/Operators.md#RotaryEmbedding. The PR reuses the contrib op RotaryEmbedding implementation under the hood. The main difference between this op and the contrib op is that position_ids is optional in ONNX RotaryEmbedding; when it is not provided, cos_cache and sin_cache must be 3D.
1 parent 8645dd5 commit 15e75af
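As a reference for the semantics, here is a minimal CPU-side sketch of the non-interleaved rotation with an optional position_ids pointer. This is illustrative only; RotaryRef, its parameter names, and its layout assumptions are not part of this commit.

// Illustrative scalar reference (a sketch under assumed layouts, not commit code).
// x, y: [batch, seq, heads, head_size], contiguous.
// With position_ids (shape [batch, seq]):  caches are 2D, [max_position, rotary_dim / 2].
// Without position_ids (nullptr):          caches are 3D, [batch, seq, rotary_dim / 2].
#include <cstddef>
#include <cstdint>

void RotaryRef(float* y, const float* x, const float* cos_cache, const float* sin_cache,
               const int64_t* position_ids,  // optional: may be nullptr
               int batch, int seq, int heads, int head_size, int rotary_dim) {
  const int half = rotary_dim / 2;
  for (int b = 0; b < batch; ++b) {
    for (int s = 0; s < seq; ++s) {
      // Pick the cache row: an explicit position id, or the (b, s) row of a 3D cache.
      const int row = position_ids ? static_cast<int>(position_ids[b * seq + s]) : b * seq + s;
      const float* c = cos_cache + static_cast<size_t>(row) * half;
      const float* sn = sin_cache + static_cast<size_t>(row) * half;
      for (int n = 0; n < heads; ++n) {
        const size_t base = ((static_cast<size_t>(b) * seq + s) * heads + n) * head_size;
        for (int i = 0; i < head_size; ++i) {
          if (i >= rotary_dim) {  // dimensions past rotary_dim pass through unchanged
            y[base + i] = x[base + i];
            continue;
          }
          // Non-interleaved layout: element i rotates with its partner half a dim away.
          const int k = i % half;
          const float sign = (i < half) ? -1.0f : 1.0f;
          const int j = (i + half) % rotary_dim;
          y[base + i] = x[base + i] * c[k] + sign * x[base + j] * sn[k];
        }
      }
    }
  }
}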

File tree

7 files changed: +344 -6 lines changed

  docs/OperatorKernels.md
  onnxruntime/core/providers/cpu/llm/rotary_embedding.cc
  onnxruntime/core/providers/cuda/cuda_execution_provider.cc
  onnxruntime/core/providers/cuda/llm/rotary_embedding.cc
  onnxruntime/core/providers/cuda/llm/rotary_embedding.h
  onnxruntime/core/providers/cuda/llm/rotary_embedding_impl.cu
  onnxruntime/core/providers/cuda/llm/rotary_embedding_impl.h

docs/OperatorKernels.md

Lines changed: 1 addition & 0 deletions

@@ -828,6 +828,7 @@ Do not modify directly.*
 |||10|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)|
 |ReverseSequence|*in* input:**T**<br> *in* sequence_lens:**tensor(int64)**<br> *out* Y:**T**|10+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
 |RoiAlign|*in* X:**T1**<br> *in* rois:**T1**<br> *in* batch_indices:**T2**<br> *out* Y:**T1**|10+|**T1** = tensor(double), tensor(float)<br/> **T2** = tensor(int64)|
+|RotaryEmbedding|*in* X:**T**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* position_ids:**M**<br> *out* Y:**T**|23+|**M** = tensor(int64)<br/> **T** = tensor(bfloat16), tensor(float), tensor(float16)|
 |Round|*in* X:**T**<br> *out* Y:**T**|11+|**T** = tensor(double), tensor(float), tensor(float16)|
 |ScaledTanh|*in* input:**T**<br> *out* output:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
 |Scan|*in* initial_state_and_scan_inputs:**V**<br> *out* final_state_and_scan_outputs:**V**<br><br>or<br><br>*in* sequence_lens:**I**<br> *in* initial_state_and_scan_inputs:**V**<br> *out* final_state_and_scan_outputs:**V**|19+|**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|

onnxruntime/core/providers/cpu/llm/rotary_embedding.cc

Lines changed: 6 additions & 6 deletions

@@ -72,13 +72,13 @@ Status RunRotaryEmbedding(concurrency::ThreadPool* tp, RotaryParameters paramete
       const T* cos_data;
       const T* sin_data;
       int cache_offset;
-      if (position_ids_format == 0) {
-        cache_offset = (b * sequence_length + s) * half_rotary_emb_dim;
-      } else {
-        // Cache is (M, H/2) or (M, rotary_embedding_dim/2)
-        const int position_id = static_cast<int>(position_ids[b * sequence_length + s]);
-        cache_offset = position_id * half_rotary_emb_dim;
+      // position_ids_format == 0 means position_ids is nullptr
+      // position_ids_format == 1 means position_ids is a 2D array of size (batch_size, sequence_length)
+      int b_s_index = b * sequence_length + s;
+      if (position_ids_format != 0) {
+        b_s_index = static_cast<int>(position_ids[b_s_index]);
       }
+      cache_offset = b_s_index * half_rotary_emb_dim;
       cos_data = cos_cache + cache_offset;
       sin_data = sin_cache + cache_offset;
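A short worked example of the unified offset logic above, with illustrative numbers that are not from the commit:

// batch_size = 1, sequence_length = 4, half_rotary_emb_dim = 32, token (b = 0, s = 2):
//
// position_ids_format == 0 (position_ids is nullptr): the cache row is the token index.
//   b_s_index    = 0 * 4 + 2 = 2
//   cache_offset = 2 * 32   = 64
//
// position_ids_format == 1 with position_ids = {10, 11, 12, 13}: the same token is
// redirected to row 12 of the (max_position, 32) cache.
//   b_s_index    = static_cast<int>(position_ids[2]) = 12
//   cache_offset = 12 * 32 = 384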

onnxruntime/core/providers/cuda/cuda_execution_provider.cc

Lines changed: 6 additions & 0 deletions

@@ -1490,6 +1490,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, BFloat16_BFloat16, RMSNormalization);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, float_MLFloat16, RMSNormalization);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, MLFloat16_float, RMSNormalization);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, float, RotaryEmbedding);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, MLFloat16, RotaryEmbedding);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, BFloat16, RotaryEmbedding);

 #endif

@@ -2480,6 +2483,9 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, BFloat16_BFloat16, RMSNormalization)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, float_MLFloat16, RMSNormalization)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, MLFloat16_float, RMSNormalization)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, float, RotaryEmbedding)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, BFloat16, RotaryEmbedding)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, MLFloat16, RotaryEmbedding)>,
 #endif
 };

onnxruntime/core/providers/cuda/llm/rotary_embedding.cc

Lines changed: 85 additions & 0 deletions (new file)
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/cuda/cuda_common.h"
#include "core/providers/cpu/llm/rotary_embedding_helper.h"
#include "core/providers/cuda/llm/rotary_embedding.h"
#include "core/providers/cuda/llm/rotary_embedding_impl.h"

using namespace onnxruntime::cuda;
using namespace ::onnxruntime::common;
using namespace ONNX_NAMESPACE;
using namespace onnxruntime::rotary_embedding_helper;

namespace onnxruntime {
namespace cuda {

#define REGISTER_KERNEL_TYPED(T)                                        \
  ONNX_OPERATOR_TYPED_KERNEL_EX(                                        \
      RotaryEmbedding,                                                  \
      kOnnxDomain,                                                      \
      23,                                                               \
      T,                                                                \
      kCudaExecutionProvider,                                           \
      (*KernelDefBuilder::Create())                                     \
          .TypeConstraint("T", DataTypeImpl::GetTensorType<T>())        \
          .TypeConstraint("M", DataTypeImpl::GetTensorType<int64_t>()), \
      RotaryEmbedding<T>);

REGISTER_KERNEL_TYPED(float)
REGISTER_KERNEL_TYPED(MLFloat16)
REGISTER_KERNEL_TYPED(BFloat16)

template <typename T>
RotaryEmbedding<T>::RotaryEmbedding(const OpKernelInfo& info) : CudaKernel(info) {
  rotary_embedding_dim = static_cast<int>(info.GetAttrOrDefault<int64_t>("rotary_embedding_dim", 0));
  num_heads = static_cast<int>(info.GetAttrOrDefault<int64_t>("num_heads", 0));
  interleaved = (info.GetAttrOrDefault<int64_t>("interleaved", 0) == 1);
}

template <typename T>
Status RotaryEmbedding<T>::ComputeInternal(OpKernelContext* context) const {
  const Tensor* input = context->Input<Tensor>(0);
  const Tensor* cos_cache = context->Input<Tensor>(1);
  const Tensor* sin_cache = context->Input<Tensor>(2);
  const Tensor* position_ids = context->Input<Tensor>(3);  // Optional, can be nullptr

  RotaryParameters parameters = {};
  ORT_RETURN_IF_ERROR(rotary_embedding_helper::CheckInputs<Tensor>(input,
                                                                   position_ids,
                                                                   cos_cache,
                                                                   sin_cache,
                                                                   num_heads,
                                                                   rotary_embedding_dim,
                                                                   &parameters));

  Tensor* output = context->Output(0, input->Shape());

  // Launch rotary embedding kernel
  typedef typename ToCudaType<T>::MappedType CudaT;
  auto& device_prop = GetDeviceProp();

  // Handle optional position_ids - pass nullptr if position_ids is null
  const int64_t* position_ids_data = (position_ids != nullptr) ? position_ids->Data<int64_t>() : nullptr;

  return LaunchRotaryEmbeddingKernel<CudaT>(
      Stream(context),
      reinterpret_cast<CudaT*>(output->template MutableData<T>()),
      reinterpret_cast<const CudaT*>(input->template Data<T>()),
      position_ids_data,
      reinterpret_cast<const CudaT*>(cos_cache->template Data<T>()),
      reinterpret_cast<const CudaT*>(sin_cache->template Data<T>()),
      parameters.batch_size,
      parameters.sequence_length,
      parameters.num_heads,
      parameters.head_size,
      parameters.rotary_embedding_dim,
      parameters.max_sequence_length,
      parameters.position_ids_format,
      interleaved,
      device_prop.maxThreadsPerBlock,
      parameters.transposed);
}

}  // namespace cuda
}  // namespace onnxruntime
onnxruntime/core/providers/cuda/llm/rotary_embedding.h

Lines changed: 26 additions & 0 deletions (new file)
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include "core/common/common.h"
#include "core/providers/cuda/cuda_kernel.h"

namespace onnxruntime {
namespace cuda {

using namespace onnxruntime::cuda;

template <typename T>
class RotaryEmbedding final : public CudaKernel {
 public:
  RotaryEmbedding(const OpKernelInfo& info);
  Status ComputeInternal(OpKernelContext* context) const override;

 protected:
  int num_heads;
  int rotary_embedding_dim;
  int interleaved;
};

}  // namespace cuda
}  // namespace onnxruntime
onnxruntime/core/providers/cuda/llm/rotary_embedding_impl.cu

Lines changed: 169 additions & 0 deletions (new file)
/*
Copyright (c) Microsoft Corporation.
Licensed under the MIT License.
*/

/*
Kernel implementation for rotary embeddings.
*/

#include "core/providers/cuda/llm/rotary_embedding_impl.h"
#include "core/providers/cuda/cu_inc/common.cuh"
#include <cuda_fp16.h>

using namespace onnxruntime::cuda;

namespace onnxruntime {
namespace cuda {

template <typename T>
__global__ void RotaryEmbeddingBSNH(T* output,                    // BxSxNxH
                                    const T* input,               // BxSxNxH
                                    const T* cos_cache,           // BxSx(H/2) or Mx(H/2)
                                    const T* sin_cache,           // BxSx(H/2) or Mx(H/2)
                                    const int64_t* position_ids,  // (0) or BxS
                                    const int sequence_length, const int num_heads, const int head_size,
                                    const int rotary_embedding_dim, const int position_ids_format,
                                    const bool interleaved,
                                    int4 in_strides, int4 out_strides  // strides in bnsh coord, h is always contiguous
) {
  // B = batch size, S = sequence length, N = num heads, H = head size, M = max sequence length
  // Use .x in innermost loop to access global memory efficiently

  const int b = blockIdx.y;
  const int s = blockIdx.x;
  const int n = blockIdx.z;

  const int i = threadIdx.x;

  if (i >= head_size) {
    return;
  }

  const T* input_data = input + b * in_strides.x + s * in_strides.z + n * in_strides.y;
  T* output_data = output + b * out_strides.x + s * out_strides.z + n * out_strides.y;

  if (i >= rotary_embedding_dim) {
    output_data[i] = input_data[i];
    return;
  }

  // Cache is (M, H/2)
  const int half_rotary_embedding_dim = rotary_embedding_dim / 2;
  int cache_offset;

  // position_ids_format == 0 means position_ids is nullptr
  // position_ids_format == 1 means position_ids is a 2D array of size (batch_size, sequence_length)
  int b_s_index = b * sequence_length + s;
  if (position_ids_format != 0) {
    b_s_index = static_cast<int>(position_ids[b_s_index]);
  }
  cache_offset = b_s_index * half_rotary_embedding_dim;
  const T* cos_data = cos_cache + cache_offset;
  const T* sin_data = sin_cache + cache_offset;

  int cache_idx = 0;
  T sign = 0;
  int j = 0;
  if (interleaved) {
    cache_idx = (i / 2) % half_rotary_embedding_dim;
    sign = (i % 2 == 0) ? -1 : 1;
    j = (i % 2 == 0) ? i + 1 : i - 1;  // i - sign
  } else {
    cache_idx = i % half_rotary_embedding_dim;
    sign = (i < half_rotary_embedding_dim) ? -1 : 1;
    j = (i + half_rotary_embedding_dim) % rotary_embedding_dim;
  }
  output_data[i] = input_data[i] * cos_data[cache_idx] + sign * input_data[j] * sin_data[cache_idx];
}

template <typename T>
Status LaunchRotaryEmbeddingKernel(cudaStream_t stream, T* output, const T* input, const int64_t* position_ids,
                                   const T* cos_cache, const T* sin_cache, const int batch_size,
                                   const int sequence_length, const int num_heads, const int head_size,
                                   const int rotary_embedding_dim, const int max_sequence_length,
                                   const int position_ids_format, const bool interleaved,
                                   const int max_threads_per_block, const bool is_input_bnsh_format) {
  int4 in_strides;
  int4 out_strides;
  if (is_input_bnsh_format) {
    // Semantic meaning of the strides:
    // int in_head_stride = sequence_length * head_size;
    // int out_head_stride = sequence_length * head_size;
    // in_strides = int4{num_heads * in_head_stride, in_head_stride, in_head_stride / sequence_length, 1};
    // out_strides = int4{num_heads * out_head_stride, out_head_stride, out_head_stride / sequence_length, 1};
    // Simplify to:
    in_strides = int4{num_heads * sequence_length * head_size, sequence_length * head_size, head_size, 1};
    out_strides = int4{num_heads * sequence_length * head_size, sequence_length * head_size, head_size, 1};
  } else {
    // input is in bsnh format
    // int in_head_stride = head_size;
    // int out_head_stride = head_size;
    // Simplify to:
    in_strides = int4{num_heads * sequence_length * head_size, head_size, num_heads * head_size, 1};
    out_strides = int4{num_heads * sequence_length * head_size, head_size, num_heads * head_size, 1};
  }
  return LaunchRotaryEmbeddingKernel<T>(
      stream, output, input, position_ids,
      cos_cache, sin_cache, batch_size,
      sequence_length, num_heads, head_size,
      rotary_embedding_dim, max_sequence_length,
      position_ids_format, interleaved,
      max_threads_per_block,
      in_strides, out_strides);
}

template <typename T>
Status LaunchRotaryEmbeddingKernel(cudaStream_t stream, T* output, const T* input, const int64_t* position_ids,
                                   const T* cos_cache, const T* sin_cache, const int batch_size,
                                   const int sequence_length, const int num_heads, const int head_size,
                                   const int rotary_embedding_dim, const int /*max_sequence_length*/,
                                   const int position_ids_format, const bool interleaved,
                                   const int max_threads_per_block,
                                   int4 in_strides, int4 out_strides  // strides in bnsh coord
) {
  // Note: Current implementation assumes head_size <= max_threads_per_block
  // because head_size is currently large for LLaMA-2. For smaller head_size
  // and num_heads values, we can create a block as `block(num_heads, head_size, 1)`
  // instead. This will require kernel changes to support.
  ORT_ENFORCE(head_size <= max_threads_per_block, "Rotary embedding dim must be <= max_threads_per_block");
  // strides in canonical bnsh coord, h is always contiguous (dim_stride == 1)
  ORT_ENFORCE(in_strides.w == 1 && out_strides.w == 1, "head dim must be contiguous");

  // Round threads per block up to the next multiple of the warp size (32).
  int tpb = (head_size + 31) / 32 * 32;

  const dim3 block(tpb);
  const dim3 grid(sequence_length, batch_size, num_heads);

  assert(head_size <= max_threads_per_block);
  RotaryEmbeddingBSNH<<<grid, block, 0, stream>>>(output, input, cos_cache, sin_cache, position_ids, sequence_length,
                                                  num_heads, head_size, rotary_embedding_dim, position_ids_format,
                                                  interleaved, in_strides, out_strides);
  return CUDA_CALL(cudaGetLastError());
}

template Status LaunchRotaryEmbeddingKernel<float>(cudaStream_t stream, float* output, const float* input,
                                                   const int64_t* position_ids, const float* cos_cache,
                                                   const float* sin_cache, const int batch_size,
                                                   const int sequence_length, const int num_heads, const int head_size,
                                                   const int rotary_embedding_dim, const int max_sequence_length,
                                                   const int position_ids_format, const bool interleaved,
                                                   const int max_threads_per_block, const bool is_input_bnsh_format);

template Status LaunchRotaryEmbeddingKernel<half>(cudaStream_t stream, half* output, const half* input,
                                                  const int64_t* position_ids, const half* cos_cache,
                                                  const half* sin_cache, const int batch_size,
                                                  const int sequence_length, const int num_heads, const int head_size,
                                                  const int rotary_embedding_dim, const int max_sequence_length,
                                                  const int position_ids_format, const bool interleaved,
                                                  const int max_threads_per_block, const bool is_input_bnsh_format);

template Status LaunchRotaryEmbeddingKernel<BFloat16>(
    cudaStream_t stream, BFloat16* output, const BFloat16* input, const int64_t* position_ids,
    const BFloat16* cos_cache, const BFloat16* sin_cache, const int batch_size, const int sequence_length,
    const int num_heads, const int head_size, const int rotary_embedding_dim, const int max_sequence_length,
    const int position_ids_format, const bool interleaved, const int max_threads_per_block,
    const bool is_input_bnsh_format);

}  // namespace cuda
}  // namespace onnxruntime
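For readers following the cache_idx / sign / j selection in RotaryEmbeddingBSNH, here is a hand-worked pairing for a toy rotary_embedding_dim = 4, plus a stride sanity check. Both are derived from the kernel logic above; the concrete dimensions are illustrative, not from the commit.

// Pairing for rotary_embedding_dim = 4 (half_rotary_embedding_dim = 2):
//
// interleaved == false ("half" layout): pairs (0,2) and (1,3)
//   i = 0: j = 2, sign = -1  ->  y[0] = x[0]*cos[0] - x[2]*sin[0]
//   i = 1: j = 3, sign = -1  ->  y[1] = x[1]*cos[1] - x[3]*sin[1]
//   i = 2: j = 0, sign = +1  ->  y[2] = x[2]*cos[0] + x[0]*sin[0]
//   i = 3: j = 1, sign = +1  ->  y[3] = x[3]*cos[1] + x[1]*sin[1]
//
// interleaved == true: adjacent pairs (0,1) and (2,3)
//   i = 0: j = 1, sign = -1  ->  y[0] = x[0]*cos[0] - x[1]*sin[0]
//   i = 1: j = 0, sign = +1  ->  y[1] = x[1]*cos[0] + x[0]*sin[0]
//   i = 2: j = 3, sign = -1  ->  y[2] = x[2]*cos[1] - x[3]*sin[1]
//   i = 3: j = 2, sign = +1  ->  y[3] = x[3]*cos[1] + x[2]*sin[1]
//
// Each pair is a standard 2D rotation by the angle cached at cos/sin[cache_idx].

// Stride sanity check for BSNH input with batch = 2, seq = 128, heads = 8, head_size = 64:
//   in_strides = {heads * seq * head_size, head_size, heads * head_size, 1}
//              = {65536, 64, 512, 1}  // batch, head, seq, element strides in bnsh order
// so the kernel reads input + b * 65536 + s * 512 + n * 64, as in the (b, s, n) offsets above.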
onnxruntime/core/providers/cuda/llm/rotary_embedding_impl.h

Lines changed: 51 additions & 0 deletions (new file)
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include "core/common/common.h"
#include "core/providers/cuda/shared_inc/cuda_utils.h"

namespace onnxruntime {
namespace cuda {

template <typename T>
Status LaunchRotaryEmbeddingKernel(
    cudaStream_t stream,
    T* output,
    const T* input,
    const int64_t* position_ids,
    const T* cos_cache,
    const T* sin_cache,
    const int batch_size,
    const int sequence_length,
    const int num_heads,
    const int head_size,
    const int rotary_embedding_dim,
    const int max_sequence_length,
    const int position_ids_format,
    const bool interleaved,
    const int max_threads_per_block,
    const bool is_input_bnsh_format);

template <typename T>
Status LaunchRotaryEmbeddingKernel(
    cudaStream_t stream,
    T* output,
    const T* input,
    const int64_t* position_ids,
    const T* cos_cache,
    const T* sin_cache,
    const int batch_size,
    const int sequence_length,
    const int num_heads,
    const int head_size,
    const int rotary_embedding_dim,
    const int max_sequence_length,
    const int position_ids_format,
    const bool interleaved,
    const int max_threads_per_block,
    int4 in_strides,
    int4 out_strides);

}  // namespace cuda
}  // namespace onnxruntime
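A hypothetical call site for the convenience overload (first declaration above) might look like the sketch below. The dimensions, buffer names, and the ApplyRope wrapper are assumptions for illustration; only the launcher signature comes from this commit.

// Sketch: apply RoPE to a [batch=2, seq=128, heads=8, head_size=64] BSNH tensor
// with 2D caches and explicit position ids. Assumes rotary_embedding_impl.h is included
// and all pointers are valid device buffers.
Status ApplyRope(cudaStream_t stream, float* d_out, const float* d_in,
                 const int64_t* d_pos,  // [2, 128] position ids, or nullptr with 3D caches
                 const float* d_cos,    // [4096, 32] cos cache
                 const float* d_sin) {  // [4096, 32] sin cache
  return onnxruntime::cuda::LaunchRotaryEmbeddingKernel<float>(
      stream, d_out, d_in, d_pos, d_cos, d_sin,
      /*batch_size=*/2, /*sequence_length=*/128, /*num_heads=*/8, /*head_size=*/64,
      /*rotary_embedding_dim=*/64, /*max_sequence_length=*/4096,
      /*position_ids_format=*/1, /*interleaved=*/false,
      /*max_threads_per_block=*/1024, /*is_input_bnsh_format=*/false);
}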
