
Commit b7ae53f

MultiheadAttention CUDA BF16 Support (microsoft#26083)
### Description

MultiheadAttention CUDA BF16 Support
1 parent f2f50eb commit b7ae53f

File tree: 8 files changed, +71 −22 lines

docs/ContribOperators.md
Lines changed: 2 additions & 2 deletions

@@ -3264,9 +3264,9 @@ This version of the operator has been available since version 1 of the 'com.micr
 #### Type Constraints
 
 <dl>
-<dt><tt>T</tt> : tensor(float), tensor(float16)</dt>
+<dt><tt>T</tt> : tensor(float), tensor(float16), tensor(bfloat16)</dt>
 <dd>Constrain input and output to float tensors.</dd>
-<dt><tt>QK</tt> : tensor(float), tensor(float16)</dt>
+<dt><tt>QK</tt> : tensor(float), tensor(float16), tensor(bfloat16)</dt>
 <dd>Constrain QK output to float32 or float16 tensors, independent of input type or output type.</dd>
 <dt><tt>M</tt> : tensor(int32)</dt>
 <dd>Constrain mask to integer types</dd>

docs/OperatorKernels.md
Lines changed: 1 addition & 1 deletion

@@ -992,7 +992,7 @@ Do not modify directly.*
 |MatMulBnb4|*in* A:**T1**<br> *in* B:**T2**<br> *in* absmax:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(bfloat16), tensor(float), tensor(float16)<br/> **T2** = tensor(uint8)|
 |MatMulNBits|*in* A:**T1**<br> *in* B:**T2**<br> *in* scales:**T1**<br> *in* zero_points:**T3**<br> *in* g_idx:**T4**<br> *in* bias:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(bfloat16), tensor(float), tensor(float16)<br/> **T2** = tensor(uint8)<br/> **T3** = tensor(bfloat16), tensor(float), tensor(float16), tensor(uint8)|
 |MoE|*in* input:**T**<br> *in* router_probs:**T**<br> *in* fc1_experts_weights:**T**<br> *in* fc1_experts_bias:**T**<br> *in* fc2_experts_weights:**T**<br> *in* fc2_experts_bias:**T**<br> *in* fc3_experts_weights:**T**<br> *in* fc3_experts_bias:**T**<br> *out* output:**T**|1+|**T** = tensor(bfloat16), tensor(float), tensor(float16)|
-|MultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**M**<br> *in* attention_bias:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* past_sequence_length:**M**<br> *in* cache_indirection:**M**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**<br> *out* qk:**QK**|1+|**QK** = tensor(float), tensor(float16)<br/> **T** = tensor(float), tensor(float16)|
+|MultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**M**<br> *in* attention_bias:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* past_sequence_length:**M**<br> *in* cache_indirection:**M**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**<br> *out* qk:**QK**|1+|**QK** = tensor(bfloat16), tensor(float), tensor(float16)<br/> **T** = tensor(bfloat16), tensor(float), tensor(float16)|
 |NGramRepeatBlock|*in* input_ids:**Tid**<br> *in* scores:**T**<br> *out* scores_out:**T**|1+|**T** = tensor(float)<br/> **Tid** = tensor(int64)|
 |NhwcConv|*in* X:**T**<br> *in* W:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
 |PackedAttention|*in* input:**T**<br> *in* weights:**T**<br> *in* bias:**T**<br> *in* token_offset:**M**<br> *in* cumulative_sequence_length:**M**<br> *in* attention_bias:**T**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|

onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
Lines changed: 10 additions & 0 deletions

@@ -1089,6 +1089,16 @@ template Status QkvToContext<half, float>(
     contrib::AttentionParameters& parameters,
     AttentionData<half>& data);
 
+template onnxruntime::common::Status
+QkvToContext<float, BFloat16>(
+    const cudaDeviceProp&, cublasHandle_t&, cudnnHandle_t&,
+    Stream*, contrib::AttentionParameters&, AttentionData<float>&);
+
+template onnxruntime::common::Status
+QkvToContext<BFloat16, float>(
+    const cudaDeviceProp&, cublasHandle_t&, cudnnHandle_t&,
+    Stream*, contrib::AttentionParameters&, AttentionData<BFloat16>&);
+
 template Status LaunchDecoderMaskedMultiHeadAttention<float, float>(
     const DecoderMaskedMultiHeadAttentionParameters& parameters,
     cudaStream_t stream,
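The added lines are explicit template instantiations: `QkvToContext` is defined only inside this .cu translation unit, so each (T, QK) combination the bf16 kernels now use must be instantiated here or the link step fails with unresolved symbols. A minimal sketch of that pattern, using illustrative names rather than the real ONNX Runtime signatures:

```cuda
// cast_impl.cu -- the template body lives only in this .cu file; other
// translation units see just a declaration of LaunchCast.
#include <cuda_bf16.h>
#include <cuda_runtime.h>

template <typename TIn, typename TOut>
__global__ void CastKernel(const TIn* in, TOut* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = static_cast<TOut>(in[i]);
}

template <typename TIn, typename TOut>
void LaunchCast(const TIn* in, TOut* out, int n, cudaStream_t stream) {
  const int block = 256;
  const int grid = (n + block - 1) / block;
  CastKernel<TIn, TOut><<<grid, block, 0, stream>>>(in, out, n);
}

// Explicit instantiations: without one per (TIn, TOut) pair used elsewhere,
// callers in other .cc files fail to link with an unresolved-symbol error.
template void LaunchCast<float, __nv_bfloat16>(const float*, __nv_bfloat16*, int, cudaStream_t);
template void LaunchCast<__nv_bfloat16, float>(const __nv_bfloat16*, float*, int, cudaStream_t);
```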

onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu
Lines changed: 1 addition & 0 deletions

@@ -765,6 +765,7 @@ Status PrepareQkv(contrib::AttentionParameters& parameters,
 // Template Instantiation
 template bool NoQkvWorkspace<float>(contrib::AttentionParameters& parameters, AttentionData<float>& data);
 template bool NoQkvWorkspace<half>(contrib::AttentionParameters& parameters, AttentionData<half>& data);
+template bool NoQkvWorkspace<BFloat16>(contrib::AttentionParameters& parameters, AttentionData<BFloat16>& data);
 
 template Status PrepareQkv<float>(
     contrib::AttentionParameters& parameters,

onnxruntime/contrib_ops/cuda/bert/attention_qk.cu
Lines changed: 39 additions & 12 deletions

@@ -2,6 +2,7 @@
 // Licensed under the MIT License.
 
 #include "core/providers/cuda/cu_inc/common.cuh"
+#include "core/providers/cuda/cuda_type_conversion.h"
 #include "contrib_ops/cuda/bert/attention_qk.h"
 
 using namespace onnxruntime::cuda;
@@ -32,22 +33,38 @@ __global__ void ConvertAndCopyQK(const int count, const T* input, T* output) {
   }
 }
 
+__global__ void ConvertAndCopyQK(const int count, const float* input, nv_bfloat16* output) {
+  int idx = threadIdx.x + blockIdx.x * blockDim.x;
+  if (idx < count) {
+    output[idx] = __float2bfloat16(input[idx]);
+  }
+}
+
+__global__ void ConvertAndCopyQK(const int count, const nv_bfloat16* input, float* output) {
+  int idx = threadIdx.x + blockIdx.x * blockDim.x;
+  if (idx < count) {
+    output[idx] = __bfloat162float(input[idx]);
+  }
+}
+
 template <typename T, typename QK>
-Status CopyQK(cudaStream_t stream,
-              const int qk_size,
-              const T* input,
-              QK* output) {
-  if constexpr (std::is_same_v<T, QK>) {
-    CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(output, input, static_cast<size_t>(qk_size) * sizeof(T), cudaMemcpyDeviceToDevice, stream));
+Status CopyQK(cudaStream_t stream, int qk_size, const T* input, QK* output) {
+  using CudaT = typename OrtToCudaType<T>::type;
+  using CudaQK = typename OrtToCudaType<QK>::type;
+
+  if constexpr (std::is_same_v<CudaT, CudaQK>) {
+    CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(
+        output, input, size_t(qk_size) * sizeof(T),
+        cudaMemcpyDeviceToDevice, stream));
     return Status::OK();
   } else {
-    constexpr const bool half2float = std::is_same<T, half>::value && std::is_same<QK, float>::value;
-    constexpr const bool float2half = std::is_same<T, float>::value && std::is_same<QK, half>::value;
-    static_assert(half2float || float2half, "This function supports either <float,half> or <half,float>");
+    constexpr int block = 256;
+    const int grid = (qk_size + block - 1) / block;
 
-    constexpr const int block_size = 256;
-    int num_blocks = (qk_size + block_size - 1) / block_size;
-    ConvertAndCopyQK<<<num_blocks, block_size, 0, stream>>>(qk_size, input, output);
+    ConvertAndCopyQK<<<grid, block, 0, stream>>>(
+        qk_size,
+        reinterpret_cast<const CudaT*>(input),
+        reinterpret_cast<CudaQK*>(output));
 
     return CUDA_CALL(cudaGetLastError());
   }
@@ -63,6 +80,16 @@ template Status CopyQK<half, float>(cudaStream_t stream,
                                     const half* input,
                                     float* output);
 
+template Status CopyQK<BFloat16, float>(cudaStream_t stream,
+                                        const int qk_size,
+                                        const BFloat16* input,
+                                        float* output);
+
+template Status CopyQK<float, BFloat16>(cudaStream_t stream,
+                                        const int qk_size,
+                                        const float* input,
+                                        BFloat16* output);
+
 template Status CopyQK(cudaStream_t stream,
                        const int qk_size,
                        const float* input,
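The new kernel overloads mirror the existing half/float pair: a one-thread-per-element cast copies the QK buffer into the requested output type, and `OrtToCudaType` maps ORT's `BFloat16` wrapper onto CUDA's `nv_bfloat16` so a single dispatch covers all type pairs. A self-contained sketch of that conversion-and-copy pattern with a small host-side harness (error checks trimmed; this is an illustration, not the ORT code path):

```cuda
#include <cstdio>
#include <vector>
#include <cuda_bf16.h>
#include <cuda_runtime.h>

// Element-wise cast from float to bfloat16, one thread per element.
__global__ void ConvertAndCopy(const int count, const float* input, __nv_bfloat16* output) {
  int idx = threadIdx.x + blockIdx.x * blockDim.x;
  if (idx < count) {
    output[idx] = __float2bfloat16(input[idx]);
  }
}

int main() {
  const int n = 1024;
  std::vector<float> h_in(n);
  for (int i = 0; i < n; ++i) h_in[i] = 0.001f * i;

  float* d_in = nullptr;
  __nv_bfloat16* d_out = nullptr;
  cudaMalloc(&d_in, n * sizeof(float));
  cudaMalloc(&d_out, n * sizeof(__nv_bfloat16));
  cudaMemcpy(d_in, h_in.data(), n * sizeof(float), cudaMemcpyHostToDevice);

  const int block = 256;  // same block size as in the diff
  const int grid = (n + block - 1) / block;
  ConvertAndCopy<<<grid, block>>>(n, d_in, d_out);
  cudaDeviceSynchronize();

  std::vector<__nv_bfloat16> h_out(n);
  cudaMemcpy(h_out.data(), d_out, n * sizeof(__nv_bfloat16), cudaMemcpyDeviceToHost);
  printf("in=%f out=%f\n", h_in[100], __bfloat162float(h_out[100]));

  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}
```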

onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc
Lines changed: 10 additions & 5 deletions

@@ -38,6 +38,9 @@ REGISTER_KERNEL_TYPED(float, float)
 REGISTER_KERNEL_TYPED(float, MLFloat16)
 REGISTER_KERNEL_TYPED(MLFloat16, float)
 REGISTER_KERNEL_TYPED(MLFloat16, MLFloat16)
+REGISTER_KERNEL_TYPED(float, BFloat16)
+REGISTER_KERNEL_TYPED(BFloat16, float)
+REGISTER_KERNEL_TYPED(BFloat16, BFloat16)
 
 template <typename T, typename QK>
 MultiHeadAttention<T, QK>::MultiHeadAttention(const OpKernelInfo& info)
@@ -56,20 +59,22 @@ MultiHeadAttention<T, QK>::MultiHeadAttention(const OpKernelInfo& info)
 
   kernel_options_ = this->GetAttentionKernelOptions();
 
-  disable_fused_self_attention_ = sizeof(T) != 2 || !kernel_options_->UseTrtFusedAttention();
-  enable_trt_flash_attention_ = sizeof(T) == 2 && kernel_options_->UseTrtFlashAttention();
+  constexpr bool kIsFp16 = std::is_same<T, MLFloat16>::value;
 
-  disable_flash_attention_ = sizeof(T) != 2 || !kernel_options_->UseFlashAttention();
+  disable_fused_self_attention_ = !kIsFp16 || !kernel_options_->UseTrtFusedAttention();
+  enable_trt_flash_attention_ = kIsFp16 && kernel_options_->UseTrtFlashAttention();
+
+  disable_flash_attention_ = !kIsFp16 || !kernel_options_->UseFlashAttention();
 
 #if USE_LEAN_ATTENTION
   enable_lean_attention_ = sizeof(T) == 2 && kernel_options_->UseLeanAttention();
 #endif
 
   disable_memory_efficient_attention_ = !kernel_options_->UseEfficientAttention();
 
-  disable_fused_cross_attention_ = sizeof(T) != 2 || !kernel_options_->UseTrtCrossAttention();
+  disable_fused_cross_attention_ = !kIsFp16 || !kernel_options_->UseTrtCrossAttention();
 
-  enable_cudnn_flash_attention_ = sizeof(T) == 2 && kernel_options_->UseCudnnFlashAttention();
+  enable_cudnn_flash_attention_ = kIsFp16 && kernel_options_->UseCudnnFlashAttention();
 
   disable_decoder_attention_ = !kernel_options_->UseDecoderAttention();
 
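The constructor change matters because `BFloat16` is also a 2-byte type: the old `sizeof(T) == 2` gate would have routed bf16 inputs into the fp16-only TRT fused, flash, and cuDNN attention paths, whereas `std::is_same<T, MLFloat16>` keeps those paths restricted to real fp16. A small standalone illustration of the difference (the MLFloat16/BFloat16 structs below are 2-byte stand-ins, not ORT's actual classes):

```cuda
#include <cstdio>
#include <type_traits>

// Stand-ins for ORT's MLFloat16 and BFloat16 wrappers: both are 2 bytes wide,
// so sizeof(T) == 2 cannot tell them apart.
struct MLFloat16 { unsigned short val; };
struct BFloat16  { unsigned short val; };

template <typename T>
void ReportKernelChoice(const char* name) {
  // Old gate: any 2-byte type (including BFloat16) enables the fp16-only path.
  bool fp16_path_by_size = (sizeof(T) == 2);
  // New gate: only the real fp16 type enables fused/flash attention.
  constexpr bool kIsFp16 = std::is_same<T, MLFloat16>::value;
  printf("%s: by-size=%d, by-type=%d\n", name, (int)fp16_path_by_size, (int)kIsFp16);
}

int main() {
  ReportKernelChoice<MLFloat16>("MLFloat16");  // by-size=1, by-type=1
  ReportKernelChoice<BFloat16>("BFloat16");    // by-size=1, by-type=0 <- the mismatch the diff avoids
  ReportKernelChoice<float>("float");          // by-size=0, by-type=0
  return 0;
}
```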

onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
Lines changed: 6 additions & 0 deletions

@@ -104,6 +104,9 @@ class CUDA_MS_OP_TYPED_CLASS_NAME(1, float_float, MultiHeadAttention);
 class CUDA_MS_OP_TYPED_CLASS_NAME(1, float_MLFloat16, MultiHeadAttention);
 class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16_float, MultiHeadAttention);
 class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16_MLFloat16, MultiHeadAttention);
+class CUDA_MS_OP_TYPED_CLASS_NAME(1, float_BFloat16, MultiHeadAttention);
+class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16_float, MultiHeadAttention);
+class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16_BFloat16, MultiHeadAttention);
 class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, GroupQueryAttention);
 class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, GroupQueryAttention);
 class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, PagedAttention);
@@ -342,6 +345,9 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
     BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float_MLFloat16, MultiHeadAttention)>,
     BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16_float, MultiHeadAttention)>,
     BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16_MLFloat16, MultiHeadAttention)>,
+    BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float_BFloat16, MultiHeadAttention)>,
+    BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16_float, MultiHeadAttention)>,
+    BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16_BFloat16, MultiHeadAttention)>,
     BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, GroupQueryAttention)>,
     BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, GroupQueryAttention)>,
    BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, PagedAttention)>,
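Each (T, QK) pairing is a distinct typed kernel class, so adding bf16 needs both the forward declarations above and matching entries in the registration list. A generic sketch of the idea behind such a typed-kernel registry (the map, key format, and factory names here are illustrative, not ORT's actual KernelRegistry API):

```cuda
#include <cstdio>
#include <functional>
#include <map>
#include <memory>
#include <string>

// A toy kernel interface and one templated implementation per (T, QK) pair.
struct Kernel { virtual ~Kernel() = default; virtual const char* Name() const = 0; };

template <typename T, typename QK>
struct MultiHeadAttentionKernel : Kernel {
  const char* Name() const override { return "MultiHeadAttention"; }
};

// Registry keyed by "op:T:QK"; each registered entry is a factory function.
using Factory = std::function<std::unique_ptr<Kernel>()>;
std::map<std::string, Factory>& Registry() {
  static std::map<std::string, Factory> r;
  return r;
}

template <typename T, typename QK>
void RegisterMHA(const std::string& key) {
  Registry()[key] = [] { return std::unique_ptr<Kernel>(new MultiHeadAttentionKernel<T, QK>()); };
}

struct BFloat16 { unsigned short v; };  // stand-in for ORT's BFloat16

int main() {
  // Without explicit registrations like these, a bf16 model would find no kernel.
  RegisterMHA<float, BFloat16>("MultiHeadAttention:float:bfloat16");
  RegisterMHA<BFloat16, float>("MultiHeadAttention:bfloat16:float");
  RegisterMHA<BFloat16, BFloat16>("MultiHeadAttention:bfloat16:bfloat16");
  printf("registered %zu typed kernels\n", Registry().size());
  return 0;
}
```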

onnxruntime/core/graph/contrib_ops/bert_defs.cc
Lines changed: 2 additions & 2 deletions

@@ -1113,8 +1113,8 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
                "normalized Q * K, of shape (batch_size, num_heads, sequence_length, total_sequence_length). ",
                "QK",
                OpSchema::Optional)
-        .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output to float tensors.")
-        .TypeConstraint("QK", {"tensor(float)", "tensor(float16)"}, "Constrain QK output to float32 or float16 tensors, independent of input type or output type.")
+        .TypeConstraint("T", {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"}, "Constrain input and output to float tensors.")
+        .TypeConstraint("QK", {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"}, "Constrain QK output to float32 or float16 tensors, independent of input type or output type.")
        .TypeConstraint("M", {"tensor(int32)"}, "Constrain mask to integer types")
        .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
          MultiHeadAttentionTypeAndShapeInference(ctx, 6);
