/*
Copyright (c) Microsoft Corporation.
Licensed under the MIT License.
*/

/*
Kernel implementation for rotary embeddings.
*/
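
/*
For reference: rotary position embedding (RoPE) rotates pairs of input
elements by a position-dependent angle. With c = cos(theta) and
s = sin(theta) read from the caches, a pair (x_i, x_j) becomes:

  out_i = x_i * c - x_j * s
  out_j = x_j * c + x_i * s

In the interleaved layout the pairs are adjacent elements (2k, 2k + 1);
in the non-interleaved (half-split) layout they are (k, k + rotary_dim / 2).
*/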

#include "core/providers/cuda/llm/rotary_embedding_impl.h"
#include "core/providers/cuda/cu_inc/common.cuh"
#include <cuda_fp16.h>

using namespace onnxruntime::cuda;

namespace onnxruntime {
namespace cuda {

template <typename T>
__global__ void RotaryEmbeddingBSNH(T* output,                    // BxSxNxH
                                    const T* input,               // BxSxNxH
                                    const T* cos_cache,           // BxSx(H/2) or Mx(H/2)
                                    const T* sin_cache,           // BxSx(H/2) or Mx(H/2)
                                    const int64_t* position_ids,  // (0) or BxS
                                    const int sequence_length, const int num_heads, const int head_size,
                                    const int rotary_embedding_dim, const int position_ids_format,
                                    const bool interleaved,
                                    int4 in_strides, int4 out_strides  // strides in bnsh coord, h is always contiguous
) {
  // B = batch size, S = sequence length, N = num heads, H = head size, M = max sequence length
  // Use .x (the contiguous h dim) in the innermost loop to access global memory efficiently

  const int b = blockIdx.y;
  const int s = blockIdx.x;
  const int n = blockIdx.z;

  const int i = threadIdx.x;

  if (i >= head_size) {
    return;
  }

  const T* input_data = input + b * in_strides.x + s * in_strides.z + n * in_strides.y;
  T* output_data = output + b * out_strides.x + s * out_strides.z + n * out_strides.y;

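  // Elements past rotary_embedding_dim are copied through unchanged, which
  // supports partial RoPE where only a prefix of the head dim is rotated.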
  if (i >= rotary_embedding_dim) {
    output_data[i] = input_data[i];
    return;
  }

  // Cache is (M, H/2)
  const int half_rotary_embedding_dim = rotary_embedding_dim / 2;
  int cache_offset;

  // position_ids_format == 0 means position_ids is nullptr
  // position_ids_format == 1 means position_ids is a 2D array of size (batch_size, sequence_length)
  int b_s_index = b * sequence_length + s;
  if (position_ids_format != 0) {
    b_s_index = static_cast<int>(position_ids[b_s_index]);
  }
  cache_offset = b_s_index * half_rotary_embedding_dim;
  const T* cos_data = cos_cache + cache_offset;
  const T* sin_data = sin_cache + cache_offset;

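  // Each thread computes one rotated element from the pair (i, j): cache_idx
  // selects the angle and sign carries the rotation's minus term.
  // Worked example (interleaved, rotary_embedding_dim = 4):
  //   i = 0: j = 1, sign = -1 -> out[0] = x0 * c0 - x1 * s0
  //   i = 1: j = 0, sign = +1 -> out[1] = x1 * c0 + x0 * s0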
  int cache_idx = 0;
  T sign = 0;
  int j = 0;
  if (interleaved) {
    cache_idx = (i / 2) % half_rotary_embedding_dim;
    sign = (i % 2 == 0) ? -1 : 1;
    j = (i % 2 == 0) ? i + 1 : i - 1;  // i - sign
  } else {
    cache_idx = i % half_rotary_embedding_dim;
    sign = (i < half_rotary_embedding_dim) ? -1 : 1;
    j = (i + half_rotary_embedding_dim) % rotary_embedding_dim;
  }
  output_data[i] = input_data[i] * cos_data[cache_idx] + sign * input_data[j] * sin_data[cache_idx];
}

template <typename T>
Status LaunchRotaryEmbeddingKernel(cudaStream_t stream, T* output, const T* input, const int64_t* position_ids,
                                   const T* cos_cache, const T* sin_cache, const int batch_size,
                                   const int sequence_length, const int num_heads, const int head_size,
                                   const int rotary_embedding_dim, const int max_sequence_length,
                                   const int position_ids_format, const bool interleaved,
                                   const int max_threads_per_block, const bool is_input_bnsh_format) {
  int4 in_strides;
  int4 out_strides;
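  // Stride components are ordered by BNSH axis: .x = batch stride,
  // .y = head stride, .z = sequence stride, .w = head-dim stride (always 1),
  // so element (b, s, n, h) lives at b * .x + n * .y + s * .z + h * .w.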
  if (is_input_bnsh_format) {
    // Semantic meaning of the strides:
    //   int in_head_stride = sequence_length * head_size;
    //   int out_head_stride = sequence_length * head_size;
    //   in_strides = int4{num_heads * in_head_stride, in_head_stride, in_head_stride / sequence_length, 1};
    //   out_strides = int4{num_heads * out_head_stride, out_head_stride, out_head_stride / sequence_length, 1};
    // Simplify to:
    in_strides = int4{num_heads * sequence_length * head_size, sequence_length * head_size, head_size, 1};
    out_strides = int4{num_heads * sequence_length * head_size, sequence_length * head_size, head_size, 1};
  } else {
    // Input is in BSNH format:
    //   int in_head_stride = head_size;
    //   int out_head_stride = head_size;
    // Simplify to:
    in_strides = int4{num_heads * sequence_length * head_size, head_size, num_heads * head_size, 1};
    out_strides = int4{num_heads * sequence_length * head_size, head_size, num_heads * head_size, 1};
  }
  return LaunchRotaryEmbeddingKernel<T>(
      stream, output, input, position_ids,
      cos_cache, sin_cache, batch_size,
      sequence_length, num_heads, head_size,
      rotary_embedding_dim, max_sequence_length,
      position_ids_format, interleaved,
      max_threads_per_block,
      in_strides, out_strides);
}

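// Example call with hypothetical values (q and q_rotated are device buffers),
// applying non-interleaved RoPE to a BSNH query tensor:
//   ORT_RETURN_IF_ERROR(LaunchRotaryEmbeddingKernel<half>(
//       stream, q_rotated, q, /*position_ids=*/nullptr,
//       cos_cache, sin_cache, /*batch_size=*/2, /*sequence_length=*/128,
//       /*num_heads=*/32, /*head_size=*/128, /*rotary_embedding_dim=*/128,
//       /*max_sequence_length=*/4096, /*position_ids_format=*/0,
//       /*interleaved=*/false, /*max_threads_per_block=*/1024,
//       /*is_input_bnsh_format=*/false));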

template <typename T>
Status LaunchRotaryEmbeddingKernel(cudaStream_t stream, T* output, const T* input, const int64_t* position_ids,
                                   const T* cos_cache, const T* sin_cache, const int batch_size,
                                   const int sequence_length, const int num_heads, const int head_size,
                                   const int rotary_embedding_dim, const int /*max_sequence_length*/,
                                   const int position_ids_format, const bool interleaved,
                                   const int max_threads_per_block,
                                   int4 in_strides, int4 out_strides  // strides in bnsh coord
) {
  // Note: The current implementation assumes head_size <= max_threads_per_block
  // because head_size is large for LLaMA-2. For smaller head_size and num_heads
  // values, we could create a block as `block(num_heads, head_size, 1)` instead,
  // but that would require kernel changes to support.
  ORT_ENFORCE(head_size <= max_threads_per_block, "head_size must be <= max_threads_per_block");
  // Strides are in canonical BNSH coord; h is always contiguous (dim_stride == 1).
  ORT_ENFORCE(in_strides.w == 1 && out_strides.w == 1, "head dim must be contiguous");

  // Round the block size up to a whole number of warps (warp size = 32).
  int tpb = (head_size + 31) / 32 * 32;

  const dim3 block(tpb);
  const dim3 grid(sequence_length, batch_size, num_heads);

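  // Launch geometry: grid.x = sequence position, grid.y = batch, grid.z = head,
  // matching the blockIdx usage in RotaryEmbeddingBSNH; each thread handles one
  // element of the head dim, and threads with i >= head_size exit immediately.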
  assert(head_size <= max_threads_per_block);
  RotaryEmbeddingBSNH<<<grid, block, 0, stream>>>(output, input, cos_cache, sin_cache, position_ids, sequence_length,
                                                  num_heads, head_size, rotary_embedding_dim, position_ids_format,
                                                  interleaved, in_strides, out_strides);
  return CUDA_CALL(cudaGetLastError());
}

template Status LaunchRotaryEmbeddingKernel<float>(cudaStream_t stream, float* output, const float* input,
                                                   const int64_t* position_ids, const float* cos_cache,
                                                   const float* sin_cache, const int batch_size,
                                                   const int sequence_length, const int num_heads, const int head_size,
                                                   const int rotary_embedding_dim, const int max_sequence_length,
                                                   const int position_ids_format, const bool interleaved,
                                                   const int max_threads_per_block, const bool is_input_bnsh_format);

template Status LaunchRotaryEmbeddingKernel<half>(cudaStream_t stream, half* output, const half* input,
                                                  const int64_t* position_ids, const half* cos_cache,
                                                  const half* sin_cache, const int batch_size,
                                                  const int sequence_length, const int num_heads, const int head_size,
                                                  const int rotary_embedding_dim, const int max_sequence_length,
                                                  const int position_ids_format, const bool interleaved,
                                                  const int max_threads_per_block, const bool is_input_bnsh_format);

template Status LaunchRotaryEmbeddingKernel<BFloat16>(
    cudaStream_t stream, BFloat16* output, const BFloat16* input, const int64_t* position_ids,
    const BFloat16* cos_cache, const BFloat16* sin_cache, const int batch_size, const int sequence_length,
    const int num_heads, const int head_size, const int rotary_embedding_dim, const int max_sequence_length,
    const int position_ids_format, const bool interleaved, const int max_threads_per_block,
    const bool is_input_bnsh_format);

}  // namespace cuda
}  // namespace onnxruntime