Skip to content

Commit 1dd6da9

Browse files
prathikrjavier-intel
authored andcommitted
[WebGPU EP] bug fix for convolution operator (microsoft#25000)
`is_channels_last` is being passed to MatMulProgram but not to MatMulNaiveProgram causing issues for musicgen model
1 parent 017a7c0 commit 1dd6da9

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

onnxruntime/core/providers/webgpu/math/matmul.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ Status MatMulNaiveProgram::GenerateShaderCode(ShaderHelper& shader) const {
5555
std::string process_bias;
5656
if (has_bias_) {
5757
shader.AddInput("bias", ShaderUsage::UseUniform);
58-
process_bias = is_channels_last_ ? "value += output_value_t(bias[col])" : "value += output_value_t(bias[row + i]);";
58+
process_bias = is_channels_last_ ? "value += output_value_t(bias[col]);" : "value += output_value_t(bias[row + i]);";
5959
}
6060

6161
std::string apply_activation = GetActivationSnippet(activation_, "output_value_t", "output_element_t");

onnxruntime/core/providers/webgpu/nn/conv.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ Status Conv<is_channels_last, is_fused>::ComputeInternal(ComputeContext& context
192192
uint32_t output_size = static_cast<uint32_t>(output_shape.Size() / components / output_number);
193193
const size_t output_rank = matmul_output_shape.NumDimensions();
194194
TensorShape outer_dims = output_rank > 2 ? matmul_output_shape.Slice(0, output_rank - 2) : TensorShape({});
195-
MatMulNaiveProgram program(activation_, output_rank, output_number, has_bias);
195+
MatMulNaiveProgram program(activation_, output_rank, output_number, has_bias, is_channels_last);
196196
program
197197
.CacheHint(std::to_string(components), std::to_string(a_components), std::to_string(output_number))
198198
.AddInputs({{matmul_inputs[0], ProgramTensorMetadataDependency::TypeAndRank, ReduceShapeByComponents(matmul_input_reshapes[0], a_components), int(a_components)},

0 commit comments

Comments
 (0)