Skip to content

Commit 8c68aa1

Browse files
adrastogiAditya Rastogigithub-actions[bot]
authored
Add support for generating and validating compiled model compatibility info (microsoft#25749)
### Description

This pull request introduces a new mechanism for validating compiled model compatibility with execution providers (EPs) in ONNX Runtime. It adds infrastructure for EPs to generate and store compatibility information in model metadata, and for the runtime to enforce compatibility checks during session initialization.

### Motivation and Context

<!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->

The APIs proposed in this PR address two requirements:

1. Apps that already have a pre-compiled model on the device need a way to determine whether that pre-compiled model is still valid (given the EPs / drivers / etc. on the system).
2. Apps may have many different pre-compiled versions of a model stored on a remote server, and want to figure out which of those models they should download for the device where they are running.

### Testing

Validated that the new suite of tests passes cleanly. Created a private build of this ORT and the AMD Vitis EP. I stepped through the core logic (the EP doesn't have this support wired up yet, so no compatibility info is written out) and, for regression purposes, confirmed I could compile and run inferences through ResNet.

---------

Co-authored-by: Aditya Rastogi <adityar@ntdev.microsoft.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent d6e712c commit 8c68aa1

14 files changed

+675
-2
lines changed

include/onnxruntime/core/framework/execution_provider.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ class GraphOptimizerRegistry;
3636
#include "core/framework/framework_provider_common.h"
3737
#include "core/framework/stream_handles.h"
3838
#include "core/framework/tuning_context.h"
39+
#include "core/session/onnxruntime_c_api.h"
3940

4041
struct OrtEpDevice;
4142
struct OrtRunOptions;
@@ -322,6 +323,29 @@ class IExecutionProvider {
322323
virtual common::Status Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
323324
std::vector<NodeComputeInfo>& node_compute_funcs);
324325

326+
/**
327+
* Get the compatibility info for a compiled model.
328+
*
329+
* The execution provider determines this value, which denotes the compatibility of the compiled model with the EP.
330+
* This is stored in the model metadata under a key associated with the EP type.
331+
*/
332+
virtual std::string GetCompiledModelCompatibilityInfo(const onnxruntime::GraphViewer& graph_viewer) const {
333+
// graph_viewer and model_metadata are not used in the default implementation.
334+
ORT_UNUSED_PARAMETER(graph_viewer);
335+
// Default implementation returns empty string
336+
return std::string();
337+
}
338+
339+
/**
340+
* Validate the compatibility of a compiled model with this execution provider.
341+
*/
342+
virtual common::Status ValidateCompiledModelCompatibilityInfo(const std::string& /*compatibility_info*/,
343+
OrtCompiledModelCompatibility& model_compatibility) const {
344+
// Default implementation indicates this EP does not support model compatibility validation
345+
model_compatibility = OrtCompiledModelCompatibility_EP_NOT_APPLICABLE;
346+
return Status::OK();
347+
}
348+
325349
#endif
326350

327351
void SetLogger(const logging::Logger* logger) {

include/onnxruntime/core/session/onnxruntime_ep_device_ep_metadata_keys.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,8 @@
88

99
// Key for the execution provider version string. This should be available for all plugin EPs.
1010
static const char* const kOrtEpDevice_EpMetadataKey_Version = "version";
11+
12+
// Prefix for execution provider compatibility information stored in model metadata.
13+
// Used when generating EP context models to store compatibility strings for each EP.
14+
// Full key format: "ep_compatibility_info.<EP_TYPE>"
15+
static const char* const kOrtModelMetadata_EpCompatibilityInfoPrefix = "ep_compatibility_info.";

include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -382,8 +382,8 @@ static const char* const kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel = "sessio
382382
// THIS OPTION IS NOT A REGULAR SESSION OPTION SINCE IT CAN BE MODIFIED AT ANY TIME
383383
// Meant to be used with SetEpDynamicOptions
384384
// Specify the type of workload for this session.
385-
// Default: OS determines the scheduling priority and processor performance to service this workload. [Default]
386-
// Efficient: OS treats this workload is efficiency oriented with low scheduling priority and efficient processor performance.
385+
// "Default": OS determines the scheduling priority and processor performance to service this workload. [Default]
386+
// "Efficient": OS treats this workload as efficiency oriented, with low scheduling priority and efficient processor performance.
387387
static const char* const kOrtEpDynamicOptionsWorkloadType = "ep.dynamic.workload_type";
388388

389389
// Disables model compilation during session initialization.
@@ -401,3 +401,10 @@ static const char* const kOrtEpDynamicOptionsWorkloadType = "ep.dynamic.workload
401401
// - "0": EP compile is not disabled. [DEFAULT]
402402
// - "1": EP compile is disabled.
403403
static const char* const kOrtSessionOptionsDisableModelCompile = "session.disable_model_compile";
404+
405+
// Controls behavior when compiled model compatibility is SUPPORTED_PREFER_RECOMPILATION.
406+
// "0": Allow execution with suboptimal performance. [DEFAULT]
407+
// "1": Fail session creation to require recompilation for optimal performance.
408+
// Note: UNSUPPORTED models always fail regardless of this setting.
409+
static const char* const kOrtSessionOptionsFailOnSuboptimalCompiledModel =
410+
"session.fail_on_suboptimal_compiled_model";

onnxruntime/core/framework/graph_partitioner.cc

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "core/graph/model.h"
2323
#include "core/graph/model_saving_options.h"
2424
#include "core/session/onnxruntime_session_options_config_keys.h"
25+
#include "core/session/onnxruntime_ep_device_ep_metadata_keys.h"
2526

2627
// uncomment this line to count non-CUDA ops in ONNX domain
2728
// #define COUNT_NON_CUDA_OPS
@@ -909,6 +910,34 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers
909910
}
910911
}
911912

913+
// Generate EP compatibility strings for OrtEp types and add to model metadata
914+
// At this point, the graph has been populated with all the EPContext nodes
915+
{
916+
ORT_RETURN_IF_ERROR(ep_graph.Resolve());
917+
const GraphViewer graph_viewer(ep_graph);
918+
for (const auto& ep : execution_providers) {
919+
try {
920+
// Generate the compatibility string for this EP
921+
std::string compatibility_string = ep->GetCompiledModelCompatibilityInfo(graph_viewer);
922+
if (!compatibility_string.empty()) {
923+
// Create a unique key for this EP's compatibility info
924+
// Use format: "ep_compatibility_info.<EP_TYPE>"
925+
// All EPs in a session must have a unique Type() value, so this will be unique for the generated model
926+
std::string metadata_key = std::string(kOrtModelMetadata_EpCompatibilityInfoPrefix) + ep->Type();
927+
auto& model_metadata = ep_context_model.MetaData();
928+
auto [it, was_inserted] =
929+
model_metadata.insert_or_assign(metadata_key, compatibility_string);
930+
if (!was_inserted) {
931+
LOGS(logger, WARNING) << "Overwriting existing EP compatibility info for key: " << metadata_key << " (EP: " << ep->Type() << ")";
932+
}
933+
LOGS(logger, VERBOSE) << "Added EP compatibility info for " << ep->Type() << " with key: " << metadata_key;
934+
}
935+
} catch (const std::exception& ex) {
936+
LOGS(logger, WARNING) << "Failed to generate compatibility string for EP " << ep->Type() << ": " << ex.what();
937+
}
938+
}
939+
}
940+
912941
size_t ini_size_threshold = ep_context_gen_options.output_external_initializer_size_threshold;
913942
std::filesystem::path external_ini_path = ep_context_gen_options.output_external_initializers_file_path;
914943
bool force_embed_external_ini = false;

onnxruntime/core/graph/model.cc

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,10 @@ const ModelMetaData& Model::MetaData() const noexcept {
361361
return model_metadata_;
362362
}
363363

364+
// Non-const accessor: returns a mutable reference to this model's metadata (model_metadata_),
// so callers can add or overwrite entries in place.
ModelMetaData& Model::MetaData() noexcept {
365+
return model_metadata_;
366+
}
367+
364368
Graph& Model::MainGraph() noexcept {
365369
return *graph_;
366370
}
@@ -377,6 +381,15 @@ ModelProto Model::ToProto() const {
377381
// out dense duplicates of sparse initializers and leave the original
378382
// proto intact.
379383
ModelProto result(model_proto_);
384+
385+
// Sync current model_metadata_ back to protobuf metadata_props
386+
result.clear_metadata_props();
387+
for (const auto& metadata : model_metadata_) {
388+
const gsl::not_null<StringStringEntryProto*> prop{result.add_metadata_props()};
389+
prop->set_key(metadata.first);
390+
prop->set_value(metadata.second);
391+
}
392+
380393
const auto& graph = *graph_;
381394
*(result.mutable_graph()) = graph.ToGraphProto();
382395
return result;
@@ -386,6 +399,15 @@ ModelProto Model::ToGraphProtoWithExternalInitializers(const std::filesystem::pa
386399
const std::filesystem::path& file_path,
387400
const ModelSavingOptions& model_saving_options) const {
388401
ModelProto result(model_proto_);
402+
403+
// Sync current model_metadata_ back to protobuf metadata_props
404+
result.clear_metadata_props();
405+
for (const auto& metadata : model_metadata_) {
406+
const gsl::not_null<StringStringEntryProto*> prop{result.add_metadata_props()};
407+
prop->set_key(metadata.first);
408+
prop->set_value(metadata.second);
409+
}
410+
389411
const auto& graph = *graph_;
390412
*(result.mutable_graph()) = graph.ToGraphProtoWithExternalInitializers(external_file_name,
391413
file_path,

onnxruntime/core/graph/model.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,8 @@ class Model {
189189

190190
const ModelMetaData& MetaData() const noexcept;
191191

192+
ModelMetaData& MetaData() noexcept;
193+
192194
// Gets the path from which the model was loaded, if any.
193195
const std::filesystem::path& ModelPath() const noexcept { return model_path_; }
194196

onnxruntime/core/session/plugin_ep/ep_factory_internal.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ EpFactoryInternal::EpFactoryInternal(std::unique_ptr<EpFactoryInternalImpl> impl
2323
OrtEpFactory::GetSupportedDevices = Forward::GetSupportedDevices;
2424
OrtEpFactory::CreateEp = Forward::CreateEp;
2525
OrtEpFactory::ReleaseEp = Forward::ReleaseEp;
26+
OrtEpFactory::ValidateCompiledModelCompatibilityInfo = Forward::ValidateCompiledModelCompatibilityInfo;
2627
OrtEpFactory::CreateAllocator = Forward::CreateAllocator;
2728
OrtEpFactory::ReleaseAllocator = Forward::ReleaseAllocator;
2829
OrtEpFactory::CreateDataTransfer = Forward::CreateDataTransfer;

onnxruntime/core/session/plugin_ep/ep_factory_internal.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,11 @@ class EpFactoryInternal : public OrtEpFactory {
8080
return impl_->CreateSyncStreamForDevice(memory_device, stream_options, stream);
8181
}
8282

83+
// Forwards compiled-model compatibility validation to the wrapped implementation object (impl_),
// passing both arguments through unchanged and returning whatever status impl_ returns.
OrtStatus* ValidateCompiledModelCompatibilityInfo(_In_ const char* compatibility_info,
84+
_Out_ OrtCompiledModelCompatibility* model_compatibility) noexcept {
85+
return impl_->ValidateCompiledModelCompatibilityInfo(compatibility_info, model_compatibility);
86+
}
87+
8388
// Function ORT calls to release an EP instance.
8489
void ReleaseEp(OrtEp* /*ep*/) noexcept {
8590
// we never create an OrtEp so we should never be trying to release one

onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,14 @@ class EpFactoryInternalImpl {
6262
return false;
6363
}
6464

65+
// Default compatibility validation for an internal EP factory.
// The compatibility string is not inspected: the result is always reported as
// OrtCompiledModelCompatibility_EP_NOT_APPLICABLE and a null status pointer is returned.
virtual OrtStatus* ValidateCompiledModelCompatibilityInfo(
    _In_ const char* compatibility_info,
    _Out_ OrtCompiledModelCompatibility* model_compatibility) noexcept {
  // Mark as not applicable; the input string plays no part in the default behavior.
  *model_compatibility = OrtCompiledModelCompatibility_EP_NOT_APPLICABLE;
  ORT_UNUSED_PARAMETER(compatibility_info);
  return nullptr;
}
72+
6573
virtual OrtStatus* CreateSyncStreamForDevice(_In_ const OrtMemoryDevice* /*memory_device*/,
6674
_In_opt_ const OrtKeyValuePairs* /*stream_options*/,
6775
_Outptr_result_maybenull_ OrtSyncStreamImpl** stream) noexcept {

onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -644,4 +644,35 @@ void PluginExecutionProvider::RegisterStreamHandlers(IStreamCommandHandleRegistr
644644
registry.RegisterWaitFn(device_type, OrtDevice::CPU, plugin_ep::Notification::WaitNotificationOnHost);
645645
}
646646
}
647+
648+
// Obtains the compiled-model compatibility string from the plugin EP.
//
// If the plugin did not provide OrtEp::GetCompiledModelCompatibilityInfo, the base-class
// implementation is used. Otherwise the GraphViewer is wrapped in an EpGraph and handed to the
// plugin callback.
//
// Returns an empty string when the EpGraph cannot be created (the failure is logged) or when
// the plugin callback returns a null pointer.
std::string PluginExecutionProvider::GetCompiledModelCompatibilityInfo(const onnxruntime::GraphViewer& graph_viewer) const {
  if (ort_ep_->GetCompiledModelCompatibilityInfo == nullptr) {
    // Plugin EP did not provide an implementation of this function, so we call a default implementation.
    return Base::GetCompiledModelCompatibilityInfo(graph_viewer);
  }

  // The plugin callback operates on an EpGraph rather than a GraphViewer.
  std::unique_ptr<EpGraph> ep_graph = nullptr;
  auto status = EpGraph::Create(graph_viewer, ep_graph);
  if (!status.IsOK()) {
    LOGS(*GetLogger(), ERROR) << "Failed to create EpGraph: " << status.ToString();
    return {};
  }

  // Call EP plugin's OrtEp::GetCompiledModelCompatibilityInfo() function.
  // Building a std::string from a null character pointer is undefined behavior, so a null
  // result from the plugin is treated as "no compatibility info".
  const char* compatibility_info = ort_ep_->GetCompiledModelCompatibilityInfo(ort_ep_.get(), ep_graph.get());
  return compatibility_info != nullptr ? std::string(compatibility_info) : std::string();
}
664+
665+
// Validates a compiled model's compatibility string for this plugin EP.
//
// When the EP factory exposes ValidateCompiledModelCompatibilityInfo, the string is passed to it
// and the factory's result is written to model_compatibility; a failure status from the factory
// is propagated to the caller. When the factory has no such function, the base-class
// implementation handles the request.
Status PluginExecutionProvider::ValidateCompiledModelCompatibilityInfo(
    const std::string& compatibility_info,
    OrtCompiledModelCompatibility& model_compatibility) const {
  const bool factory_can_validate = ep_factory_.ValidateCompiledModelCompatibilityInfo != nullptr;
  if (factory_can_validate) {
    // Delegate to the EP factory's validation method.
    ORT_RETURN_IF_ERROR(ToStatusAndRelease(
        ep_factory_.ValidateCompiledModelCompatibilityInfo(&ep_factory_,
                                                           compatibility_info.c_str(),
                                                           &model_compatibility)));
    return Status::OK();
  }

  // Plugin EP did not provide an implementation of this function, so we call a default implementation.
  return Base::ValidateCompiledModelCompatibilityInfo(compatibility_info, model_compatibility);
}
677+
647678
} // namespace onnxruntime

0 commit comments

Comments
 (0)