Skip to content

Commit fcd448a

Browse files
authored
[WebNN] Always create a new constant for zero_points (microsoft#25286)
MatMulNBits is a decomposed op in the WebNN EP. Previously, we shared a single WebNN constant for zero_points whenever they had the same value and data type. However, this sharing makes it considerably harder for developers to fuse the decomposed ops back into MatMulNBits in the underlying WebNN implementation in Chromium. With this PR, we always create a new constant for zero_points.
1 parent e63e053 commit fcd448a

File tree

1 file changed

+12
-7
lines changed

1 file changed

+12
-7
lines changed

onnxruntime/core/providers/webnn/builders/impl/matMulNBits_op_builder.cc

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -100,20 +100,25 @@ Status MatMulNBitsBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
100100
// x_zero_point has the same shape as x_scale
101101
const bool has_zero_points = TensorExists(input_defs, 3);
102102
emscripten::val x_zero_point = emscripten::val::undefined();
103+
emscripten::val zero_points_desc = emscripten::val::object();
104+
zero_points_desc.set("dataType", emscripten::val("uint4"));
105+
zero_points_desc.set("shape", x_scale_shape_array);
106+
zero_points_desc.set("dimensions", x_scale_shape_array);
103107
if (has_zero_points) {
104108
// zero_points is an initializer with data type 'uint8', we need to register it as 'uint4' WebNN constant
105109
const auto zero_points_tensor = *initializers.at(input_defs[3]->Name());
106-
emscripten::val zero_points_desc = emscripten::val::object();
107-
zero_points_desc.set("dataType", emscripten::val("uint4"));
108-
zero_points_desc.set("shape", x_scale_shape_array);
109-
zero_points_desc.set("dimensions", x_scale_shape_array);
110110
ORT_RETURN_IF_ERROR(model_builder.RegisterConstant(zero_points_tensor, x_zero_point, zero_points_desc, logger));
111111
} else {
112112
// zero_points' default value is 8, referred from CPU EP
113113
const int8_t default_zero_point = 8;
114-
x_zero_point = model_builder.CreateOrGetConstant<int8_t>(ONNX_NAMESPACE::TensorProto_DataType_UINT4,
115-
default_zero_point,
116-
x_scale_shape);
114+
// Always create a new WebNN constant for zero_points to facilitate MatMulNBits fusion in Chromium
115+
auto num_elements = (Product(x_scale_shape) + 1) / 2;
116+
emscripten::val default_zero_point_buffer = emscripten::val::global("Uint8Array").new_(num_elements);
117+
default_zero_point_buffer.call<void>("fill",
118+
emscripten::val(PackInt8ToUint8DoubledNibbles(
119+
default_zero_point, ONNX_NAMESPACE::TensorProto_DataType_UINT4)));
120+
x_zero_point =
121+
model_builder.GetBuilder().call<emscripten::val>("constant", zero_points_desc, default_zero_point_buffer);
117122
}
118123

119124
// DequantizeLinear

0 commit comments

Comments
 (0)