Commit 6d28e2d

[webgpu] support smooth softmax for non-FA GQA implementation (microsoft#25285)
### Description

Support smooth softmax for the non-FlashAttention GQA implementation.

This change depends on:
- microsoft#25269

Work items:
- [x] support smooth softmax
- [x] support bias
- [x] support head sink (per-head smooth softmax)

The following will not be included in this PR:
- support for FlashAttention
- support for sliding window
1 parent f0097fc commit 6d28e2d
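
A quick editorial note on the terms used above (inferred from the softmax shader changes in this commit, not part of the original commit message): smooth softmax adds a constant logit of 0 to the softmax denominator, and head sink generalizes this to a learned per-head sink logit `s_h = head_sink[h]`, so each head can divert attention mass to a virtual "sink" instead of spreading it over real tokens. In max-stabilized form:

```math
m = \max\bigl(s_h,\; \max_j x_j\bigr), \qquad
p_j = \frac{e^{x_j - m}}{e^{s_h - m} + \sum_k e^{x_k - m}}
```

with `s_h = 0` for plain smooth softmax; standard softmax omits the extra denominator term, so its rows sum to exactly 1 while the smooth/sink variants sum to less than 1.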

7 files changed (+93 −32 lines)


onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc

Lines changed: 2 additions & 0 deletions
@@ -54,6 +54,7 @@ Status GroupQueryAttention<T>::Compute(OpKernelContext* context) const {
   const Tensor* sin_cache = context->Input<Tensor>(8);
   const Tensor* position_ids = context->Input<Tensor>(9);
   const Tensor* attention_bias = context->Input<Tensor>(10);
+  const Tensor* head_sink = context->Input<Tensor>(11);
 
   GroupQueryAttentionParameters parameters = {};
   ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckInputs(query,
@@ -73,6 +74,7 @@ Status GroupQueryAttention<T>::Compute(OpKernelContext* context) const {
 
   ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckCustomAttentionInputs(position_ids,
                                                                                attention_bias,
+                                                                               head_sink,
                                                                                parameters));
 
   const int batch_size = parameters.batch_size;

onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h

Lines changed: 18 additions & 0 deletions
@@ -340,6 +340,7 @@ Status CheckInputs(const T* query,
 template <typename T = Tensor>
 Status CheckCustomAttentionInputs(const T* position_ids,
                                   const T* attention_bias,
+                                  const T* head_sink,
                                   const GroupQueryAttentionParameters& parameters) {
   if (position_ids != nullptr) {
     const auto& pos_ids_shape = position_ids->Shape();
@@ -377,6 +378,23 @@ Status CheckCustomAttentionInputs(const T* position_ids,
     }
   }
 
+  if (head_sink != nullptr) {
+    if (parameters.use_smooth_softmax) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "head_sink should not be provided when use_smooth_softmax is true.");
+    }
+
+    const auto& head_sink_shape = head_sink->Shape();
+    if (head_sink_shape.NumDimensions() != 1) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "head_sink must be a 1D tensor");
+    }
+
+    if (head_sink_shape[0] != parameters.num_heads) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "head_sink dimension 0 must be equal to the num heads, got ", head_sink_shape[0]);
+    }
+  }
+
   return Status::OK();
 }

onnxruntime/contrib_ops/webgpu/bert/attention.cc

Lines changed: 43 additions & 19 deletions
@@ -69,8 +69,8 @@ Status TransferBSDToBNSH(onnxruntime::webgpu::ComputeContext& context, int num_h
   return context.RunProgram(program);
 };
 
-void InitVarStub(std::ostringstream& ss, const Tensor* seqlen_k) {
-  if (seqlen_k != nullptr) {
+void InitVarStub(std::ostringstream& ss, bool has_seqlen_k) {
+  if (has_seqlen_k) {
     ss << "total_sequence_length = u32(seqlen_k[batch_idx]) + 1;\n";
     ss << "var past_sequence_length: u32 = select(total_sequence_length - sequence_length, 0u, uniforms.is_first_prompt > 0);\n";
   } else {
@@ -87,7 +87,7 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const {
   if (has_attention_bias_) {
     shader.AddInput("attention_bias", ShaderUsage::UseUniform);
   }
-  if (seqlen_k_ != nullptr) {
+  if (has_seqlen_k_) {
     shader.AddInput("seqlen_k", ShaderUsage::UseUniform);
   }
   shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
@@ -107,7 +107,7 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const {
                             << "let sequence_length = uniforms.M;\n"
                             << "var total_sequence_length = uniforms.N;\n";
   std::ostringstream oss;
-  InitVarStub(oss, seqlen_k_);
+  InitVarStub(oss, has_seqlen_k_);
   shader.MainFunctionBody() << oss.str();
   shader.MainFunctionBody() << "let kOffset = (batch_head_idx / uniforms.n_reps) * uniforms.kv_sequence_length * uniforms.K;\n";
   if (has_present_key_) {
@@ -182,7 +182,7 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o
   const int components = parameters.head_size_ % 4 == 0 ? 4 : (parameters.head_size_ % 2 == 0 ? 2 : 1);
 
   AttentionProbsProgram program{"AttentionProbs", feed_past_key, has_present_key, has_attention_bias, tile_size,
-                                components, parameters.is_first_prompt_, seqlen_k, parameters.past_present_share_buffer_};
+                                components, parameters.is_first_prompt_, seqlen_k != nullptr, parameters.past_present_share_buffer_};
   program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, components},
                      {K, ProgramTensorMetadataDependency::TypeAndRank, components}});
   if (feed_past_key) {
@@ -224,30 +224,44 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o
 }
 
 Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const {
-  if (seqlen_k_) {
+  if (has_seqlen_k_) {
     shader.AddInput("seqlen_k", ShaderUsage::UseUniform);
   }
+  if (has_head_sink_) {
+    shader.AddInput("head_sink", ShaderUsage::UseUniform);
+  }
   shader.AddOutput("x", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
   shader.AdditionalImplementation() << "var<workgroup> thread_max: array<f32, " << work_group_size_ << ">;\n"
                                     << "var<workgroup> thread_sum: array<f32, " << work_group_size_ << ">;\n"
                                     << "alias f32_val_t = " << (components_ == 4 ? "vec4<f32>" : (components_ == 2 ? "vec2<f32>" : "f32")) << ";\n";
   shader.MainFunctionBody() << "let sequence_length = uniforms.sequence_length;\n"
                             << "let batch_idx = u32(workgroup_idx / sequence_length) / uniforms.num_heads;\n"
+                            << "let head_idx = u32(workgroup_idx / sequence_length) % uniforms.num_heads;\n"
                             << "var total_sequence_length = uniforms.total_sequence_length_comp * " << components_ << ";\n";
   std::ostringstream oss;
-  InitVarStub(oss, seqlen_k_);
+  InitVarStub(oss, has_seqlen_k_);
   shader.MainFunctionBody() << oss.str()
                             << "let local_offset = local_idx * uniforms.elements_per_thread;\n"
                             << "let offset = workgroup_idx * uniforms.total_sequence_length_comp + local_offset;\n"
-                            << "let seq_causal_length = " << (seqlen_k_ ? "past_sequence_length + workgroup_idx % sequence_length + 1" : "uniforms.total_sequence_length_comp") << ";\n"
+                            << "let seq_causal_length = " << (has_seqlen_k_ ? "past_sequence_length + workgroup_idx % sequence_length + 1" : "uniforms.total_sequence_length_comp") << ";\n"
                             << "var thread_max_vector = f32_val_t(-3.402823e+38f);\n"
                            << "for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {\n"
                             << " thread_max_vector = max(f32_val_t(x[offset + i]), thread_max_vector);\n"
                             << "}\n"
                             << "thread_max[local_idx] = " << (components_ == 4 ? "max(max(thread_max_vector.x, thread_max_vector.y), max(thread_max_vector.z, thread_max_vector.w))" : (components_ == 2 ? "max(thread_max_vector.x, thread_max_vector.y)" : "thread_max_vector")) << ";\n"
-                            << "workgroupBarrier();\n"
-                            << "var max_value = f32(-3.402823e+38f);\n"
-                            << "for (var i = 0u; i < " << work_group_size_ << "; i++) {\n"
+                            << "workgroupBarrier();\n";
+
+  if (has_head_sink_) {
+    // Handle head sink
+    shader.MainFunctionBody() << "let sink_value: f32 = head_sink[head_idx];\n"
+                              << "var max_value = sink_value;\n";
+  } else if (use_smooth_softmax_) {
+    shader.MainFunctionBody() << "var max_value: f32 = 0.0;\n";
+  } else {
+    shader.MainFunctionBody() << "var max_value = f32(-3.402823e+38f);\n";
+  }
+
+  shader.MainFunctionBody() << "for (var i = 0u; i < " << work_group_size_ << "; i++) {\n"
                             << " max_value = max(thread_max[i], max_value);\n"
                             << "}\n"
                             << "var sum_vector = f32_val_t(0);\n"
@@ -259,8 +273,15 @@ Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const {
                             << "var sum: f32 = 0;\n"
                             << "for (var i = 0u; i < " << work_group_size_ << "; i++) {\n"
                             << " sum += thread_sum[i]\n;"
-                            << "}\n"
-                            << "if (sum == 0) {\n"
+                            << "}\n";
+
+  if (has_head_sink_) {
+    shader.MainFunctionBody() << "sum += exp(sink_value - max_value);\n";
+  } else if (use_smooth_softmax_) {
+    shader.MainFunctionBody() << "sum += exp(-max_value);\n";
+  }
+
+  shader.MainFunctionBody() << "if (sum == 0) {\n"
                             << " for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {\n"
                             << " x[offset + i] = x_value_t(x_element_t(1.0)/x_element_t(seq_causal_length));\n"
                             << " }\n"
@@ -270,7 +291,7 @@ Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const {
                             << " x[offset + i] = x_value_t(exp(f32input - max_value) / sum);\n"
                             << " }\n"
                             << "}\n";
-  if (seqlen_k_) {
+  if (has_seqlen_k_) {
     shader.MainFunctionBody() << "for (var total_seq_id: u32 = seq_causal_length; total_seq_id + local_offset < uniforms.total_sequence_length_comp; total_seq_id++) {\n"
                               << " x[offset + total_seq_id] = x_value_t(x_element_t(0));\n"
                               << "}\n";
@@ -280,7 +301,7 @@ Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const {
 }
 
 Status ComputeInPlaceSoftmax(onnxruntime::webgpu::ComputeContext& context, Tensor* probs, int32_t batch_size, int32_t num_heads, int32_t past_sequence_length, int32_t sequence_length, int32_t total_sequence_length,
-                             const Tensor* seqlen_k, bool is_first_prompt) {
+                             const Tensor* seqlen_k, bool is_first_prompt, bool use_smooth_softmax, const Tensor* head_sink) {
   const int components = seqlen_k != nullptr ? 1 : (total_sequence_length % 4 == 0 ? 4 : (total_sequence_length % 2 == 0 ? 2 : 1));
   int work_group_size = 64;
   const int total_sequence_length_comp = (total_sequence_length + components - 1) / components;
@@ -289,12 +310,15 @@ Status ComputeInPlaceSoftmax(onnxruntime::webgpu::ComputeContext& context, Tenso
   }
   const int elementsPerThread = (total_sequence_length_comp + work_group_size - 1) / work_group_size;
 
-  InPlaceSoftmaxProgram program{"InPlaceSoftmax", work_group_size, components, seqlen_k};
+  InPlaceSoftmaxProgram program{work_group_size, components, use_smooth_softmax, seqlen_k != nullptr, head_sink != nullptr};
   if (seqlen_k != nullptr) {
     program.AddInput({seqlen_k, ProgramTensorMetadataDependency::TypeAndRank});
   }
+  if (head_sink != nullptr) {
+    program.AddInput({head_sink, ProgramTensorMetadataDependency::Type});
+  }
   program.AddOutputs({{probs, ProgramTensorMetadataDependency::TypeAndRank, components}})
-      .CacheHint(work_group_size)
+      .CacheHint(work_group_size, use_smooth_softmax)
       .SetDispatchGroupSize(batch_size * num_heads * sequence_length)
       .SetWorkgroupSize(work_group_size)
       .AddUniformVariables({{static_cast<uint32_t>(batch_size)},
@@ -443,7 +467,7 @@ Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int
 
 Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias,
                       const Tensor* past_key, const Tensor* past_value, Tensor* output, Tensor* present_key, Tensor* present_value,
-                      WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k) {
+                      WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* head_sink, const Tensor* seqlen_k) {
   const int output_count = std::min({context.OutputCount(), 1 + (past_key != nullptr ? 1 : 0) + (past_value != nullptr ? 1 : 0)});
   const int past_sequence_length = output_count > 1 ? parameters.past_sequence_length_ : 0;
   const int total_sequence_length =
@@ -457,7 +481,7 @@ Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const T
                                               parameters, past_sequence_length, total_sequence_length, seqlen_k));
 
   ORT_RETURN_IF_ERROR(ComputeInPlaceSoftmax(context, &probs,
-                                            parameters.batch_size_, parameters.num_heads_, parameters.past_sequence_length_, parameters.sequence_length_, total_sequence_length, seqlen_k, parameters.is_first_prompt_));
+                                            parameters.batch_size_, parameters.num_heads_, parameters.past_sequence_length_, parameters.sequence_length_, total_sequence_length, seqlen_k, parameters.is_first_prompt_, parameters.use_smooth_softmax_, head_sink));
 
   ORT_RETURN_IF_ERROR(ComputeVxAttentionScore(context, output_count, &probs, V, past_value, output, present_value,
                                               parameters, past_sequence_length, total_sequence_length, seqlen_k));
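
To make the generated WGSL above easier to follow, here is a minimal scalar C++ sketch of the three softmax variants InPlaceSoftmaxProgram can emit (hypothetical reference code, not part of the PR; the `Mode` enum and `softmax_row` name are invented for illustration):

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

enum class Mode { kStandard, kSmoothSoftmax, kHeadSink };

// Reference for one attention-probability row, mirroring the shader: the sink
// contributes to the running max and to the denominator, but no output element
// is written for it.
void softmax_row(std::vector<float>& x, Mode mode, float sink_logit = 0.0f) {
  float max_value = -3.402823e+38f;                     // standard: lowest f32
  if (mode == Mode::kSmoothSoftmax) max_value = 0.0f;   // sink logit fixed at 0
  if (mode == Mode::kHeadSink) max_value = sink_logit;  // learned per-head logit
  for (float v : x) max_value = std::max(max_value, v);

  float sum = 0.0f;
  for (float v : x) sum += std::exp(v - max_value);
  if (mode == Mode::kSmoothSoftmax) sum += std::exp(0.0f - max_value);
  if (mode == Mode::kHeadSink) sum += std::exp(sink_logit - max_value);

  for (float& v : x) v = std::exp(v - max_value) / sum;
}
```

With the sink term in the denominator, a head whose scores all sit far below the sink logit produces near-zero probabilities for every real token, which is the behavior the smooth-softmax/attention-sink formulation is after.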

onnxruntime/contrib_ops/webgpu/bert/attention.h

Lines changed: 8 additions & 6 deletions
@@ -34,8 +34,8 @@ class TransferBSDToBNSHProgram final : public Program<TransferBSDToBNSHProgram>
 class AttentionProbsProgram final : public Program<AttentionProbsProgram> {
  public:
   AttentionProbsProgram(const std::string& kernel_name, bool feed_past_key, bool has_present_key,
-                        bool has_attention_bias, int tile_size, int components, bool is_first_prompt, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false)
-      : Program{kernel_name}, feed_past_key_(feed_past_key), has_present_key_(has_present_key), has_attention_bias_(has_attention_bias), tile_size_(tile_size), components_(components), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt) {
+                        bool has_attention_bias, int tile_size, int components, bool is_first_prompt, bool has_seqlen_k = false, bool past_present_share_buffer = false)
+      : Program{kernel_name}, feed_past_key_(feed_past_key), has_present_key_(has_present_key), has_attention_bias_(has_attention_bias), tile_size_(tile_size), components_(components), has_seqlen_k_(has_seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt) {
   }
 
   Status GenerateShaderCode(ShaderHelper& sh) const override;
@@ -62,15 +62,15 @@ class AttentionProbsProgram final : public Program<AttentionProbsProgram> {
   bool has_attention_bias_;
   int tile_size_;
   int components_;
-  const Tensor* seqlen_k_;
+  bool has_seqlen_k_;
   bool past_present_share_buffer_;
   bool is_first_prompt_;
 };
 
 class InPlaceSoftmaxProgram final : public Program<InPlaceSoftmaxProgram> {
  public:
-  InPlaceSoftmaxProgram(const std::string& kernel_name, int work_group_size, int components, const Tensor* seqlen_k = nullptr)
-      : Program{kernel_name}, work_group_size_(work_group_size), components_(components), seqlen_k_(seqlen_k) {
+  InPlaceSoftmaxProgram(int work_group_size, int components, bool use_smooth_softmax, bool has_seqlen_k, bool has_head_sink)
+      : Program{"InPlaceSoftmax"}, work_group_size_(work_group_size), components_(components), use_smooth_softmax_(use_smooth_softmax), has_seqlen_k_(has_seqlen_k), has_head_sink_(has_head_sink) {
   }
 
   Status GenerateShaderCode(ShaderHelper& sh) const override;
@@ -86,7 +86,9 @@ class InPlaceSoftmaxProgram final : public Program<InPlaceSoftmaxProgram> {
  private:
   int work_group_size_;
   int components_;
-  const Tensor* seqlen_k_;
+  bool use_smooth_softmax_;
+  bool has_seqlen_k_;
+  bool has_head_sink_;
 };
 
 class VxAttentionScoreProgram final : public Program<VxAttentionScoreProgram> {

onnxruntime/contrib_ops/webgpu/bert/attention_common.h

Lines changed: 2 additions & 1 deletion
@@ -123,7 +123,8 @@ Status TransferBSDToBNSH(onnxruntime::webgpu::ComputeContext& context, int num_h
 
 Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias,
                       const Tensor* past_key, const Tensor* past_value, Tensor* output, Tensor* present_key, Tensor* present_value,
-                      WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k = nullptr);
+                      WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context,
+                      const Tensor* head_sink = nullptr, const Tensor* seqlen_k = nullptr);
 
 }  // namespace webgpu
 }  // namespace contrib
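
Worth calling out (an editorial observation, not from the commit message): `head_sink` was inserted ahead of `seqlen_k`, and both are `const Tensor*` with `nullptr` defaults, so a pre-existing caller that passed `seqlen_k` as the last positional argument would now silently bind it to `head_sink`. A hypothetical call site after this change:

```cpp
// Pass both trailing arguments explicitly so seqlen_k cannot be mistaken for head_sink.
ORT_RETURN_IF_ERROR(ApplyAttention(Q, K, V, attention_bias, past_key, past_value,
                                   output, present_key, present_value, parameters, context,
                                   /*head_sink=*/nullptr, /*seqlen_k=*/seqlen_k));
```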

onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc

Lines changed: 2 additions & 0 deletions
@@ -382,6 +382,8 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const {
   // sum is the second term of the same expression : Σ_j=1:b e^(Xi[j]-Mi)
   // o_ratio is the part of the first term of o'_i expression above : d'_(i-1) * e^(M_(i-1)-M_i) / d'_i
   //
+
+  // TODO: support smooth softmax and head_sink
   shader.MainFunctionBody() << R"MAIN_FN(
   var local_max_temp = max(qk_1, qk_2);
   if (sg_size > 8)
