@@ -69,8 +69,8 @@ Status TransferBSDToBNSH(onnxruntime::webgpu::ComputeContext& context, int num_h
6969 return context.RunProgram (program);
7070};
7171
72- void InitVarStub (std::ostringstream& ss, const Tensor* seqlen_k ) {
73- if (seqlen_k != nullptr ) {
72+ void InitVarStub (std::ostringstream& ss, bool has_seqlen_k ) {
73+ if (has_seqlen_k ) {
7474 ss << " total_sequence_length = u32(seqlen_k[batch_idx]) + 1;\n " ;
7575 ss << " var past_sequence_length: u32 = select(total_sequence_length - sequence_length, 0u, uniforms.is_first_prompt > 0);\n " ;
7676 } else {
@@ -87,7 +87,7 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const {
8787 if (has_attention_bias_) {
8888 shader.AddInput (" attention_bias" , ShaderUsage::UseUniform);
8989 }
90- if (seqlen_k_ != nullptr ) {
90+ if (has_seqlen_k_ ) {
9191 shader.AddInput (" seqlen_k" , ShaderUsage::UseUniform);
9292 }
9393 shader.AddOutput (" output" , ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
@@ -107,7 +107,7 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const {
107107 << " let sequence_length = uniforms.M;\n "
108108 << " var total_sequence_length = uniforms.N;\n " ;
109109 std::ostringstream oss;
110- InitVarStub (oss, seqlen_k_ );
110+ InitVarStub (oss, has_seqlen_k_ );
111111 shader.MainFunctionBody () << oss.str ();
112112 shader.MainFunctionBody () << " let kOffset = (batch_head_idx / uniforms.n_reps) * uniforms.kv_sequence_length * uniforms.K;\n " ;
113113 if (has_present_key_) {
@@ -182,7 +182,7 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o
182182 const int components = parameters.head_size_ % 4 == 0 ? 4 : (parameters.head_size_ % 2 == 0 ? 2 : 1 );
183183
184184 AttentionProbsProgram program{" AttentionProbs" , feed_past_key, has_present_key, has_attention_bias, tile_size,
185- components, parameters.is_first_prompt_ , seqlen_k, parameters.past_present_share_buffer_ };
185+ components, parameters.is_first_prompt_ , seqlen_k != nullptr , parameters.past_present_share_buffer_ };
186186 program.AddInputs ({{Q, ProgramTensorMetadataDependency::TypeAndRank, components},
187187 {K, ProgramTensorMetadataDependency::TypeAndRank, components}});
188188 if (feed_past_key) {
@@ -224,30 +224,44 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o
224224}
225225
226226Status InPlaceSoftmaxProgram::GenerateShaderCode (ShaderHelper& shader) const {
227- if (seqlen_k_ ) {
227+ if (has_seqlen_k_ ) {
228228 shader.AddInput (" seqlen_k" , ShaderUsage::UseUniform);
229229 }
230+ if (has_head_sink_) {
231+ shader.AddInput (" head_sink" , ShaderUsage::UseUniform);
232+ }
230233 shader.AddOutput (" x" , ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
231234 shader.AdditionalImplementation () << " var<workgroup> thread_max: array<f32, " << work_group_size_ << " >;\n "
232235 << " var<workgroup> thread_sum: array<f32, " << work_group_size_ << " >;\n "
233236 << " alias f32_val_t = " << (components_ == 4 ? " vec4<f32>" : (components_ == 2 ? " vec2<f32>" : " f32" )) << " ;\n " ;
234237 shader.MainFunctionBody () << " let sequence_length = uniforms.sequence_length;\n "
235238 << " let batch_idx = u32(workgroup_idx / sequence_length) / uniforms.num_heads;\n "
239+ << " let head_idx = u32(workgroup_idx / sequence_length) % uniforms.num_heads;\n "
236240 << " var total_sequence_length = uniforms.total_sequence_length_comp * " << components_ << " ;\n " ;
237241 std::ostringstream oss;
238- InitVarStub (oss, seqlen_k_ );
242+ InitVarStub (oss, has_seqlen_k_ );
239243 shader.MainFunctionBody () << oss.str ()
240244 << " let local_offset = local_idx * uniforms.elements_per_thread;\n "
241245 << " let offset = workgroup_idx * uniforms.total_sequence_length_comp + local_offset;\n "
242- << " let seq_causal_length = " << (seqlen_k_ ? " past_sequence_length + workgroup_idx % sequence_length + 1" : " uniforms.total_sequence_length_comp" ) << " ;\n "
246+ << " let seq_causal_length = " << (has_seqlen_k_ ? " past_sequence_length + workgroup_idx % sequence_length + 1" : " uniforms.total_sequence_length_comp" ) << " ;\n "
243247 << " var thread_max_vector = f32_val_t(-3.402823e+38f);\n "
244248 << " for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {\n "
245249 << " thread_max_vector = max(f32_val_t(x[offset + i]), thread_max_vector);\n "
246250 << " }\n "
247251 << " thread_max[local_idx] = " << (components_ == 4 ? " max(max(thread_max_vector.x, thread_max_vector.y), max(thread_max_vector.z, thread_max_vector.w))" : (components_ == 2 ? " max(thread_max_vector.x, thread_max_vector.y)" : " thread_max_vector" )) << " ;\n "
248- << " workgroupBarrier();\n "
249- << " var max_value = f32(-3.402823e+38f);\n "
250- << " for (var i = 0u; i < " << work_group_size_ << " ; i++) {\n "
252+ << " workgroupBarrier();\n " ;
253+
254+ if (has_head_sink_) {
255+ // Handle head sink
256+ shader.MainFunctionBody () << " let sink_value: f32 = head_sink[head_idx];\n "
257+ << " var max_value = sink_value;\n " ;
258+ } else if (use_smooth_softmax_) {
259+ shader.MainFunctionBody () << " var max_value: f32 = 0.0;\n " ;
260+ } else {
261+ shader.MainFunctionBody () << " var max_value = f32(-3.402823e+38f);\n " ;
262+ }
263+
264+ shader.MainFunctionBody () << " for (var i = 0u; i < " << work_group_size_ << " ; i++) {\n "
251265 << " max_value = max(thread_max[i], max_value);\n "
252266 << " }\n "
253267 << " var sum_vector = f32_val_t(0);\n "
@@ -259,8 +273,15 @@ Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const {
259273 << " var sum: f32 = 0;\n "
260274 << " for (var i = 0u; i < " << work_group_size_ << " ; i++) {\n "
261275 << " sum += thread_sum[i]\n ;"
262- << " }\n "
263- << " if (sum == 0) {\n "
276+ << " }\n " ;
277+
278+ if (has_head_sink_) {
279+ shader.MainFunctionBody () << " sum += exp(sink_value - max_value);\n " ;
280+ } else if (use_smooth_softmax_) {
281+ shader.MainFunctionBody () << " sum += exp(-max_value);\n " ;
282+ }
283+
284+ shader.MainFunctionBody () << " if (sum == 0) {\n "
264285 << " for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {\n "
265286 << " x[offset + i] = x_value_t(x_element_t(1.0)/x_element_t(seq_causal_length));\n "
266287 << " }\n "
@@ -270,7 +291,7 @@ Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const {
270291 << " x[offset + i] = x_value_t(exp(f32input - max_value) / sum);\n "
271292 << " }\n "
272293 << " }\n " ;
273- if (seqlen_k_ ) {
294+ if (has_seqlen_k_ ) {
274295 shader.MainFunctionBody () << " for (var total_seq_id: u32 = seq_causal_length; total_seq_id + local_offset < uniforms.total_sequence_length_comp; total_seq_id++) {\n "
275296 << " x[offset + total_seq_id] = x_value_t(x_element_t(0));\n "
276297 << " }\n " ;
@@ -280,7 +301,7 @@ Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const {
280301}
281302
282303Status ComputeInPlaceSoftmax (onnxruntime::webgpu::ComputeContext& context, Tensor* probs, int32_t batch_size, int32_t num_heads, int32_t past_sequence_length, int32_t sequence_length, int32_t total_sequence_length,
283- const Tensor* seqlen_k, bool is_first_prompt) {
304+ const Tensor* seqlen_k, bool is_first_prompt, bool use_smooth_softmax, const Tensor* head_sink ) {
284305 const int components = seqlen_k != nullptr ? 1 : (total_sequence_length % 4 == 0 ? 4 : (total_sequence_length % 2 == 0 ? 2 : 1 ));
285306 int work_group_size = 64 ;
286307 const int total_sequence_length_comp = (total_sequence_length + components - 1 ) / components;
@@ -289,12 +310,15 @@ Status ComputeInPlaceSoftmax(onnxruntime::webgpu::ComputeContext& context, Tenso
289310 }
290311 const int elementsPerThread = (total_sequence_length_comp + work_group_size - 1 ) / work_group_size;
291312
292- InPlaceSoftmaxProgram program{" InPlaceSoftmax " , work_group_size, components, seqlen_k};
313+ InPlaceSoftmaxProgram program{work_group_size, components, use_smooth_softmax, seqlen_k != nullptr , head_sink != nullptr };
293314 if (seqlen_k != nullptr ) {
294315 program.AddInput ({seqlen_k, ProgramTensorMetadataDependency::TypeAndRank});
295316 }
317+ if (head_sink != nullptr ) {
318+ program.AddInput ({head_sink, ProgramTensorMetadataDependency::Type});
319+ }
296320 program.AddOutputs ({{probs, ProgramTensorMetadataDependency::TypeAndRank, components}})
297- .CacheHint (work_group_size)
321+ .CacheHint (work_group_size, use_smooth_softmax )
298322 .SetDispatchGroupSize (batch_size * num_heads * sequence_length)
299323 .SetWorkgroupSize (work_group_size)
300324 .AddUniformVariables ({{static_cast <uint32_t >(batch_size)},
@@ -443,7 +467,7 @@ Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int
443467
444468Status ApplyAttention (const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias,
445469 const Tensor* past_key, const Tensor* past_value, Tensor* output, Tensor* present_key, Tensor* present_value,
446- WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k) {
470+ WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* head_sink, const Tensor* seqlen_k) {
447471 const int output_count = std::min ({context.OutputCount (), 1 + (past_key != nullptr ? 1 : 0 ) + (past_value != nullptr ? 1 : 0 )});
448472 const int past_sequence_length = output_count > 1 ? parameters.past_sequence_length_ : 0 ;
449473 const int total_sequence_length =
@@ -457,7 +481,7 @@ Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const T
457481 parameters, past_sequence_length, total_sequence_length, seqlen_k));
458482
459483 ORT_RETURN_IF_ERROR (ComputeInPlaceSoftmax (context, &probs,
460- parameters.batch_size_ , parameters.num_heads_ , parameters.past_sequence_length_ , parameters.sequence_length_ , total_sequence_length, seqlen_k, parameters.is_first_prompt_ ));
484+ parameters.batch_size_ , parameters.num_heads_ , parameters.past_sequence_length_ , parameters.sequence_length_ , total_sequence_length, seqlen_k, parameters.is_first_prompt_ , parameters. use_smooth_softmax_ , head_sink ));
461485
462486 ORT_RETURN_IF_ERROR (ComputeVxAttentionScore (context, output_count, &probs, V, past_value, output, present_value,
463487 parameters, past_sequence_length, total_sequence_length, seqlen_k));
0 commit comments