1+ // Copyright (c) Microsoft Corporation. All rights reserved.
2+ // Licensed under the MIT License.
3+
4+ #include " core/optimizer/qdq_transformer/where_dummy_dq.h"
5+
6+ #include " core/framework/tensorprotoutils.h"
7+ #include " core/common/common.h"
8+ #include " core/util/qmath.h"
9+ #include " core/graph/graph_utils.h"
10+ #include " core/graph/graph_viewer.h"
11+ #include " core/optimizer/initializer.h"
12+ #include " core/optimizer/utils.h"
13+ #include " core/optimizer/qdq_transformer/qdq_util.h"
14+
15+ namespace onnxruntime {
16+ bool WhereDummyDq::SatisfyCondition (const Graph& graph, const Node& node) const {
17+ if (!(node.OpType () == " Where" )) {
18+ return false ;
19+ }
20+ const auto & where_inputs = node.InputDefs ();
21+ const Node* parent_node_1 = graph.GetProducerNode (where_inputs[1 ]->Name ());
22+ const Node* parent_node_2 = graph.GetProducerNode (where_inputs[2 ]->Name ());
23+
24+ bool is_p1_dq = (parent_node_1 && parent_node_1->OpType () == QDQ::DQOpName);
25+ bool is_p2_dq = (parent_node_2 && parent_node_2->OpType () == QDQ::DQOpName);
26+
27+ // WhereDummyDq focus on WhereOp with one DQ input and one scalar initializer input
28+ if (is_p1_dq && !parent_node_2) {
29+ return (where_inputs[2 ]->Shape ()->dim_size () == 0 );
30+ }
31+ if (!parent_node_1 && is_p2_dq) {
32+ return (where_inputs[1 ]->Shape ()->dim_size () == 0 );
33+ }
34+ return false ;
35+ }
36+
37+ Status WhereDummyDq::InsertDummyDQ (Node& node, Graph& graph, bool & modified, const logging::Logger& logger) const {
38+ const auto & where_inputs = node.InputDefs ();
39+ const Node* parent_node_1 = graph.GetProducerNode (where_inputs[1 ]->Name ());
40+ const Node* parent_node_2 = graph.GetProducerNode (where_inputs[2 ]->Name ());
41+
42+ // With SatisfyCondition, we must have one DQ and one initializer
43+ const Node* dq_node = parent_node_1 ? parent_node_1 : parent_node_2;
44+ int const_idx = parent_node_1 ? 2 : 1 ;
45+
46+ const ONNX_NAMESPACE::TensorProto* dq_node_scale_proto = nullptr ;
47+ graph.GetInitializedTensor (dq_node->InputDefs ()[1 ]->Name (), dq_node_scale_proto);
48+ const ONNX_NAMESPACE::TensorProto* dq_node_zp_proto = nullptr ;
49+ graph.GetInitializedTensor (dq_node->InputDefs ()[2 ]->Name (), dq_node_zp_proto);
50+
51+ // Dummy data initializer.
52+ ONNX_NAMESPACE::TensorProto dummy_data_proto;
53+ dummy_data_proto.set_name (graph.GenerateNodeArgName (node.Name () + " _dummy_data" ));
54+ // Set data type to dq node's zp dtype
55+ dummy_data_proto.set_data_type (dq_node_zp_proto->data_type ());
56+
57+ // Dummy zero point initializer.
58+ ONNX_NAMESPACE::TensorProto dummy_zp_proto;
59+ dummy_zp_proto.set_name (graph.GenerateNodeArgName (node.Name () + " _dummy_zp" ));
60+ dummy_zp_proto.set_data_type (dq_node_zp_proto->data_type ());
61+
62+ switch (dummy_zp_proto.data_type ()) {
63+ case ONNX_NAMESPACE::TensorProto_DataType_INT8: {
64+ int8_t zp = 0 ;
65+ int8_t dummy_data = 1 ;
66+ dummy_zp_proto.set_raw_data (&zp, 1 );
67+ dummy_data_proto.set_raw_data (&dummy_data, 1 );
68+ break ;
69+ }
70+ case ONNX_NAMESPACE::TensorProto_DataType_UINT8: {
71+ uint8_t zp = 0 ;
72+ uint8_t dummy_data = 1 ;
73+ dummy_zp_proto.set_raw_data (&zp, 1 );
74+ dummy_data_proto.set_raw_data (&dummy_data, 1 );
75+ break ;
76+ }
77+ case ONNX_NAMESPACE::TensorProto_DataType_INT16: {
78+ int16_t zp = 0 ;
79+ int16_t dummy_data = 1 ;
80+ dummy_zp_proto.set_raw_data (&zp, 2 );
81+ dummy_data_proto.set_raw_data (&dummy_data, 2 );
82+ break ;
83+ }
84+ case ONNX_NAMESPACE::TensorProto_DataType_UINT16: {
85+ uint16_t zp = 0 ;
86+ uint16_t dummy_data = 1 ;
87+ dummy_zp_proto.set_raw_data (&zp, 2 );
88+ dummy_data_proto.set_raw_data (&dummy_data, 2 );
89+ break ;
90+ }
91+ default :
92+ LOGS (logger, WARNING) << " Currently support existing DQ's zero point with INT8, UINT8, INT16, UINT16" ;
93+ return Status::OK ();
94+ }
95+
96+ // Set dummy scale to the original value
97+ const ONNX_NAMESPACE::TensorProto* const_node_data_proto = nullptr ;
98+ graph.GetInitializedTensor (where_inputs[const_idx]->Name (), const_node_data_proto);
99+ Initializer initializer (graph, *const_node_data_proto, graph.ModelPath ());
100+ if (dq_node_scale_proto->data_type () != const_node_data_proto->data_type ()) {
101+ // WhereDummyDq fills the const value to the dummy DQ's scale
102+ LOGS (logger, WARNING) << " Currently only support existing DQ's scale with same datatype as scalar" ;
103+ return Status::OK ();
104+ }
105+
106+ // Dummy scale initializer.
107+ ONNX_NAMESPACE::TensorProto dummy_scale_proto;
108+ dummy_scale_proto.set_name (graph.GenerateNodeArgName (node.Name () + " _dummy_scale" ));
109+ dummy_scale_proto.set_data_type (dq_node_scale_proto->data_type ());
110+ switch (initializer.data_type ()) {
111+ case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: {
112+ float * where_const_scalar = initializer.data <float >();
113+ dummy_scale_proto.set_raw_data (where_const_scalar, sizeof (float ));
114+ break ;
115+ }
116+ default :
117+ LOGS (logger, WARNING) << " Currently support scalar with FLOAT" ;
118+ return Status::OK ();
119+ }
120+
121+ // Start editing the graph
122+ NodeArg& dummy_data_arg = graph_utils::AddInitializerWithExternalData (graph, dummy_data_proto);
123+ NodeArg& dummy_scale_arg = graph_utils::AddInitializerWithExternalData (graph, dummy_scale_proto);
124+ NodeArg& dummy_zp_arg = graph_utils::AddInitializerWithExternalData (graph, dummy_zp_proto);
125+
126+ ONNX_NAMESPACE::TypeProto dummy_dq_type_proto = utils::TypeProtoFromTensorProto (*const_node_data_proto);
127+ dummy_dq_type_proto.mutable_tensor_type ()->set_elem_type (const_node_data_proto->data_type ());
128+ NodeArg& dummy_dq_arg =
129+ graph.GetOrCreateNodeArg (graph.GenerateNodeArgName (node.Name () + " _dummy_dq" ), &dummy_dq_type_proto);
130+ Node& dummy_dq_node =
131+ graph.AddNode (
132+ graph.GenerateNodeArgName (node.Name () + " _dummy_dq" ),
133+ QDQ::DQOpName,
134+ " DeQuantizeLinear from WhereDummyDq GraphTransformer" ,
135+ {&dummy_data_arg, &dummy_scale_arg, &dummy_zp_arg},
136+ {&dummy_dq_arg},
137+ nullptr ,
138+ dq_node->Domain ());
139+
140+ node.MutableInputDefs ()[const_idx] = &dummy_dq_arg;
141+ if (graph.GetConsumerNodes (where_inputs[const_idx]->Name ()).size () == 0 ) {
142+ graph.RemoveInitializedTensor (where_inputs[const_idx]->Name ());
143+ }
144+ graph.AddEdge (dummy_dq_node.Index (), node.Index (), 0 , const_idx);
145+ modified = true ;
146+
147+ return Status::OK ();
148+ }
149+
150+ Status WhereDummyDq::ApplyImpl (Graph& graph, bool & modified, int graph_level, const logging::Logger& logger) const {
151+ const GraphViewer graph_viewer{graph};
152+ const auto & node_indices = graph_viewer.GetNodesInTopologicalOrder ();
153+ for (const auto node_idx : node_indices) {
154+ auto * node_ptr = graph.GetNode (node_idx);
155+ if (!node_ptr) {
156+ continue ;
157+ }
158+
159+ Node& node = *node_ptr;
160+ ORT_RETURN_IF_ERROR (Recurse (node, modified, graph_level, logger));
161+
162+ if (this ->SatisfyCondition (graph, node)) {
163+ ORT_RETURN_IF_ERROR (WhereDummyDq::InsertDummyDQ (node, graph, modified, logger));
164+ }
165+ }
166+
167+ return Status::OK ();
168+ }
169+ } // namespace onnxruntime
0 commit comments