Skip to content

Commit d4e31dc

Browse files
Improve SimplifiedLayerNorm by using same techniques as SkipSimplifiedLayerNorm (microsoft#25850)
### Description

Use shaders similar to those in SkipSimplifiedLayerNorm for SimplifiedLayerNorm, to fix the performance issues with SimplifiedLayerNorm.

### Motivation and Context

Prior to this change, generation in BitNet was bottlenecked on SimplifiedLayerNorm:

<img width="332" height="378" alt="Profile before the change, showing SimplifiedLayerNorm as the bottleneck" src="https://github.com/user-attachments/assets/3bc16ac1-ef7d-46bf-b403-92fc9192a2df" />

With this change, performance has improved to match SkipSimplifiedLayerNorm:

<img width="699" height="179" alt="Profile after the change, showing SimplifiedLayerNorm performance matching SkipSimplifiedLayerNorm" src="https://github.com/user-attachments/assets/30009d85-d5d9-4585-987a-b39ecf52e0b5" />
1 parent 1d07e94 commit d4e31dc

File tree

2 files changed: +132 additions, -41 deletions

onnxruntime/core/providers/webgpu/nn/layer_norm.cc

Lines changed: 125 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ static TensorShape GetOverrideShape(const TensorShape& shape, int components) {
2626
}
2727

2828
Status LayerNormProgram::GenerateShaderCode(ShaderHelper& shader) const {
29-
const auto& x = shader.AddInput("x", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
29+
const auto& x = shader.AddInput("x", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
3030
shader.AddInput("scale", ShaderUsage::UseUniform);
3131
if (has_bias_) {
3232
shader.AddInput("bias", ShaderUsage::UseUniform);
@@ -39,35 +39,113 @@ Status LayerNormProgram::GenerateShaderCode(ShaderHelper& shader) const {
3939
shader.AddOutput("inv_std_dev_output", ShaderUsage::None);
4040
}
4141

42-
int components = x.NumComponents();
43-
std::string bias = (has_bias_) ? " + bias[j]" : "";
44-
std::string simpl1 = (simplified_) ? "" : " - mean * mean";
45-
std::string simpl2 = (simplified_) ? "" : " - mean";
46-
47-
shader.AdditionalImplementation() << "alias element_t = " << (is_fp16_ ? "f16;\n" : "f32;\n")
48-
<< "alias f32_val_t = " << (components == 4 ? "vec4<f32>" : (components == 2 ? "vec2<f32>" : "f32")) << ";\n";
49-
50-
shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.norm_count")
51-
<< "let offset = global_idx * uniforms.norm_size_vectorized;\n"
52-
<< "var mean_vector = f32_val_t(0);\n"
53-
<< "var mean_square_vector = f32_val_t(0);\n"
54-
<< "for (var h: u32 = 0u; h < uniforms.norm_size_vectorized; h++) {\n"
55-
<< " let value = f32_val_t(x[h + offset]);\n"
56-
<< " mean_vector += value;\n"
57-
<< " mean_square_vector += value * value;\n"
58-
<< "}\n"
59-
<< "let mean = " << SumVector("mean_vector", components) << " / f32(uniforms.norm_size);\n"
60-
<< "let inv_std_dev = inverseSqrt(" << SumVector("mean_square_vector", components) << " / f32(uniforms.norm_size)" << simpl1 << " + uniforms.epsilon);\n"
61-
<< "for (var j: u32 = 0; j < uniforms.norm_size_vectorized; j++) {\n"
62-
<< " let f32input = f32_val_t(x[j + offset]);\n"
63-
<< " let f32scale = f32_val_t(scale[j]);\n"
64-
<< " y[j + offset] = x_value_t((f32input" << simpl2 << ") * inv_std_dev * f32scale)" << bias << ";\n"
65-
<< "}\n";
66-
if (has_mean_output_) {
67-
shader.MainFunctionBody() << "mean_output[global_idx] = mean;\n";
68-
}
69-
if (has_inv_std_dev_output_) {
70-
shader.MainFunctionBody() << "inv_std_dev_output[global_idx] = inv_std_dev;\n";
42+
std::string simpl1 = (simplified_) ? "" : "- mean * mean ";
43+
std::string simpl2 = (simplified_) ? "" : "- x_element_t(mean) ";
44+
45+
if (split_norm_dim_) {
46+
shader.AdditionalImplementation()
47+
<< "var<workgroup> sum_shared : array<f32, workgroup_size_x>;\n"
48+
<< "var<workgroup> sum_squared_shared : array<f32, workgroup_size_x>;\n";
49+
50+
shader.MainFunctionBody()
51+
<< " var sum_vec4 = vec4<f32>(0);\n"
52+
<< " var sum_squared_vec4 = vec4<f32>(0);\n"
53+
<< " var cur_input = x_value_t(0);\n"
54+
<< " for (var i: u32 = 0; i < uniforms.norm_size / (workgroup_size_x * 4); i++) {\n"
55+
<< " let input_offset = i * workgroup_size_x + local_idx;\n"
56+
<< " let input_value = x[input_offset];\n"
57+
<< " if (i == workgroup_idx) {\n"
58+
<< " cur_input = input_value;\n"
59+
<< " }\n"
60+
<< " let f32_value = vec4<f32>(input_value);\n"
61+
<< " sum_vec4 += f32_value;\n"
62+
<< " sum_squared_vec4 += f32_value * f32_value;\n"
63+
<< " }\n"
64+
<< " var sum = " << SumVector("sum_vec4", 4) << ";\n"
65+
<< " var sum_squared = " << SumVector("sum_squared_vec4", 4) << ";\n"
66+
<< " sum_shared[local_idx] = sum;\n"
67+
<< " sum_squared_shared[local_idx] = sum_squared;\n"
68+
<< " workgroupBarrier();\n"
69+
<< " var reduce_size : u32 = workgroup_size_x;\n"
70+
<< " for (var curr_size = reduce_size >> 1; curr_size > 0; curr_size = reduce_size >> 1) {\n"
71+
<< " reduce_size = curr_size + (reduce_size & 1);\n"
72+
<< " if (local_idx < curr_size) {\n"
73+
<< " sum_shared[local_idx] += sum_shared[local_idx + reduce_size];\n"
74+
<< " sum_squared_shared[local_idx] += sum_squared_shared[local_idx + reduce_size];\n"
75+
<< " }\n"
76+
<< " workgroupBarrier();\n"
77+
<< " }\n"
78+
<< " let mean = sum_shared[0] / f32(uniforms.norm_size);\n"
79+
<< " let inv_std_dev = inverseSqrt(sum_squared_shared[0] / f32(uniforms.norm_size) " << simpl1 << "+ uniforms.epsilon);\n"
80+
<< " let offset = workgroup_idx * workgroup_size_x + local_idx;\n"
81+
<< " y[offset] = ((cur_input " << simpl2 << ") * x_element_t(inv_std_dev) * scale[offset]" << (has_bias_ ? " + bias[offset] " : "") << ");\n";
82+
83+
if (has_mean_output_) {
84+
shader.MainFunctionBody() << " if (local_idx == 0 && workgroup_idx == 0) {\n"
85+
<< " mean_output[global_idx / uniforms.norm_size] = mean;\n"
86+
<< " }\n";
87+
}
88+
if (has_inv_std_dev_output_) {
89+
shader.MainFunctionBody() << " if (local_idx == 0 && workgroup_idx == 0) {\n"
90+
<< " inv_std_dev_output[global_idx / uniforms.norm_size] = inv_std_dev;\n"
91+
<< " }\n";
92+
}
93+
} else {
94+
int components = x.NumComponents();
95+
std::string bias = (has_bias_) ? " + bias[offset1d + i] " : "";
96+
97+
shader.AdditionalImplementation()
98+
<< "alias f32_val_t = " << (components == 4 ? "vec4<f32>" : (components == 2 ? "vec2<f32>" : "f32")) << ";\n"
99+
<< "var<workgroup> sum_shared : array<f32_val_t, workgroup_size_x>;\n"
100+
<< "var<workgroup> sum_squared_shared : array<f32_val_t, workgroup_size_x>;\n";
101+
102+
shader.MainFunctionBody()
103+
<< "let ix = local_idx;\n"
104+
<< "let iy = global_idx / workgroup_size_x;\n"
105+
<< "let norm_size_vectorized: u32 = uniforms.norm_size / uniforms.components;\n"
106+
<< "var stride = norm_size_vectorized / workgroup_size_x;\n"
107+
<< "let offset = ix * stride + iy * norm_size_vectorized;\n"
108+
<< "let offset1d = stride * ix;\n"
109+
<< "sum_shared[ix] = f32_val_t(0);\n"
110+
<< "sum_squared_shared[ix] = f32_val_t(0);\n"
111+
<< "if (ix == workgroup_size_x - 1) {\n"
112+
<< " stride = norm_size_vectorized - stride * ix;\n"
113+
<< "}\n"
114+
<< "for (var i: u32 = 0; i < stride; i++) {\n"
115+
<< " let input_value = x[offset + i];\n"
116+
<< " y[offset + i] = input_value;\n"
117+
<< " let f32_value = f32_val_t(input_value);\n"
118+
<< " sum_shared[ix] += f32_value;\n"
119+
<< " sum_squared_shared[ix] += f32_value * f32_value;\n"
120+
<< "}\n"
121+
<< "workgroupBarrier();\n"
122+
<< "var reduce_size : u32 = workgroup_size_x;\n"
123+
<< "for (var curr_size = reduce_size >> 1; curr_size > 0; curr_size = reduce_size >> 1) {\n"
124+
<< " reduce_size = curr_size + (reduce_size & 1);\n"
125+
<< " if (ix < curr_size) {\n"
126+
<< " sum_shared[ix] += sum_shared[ix + reduce_size];\n"
127+
<< " sum_squared_shared[ix] += sum_squared_shared[ix + reduce_size];\n"
128+
<< " }\n"
129+
<< " workgroupBarrier();\n"
130+
<< "}\n"
131+
<< "let sum = sum_shared[0];\n"
132+
<< "let square_sum = sum_squared_shared[0];\n"
133+
<< "let mean = " << SumVector("sum", components) << " / f32(uniforms.norm_size);\n"
134+
<< "let inv_std_dev = inverseSqrt(" << SumVector("square_sum", components) << " / f32(uniforms.norm_size) " << simpl1 << "+ uniforms.epsilon);\n"
135+
<< "for (var i: u32 = 0; i < stride; i++) {\n"
136+
<< " y[offset + i] = (y[offset + i] " << simpl2 << ") * x_element_t(inv_std_dev) * scale[offset1d + i]" << bias << ";\n"
137+
<< "};\n";
138+
139+
if (has_mean_output_) {
140+
shader.MainFunctionBody() << "if (ix == 0) {\n"
141+
<< " mean_output[iy] = mean;\n"
142+
<< "}\n";
143+
}
144+
if (has_inv_std_dev_output_) {
145+
shader.MainFunctionBody() << "if (ix == 0) {\n"
146+
<< " inv_std_dev_output[iy] = inv_std_dev;\n"
147+
<< "}\n";
148+
}
71149
}
72150

73151
return Status::OK();
@@ -81,8 +159,6 @@ Status LayerNorm<simplified>::ComputeInternal(onnxruntime::webgpu::ComputeContex
81159

82160
const auto x_shape = x->Shape();
83161

84-
const bool is_fp16 = x->GetElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16;
85-
86162
const size_t axis = NormalizeAxis(axis_, x_shape.NumDimensions());
87163
const uint32_t norm_count = onnxruntime::narrow<uint32_t>(x_shape.SizeToDimension(axis));
88164
const int64_t norm_size = x_shape.SizeFromDimension(axis);
@@ -116,14 +192,19 @@ Status LayerNorm<simplified>::ComputeInternal(onnxruntime::webgpu::ComputeContex
116192
return Status::OK();
117193
}
118194

119-
LayerNormProgram program{bias != nullptr, is_fp16, simplified, mean != nullptr, inv_std_dev != nullptr};
195+
// Check if we should use split norm dimension optimization
196+
const bool split_norm_dim = norm_size % 512 == 0 && norm_count == 1;
197+
198+
LayerNormProgram program{bias != nullptr, simplified, mean != nullptr, inv_std_dev != nullptr, split_norm_dim};
120199

121-
program.CacheHint(components, simplified)
200+
program.CacheHint(components, simplified, split_norm_dim)
122201
.AddInputs({{x, ProgramTensorMetadataDependency::Type, GetOverrideShape(x->Shape(), components), components}})
123202
.AddInputs(
124203
{{scale, ProgramTensorMetadataDependency::Type, GetOverrideShape(scale->Shape(), components), components}})
125204
.AddOutputs({{y, ProgramTensorMetadataDependency::None, GetOverrideShape(y->Shape(), components), components}})
126-
.SetDispatchGroupSize((norm_count + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
205+
.AddUniformVariables({
206+
{static_cast<uint32_t>(components)},
207+
})
127208
.AddUniformVariables({
128209
{static_cast<uint32_t>(norm_count)},
129210
})
@@ -137,6 +218,15 @@ Status LayerNorm<simplified>::ComputeInternal(onnxruntime::webgpu::ComputeContex
137218
{static_cast<float>(epsilon_)},
138219
});
139220

221+
if (split_norm_dim) {
222+
const uint32_t workgroup_size_x = 128;
223+
const uint32_t dispatch_size_x = onnxruntime::narrow<uint32_t>(norm_size / (workgroup_size_x * components));
224+
program.SetDispatchGroupSize(dispatch_size_x, 1, 1)
225+
.SetWorkgroupSize(workgroup_size_x);
226+
} else {
227+
program.SetDispatchGroupSize(norm_count);
228+
}
229+
140230
if (bias != nullptr) {
141231
program.AddInput(
142232
{bias, ProgramTensorMetadataDependency::Type, GetOverrideShape(bias->Shape(), components), components});

onnxruntime/core/providers/webgpu/nn/layer_norm.h

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,28 +11,29 @@ namespace webgpu {
1111

1212
class LayerNormProgram final : public Program<LayerNormProgram> {
1313
public:
14-
LayerNormProgram(bool has_bias, bool is_fp16, bool simplified, bool has_mean_output,
15-
bool has_inv_std_dev_output)
14+
LayerNormProgram(bool has_bias, bool simplified, bool has_mean_output,
15+
bool has_inv_std_dev_output, bool split_norm_dim = false)
1616
: Program{"LayerNorm"},
1717
has_bias_{has_bias},
18-
is_fp16_{is_fp16},
1918
simplified_{simplified},
2019
has_mean_output_{has_mean_output},
21-
has_inv_std_dev_output_{has_inv_std_dev_output} {}
20+
has_inv_std_dev_output_{has_inv_std_dev_output},
21+
split_norm_dim_{split_norm_dim} {}
2222

2323
Status GenerateShaderCode(ShaderHelper& sh) const override;
2424

25-
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"norm_count", ProgramUniformVariableDataType::Uint32},
25+
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"components", ProgramUniformVariableDataType::Uint32},
26+
{"norm_count", ProgramUniformVariableDataType::Uint32},
2627
{"norm_size", ProgramUniformVariableDataType::Uint32},
2728
{"norm_size_vectorized", ProgramUniformVariableDataType::Uint32},
2829
{"epsilon", ProgramUniformVariableDataType::Float32});
2930

3031
private:
3132
bool has_bias_;
32-
bool is_fp16_;
3333
bool simplified_;
3434
bool has_mean_output_;
3535
bool has_inv_std_dev_output_;
36+
bool split_norm_dim_;
3637
};
3738

3839
template <bool simplified>

0 commit comments

Comments (0)