Skip to content

Commit b9b1a03

Browse files
authored
[WebNN] QDQ's axis should be used for broadcasting (microsoft#22721)
For per-axis quantization/dequantization, WebNN requires the scale and zero_point inputs to be broadcastable with the input tensor. The axis attribute should be used to reshape these two inputs.
1 parent d3ad76b commit b9b1a03

File tree

1 file changed

+47
-2
lines changed

1 file changed

+47
-2
lines changed

onnxruntime/core/providers/webnn/builders/impl/qdq_op_builder.cc

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,13 @@ Status QDQOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
3535

3636
std::vector<int64_t> input_shape;
3737
std::vector<int64_t> scale_shape;
38+
std::vector<uint32_t> zero_point_shape;
3839
ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get input shape");
3940
ORT_RETURN_IF_NOT(GetShape(*input_defs[1], scale_shape, logger), "Cannot get scale shape");
4041
int32_t input_type = 0;
4142
int32_t output_type = 0;
4243
int32_t zero_point_type = 0;
44+
bool has_zero_point = false;
4345
ORT_RETURN_IF_NOT(GetType(*input_defs[0], input_type, logger), "Cannot get input data type");
4446
ORT_RETURN_IF_NOT(GetType(*output_defs[0], output_type, logger), "Cannot get output data type");
4547

@@ -49,12 +51,55 @@ Status QDQOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
4951

5052
if (input_defs.size() == 3 && input_defs[2]->Exists()) {
5153
zero_point = model_builder.GetOperand(node.InputDefs()[2]->Name());
54+
has_zero_point = true;
5255
} else {
5356
// DequantizeLinear: x_zero_point's data type equals to input data type
5457
// QuantizeLinear: x_zero_point's data type equals to output data type
55-
// WebNN requires the zero_point to have the same shape as the scale
5658
zero_point_type = op_type == "DequantizeLinear" ? input_type : output_type;
57-
const auto zero_point_shape = GetVecUint32FromVecInt64(scale_shape);
59+
}
60+
61+
const auto input_rank = input_shape.size();
62+
NodeAttrHelper helper(node);
63+
int32_t block_size = helper.Get("block_size", 0);
64+
int32_t axis = helper.Get("axis", 1);
65+
if (axis < 0) {
66+
axis = SafeInt<int32_t>(HandleNegativeAxis(axis, input_rank));
67+
}
68+
69+
// For per-axis quantization/dequantization and axis is not equal to input_rank - 1,
70+
// we need to reshape the scale and zero_point tensors to make them broadcastable with the input tensor.
71+
if (scale_shape.size() == 1 && input_rank > 1 &&
72+
block_size == 0 && axis != static_cast<int32_t>(input_rank - 1)) {
73+
// Insert ones before and after the axis dimension for broadcasting of scale tensor.
74+
std::vector<uint32_t> target_shape{SafeInt<uint32_t>(input_shape[axis])};
75+
target_shape.insert(target_shape.begin(), axis, 1);
76+
target_shape.insert(target_shape.end(), input_rank - axis - 1, 1);
77+
// zero_point has the same shape as the scale tensor.
78+
zero_point_shape = target_shape;
79+
emscripten::val reshape_scale_options = emscripten::val::object();
80+
reshape_scale_options.set("label", node.Name() + "_reshape_scale");
81+
scale = model_builder.GetBuilder().call<emscripten::val>("reshape",
82+
scale,
83+
emscripten::val::array(target_shape),
84+
reshape_scale_options);
85+
86+
if (has_zero_point) {
87+
// Reshape the zero_point tensor too.
88+
emscripten::val reshape_zero_point_options = emscripten::val::object();
89+
reshape_zero_point_options.set("label", node.Name() + "_reshape_zero_point");
90+
zero_point = model_builder.GetBuilder().call<emscripten::val>("reshape",
91+
zero_point,
92+
emscripten::val::array(target_shape),
93+
reshape_zero_point_options);
94+
}
95+
}
96+
97+
// If zero_point is not provided, create a zero constant with the same shape as the scale tensor.
98+
if (!has_zero_point) {
99+
if (zero_point_shape.empty()) {
100+
// zero_point has the same shape as the scale tensor.
101+
zero_point_shape = GetVecUint32FromVecInt64(scale_shape);
102+
}
58103
zero_point = model_builder.GetZeroConstant(zero_point_type, zero_point_shape);
59104
}
60105

0 commit comments

Comments
 (0)