Skip to content

Commit 3466f3b

Browse files
authored
Merge pull request #398 from intel/saurabh/fix_q_linear
fix graph output expects int8 dtype
2 parents 162e0c8 + 8965c66 commit 3466f3b

File tree

1 file changed

+20
-0
lines changed

1 file changed

+20
-0
lines changed

onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,18 @@ static bool CheckDQRuleSet(const NodeUnit& node_unit,
337337
}
338338
}
339339

340+
static bool CheckQFeedsIntoQuantizedOutput(const NodeUnit& node_unit,
341+
const std::unordered_map<std::string, std::string> graph_op_data_type) {
342+
auto op_of_quantized_layer = node_unit.Outputs();
343+
for (auto itr : op_of_quantized_layer) {
344+
auto it = graph_op_data_type.find(itr.node_arg.Name());
345+
if (it != graph_op_data_type.end() && it->second == "tensor(uint8)") {
346+
return true;
347+
}
348+
}
349+
return false;
350+
}
351+
340352
static bool CheckQRuleSet(const NodeUnit& node_unit,
341353
const Node* q_node,
342354
const onnxruntime::GraphViewer& src_graph,
@@ -347,6 +359,12 @@ static bool CheckQRuleSet(const NodeUnit& node_unit,
347359
const auto& target_node = node_unit.GetNode();
348360
auto op_type = node_unit.OpType();
349361

362+
auto op = src_graph.GetOutputs();
363+
std::unordered_map<std::string, std::string> graph_op_data_type;
364+
for (auto& ops : op) {
365+
graph_op_data_type[src_graph.GetNodeArg(ops->Name())->Name()] = ops->Type()->data();
366+
}
367+
350368
// If UInt16 Q, don't keep it
351369
if (GetQDQDataType(q_node) == DT_UINT16 || GetQDQDataType(q_node) == DT_INT16) {
352370
reason = SkipReason::Int16QDQ;
@@ -359,6 +377,8 @@ static bool CheckQRuleSet(const NodeUnit& node_unit,
359377
} else if (op_type == "Add") {
360378
// Add keeps all Qs
361379
return true;
380+
} else if (CheckQFeedsIntoQuantizedOutput(node_unit, graph_op_data_type)) {
381+
return true;
362382
} else {
363383
// Keep Q of an unsupported Op only if the target that succeeds it is a supported Op in this list
364384
return IsNextTargetNodeOfQValid(q_node, &target_node, src_graph, {"Conv", "Add", "MatMul"}, false);

0 commit comments

Comments
 (0)