@@ -35,11 +35,13 @@ Status QDQOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
3535
3636 std::vector<int64_t > input_shape;
3737 std::vector<int64_t > scale_shape;
38+ std::vector<uint32_t > zero_point_shape;
3839 ORT_RETURN_IF_NOT (GetShape (*input_defs[0 ], input_shape, logger), " Cannot get input shape" );
3940 ORT_RETURN_IF_NOT (GetShape (*input_defs[1 ], scale_shape, logger), " Cannot get scale shape" );
4041 int32_t input_type = 0 ;
4142 int32_t output_type = 0 ;
4243 int32_t zero_point_type = 0 ;
44+ bool has_zero_point = false ;
4345 ORT_RETURN_IF_NOT (GetType (*input_defs[0 ], input_type, logger), " Cannot get input data type" );
4446 ORT_RETURN_IF_NOT (GetType (*output_defs[0 ], output_type, logger), " Cannot get output data type" );
4547
@@ -49,12 +51,55 @@ Status QDQOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
4951
5052 if (input_defs.size () == 3 && input_defs[2 ]->Exists ()) {
5153 zero_point = model_builder.GetOperand (node.InputDefs ()[2 ]->Name ());
54+ has_zero_point = true ;
5255 } else {
5356 // DequantizeLinear: x_zero_point's data type equals to input data type
5457 // QuantizeLinear: x_zero_point's data type equals to output data type
55- // WebNN requires the zero_point to have the same shape as the scale
5658 zero_point_type = op_type == " DequantizeLinear" ? input_type : output_type;
57- const auto zero_point_shape = GetVecUint32FromVecInt64 (scale_shape);
59+ }
60+
61+ const auto input_rank = input_shape.size ();
62+ NodeAttrHelper helper (node);
63+ int32_t block_size = helper.Get (" block_size" , 0 );
64+ int32_t axis = helper.Get (" axis" , 1 );
65+ if (axis < 0 ) {
66+ axis = SafeInt<int32_t >(HandleNegativeAxis (axis, input_rank));
67+ }
68+
69+ // For per-axis quantization/dequantization and axis is not equal to input_rank - 1,
70+ // we need to reshape the scale and zero_point tensors to make them broadcastable with the input tensor.
71+ if (scale_shape.size () == 1 && input_rank > 1 &&
72+ block_size == 0 && axis != static_cast <int32_t >(input_rank - 1 )) {
73+ // Insert ones before and after the axis dimension for broadcasting of scale tensor.
74+ std::vector<uint32_t > target_shape{SafeInt<uint32_t >(input_shape[axis])};
75+ target_shape.insert (target_shape.begin (), axis, 1 );
76+ target_shape.insert (target_shape.end (), input_rank - axis - 1 , 1 );
77+ // zero_point has the same shape as the scale tensor.
78+ zero_point_shape = target_shape;
79+ emscripten::val reshape_scale_options = emscripten::val::object ();
80+ reshape_scale_options.set (" label" , node.Name () + " _reshape_scale" );
81+ scale = model_builder.GetBuilder ().call <emscripten::val>(" reshape" ,
82+ scale,
83+ emscripten::val::array (target_shape),
84+ reshape_scale_options);
85+
86+ if (has_zero_point) {
87+ // Reshape the zero_point tensor too.
88+ emscripten::val reshape_zero_point_options = emscripten::val::object ();
89+ reshape_zero_point_options.set (" label" , node.Name () + " _reshape_zero_point" );
90+ zero_point = model_builder.GetBuilder ().call <emscripten::val>(" reshape" ,
91+ zero_point,
92+ emscripten::val::array (target_shape),
93+ reshape_zero_point_options);
94+ }
95+ }
96+
97+ // If zero_point is not provided, create a zero constant with the same shape as the scale tensor.
98+ if (!has_zero_point) {
99+ if (zero_point_shape.empty ()) {
100+ // zero_point has the same shape as the scale tensor.
101+ zero_point_shape = GetVecUint32FromVecInt64 (scale_shape);
102+ }
58103 zero_point = model_builder.GetZeroConstant (zero_point_type, zero_point_shape);
59104 }
60105
0 commit comments