 #include "core/framework/tensorprotoutils.h"
 #include "core/graph/onnx_protobuf.h"
 #include "core/common/safeint.h"
+#include "core/platform/env_var_utils.h"
 #include "core/platform/threadpool.h"
+#include "core/mlas/inc/mlas.h"

+#include <algorithm>
+#include <type_traits>
 #include <unsupported/Eigen/SpecialFunctions>
 #include <vector>

@@ -39,6 +43,11 @@ MultiHeadAttention<T>::MultiHeadAttention(const OpKernelInfo& info) : OpKernel(i

   mask_filter_value_ = info.GetAttrOrDefault<float>("mask_filter_value", -10000.0f);
   is_unidirectional_ = info.GetAttrOrDefault<int64_t>("unidirectional", 0) == 1;
+
+  const auto& env = Env::Default();
+  l2_cache_size_ = env.GetL2CacheSize();
+
+  disable_flash_ = ParseEnvironmentVariableWithDefault<bool>(attention::kDisableFlashAttention, false);
 }

 template <typename T>
@@ -60,7 +69,6 @@ Status MultiHeadAttention<T>::Compute(OpKernelContext* context) const {
   }

   AttentionParameters parameters = {};
-  constexpr float scale = 1.0f;
   bool past_present_share_buffer = false;
   ORT_RETURN_IF_ERROR(multihead_attention_helper::CheckInputs<Tensor>(query,
                                                                       key,
@@ -74,7 +82,7 @@ Status MultiHeadAttention<T>::Compute(OpKernelContext* context) const {
                                                                       &parameters,
                                                                       num_heads_,
                                                                       mask_filter_value_,
-                                                                      scale,
+                                                                      scale_,
                                                                       is_unidirectional_,
                                                                       past_present_share_buffer,
                                                                       false));
@@ -99,8 +107,14 @@ Status MultiHeadAttention<T>::Compute(OpKernelContext* context) const {
   const int v_bias_offset = 2 * qk_hidden_size;

   // If optional outputs aren't needed, present_k and present_v will be null
-  std::vector<int64_t> present_k_shape({static_cast<int64_t>(batch_size), static_cast<int64_t>(num_heads_), static_cast<int64_t>(total_kv_sequence_length), static_cast<int64_t>(qk_head_size)});
-  std::vector<int64_t> present_v_shape({static_cast<int64_t>(batch_size), static_cast<int64_t>(num_heads_), static_cast<int64_t>(total_kv_sequence_length), static_cast<int64_t>(v_head_size)});
+  std::vector<int64_t> present_k_shape({static_cast<int64_t>(batch_size),
+                                        static_cast<int64_t>(num_heads_),
+                                        static_cast<int64_t>(total_kv_sequence_length),
+                                        static_cast<int64_t>(qk_head_size)});
+  std::vector<int64_t> present_v_shape({static_cast<int64_t>(batch_size),
+                                        static_cast<int64_t>(num_heads_),
+                                        static_cast<int64_t>(total_kv_sequence_length),
+                                        static_cast<int64_t>(v_head_size)});
   Tensor* present_k = context->Output(1, present_k_shape);
   Tensor* present_v = context->Output(2, present_v_shape);

@@ -138,6 +152,70 @@ Status MultiHeadAttention<T>::Compute(OpKernelContext* context) const {
   ORT_RETURN_IF_ERROR(MaybeTransposeToBNSHAndAddBias<T>(
       context, allocator, batch_size, num_heads_, kv_sequence_length, v_head_size, value, bias, v_bias_offset, V));

+  if (std::is_same_v<T, float> &&
+      !disable_flash_ &&
+      !is_unidirectional_ &&
+      key_padding_mask == nullptr &&
+      extra_add_qk == nullptr &&
+      past_key == nullptr &&
+      past_value == nullptr &&
+      present_k == nullptr &&
+      present_v == nullptr &&
+      l2_cache_size_ > 0) {
+    MlasFlashAttentionThreadedArgs args;
+    args.batch_size = batch_size;
+    args.num_heads = num_heads_;
+    args.q_sequence_length = q_sequence_length;
+    args.kv_sequence_length = kv_sequence_length;
+    args.qk_head_size = qk_head_size;
+    args.v_head_size = v_head_size;
+    args.scale = (scale_ == 0.0f) ? 1.0f / sqrt(static_cast<float>(qk_head_size)) : scale_;
+    /*
+      q_block_size and kv_block_size correspond to Br and Bc in the FlashAttention paper.
+      Let M = l2_cache_size / sizeof(float).
+      In the FlashAttention kernel, there are 5 big matrices that we need to keep in L2 cache:
+        slice of Q                           -- [Br, qk_head_size]
+        slice of K                           -- [Bc, qk_head_size]
+        slice of V                           -- [Bc, v_head_size]
+        result of QK                         -- [Br, Bc]
+        temporary output (same shape as QKV) -- [Br, v_head_size]
+      The total size of these matrices is (Br + Bc) * (qk_head_size + v_head_size) + Br * Bc.
+      By taking Bc = M / (4 * (qk_head_size + v_head_size)) and Br = min(Bc, qk_head_size + v_head_size), we have
+        (Br + Bc) * (qk_head_size + v_head_size) + Br * Bc
+          <= 2 * Bc * (qk_head_size + v_head_size) + Br * Bc
+          <= 2 * Bc * (qk_head_size + v_head_size) + M/4
+          <= 2 * M/4 + M/4 = M * (3/4)
+
+      We leave 1/4 of the L2 cache for
+        1. storing the small tensors l and m
+        2. instructions (code)
+    */
+    args.kv_block_size = l2_cache_size_ / (static_cast<int>(sizeof(float)) * 4 * (qk_head_size + v_head_size));
+    args.kv_block_size = std::max(args.kv_block_size, 1);  // avoid kv_block_size == 0
+    args.q_block_size = std::min(args.kv_block_size, qk_head_size + v_head_size);
+    args.kv_block_size = std::min(args.kv_block_size, kv_sequence_length);  // no point in kv_block_size > kv_sequence_length
+    args.q_block_size = std::min(args.q_block_size, q_sequence_length);     // no point in q_block_size > q_sequence_length
+
+    auto* tp = context->GetOperatorThreadPool();
+    args.thread_count = concurrency::ThreadPool::DegreeOfParallelism(tp);
+    args.buffer_size_per_thread = (static_cast<size_t>(args.q_block_size) * 2 +
+                                   static_cast<size_t>(args.q_block_size) * static_cast<size_t>(args.kv_block_size) +
+                                   static_cast<size_t>(args.q_block_size) * static_cast<size_t>(args.v_head_size)) *
+                                  sizeof(float);
+    size_t buffer_bytes = args.buffer_size_per_thread * args.thread_count;
+    IAllocatorUniquePtr<void> buffer = IAllocator::MakeUniquePtr<void>(allocator, buffer_bytes);
+
+    args.buffer = reinterpret_cast<float*>(buffer.get());
+
+    args.query = Q.Get<Tensor>().Data<float>();
+    args.key = K.Get<Tensor>().Data<float>();
+    args.value = V.Get<Tensor>().Data<float>();
+    args.output = output->MutableData<float>();
+
+    MlasFlashAttention(&args, tp);
+    return Status::OK();
+  }
+
   // Compute the attention score and apply the score to V
   return ApplyAttention(Q.GetMutable<Tensor>()->MutableData<T>(),
                         K.GetMutable<Tensor>()->MutableData<T>(),
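As a quick sanity check of the Br/Bc derivation in the comment above, the following standalone sketch plugs concrete numbers into the same formulas. It is not part of this change; the 1 MiB L2 cache and 64-dimensional heads are assumed values chosen only for illustration.

```cpp
// Minimal sketch (not from the PR): evaluate the Br/Bc block-size heuristic
// with hypothetical inputs and check the resulting L2 working-set size.
#include <algorithm>
#include <cstdio>

int main() {
  const int l2_cache_size = 1024 * 1024;  // assumed 1 MiB L2 cache
  const int qk_head_size = 64;            // assumed head sizes
  const int v_head_size = 64;

  // Bc = M / (4 * (qk_head_size + v_head_size)), with M = L2 size in floats
  int kv_block_size = l2_cache_size / (static_cast<int>(sizeof(float)) * 4 * (qk_head_size + v_head_size));
  kv_block_size = std::max(kv_block_size, 1);
  // Br = min(Bc, qk_head_size + v_head_size)
  int q_block_size = std::min(kv_block_size, qk_head_size + v_head_size);

  // Working set kept in L2: (Br + Bc) * (qk + v) + Br * Bc floats
  long long working_set_bytes =
      (static_cast<long long>(q_block_size + kv_block_size) * (qk_head_size + v_head_size) +
       static_cast<long long>(q_block_size) * kv_block_size) *
      static_cast<long long>(sizeof(float));

  // Prints: kv_block_size=512 q_block_size=128 working_set=589824 budget=786432
  std::printf("kv_block_size=%d q_block_size=%d working_set=%lld budget=%d\n",
              kv_block_size, q_block_size, working_set_bytes, l2_cache_size * 3 / 4);
  return 0;
}
```

With these inputs the working set comes to 589824 bytes, comfortably under the 786432-byte budget (3/4 of the L2 cache) that the comment's inequality guarantees.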