
Commit 80b56fe

Qingnan Duan (duanqn) authored, with co-author Tianlei Wu (tianleiwu)
Implement FlashAttention for CPU (microsoft#20805)
### Description

Implement [FlashAttention](https://arxiv.org/pdf/2205.14135) and [FlashAttention-2](https://arxiv.org/pdf/2307.08691) for MultiHeadAttention on CPU.

### Motivation and Context

Accelerate the execution of MultiHeadAttention. Current performance: 10 ms vs 16 ms (com.microsoft.MultiHeadAttention) on my Linux machine and 10 ms vs 38 ms (com.microsoft.MultiHeadAttention) on my Windows machine. May need further optimizations.

---

Co-authored-by: Tianlei Wu <tlwu@microsoft.com>
Co-authored-by: Qingnan Duan <qiduan@microsoft.com>
1 parent 33e7c7f commit 80b56fe

File tree

10 files changed: +363 -5 lines changed


cmake/onnxruntime_mlas.cmake

Lines changed: 1 addition & 0 deletions
@@ -39,6 +39,7 @@ onnxruntime_add_static_library(onnxruntime_mlas
   ${MLAS_SRC_DIR}/sqnbitgemm.h
   ${MLAS_SRC_DIR}/sqnbitgemm.cpp
   ${MLAS_SRC_DIR}/sqnbitgemm_q8_block.h
+  ${MLAS_SRC_DIR}/flashattn.cpp
 )
 
 target_sources(onnxruntime_mlas PRIVATE

onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc

Lines changed: 82 additions & 4 deletions
@@ -10,8 +10,12 @@
 #include "core/framework/tensorprotoutils.h"
 #include "core/graph/onnx_protobuf.h"
 #include "core/common/safeint.h"
+#include "core/platform/env_var_utils.h"
 #include "core/platform/threadpool.h"
+#include "core/mlas/inc/mlas.h"
 
+#include <algorithm>
+#include <type_traits>
 #include <unsupported/Eigen/SpecialFunctions>
 #include <vector>
 
@@ -39,6 +43,11 @@ MultiHeadAttention<T>::MultiHeadAttention(const OpKernelInfo& info) : OpKernel(i
 
   mask_filter_value_ = info.GetAttrOrDefault<float>("mask_filter_value", -10000.0f);
   is_unidirectional_ = info.GetAttrOrDefault<int64_t>("unidirectional", 0) == 1;
+
+  const auto& env = Env::Default();
+  l2_cache_size_ = env.GetL2CacheSize();
+
+  disable_flash_ = ParseEnvironmentVariableWithDefault<bool>(attention::kDisableFlashAttention, false);
 }
 
 template <typename T>
@@ -60,7 +69,6 @@ Status MultiHeadAttention<T>::Compute(OpKernelContext* context) const {
   }
 
   AttentionParameters parameters = {};
-  constexpr float scale = 1.0f;
   bool past_present_share_buffer = false;
   ORT_RETURN_IF_ERROR(multihead_attention_helper::CheckInputs<Tensor>(query,
                                                                       key,
@@ -74,7 +82,7 @@ Status MultiHeadAttention<T>::Compute(OpKernelContext* context) const {
                                                                       &parameters,
                                                                       num_heads_,
                                                                       mask_filter_value_,
-                                                                      scale,
+                                                                      scale_,
                                                                       is_unidirectional_,
                                                                       past_present_share_buffer,
                                                                       false));
@@ -99,8 +107,14 @@ Status MultiHeadAttention<T>::Compute(OpKernelContext* context) const {
   const int v_bias_offset = 2 * qk_hidden_size;
 
   // If optional outputs aren't needed, present_k and present_v will be null
-  std::vector<int64_t> present_k_shape({static_cast<int64_t>(batch_size), static_cast<int64_t>(num_heads_), static_cast<int64_t>(total_kv_sequence_length), static_cast<int64_t>(qk_head_size)});
-  std::vector<int64_t> present_v_shape({static_cast<int64_t>(batch_size), static_cast<int64_t>(num_heads_), static_cast<int64_t>(total_kv_sequence_length), static_cast<int64_t>(v_head_size)});
+  std::vector<int64_t> present_k_shape({static_cast<int64_t>(batch_size),
+                                        static_cast<int64_t>(num_heads_),
+                                        static_cast<int64_t>(total_kv_sequence_length),
+                                        static_cast<int64_t>(qk_head_size)});
+  std::vector<int64_t> present_v_shape({static_cast<int64_t>(batch_size),
+                                        static_cast<int64_t>(num_heads_),
+                                        static_cast<int64_t>(total_kv_sequence_length),
+                                        static_cast<int64_t>(v_head_size)});
   Tensor* present_k = context->Output(1, present_k_shape);
   Tensor* present_v = context->Output(2, present_v_shape);
 
@@ -138,6 +152,70 @@ Status MultiHeadAttention<T>::Compute(OpKernelContext* context) const {
   ORT_RETURN_IF_ERROR(MaybeTransposeToBNSHAndAddBias<T>(
       context, allocator, batch_size, num_heads_, kv_sequence_length, v_head_size, value, bias, v_bias_offset, V));
 
+  if (std::is_same_v<T, float> &&
+      !disable_flash_ &&
+      !is_unidirectional_ &&
+      key_padding_mask == nullptr &&
+      extra_add_qk == nullptr &&
+      past_key == nullptr &&
+      past_value == nullptr &&
+      present_k == nullptr &&
+      present_v == nullptr &&
+      l2_cache_size_ > 0) {
+    MlasFlashAttentionThreadedArgs args;
+    args.batch_size = batch_size;
+    args.num_heads = num_heads_;
+    args.q_sequence_length = q_sequence_length;
+    args.kv_sequence_length = kv_sequence_length;
+    args.qk_head_size = qk_head_size;
+    args.v_head_size = v_head_size;
+    args.scale = (scale_ == 0.0f) ? 1.0f / sqrt(static_cast<float>(qk_head_size)) : scale_;
+    /*
+      q_block_size, kv_block_size correspond to Br, Bc in the FlashAttention paper.
+      Let M = l2_cache_size / sizeof(float).
+      In the FlashAttention kernel, there are 5 big matrices that we need to keep in L2 cache:
+        slice of Q -- [Br, qk_head_size]
+        slice of K -- [Bc, qk_head_size]
+        slice of V -- [Bc, v_head_size]
+        result of QK -- [Br, Bc]
+        temporary output (same shape as QKV) -- [Br, v_head_size]
+      The total size of these matrices is (Br + Bc) * (qk_head_size + v_head_size) + Br * Bc.
+      By taking Bc = M / (4 * (qk_head_size + v_head_size)) and Br = min(Bc, qk_head_size + v_head_size), we have
+        (Br + Bc) * (qk_head_size + v_head_size) + Br * Bc
+          <= 2 * Bc * (qk_head_size + v_head_size) + Br * Bc
+          <= 2 * Bc * (qk_head_size + v_head_size) + M / 4
+          <= 2 * M / 4 + M / 4 = M * (3/4)
+
+      We leave 1/4 of the L2 cache for
+        1. storing the small tensors l and m
+        2. instructions (code)
+    */
+    args.kv_block_size = l2_cache_size_ / (static_cast<int>(sizeof(float)) * 4 * (qk_head_size + v_head_size));
+    args.kv_block_size = std::max(args.kv_block_size, 1);  // avoid kv_block_size = 0
+    args.q_block_size = std::min(args.kv_block_size, qk_head_size + v_head_size);
+    args.kv_block_size = std::min(args.kv_block_size, kv_sequence_length);  // No point to have kv_block_size > kv_sequence_length
+    args.q_block_size = std::min(args.q_block_size, q_sequence_length);     // No point to have q_block_size > q_sequence_length
+
+    auto* tp = context->GetOperatorThreadPool();
+    args.thread_count = concurrency::ThreadPool::DegreeOfParallelism(tp);
+    args.buffer_size_per_thread = (static_cast<size_t>(args.q_block_size) * 2 +
+                                   static_cast<size_t>(args.q_block_size) * static_cast<size_t>(args.kv_block_size) +
+                                   static_cast<size_t>(args.q_block_size) * static_cast<size_t>(args.v_head_size)) *
+                                  sizeof(float);
+    size_t buffer_bytes = args.buffer_size_per_thread * args.thread_count;
+    IAllocatorUniquePtr<void> buffer = IAllocator::MakeUniquePtr<void>(allocator, buffer_bytes);
+
+    args.buffer = reinterpret_cast<float*>(buffer.get());
+
+    args.query = Q.Get<Tensor>().Data<float>();
+    args.key = K.Get<Tensor>().Data<float>();
+    args.value = V.Get<Tensor>().Data<float>();
+    args.output = output->MutableData<float>();
+
+    MlasFlashAttention(&args, tp);
+    return Status::OK();
+  }
+
   // Compute the attention score and apply the score to V
   return ApplyAttention(Q.GetMutable<Tensor>()->MutableData<T>(),
                         K.GetMutable<Tensor>()->MutableData<T>(),
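
As a sanity check on the block-size comment in the hunk above, here is a small standalone C++ sketch (not part of the PR) that plugs in a hypothetical 1 MiB L2 cache and qk_head_size = v_head_size = 64, then verifies that the resulting working set stays within 3/4 of the cache, as the comment argues.

```cpp
// Standalone check of the Br/Bc sizing rule (illustrative values, not part of the PR).
#include <algorithm>
#include <cstdio>

int main() {
    const int l2_cache_size = 1 * 1024 * 1024;  // assume a 1 MiB L2 cache
    const int qk_head_size = 64;                // example head sizes
    const int v_head_size = 64;

    // Bc = M / (4 * (qk_head_size + v_head_size)), with M = L2 size in floats.
    int kv_block_size = l2_cache_size / (static_cast<int>(sizeof(float)) * 4 * (qk_head_size + v_head_size));
    kv_block_size = std::max(kv_block_size, 1);
    // Br = min(Bc, qk_head_size + v_head_size)
    int q_block_size = std::min(kv_block_size, qk_head_size + v_head_size);

    // Working set: Q, K, V slices plus the QK^T block and the partial output.
    long long floats_needed =
        static_cast<long long>(q_block_size + kv_block_size) * (qk_head_size + v_head_size) +
        static_cast<long long>(q_block_size) * kv_block_size;
    long long float_budget = 3LL * l2_cache_size / (4 * static_cast<int>(sizeof(float)));  // 3/4 of M

    std::printf("Br=%d Bc=%d working_set=%lld budget=%lld fits=%d\n",
                q_block_size, kv_block_size, floats_needed, float_budget,
                static_cast<int>(floats_needed <= float_budget));
    return 0;
}
```

With these example values it prints Br=128, Bc=512, a working set of 147456 floats against a budget of 196608, so the bound holds.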

onnxruntime/contrib_ops/cpu/bert/multihead_attention.h

Lines changed: 2 additions & 0 deletions
@@ -19,6 +19,8 @@ class MultiHeadAttention final : public OpKernel, public AttentionCPUBase {
   int num_heads_;  // number of attention heads
   float mask_filter_value_;
   bool is_unidirectional_;
+  bool disable_flash_;
+  int l2_cache_size_;
 };
 
 }  // namespace contrib

onnxruntime/core/mlas/inc/mlas.h

Lines changed: 32 additions & 0 deletions
@@ -1825,3 +1825,35 @@ MlasNhwcAvgPool(
     );
 
 #endif
+
+struct MlasFlashAttentionThreadedArgs {
+    int batch_size;
+    int num_heads;
+    int q_sequence_length;
+    int kv_sequence_length;
+    int qk_head_size;
+    int v_head_size;
+    int q_block_size;
+    int kv_block_size;
+    float scale;
+    int thread_count;
+    float* buffer;
+    size_t buffer_size_per_thread;
+    const float* query;
+    const float* key;
+    const float* value;
+    float* output;
+};
+
+/**
+ * @brief fp32 Flash Attention entry point; splits the work across threads
+ * @param args       Arguments (shapes, block sizes, scratch buffer, tensors)
+ * @param ThreadPool Thread pool used to run the per-thread workers
+ */
+void
+MLASCALL
+MlasFlashAttention(
+    MlasFlashAttentionThreadedArgs* args,
+    MLAS_THREADPOOL* ThreadPool
+);
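
The header above only declares the entry point; the caller owns the scratch buffer. Below is a hedged usage sketch: `RunFlashAttentionExample` is a hypothetical helper, the dimensions are example values (matching the Br = 128, Bc = 512 worked example earlier), and it assumes compilation inside the onnxruntime tree so that `mlas.h` and `MLAS_THREADPOOL` are available. The layout comments are read off the kernel's indexing, not a documented contract.

```cpp
// Hedged usage sketch of the API declared above; all dimensions are example values.
#include <vector>

#include "core/mlas/inc/mlas.h"  // MlasFlashAttentionThreadedArgs, MlasFlashAttention, MLAS_THREADPOOL

void RunFlashAttentionExample(const float* q, const float* k, const float* v, float* out,
                              MLAS_THREADPOOL* tp) {
    MlasFlashAttentionThreadedArgs args{};
    args.batch_size = 1;             // example shapes, not taken from the PR
    args.num_heads = 8;
    args.q_sequence_length = 512;
    args.kv_sequence_length = 512;
    args.qk_head_size = 64;
    args.v_head_size = 64;
    args.q_block_size = 128;         // Br
    args.kv_block_size = 512;        // Bc
    args.scale = 1.0f / 8.0f;        // 1/sqrt(qk_head_size)
    args.thread_count = 4;

    // Per-thread scratch: l and m (q_block_size floats each), the QK^T block,
    // and the partial output block, mirroring the kernel's buffer layout.
    args.buffer_size_per_thread =
        (static_cast<size_t>(args.q_block_size) * 2 +
         static_cast<size_t>(args.q_block_size) * static_cast<size_t>(args.kv_block_size) +
         static_cast<size_t>(args.q_block_size) * static_cast<size_t>(args.v_head_size)) *
        sizeof(float);

    std::vector<float> scratch(args.buffer_size_per_thread * args.thread_count / sizeof(float));
    args.buffer = scratch.data();

    // From the kernel's indexing: q/k/v are read as [batch, num_heads, sequence, head_size],
    // and output is written as [batch, sequence, num_heads, v_head_size].
    args.query = q;
    args.key = k;
    args.value = v;
    args.output = out;

    MlasFlashAttention(&args, tp);
}
```
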

onnxruntime/core/mlas/lib/flashattn.cpp (new file)

Lines changed: 167 additions & 0 deletions
@@ -0,0 +1,167 @@
+#include <numeric>
+
+#include "mlasi.h"
+
+void
+MlasFlashAttentionThreaded(
+    void* argptr,
+    std::ptrdiff_t thread_id
+)
+{
+    const MlasFlashAttentionThreadedArgs* args = reinterpret_cast<MlasFlashAttentionThreadedArgs*>(argptr);
+    ptrdiff_t q_block_size = static_cast<ptrdiff_t>(args->q_block_size);
+    ptrdiff_t kv_block_size = static_cast<ptrdiff_t>(args->kv_block_size);
+    ptrdiff_t batch_size = static_cast<ptrdiff_t>(args->batch_size);
+    ptrdiff_t num_heads = static_cast<ptrdiff_t>(args->num_heads);
+    ptrdiff_t q_sequence_length = static_cast<ptrdiff_t>(args->q_sequence_length);
+    ptrdiff_t kv_sequence_length = static_cast<ptrdiff_t>(args->kv_sequence_length);
+    ptrdiff_t qk_head_size = static_cast<ptrdiff_t>(args->qk_head_size);
+    ptrdiff_t v_head_size = static_cast<ptrdiff_t>(args->v_head_size);
+    float* buffer = args->buffer;
+    ptrdiff_t buffer_size_per_thread = static_cast<ptrdiff_t>(args->buffer_size_per_thread);
+    ptrdiff_t thread_count = static_cast<ptrdiff_t>(args->thread_count);
+    const float* query = args->query;
+    const float* key = args->key;
+    const float* value = args->value;
+    float* output = args->output;
+
+#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64)
+    auto&& mlas_platform = GetMlasPlatform();
+#endif
+
+    ptrdiff_t q_chunk_count = (q_sequence_length + (q_block_size - 1)) / q_block_size;
+
+    ptrdiff_t task_start = 0;
+    ptrdiff_t task_end = 0;
+    ptrdiff_t total_task_count = batch_size * num_heads * q_chunk_count;
+    ptrdiff_t quotient = total_task_count / thread_count;
+    ptrdiff_t remainder = total_task_count % thread_count;
+    if (thread_id < remainder) {
+        task_start = (quotient + 1) * thread_id;
+        task_end = task_start + quotient + 1;
+    } else {
+        task_start = quotient * thread_id + remainder;
+        task_end = task_start + quotient;
+    }
+
+    for (ptrdiff_t task_index = task_start; task_index < task_end; ++task_index) {
+        ptrdiff_t batch_idx = task_index;
+        ptrdiff_t q_idx = (batch_idx % q_chunk_count) * q_block_size;
+        batch_idx /= q_chunk_count;
+        ptrdiff_t head_idx = batch_idx % num_heads;
+        batch_idx /= num_heads;
+
+        char* buffer_current_thread = reinterpret_cast<char*>(buffer) + thread_id * buffer_size_per_thread;
+        float* l = reinterpret_cast<float*>(buffer_current_thread);
+        float* m = l + q_block_size;
+        for (ptrdiff_t t = 0; t < q_block_size; ++t) {
+            m[t] = std::numeric_limits<float>::lowest();
+        }
+        float* intermediate = m + q_block_size;
+        float* temp_output = intermediate + q_block_size * kv_block_size;
+        float negmax = 0;
+
+        for (ptrdiff_t ir = 0; ir < kv_sequence_length; ir += kv_block_size) {
+            /*
+                S = Q[batch_idx, head_idx, q_idx:q_idx+q_block_size, :] * (K[batch_idx, head_idx, ir:ir+kv_block_size, :]).T
+                old_m = m
+                m = max(m, rowmax(S))
+                diff = old_m - m
+                S = exp(S - m)
+                l = exp(diff) * l + rowsum(S)
+                O = diag(exp(diff)) * O + S * V[batch_idx, head_idx, ir:ir+kv_block_size, :]
+            */
+            ptrdiff_t h = batch_idx * num_heads + head_idx;
+            const float* inputQ = query + (h * q_sequence_length + q_idx) * qk_head_size;
+            const float* inputK = key + (h * kv_sequence_length + ir) * qk_head_size;
+            const float* inputV = value + (h * kv_sequence_length + ir) * v_head_size;
+
+            size_t row_size_q_capped = static_cast<size_t>(std::min(q_block_size, q_sequence_length - q_idx));
+            size_t row_size_kv_capped = static_cast<size_t>(std::min(kv_block_size, kv_sequence_length - ir));
+
+            MlasSgemmOperation(CBLAS_TRANSPOSE::CblasNoTrans,
+                               CBLAS_TRANSPOSE::CblasTrans,
+                               row_size_q_capped,
+                               row_size_kv_capped,
+                               static_cast<size_t>(qk_head_size),
+                               args->scale,
+                               inputQ,
+                               static_cast<size_t>(qk_head_size),
+                               inputK,
+                               static_cast<size_t>(qk_head_size),
+                               0.0f,
+                               intermediate,
+                               row_size_kv_capped);
+
+            for (ptrdiff_t irow = 0; irow < static_cast<ptrdiff_t>(row_size_q_capped); ++irow) {
+                float* p = intermediate + irow * row_size_kv_capped;
+
+#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64)
+                float rowmax = mlas_platform.ReduceMaximumF32Kernel(p, row_size_kv_capped);
+#else
+                float rowmax = MlasReduceMaximumF32Kernel(p, row_size_kv_capped);
+#endif
+                float m_diff = m[irow];
+                m[irow] = std::max(m[irow], rowmax);  // new m
+                negmax = -m[irow];
+                m_diff -= m[irow];  // old - new (less than 0)
+
+#if defined(MLAS_TARGET_AMD64)
+                float rowsum = mlas_platform.ComputeSumExpF32Kernel(p, p, row_size_kv_capped, &negmax);
+#else
+                float rowsum = MlasComputeSumExpF32Kernel(p, p, row_size_kv_capped, &negmax);
+#endif
+
+                // Note: for ir == 0, there is actually no need to calculate exp_diff
+                if (ir != 0) {
+                    float exp_diff = std::exp(m_diff);
+                    l[irow] = exp_diff * l[irow] + rowsum;
+
+                    for (ptrdiff_t icol = 0; icol < v_head_size; ++icol) {
+                        temp_output[irow * v_head_size + icol] = exp_diff * temp_output[irow * v_head_size + icol];
+                    }
+                } else {
+                    l[irow] = rowsum;
+                    // When ir == 0, there is no need to scale the old result because it is zero.
+                }
+            }
+            MlasSgemmOperation(CBLAS_TRANSPOSE::CblasNoTrans,
+                               CBLAS_TRANSPOSE::CblasNoTrans,
+                               row_size_q_capped,
+                               static_cast<size_t>(v_head_size),
+                               row_size_kv_capped,
+                               1.0f,
+                               intermediate,
+                               row_size_kv_capped,
+                               inputV,
+                               static_cast<size_t>(v_head_size),
+                               ir == 0 ? 0.0f : 1.0f,
+                               temp_output,
+                               static_cast<size_t>(v_head_size));
+        }
+
+        float* output_row = output + ((batch_idx * q_sequence_length + q_idx) * num_heads + head_idx) * v_head_size;
+        ptrdiff_t row_size_q_valid = std::min(q_block_size, q_sequence_length - q_idx);
+        // TODO: leverage advanced instruction sets
+        for (ptrdiff_t irow = 0; irow < row_size_q_valid; ++irow) {
+            for (ptrdiff_t icol = 0; icol < v_head_size; ++icol) {
+                output_row[icol] = temp_output[irow * v_head_size + icol] / l[irow];
+            }
+            output_row += num_heads * v_head_size;
+        }
+    }
+}
+
+void
+MLASCALL
+MlasFlashAttention(
+    MlasFlashAttentionThreadedArgs* args,
+    MLAS_THREADPOOL* ThreadPool
+)
+{
+    MlasExecuteThreaded(
+        MlasFlashAttentionThreaded,
+        static_cast<void*>(args),
+        static_cast<std::ptrdiff_t>(args->thread_count),
+        ThreadPool);
+}
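
The block comment inside the kv loop above is the online-softmax recurrence from FlashAttention: keep a running row max m and denominator l, and rescale the partial output by exp(old_m - new_m) whenever the max grows. The sketch below (illustrative only, independent of MLAS) runs that recurrence over a single score row split into two blocks and checks it against a direct softmax-weighted sum; it prints the same value both ways.

```cpp
// Illustrative check of the online-softmax recurrence (independent of MLAS).
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <limits>
#include <vector>

int main() {
    std::vector<float> s = {0.3f, 2.1f, -1.0f, 0.7f, 1.5f, -0.2f};   // scores for one query row
    std::vector<float> v = {1.0f, -2.0f, 0.5f, 3.0f, 0.25f, -1.0f};  // matching values (v_head_size = 1)
    const size_t block = 3;  // plays the role of kv_block_size

    float m = std::numeric_limits<float>::lowest();  // running row max
    float l = 0.0f;                                  // running softmax denominator
    float o = 0.0f;                                  // running unnormalized output

    for (size_t start = 0; start < s.size(); start += block) {
        size_t end = std::min(start + block, s.size());
        float block_max = *std::max_element(s.begin() + start, s.begin() + end);
        float new_m = std::max(m, block_max);
        float exp_diff = std::exp(m - new_m);  // rescale factor for the old l and o
        float block_sum = 0.0f;
        float block_out = 0.0f;
        for (size_t i = start; i < end; ++i) {
            float p = std::exp(s[i] - new_m);
            block_sum += p;
            block_out += p * v[i];
        }
        l = exp_diff * l + block_sum;
        o = exp_diff * o + block_out;
        m = new_m;
    }

    // Reference: plain softmax-weighted sum over all columns at once.
    float ref_m = *std::max_element(s.begin(), s.end());
    float ref_l = 0.0f;
    float ref_o = 0.0f;
    for (size_t i = 0; i < s.size(); ++i) {
        float p = std::exp(s[i] - ref_m);
        ref_l += p;
        ref_o += p * v[i];
    }

    std::printf("streaming = %f, direct = %f\n", o / l, ref_o / ref_l);
    return 0;
}
```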

onnxruntime/core/platform/env.h

Lines changed: 2 additions & 0 deletions
@@ -147,6 +147,8 @@ class Env {
 
   virtual std::vector<LogicalProcessors> GetDefaultThreadAffinities() const = 0;
 
+  virtual int GetL2CacheSize() const = 0;
+
   /// \brief Returns the number of micro-seconds since the Unix epoch.
   virtual uint64_t NowMicros() const {
     return env_time_->NowMicros();

onnxruntime/core/platform/posix/env.cc

Lines changed: 20 additions & 0 deletions
@@ -43,6 +43,10 @@ limitations under the License.
 #define ORT_USE_CPUINFO
 #endif
 
+#if defined(__APPLE__) || defined(__FreeBSD__) || defined(__NetBSD__)
+#include <sys/sysctl.h>
+#endif
+
 #include "core/common/common.h"
 #include <gsl/gsl>
 #include "core/common/logging/logging.h"
@@ -302,6 +306,22 @@ class PosixEnv : public Env {
     return ret;
   }
 
+  int GetL2CacheSize() const override {
+#ifdef _SC_LEVEL2_CACHE_SIZE
+    return static_cast<int>(sysconf(_SC_LEVEL2_CACHE_SIZE));
+#else
+    int value = 0;  // unknown
+#if (defined(__APPLE__) || defined(__FreeBSD__) || defined(__NetBSD__)) && defined(HW_L2CACHESIZE)
+    int mib[2] = {CTL_HW, HW_L2CACHESIZE};
+    size_t len = sizeof(value);
+    if (sysctl(mib, 2, &value, &len, NULL, 0) < 0) {
+      return -1;  // error
+    }
+#endif
+    return value;
+#endif
+  }
+
   void SleepForMicroseconds(int64_t micros) const override {
     while (micros > 0) {
       timespec sleep_time;
