Skip to content

Commit 47a5306

Browse files
bachelor-douJaswanth51
authored andcommitted
[CANN] Add a enable_cann_subgraph feature parameter (microsoft#25867)
### Description Add a `enable_cann_subgraph` feature parameter. this parameter controls whether graph splitting is performed and can help quickly identify issues in certain scenarios.
1 parent 6499977 commit 47a5306

File tree

6 files changed

+14
-5
lines changed

6 files changed

+14
-5
lines changed

include/onnxruntime/core/providers/cann/cann_provider_options.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ struct OrtCANNProviderOptions {
1515
onnxruntime::ArenaExtendStrategy arena_extend_strategy; // Strategy used to grow the memory arena
1616
int enable_cann_graph; // Flag indicating if prioritizing the use of
1717
// CANN's graph-running capabilities
18+
int enable_cann_subgraph; // Flag indicating whether to generate subgraph
19+
// automaticly
1820
int dump_graphs; // Flag indicating if dumping graphs
1921
int dump_om_model; // Flag indicating if dumping om model
2022
std::string precision_mode; // Operator Precision Mode

onnxruntime/core/providers/cann/cann_execution_provider.cc

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1266,17 +1266,16 @@ CANNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewe
12661266
// the single operator operation mode of CANN
12671267
if (info_.enable_cann_graph) {
12681268
std::vector<NodeIndex>&& unsupported_nodes = SupportONNXModel(graph_viewer);
1269-
1270-
if (unsupported_nodes.empty()) {
1271-
auto sub_graph = GetSubGraph(graph_viewer.GetNodesInTopologicalOrder(), graph_viewer);
1272-
result.push_back(ComputeCapability::Create(std::move(sub_graph)));
1273-
} else {
1269+
if (info_.enable_cann_subgraph && !unsupported_nodes.empty()) {
12741270
auto partitions = GetSubGraphPartition(graph_viewer.GetNodesInTopologicalOrder(), unsupported_nodes);
12751271

12761272
for (const auto& partition : partitions) {
12771273
auto sub_graph = GetSubGraph(partition, graph_viewer);
12781274
result.push_back(ComputeCapability::Create(std::move(sub_graph)));
12791275
}
1276+
} else {
1277+
auto sub_graph = GetSubGraph(graph_viewer.GetNodesInTopologicalOrder(), graph_viewer);
1278+
result.push_back(ComputeCapability::Create(std::move(sub_graph)));
12801279
}
12811280
} else {
12821281
InlinedVector<NodeIndex> candidates;

onnxruntime/core/providers/cann/cann_execution_provider_info.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ constexpr const char* kDeviceId = "device_id";
2020
constexpr const char* kMemLimit = "npu_mem_limit";
2121
constexpr const char* kArenaExtendStrategy = "arena_extend_strategy";
2222
constexpr const char* kEnableCannGraph = "enable_cann_graph";
23+
constexpr const char* kEnableCannSubGraph = "enable_cann_subgraph";
2324
constexpr const char* kDumpGraphs = "dump_graphs";
2425
constexpr const char* kDumpOmModel = "dump_om_model";
2526
constexpr const char* kPrecisionMode = "precision_mode";
@@ -58,6 +59,7 @@ CANNExecutionProviderInfo CANNExecutionProviderInfo::FromProviderOptions(const P
5859
cann::provider_option_names::kArenaExtendStrategy,
5960
arena_extend_strategy_mapping, info.arena_extend_strategy)
6061
.AddAssignmentToReference(cann::provider_option_names::kEnableCannGraph, info.enable_cann_graph)
62+
.AddAssignmentToReference(cann::provider_option_names::kEnableCannSubGraph, info.enable_cann_subgraph)
6163
.AddAssignmentToReference(cann::provider_option_names::kDumpGraphs, info.dump_graphs)
6264
.AddAssignmentToReference(cann::provider_option_names::kDumpOmModel, info.dump_om_model)
6365
.AddAssignmentToReference(cann::provider_option_names::kPrecisionMode, info.precision_mode)
@@ -74,6 +76,7 @@ ProviderOptions CANNExecutionProviderInfo::ToProviderOptions(const CANNExecution
7476
{cann::provider_option_names::kArenaExtendStrategy,
7577
EnumToName(arena_extend_strategy_mapping, info.arena_extend_strategy)},
7678
{cann::provider_option_names::kEnableCannGraph, MakeStringWithClassicLocale(info.enable_cann_graph)},
79+
{cann::provider_option_names::kEnableCannSubGraph, MakeStringWithClassicLocale(info.enable_cann_subgraph)},
7780
{cann::provider_option_names::kDumpGraphs, MakeStringWithClassicLocale(info.dump_graphs)},
7881
{cann::provider_option_names::kDumpOmModel, MakeStringWithClassicLocale(info.dump_om_model)},
7982
{cann::provider_option_names::kPrecisionMode, MakeStringWithClassicLocale(info.precision_mode)},
@@ -89,6 +92,7 @@ ProviderOptions CANNExecutionProviderInfo::ToProviderOptions(const OrtCANNProvid
8992
{cann::provider_option_names::kArenaExtendStrategy,
9093
EnumToName(arena_extend_strategy_mapping, ArenaExtendStrategy(info.arena_extend_strategy))},
9194
{cann::provider_option_names::kEnableCannGraph, MakeStringWithClassicLocale(info.enable_cann_graph)},
95+
{cann::provider_option_names::kEnableCannSubGraph, MakeStringWithClassicLocale(info.enable_cann_subgraph)},
9296
{cann::provider_option_names::kDumpGraphs, MakeStringWithClassicLocale(info.dump_graphs)},
9397
{cann::provider_option_names::kDumpOmModel, MakeStringWithClassicLocale(info.dump_om_model)},
9498
{cann::provider_option_names::kPrecisionMode, MakeStringWithClassicLocale(info.precision_mode)},

onnxruntime/core/providers/cann/cann_execution_provider_info.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ struct CANNExecutionProviderInfo {
1818
size_t npu_mem_limit{std::numeric_limits<size_t>::max()};
1919
ArenaExtendStrategy arena_extend_strategy{ArenaExtendStrategy::kNextPowerOfTwo};
2020
bool enable_cann_graph{true};
21+
bool enable_cann_subgraph{false};
2122
bool dump_graphs{false};
2223
bool dump_om_model{true};
2324
std::string precision_mode;

onnxruntime/core/providers/cann/cann_provider_factory.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ struct CANN_Provider : Provider {
7676
info.npu_mem_limit = params->npu_mem_limit;
7777
info.arena_extend_strategy = params->arena_extend_strategy;
7878
info.enable_cann_graph = params->enable_cann_graph != 0;
79+
info.enable_cann_subgraph = params->enable_cann_subgraph != 0;
7980
info.dump_graphs = params->dump_graphs != 0;
8081
info.dump_om_model = params->dump_om_model != 0;
8182
info.precision_mode = params->precision_mode;
@@ -94,6 +95,7 @@ struct CANN_Provider : Provider {
9495
cann_options.npu_mem_limit = internal_options.npu_mem_limit;
9596
cann_options.arena_extend_strategy = internal_options.arena_extend_strategy;
9697
cann_options.enable_cann_graph = internal_options.enable_cann_graph;
98+
cann_options.enable_cann_subgraph = internal_options.enable_cann_subgraph;
9799
cann_options.dump_graphs = internal_options.dump_graphs;
98100
cann_options.dump_om_model = internal_options.dump_om_model;
99101
cann_options.precision_mode = internal_options.precision_mode;

onnxruntime/core/session/provider_bridge_ort.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2902,6 +2902,7 @@ ORT_API_STATUS_IMPL(OrtApis::CreateCANNProviderOptions, _Outptr_ OrtCANNProvider
29022902
options->npu_mem_limit = SIZE_MAX;
29032903
options->arena_extend_strategy = static_cast<onnxruntime::ArenaExtendStrategy>(0);
29042904
options->enable_cann_graph = 1;
2905+
options->enable_cann_subgraph = 0;
29052906
options->dump_graphs = 0;
29062907
options->dump_om_model = 1;
29072908
options->default_memory_arena_cfg = nullptr;

0 commit comments

Comments
 (0)