Skip to content

Commit e974fbe

Browse files
authored
Sahar/psu lora fix 2 (#788)
* Changed fix * Fix to omit subgraph * Commit a fix for cluster index len * Fixing the Warning with size_t on clusters * Loop Test fix
1 parent 8ecdbd0 commit e974fbe

File tree

4 files changed

+47
-30
lines changed

4 files changed

+47
-30
lines changed

onnxruntime/core/providers/openvino/ov_versions/capability.cc

Lines changed: 22 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -166,33 +166,24 @@ std::vector<std::unique_ptr<ComputeCapability>> GetCapability::Execute() {
166166
auto connected_clusters = GetConnectedClusters(graph_viewer_, ng_clusters);
167167

168168
int no_of_clusters = 0;
169-
std::vector<NodeIndex> prev_cluster;
170-
bool try_next_cluster = false;
171-
169+
size_t cluster_index = 0;
170+
size_t total_clusters = connected_clusters.size();
172171
for (auto this_cluster : connected_clusters) {
173172
bool omit_subgraph = false;
174-
if (try_next_cluster) {
175-
// no need to check previous cluster
176-
for (auto idx : prev_cluster) {
177-
if ((std::find(this_cluster.begin(), this_cluster.end(), idx)) == this_cluster.end()) {
178-
this_cluster.emplace_back(idx);
179-
}
180-
}
181-
try_next_cluster = false;
182-
}
183173

184-
// If subgraph has less then three, graph is considered trivial unless its an epctx cluster
185-
if (!try_next_cluster && this_cluster.size() < 3) {
186-
bool is_epctx_node = false;
187-
for (auto node_idx : this_cluster) {
188-
if (graph_viewer_.GetNode(node_idx)->OpType() == "EPContext")
189-
is_epctx_node = true;
190-
}
191-
if (!is_epctx_node) {
192-
omit_subgraph = true;
193-
prev_cluster = this_cluster;
194-
try_next_cluster = true;
195-
}
174+
//auto id = this_cluster.at(0);
175+
if (this_cluster.size() == 1) {
176+
//check next cluster
177+
auto index = this_cluster.at(0);
178+
if (graph_viewer_.GetNode(index)->OpType() == "EPContext") {
179+
omit_subgraph=false;
180+
} else if(cluster_index < total_clusters-1) {
181+
bool append_node = AddTrivialClusterToNextClusterIfConnected(graph_viewer_, index, connected_clusters[cluster_index+1]);
182+
if(append_node) {
183+
connected_clusters[cluster_index+1].emplace_back(index);
184+
}
185+
omit_subgraph=true;
186+
}
196187
}
197188

198189
std::vector<std::string> cluster_graph_inputs, cluster_inputs, cluster_outputs;
@@ -233,15 +224,17 @@ std::vector<std::unique_ptr<ComputeCapability>> GetCapability::Execute() {
233224
}
234225
}
235226
}
236-
if (omit_subgraph)
237-
continue;
238227

239228
/* In scenarios, when there are no inputs or all inputs being initializers,
240229
ConstantFolding optimization in onnxruntime pre-computes the value.*/
241-
if (!cluster_inputs.empty()) {
242-
AppendClusterToSubGraph(this_cluster, cluster_inputs, cluster_outputs, result);
243-
no_of_clusters++;
230+
if (!omit_subgraph) {
231+
if (!cluster_inputs.empty()) {
232+
AppendClusterToSubGraph(this_cluster, cluster_inputs, cluster_outputs, result);
233+
no_of_clusters++;
234+
}
244235
}
236+
237+
cluster_index = cluster_index+1;
245238
}
246239
LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Supported subgraphs on OpenVINO: " << no_of_clusters;
247240
}

onnxruntime/core/providers/openvino/ov_versions/utils.cc

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,26 @@ GetConnectedClusters(const GraphViewer& graph_viewer, const std::vector<std::vec
153153
return connected_clusters;
154154
}
155155

156+
bool AddTrivialClusterToNextClusterIfConnected(const GraphViewer& graph_viewer,
157+
const NodeIndex curr_node_index,
158+
const std::vector<NodeIndex>& search_cluster) {
159+
160+
for(auto index: search_cluster) {
161+
auto curr_node = graph_viewer.GetNode(index);
162+
for (auto node = curr_node->InputNodesBegin(); node != curr_node->InputNodesEnd(); ++node) {
163+
if((*node).Index() == curr_node_index)
164+
return true;
165+
}
166+
167+
for (auto node = curr_node->OutputNodesBegin(); node != curr_node->OutputNodesEnd(); ++node) {
168+
if((*node).Index() == curr_node_index)
169+
return true;
170+
}
171+
}
172+
return false;
173+
}
174+
175+
156176
void GetInputsOutputsOfCluster(const GraphViewer& graph_viewer,
157177
const std::vector<NodeIndex>& cluster,
158178
const std::unordered_set<std::string>& ng_required_initializers,

onnxruntime/core/providers/openvino/ov_versions/utils.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,10 @@ void IdentifyConnectedNodes(
4040
std::vector<std::vector<NodeIndex>>
4141
GetConnectedClusters(const GraphViewer& graph_viewer, const std::vector<std::vector<NodeIndex>>& clusters);
4242

43+
bool AddTrivialClusterToNextClusterIfConnected(const GraphViewer& graph_viewer,
44+
const NodeIndex index,
45+
const std::vector<NodeIndex>& search_cluster);
46+
4347
void GetInputsOutputsOfCluster(const GraphViewer& graph_viewer,
4448
const std::vector<NodeIndex>& cluster,
4549
const std::unordered_set<std::string>& ng_required_initializers,

onnxruntime/test/providers/cpu/controlflow/loop_test.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1162,7 +1162,7 @@ TEST(Loop, SequenceAsLoopCarriedDependency) {
11621162
test.AddSeqOutput("loop_var_0_final", seq_output);
11631163

11641164
// Disable TensorRT on unsupported data type BOOL
1165-
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
1165+
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider});
11661166
}
11671167

11681168
#if !defined(DISABLE_OPTIONAL_TYPE)

0 commit comments

Comments
 (0)