
Commit 0ca313b

Relax WeightBiasQuantization constraint for larger QDQ node group (microsoft#25673)
### Description

Relax the WeightBiasQuantization constraint to cover larger QDQ node groups.

### Motivation and Context

The `WeightBiasQuantization` transformer quantizes float weights in the `Q -> DQ -> Conv/ConvTranspose/Gemm's weights -> Q -> DQ` sequence. The existing check on the node's output (`children_nodes.size() != 1 || children_nodes[0]->OpType() != QDQ::QOpName`) requires the output to feed a `Q` directly, which skips quantization for many common patterns, such as an unfused activation following `Conv` (`DQ -> Conv -> ReLU -> Q`). Checking for the trailing `Q` is not strictly necessary here (the weight fold could happen regardless without changing model semantics); however, to keep the behavior change minimal, this PR only extends the accepted pattern to a single-path (no branch), type-preserving path leading to `Q`, enabling more quantization support.
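Put differently (an illustration of the description above, not wording taken from the commit itself): previously the transformer only fired when the `Conv`/`ConvTranspose`/`Gemm` output fed a `QuantizeLinear` directly, as in `Q -> DQ -> Conv (float weights) -> Q -> DQ`. After this change it also fires when a single non-branching `Relu` or `Clip` sits between the node and the `QuantizeLinear`, as in `Q -> DQ -> Conv (float weights) -> Relu -> Q -> DQ`; the float weight initializer still receives the inserted `Q -> DQ` pair so the whole group can be recognized as a QDQ node unit.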
1 parent 3608cb2 commit 0ca313b

File tree

7 files changed: +93 -13 lines changed

onnxruntime/core/optimizer/qdq_transformer/weight_bias_quantization.cc
onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h
onnxruntime/core/providers/qnn/builder/opbuilder/conv_op_builder.cc
onnxruntime/core/providers/qnn/builder/opbuilder/pad_op_builder.cc
onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc
onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc
onnxruntime/test/optimizer/qdq_transformer_test.cc

onnxruntime/core/optimizer/qdq_transformer/weight_bias_quantization.cc

Lines changed: 35 additions & 5 deletions
@@ -13,6 +13,39 @@

 namespace onnxruntime {

+/**
+ * Checks whether or not the output path from a given node leads to a QuantizeLinear op, optionally, with no
+ * branching ReLU or Clip op in between. See also: NodeGroupSelector::GetQDQSelection() in qdq_selectors.cc.
+ *
+ * @param node The starting node to check the output path from.
+ * @param graph The graph containing the nodes.
+ *
+ * @return true if the path exist, false otherwise.
+ */
+static bool IsNoBranchPathToQuantizeLinear(const Node& node, const Graph& graph) {
+  const Node* current = &node;
+  while (true) {
+    // Conv / ConvTranspose / Gemm produces single output
+    if (current->OutputDefs().size() != 1) {
+      return false;
+    }
+    const std::vector<const Node*>& consumers = graph.GetConsumerNodes(current->OutputDefs()[0]->Name());
+    // Branching or no consumer: not eligible
+    if (consumers.size() != 1) {
+      return false;
+    }
+    const Node* consumer = consumers[0];
+    if (consumer->OpType() == QDQ::QOpName) {
+      return true;
+    }
+    // Allow ReLU or Clip, see also: NodeGroupSelector::GetQDQSelection() in qdq_selectors.cc.
+    if (consumer->OpType() != "Relu" && consumer->OpType() != "Clip") {
+      return false;
+    }
+    current = consumer;
+  }
+}
+
 Status WeightBiasQuantization::ApplyImpl(Graph& graph, bool& modified, int graph_level,
                                          const logging::Logger& logger) const {
   const GraphViewer graph_viewer{graph};
@@ -43,11 +76,8 @@ Status WeightBiasQuantization::ApplyImpl(Graph& graph, bool& modified, int graph
       continue;
     }

-    // Require that the node's output is consumed by a single QuantizeLinear node.
-    // Otherwise, if only the inputs are quantized, but not the output, then this node group would not
-    // be considered a QDQ node unit anyway.
-    std::vector<const Node*> children_nodes = graph.GetConsumerNodes(node.OutputDefs()[0]->Name());
-    if (children_nodes.size() != 1 || children_nodes[0]->OpType() != QDQ::QOpName) {
+    // Check if the output path leads to QuantizeLinear with optionally ReLU or Clip op in between.
+    if (!IsNoBranchPathToQuantizeLinear(node, graph)) {
      continue;
    }
onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h

Lines changed: 1 addition & 1 deletion
@@ -238,7 +238,7 @@ class BaseOpBuilder : public IOpBuilder {
   }

   // Onnx Pads is [x1_begin, x2_begin, x1_end, x2_end], QNN requires [x1_begin, x1_end, x2_begin, x2_end]
-  void ReArranagePads(std::vector<uint32_t>& pads) const {
+  void ReArrangePads(std::vector<uint32_t>& pads) const {
     auto pads_size = pads.size();
     auto middle_pos = pads_size / 2;
     std::vector<uint32_t> first_half(pads.begin(), pads.begin() + middle_pos);
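As an aside for readers who have not seen the helper being renamed here: the diff above cuts off after the `first_half` copy, so below is a minimal standalone sketch of the pad rearrangement it performs. This is an illustrative reimplementation based on the comment's ONNX/QNN layouts (the name `ReArrangePadsSketch` and the loop structure are mine), not the exact ONNX Runtime code.

```cpp
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// ONNX stores pads as [x1_begin, x2_begin, ..., x1_end, x2_end, ...];
// QNN expects them interleaved as [x1_begin, x1_end, x2_begin, x2_end, ...].
static void ReArrangePadsSketch(std::vector<uint32_t>& pads) {
  const std::size_t middle_pos = pads.size() / 2;
  std::vector<uint32_t> rearranged;
  rearranged.reserve(pads.size());
  for (std::size_t i = 0; i < middle_pos; ++i) {
    rearranged.push_back(pads[i]);               // begin of dimension i
    rearranged.push_back(pads[middle_pos + i]);  // end of dimension i
  }
  pads = std::move(rearranged);
}

// Example: {1, 2, 3, 4} (ONNX order) becomes {1, 3, 2, 4} (QNN order).
```

The real member function in `base_op_builder.h` keeps the same in-place, by-reference interface, which is why only its name changes in this commit.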

onnxruntime/core/providers/qnn/builder/opbuilder/conv_op_builder.cc

Lines changed: 2 additions & 3 deletions
@@ -24,7 +24,6 @@ static Status GetOnnxConvType(const std::string& onnx_op_type, OnnxConvType& con
   } else {
     return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN EP: Unsupported ONNX convolution op type: ", onnx_op_type.c_str());
   }
-
   return Status::OK();
 }

@@ -171,7 +170,7 @@ Status ConvOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
     return ProcessConv2D3DInputs(qnn_model_wrapper, node_unit, logger, input_names, do_op_validation);
   }

-  return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN Conv only supports 3D(rank 5), 2D (rank 4) or 1D (rank 3) inputs.");
+  return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN Conv only supports 3D (rank 5), 2D (rank 4) or 1D (rank 3) inputs.");
 }

 Status ConvOpBuilder::ProcessConv2D3DInputs(QnnModelWrapper& qnn_model_wrapper,
@@ -712,7 +711,7 @@ Status ConvOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra
     }
   }

-  ReArranagePads(pads);
+  ReArrangePads(pads);
   uint32_t pad_size = narrow<uint32_t>(pads.size() / 2);
   QnnParamWrapper pad_amount_paramwrapper(node_unit.Index(), node_unit.Name(), QNN_OP_CONV_2D_PARAM_PAD_AMOUNT,
                                           {pad_size, 2}, std::move(pads));

onnxruntime/core/providers/qnn/builder/opbuilder/pad_op_builder.cc

Lines changed: 1 addition & 1 deletion
@@ -193,7 +193,7 @@ Status PadOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrap
                  [](int64_t item) { return SafeInt<uint32_t>(item); });
   // Onnx format is begin_0, begin_1, ..., end_0, end_1, ...
   // Qnn format is begin_0, end_0, begin_1, end_1, ...
-  ReArranagePads(pad_amount);
+  ReArrangePads(pad_amount);

   std::vector<uint32_t> input_shape;
   ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[0].node_arg, input_shape), "Cannot get shape of input 0.");

onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc

Lines changed: 1 addition & 1 deletion
@@ -195,7 +195,7 @@ Status PoolOpBuilder::SetCommonPoolParams(const NodeAttrHelper& node_helper,
       }
     }
   }
-  ReArranagePads(pad_amount);
+  ReArrangePads(pad_amount);

   // Param: rounding_mode.
   rounding_mode = node_helper.Get("ceil_mode", rounding_mode);

onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc

Lines changed: 1 addition & 1 deletion
@@ -158,7 +158,7 @@ bool QnnModelWrapper::CreateQnnInputOutputTensors(const std::string& qnn_node_na
     return false;
   }

-  // During graph patitioning, we only need to do op validation, it's not required to create Qnn graph tensor
+  // During graph partitioning, we only need to do op validation, it's not required to create Qnn graph tensor
   // We only need to create the Qnn graph tensor during Compile to create Qnn graph
   if (!do_op_validation) {
     std::string error_string;

onnxruntime/test/optimizer/qdq_transformer_test.cc

Lines changed: 52 additions & 1 deletion
@@ -5444,8 +5444,59 @@ TEST(QDQTransformerTests, WeightBiasQuantization_Conv_Weight_Bias) {
 #endif
 }

+// Tests that the WeightBiasQuantization optimizer still processes nodes that contain a type-preserving no
+// branch ReLU op to QuantizeLinear e.g., Q -> DQ -> Conv (w/ float weight initializer) -> ReLU -> Q -> DQ
+TEST(QDQTransformerTests, WeightBiasQuantization_ConvWithReLU) {
+  auto test_case = [](bool use_contrib_qdq) {
+    auto build_test_case = [&](ModelTestBuilder& builder) {
+      NodeArg* input_fp32 = builder.MakeInput<float>({1, 1, 4, 4}, -1.0f, 1.0f);
+      NodeArg* weight_fp32 = builder.MakeInitializer<float>({2, 1, 3, 3}, -1.0f, 1.0f);
+      NodeArg* input_q = builder.MakeIntermediate();
+      NodeArg* input_dq = builder.MakeIntermediate();
+      NodeArg* conv_fp32 = builder.MakeIntermediate();
+      NodeArg* relu_fp32 = builder.MakeIntermediate();
+      NodeArg* relu_q = builder.MakeIntermediate();
+      NodeArg* relu_dq = builder.MakeOutput();
+      builder.AddQuantizeLinearNode<uint8_t>(input_fp32, 0.18f, static_cast<uint8_t>(127), input_q, use_contrib_qdq);
+      builder.AddDequantizeLinearNode<uint8_t>(input_q, 0.18f, static_cast<uint8_t>(127), input_dq, use_contrib_qdq);
+      auto& conv_node = builder.AddNode("Conv", {input_dq, weight_fp32}, {conv_fp32});
+      conv_node.AddAttribute("dilations", std::vector<int64_t>{1, 1});
+      conv_node.AddAttribute("kernel_shape", std::vector<int64_t>{3, 3});
+      conv_node.AddAttribute("strides", std::vector<int64_t>{1, 1});
+      conv_node.AddAttribute("group", static_cast<int64_t>(1));
+      conv_node.AddAttribute("pads", std::vector<int64_t>{0, 0, 0, 0});
+      builder.AddNode("Relu", {conv_fp32}, {relu_fp32});
+      builder.AddQuantizeLinearNode<uint8_t>(relu_fp32, 0.69f, static_cast<uint8_t>(127), relu_q, use_contrib_qdq);
+      builder.AddDequantizeLinearNode<uint8_t>(relu_q, 0.69f, static_cast<uint8_t>(127), relu_dq, use_contrib_qdq);
+    };
+
+    // Conv's weights should be quantized and folded, one additional Q/DQ pair inserted for weight
+    auto check_transformed_graph = [](InferenceSessionWrapper& session) {
+      auto op_to_count = CountOpsInGraph(session.GetGraph());
+      EXPECT_EQ(op_to_count["QuantizeLinear"] + op_to_count["com.microsoft.QuantizeLinear"], 2 + 1);
+      EXPECT_EQ(op_to_count["DequantizeLinear"] + op_to_count["com.microsoft.DequantizeLinear"], 2 + 1);
+      EXPECT_EQ(op_to_count["Conv"], 1);
+      EXPECT_EQ(op_to_count["Relu"], 1);
+    };
+
+    TransformerTester(build_test_case,
+                      check_transformed_graph,
+                      TransformerLevel::Default,
+                      TransformerLevel::Level1,
+                      /*opset_version=*/20,
+                      /*per_sample_tolerance=*/0.01,
+                      /*relative_per_sample_tolerance=*/0.01,
+                      /*transformer=*/std::make_unique<WeightBiasQuantization>());
+  };
+
+  test_case(false);
+#if !defined(DISABLE_CONTRIB_OPS)
+  test_case(true);
+#endif
+}
+
 // Tests that the WeightBiasQuantization optimizer does not process nodes that do not
-// already have an output that is consumed by a single QuantizeLinear node.
+// already have an output that is consumed by a valid path to QuantizeLinear node.
 TEST(QDQTransformerTests, WeightBiasQuantization_SkipIfOutputNotQuantized) {
   auto test_case = [](bool add_final_reshape) {
     auto build_test_case = [&](ModelTestBuilder& builder) {
