@@ -26,7 +26,7 @@ static TensorShape GetOverrideShape(const TensorShape& shape, int components) {
 }
 
 Status LayerNormProgram::GenerateShaderCode(ShaderHelper& shader) const {
-  const auto& x = shader.AddInput("x", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
+  const auto& x = shader.AddInput("x", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
   shader.AddInput("scale", ShaderUsage::UseUniform);
   if (has_bias_) {
     shader.AddInput("bias", ShaderUsage::UseUniform);
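
Note on this hunk: ShaderUsage::UseElementTypeAlias asks the shader helper to emit the per-tensor element-type alias itself, which is what lets the hand-written "alias element_t = ..." line and the is_fp16_ member disappear in the hunks below. For an fp16, 4-component input the generated preamble presumably looks like the following (a sketch of assumed helper output; it is not shown in this diff):

    alias x_element_t = f16;      // from ShaderUsage::UseElementTypeAlias (assumed form)
    alias x_value_t = vec4<f16>;  // from ShaderUsage::UseValueTypeAlias (assumed form)

The new shader body relies on these names (x_value_t(0), x_element_t(mean)) instead of the old hand-rolled element_t.
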
@@ -39,35 +39,113 @@ Status LayerNormProgram::GenerateShaderCode(ShaderHelper& shader) const {
     shader.AddOutput("inv_std_dev_output", ShaderUsage::None);
   }
 
-  int components = x.NumComponents();
-  std::string bias = (has_bias_) ? " + bias[j]" : "";
-  std::string simpl1 = (simplified_) ? "" : " - mean * mean";
-  std::string simpl2 = (simplified_) ? "" : " - mean";
-
-  shader.AdditionalImplementation() << "alias element_t = " << (is_fp16_ ? "f16;\n" : "f32;\n")
-                                    << "alias f32_val_t = " << (components == 4 ? "vec4<f32>" : (components == 2 ? "vec2<f32>" : "f32")) << ";\n";
-
-  shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.norm_count")
-                            << "let offset = global_idx * uniforms.norm_size_vectorized;\n"
-                            << "var mean_vector = f32_val_t(0);\n"
-                            << "var mean_square_vector = f32_val_t(0);\n"
-                            << "for (var h: u32 = 0u; h < uniforms.norm_size_vectorized; h++) {\n"
-                            << "  let value = f32_val_t(x[h + offset]);\n"
-                            << "  mean_vector += value;\n"
-                            << "  mean_square_vector += value * value;\n"
-                            << "}\n"
-                            << "let mean = " << SumVector("mean_vector", components) << " / f32(uniforms.norm_size);\n"
-                            << "let inv_std_dev = inverseSqrt(" << SumVector("mean_square_vector", components) << " / f32(uniforms.norm_size)" << simpl1 << " + uniforms.epsilon);\n"
-                            << "for (var j: u32 = 0; j < uniforms.norm_size_vectorized; j++) {\n"
-                            << "  let f32input = f32_val_t(x[j + offset]);\n"
-                            << "  let f32scale = f32_val_t(scale[j]);\n"
-                            << "  y[j + offset] = x_value_t((f32input" << simpl2 << ") * inv_std_dev * f32scale)" << bias << ";\n"
-                            << "}\n";
-  if (has_mean_output_) {
-    shader.MainFunctionBody() << "mean_output[global_idx] = mean;\n";
-  }
-  if (has_inv_std_dev_output_) {
-    shader.MainFunctionBody() << "inv_std_dev_output[global_idx] = inv_std_dev;\n";
+  std::string simpl1 = (simplified_) ? "" : " - mean * mean";
+  std::string simpl2 = (simplified_) ? "" : " - x_element_t(mean)";
+
+  if (split_norm_dim_) {
+    shader.AdditionalImplementation()
+        << "var<workgroup> sum_shared : array<f32, workgroup_size_x>;\n"
+        << "var<workgroup> sum_squared_shared : array<f32, workgroup_size_x>;\n";
+
+    shader.MainFunctionBody()
+        << "var sum_vec4 = vec4<f32>(0);\n"
+        << "var sum_squared_vec4 = vec4<f32>(0);\n"
+        << "var cur_input = x_value_t(0);\n"
+        << "for (var i: u32 = 0; i < uniforms.norm_size / (workgroup_size_x * 4); i++) {\n"
+        << "  let input_offset = i * workgroup_size_x + local_idx;\n"
+        << "  let input_value = x[input_offset];\n"
+        << "  if (i == workgroup_idx) {\n"
+        << "    cur_input = input_value;\n"
+        << "  }\n"
+        << "  let f32_value = vec4<f32>(input_value);\n"
+        << "  sum_vec4 += f32_value;\n"
+        << "  sum_squared_vec4 += f32_value * f32_value;\n"
+        << "}\n"
+        << "var sum = " << SumVector("sum_vec4", 4) << ";\n"
+        << "var sum_squared = " << SumVector("sum_squared_vec4", 4) << ";\n"
+        << "sum_shared[local_idx] = sum;\n"
+        << "sum_squared_shared[local_idx] = sum_squared;\n"
+        << "workgroupBarrier();\n"
+        << "var reduce_size : u32 = workgroup_size_x;\n"
+        << "for (var curr_size = reduce_size >> 1; curr_size > 0; curr_size = reduce_size >> 1) {\n"
+        << "  reduce_size = curr_size + (reduce_size & 1);\n"
+        << "  if (local_idx < curr_size) {\n"
+        << "    sum_shared[local_idx] += sum_shared[local_idx + reduce_size];\n"
+        << "    sum_squared_shared[local_idx] += sum_squared_shared[local_idx + reduce_size];\n"
+        << "  }\n"
+        << "  workgroupBarrier();\n"
+        << "}\n"
+        << "let mean = sum_shared[0] / f32(uniforms.norm_size);\n"
+        << "let inv_std_dev = inverseSqrt(sum_squared_shared[0] / f32(uniforms.norm_size)" << simpl1 << " + uniforms.epsilon);\n"
+        << "let offset = workgroup_idx * workgroup_size_x + local_idx;\n"
+        << "y[offset] = ((cur_input" << simpl2 << ") * x_element_t(inv_std_dev) * scale[offset]" << (has_bias_ ? " + bias[offset]" : "") << ");\n";
+
+    if (has_mean_output_) {
+      shader.MainFunctionBody() << "if (local_idx == 0 && workgroup_idx == 0) {\n"
+                                << "  mean_output[global_idx / uniforms.norm_size] = mean;\n"
+                                << "}\n";
+    }
+    if (has_inv_std_dev_output_) {
+      shader.MainFunctionBody() << "if (local_idx == 0 && workgroup_idx == 0) {\n"
+                                << "  inv_std_dev_output[global_idx / uniforms.norm_size] = inv_std_dev;\n"
+                                << "}\n";
+    }
+  } else {
+    int components = x.NumComponents();
+    std::string bias = (has_bias_) ? " + bias[offset1d + i]" : "";
+
+    shader.AdditionalImplementation()
+        << "alias f32_val_t = " << (components == 4 ? "vec4<f32>" : (components == 2 ? "vec2<f32>" : "f32")) << ";\n"
+        << "var<workgroup> sum_shared : array<f32_val_t, workgroup_size_x>;\n"
+        << "var<workgroup> sum_squared_shared : array<f32_val_t, workgroup_size_x>;\n";
+
+    shader.MainFunctionBody()
+        << "let ix = local_idx;\n"
+        << "let iy = global_idx / workgroup_size_x;\n"
+        << "let norm_size_vectorized: u32 = uniforms.norm_size / uniforms.components;\n"
+        << "var stride = norm_size_vectorized / workgroup_size_x;\n"
+        << "let offset = ix * stride + iy * norm_size_vectorized;\n"
+        << "let offset1d = stride * ix;\n"
+        << "sum_shared[ix] = f32_val_t(0);\n"
+        << "sum_squared_shared[ix] = f32_val_t(0);\n"
+        << "if (ix == workgroup_size_x - 1) {\n"
+        << "  stride = norm_size_vectorized - stride * ix;\n"
+        << "}\n"
+        << "for (var i: u32 = 0; i < stride; i++) {\n"
+        << "  let input_value = x[offset + i];\n"
+        << "  y[offset + i] = input_value;\n"
+        << "  let f32_value = f32_val_t(input_value);\n"
+        << "  sum_shared[ix] += f32_value;\n"
+        << "  sum_squared_shared[ix] += f32_value * f32_value;\n"
+        << "}\n"
+        << "workgroupBarrier();\n"
+        << "var reduce_size : u32 = workgroup_size_x;\n"
+        << "for (var curr_size = reduce_size >> 1; curr_size > 0; curr_size = reduce_size >> 1) {\n"
+        << "  reduce_size = curr_size + (reduce_size & 1);\n"
+        << "  if (ix < curr_size) {\n"
+        << "    sum_shared[ix] += sum_shared[ix + reduce_size];\n"
+        << "    sum_squared_shared[ix] += sum_squared_shared[ix + reduce_size];\n"
+        << "  }\n"
+        << "  workgroupBarrier();\n"
+        << "}\n"
+        << "let sum = sum_shared[0];\n"
+        << "let square_sum = sum_squared_shared[0];\n"
+        << "let mean = " << SumVector("sum", components) << " / f32(uniforms.norm_size);\n"
+        << "let inv_std_dev = inverseSqrt(" << SumVector("square_sum", components) << " / f32(uniforms.norm_size)" << simpl1 << " + uniforms.epsilon);\n"
+        << "for (var i: u32 = 0; i < stride; i++) {\n"
+        << "  y[offset + i] = (y[offset + i]" << simpl2 << ") * x_element_t(inv_std_dev) * scale[offset1d + i]" << bias << ";\n"
+        << "};\n";
+
+    if (has_mean_output_) {
+      shader.MainFunctionBody() << "if (ix == 0) {\n"
+                                << "  mean_output[iy] = mean;\n"
+                                << "}\n";
+    }
+    if (has_inv_std_dev_output_) {
+      shader.MainFunctionBody() << "if (ix == 0) {\n"
+                                << "  inv_std_dev_output[iy] = inv_std_dev;\n"
+                                << "}\n";
+    }
   }
 
   return Status::OK();
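
Both shader variants finish with the same shared-memory tree reduction. A minimal host-side C++ sketch of that loop (illustrative only; the names are mine, not from the patch) shows why the "reduce_size & 1" adjustment keeps the sum correct even when the active width is odd: the unpaired middle element is carried into a later round instead of being dropped.

    #include <cstdint>
    #include <cstdio>
    #include <numeric>
    #include <vector>

    // Sequential model of the emitted WGSL reduction: each pass folds the top
    // half of the active region onto the bottom half; "reduce_size & 1" keeps
    // the unpaired middle element alive for a later pass.
    float ReduceLikeShader(std::vector<float> shared) {
      uint32_t reduce_size = static_cast<uint32_t>(shared.size());
      for (uint32_t curr_size = reduce_size >> 1; curr_size > 0; curr_size = reduce_size >> 1) {
        reduce_size = curr_size + (reduce_size & 1);
        for (uint32_t i = 0; i < curr_size; ++i) {  // the "local_idx < curr_size" threads, in lockstep
          shared[i] += shared[i + reduce_size];
        }
      }
      return shared[0];
    }

    int main() {
      std::vector<float> v{1, 2, 3, 4, 5};  // odd size on purpose
      std::printf("%g == %g\n", ReduceLikeShader(v),
                  std::accumulate(v.begin(), v.end(), 0.0f));
    }

With the actual workgroup_size_x of 128 (a power of two) the adjustment is a no-op, but it makes the reduction safe for any workgroup size.
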
@@ -81,8 +159,6 @@ Status LayerNorm<simplified>::ComputeInternal(onnxruntime::webgpu::ComputeContex
 
   const auto x_shape = x->Shape();
 
-  const bool is_fp16 = x->GetElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16;
-
   const size_t axis = NormalizeAxis(axis_, x_shape.NumDimensions());
   const uint32_t norm_count = onnxruntime::narrow<uint32_t>(x_shape.SizeToDimension(axis));
   const int64_t norm_size = x_shape.SizeFromDimension(axis);
@@ -116,14 +192,19 @@ Status LayerNorm<simplified>::ComputeInternal(onnxruntime::webgpu::ComputeContex
     return Status::OK();
   }
 
-  LayerNormProgram program{bias != nullptr, is_fp16, simplified, mean != nullptr, inv_std_dev != nullptr};
+  // Check if we should use the split norm dimension optimization.
+  const bool split_norm_dim = norm_size % 512 == 0 && norm_count == 1;
+
+  LayerNormProgram program{bias != nullptr, simplified, mean != nullptr, inv_std_dev != nullptr, split_norm_dim};
 
-  program.CacheHint(components, simplified)
+  program.CacheHint(components, simplified, split_norm_dim)
       .AddInputs({{x, ProgramTensorMetadataDependency::Type, GetOverrideShape(x->Shape(), components), components}})
      .AddInputs(
          {{scale, ProgramTensorMetadataDependency::Type, GetOverrideShape(scale->Shape(), components), components}})
      .AddOutputs({{y, ProgramTensorMetadataDependency::None, GetOverrideShape(y->Shape(), components), components}})
-      .SetDispatchGroupSize((norm_count + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
+      .AddUniformVariables({
+          {static_cast<uint32_t>(components)},
+      })
       .AddUniformVariables({
           {static_cast<uint32_t>(norm_count)},
       })
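
On the new gate: in the split path each of the 128 threads loads one vec4 per pass, so a workgroup consumes exactly 128 * 4 = 512 input scalars per iteration of the accumulation loop. Requiring norm_size % 512 == 0 therefore means the loop and the per-workgroup output slice tile the row exactly, with no tail to handle, while norm_count == 1 restricts the path to the single-row case, where the old shader's one-thread-per-row scheme would degenerate to a single active thread.
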
@@ -137,6 +218,15 @@ Status LayerNorm<simplified>::ComputeInternal(onnxruntime::webgpu::ComputeContex
           {static_cast<float>(epsilon_)},
       });
 
+  if (split_norm_dim) {
+    const uint32_t workgroup_size_x = 128;
+    const uint32_t dispatch_size_x = onnxruntime::narrow<uint32_t>(norm_size / (workgroup_size_x * components));
+    program.SetDispatchGroupSize(dispatch_size_x, 1, 1)
+        .SetWorkgroupSize(workgroup_size_x);
+  } else {
+    program.SetDispatchGroupSize(norm_count);
+  }
+
   if (bias != nullptr) {
     program.AddInput(
         {bias, ProgramTensorMetadataDependency::Type, GetOverrideShape(bias->Shape(), components), components});
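
For a concrete sense of the split-path launch (numbers assumed for illustration, not taken from the patch): with norm_size = 4096 and components = 4, dispatch_size_x = 4096 / (128 * 4) = 8, so eight 128-thread workgroups run. Each workgroup redundantly reduces the entire row, then normalizes only its own slice of y at offset = workgroup_idx * 128 + local_idx (vec4-typed, i.e. 512 scalars per workgroup), trading duplicated reduction arithmetic for occupancy across the single row. In the non-split path the dispatch is simply one workgroup per row, norm_count of them.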