@@ -337,6 +337,18 @@ static bool CheckDQRuleSet(const NodeUnit& node_unit,
337337 }
338338}
339339
340+ static bool CheckQFeedsIntoQuantizedOutput (const NodeUnit& node_unit,
341+ const std::unordered_map<std::string, std::string> graph_op_data_type) {
342+ auto op_of_quantized_layer = node_unit.Outputs ();
343+ for (auto itr : op_of_quantized_layer) {
344+ auto it = graph_op_data_type.find (itr.node_arg .Name ());
345+ if (it != graph_op_data_type.end () && it->second == " tensor(uint8)" ) {
346+ return true ;
347+ }
348+ }
349+ return false ;
350+ }
351+
340352static bool CheckQRuleSet (const NodeUnit& node_unit,
341353 const Node* q_node,
342354 const onnxruntime::GraphViewer& src_graph,
@@ -347,6 +359,12 @@ static bool CheckQRuleSet(const NodeUnit& node_unit,
347359 const auto & target_node = node_unit.GetNode ();
348360 auto op_type = node_unit.OpType ();
349361
362+ auto op = src_graph.GetOutputs ();
363+ std::unordered_map<std::string, std::string> graph_op_data_type;
364+ for (auto & ops : op) {
365+ graph_op_data_type[src_graph.GetNodeArg (ops->Name ())->Name ()] = ops->Type ()->data ();
366+ }
367+
350368 // If UInt16 Q, don't keep it
351369 if (GetQDQDataType (q_node) == DT_UINT16 || GetQDQDataType (q_node) == DT_INT16) {
352370 reason = SkipReason::Int16QDQ;
@@ -359,6 +377,8 @@ static bool CheckQRuleSet(const NodeUnit& node_unit,
359377 } else if (op_type == " Add" ) {
360378 // Add keeps all Qs
361379 return true ;
380+ } else if (CheckQFeedsIntoQuantizedOutput (node_unit, graph_op_data_type)) {
381+ return true ;
362382 } else {
363383 // Keep Q of an unsupported Op only if the target that succeeds it is a supported Op in this list
364384 return IsNextTargetNodeOfQValid (q_node, &target_node, src_graph, {" Conv" , " Add" , " MatMul" }, false );
0 commit comments