Skip to content

Commit c7997c2

Browse files
adrastogiAditya Rastogi
authored and committed
Language bindings for model compatibility API (microsoft#25878)
### Description This change builds on top of microsoft#25841 and adds the scaffolding necessary to call into this API from C++ / C# / Python. ### Motivation and Context microsoft#25454 discusses the broader notion of precompiled model compatibility in more detail. This change is directed at app developers whose apps may want to determine if a particular precompiled model (e.g. on a server somewhere) is compatible with the device where the application is running. There is functionality in `OrtEpFactory` for making this determination, which was exposed as a C API in microsoft#25841, and this change makes the API more broadly available in other languages. ### Testing and Validation Introduced new unit test cases for each language, and verified that the API was being called and returned the correct result for the default CPU EP. --------- Co-authored-by: Aditya Rastogi <adityar@ntdev.microsoft.com>
1 parent b3a074d commit c7997c2

File tree

8 files changed

+309
-0
lines changed

8 files changed

+309
-0
lines changed

csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,88 @@ public struct OrtApi
368368
public IntPtr EpDevice_Device;
369369
public IntPtr GetEpApi;
370370
public IntPtr GetTensorSizeInBytes;
371+
372+
public IntPtr AllocatorGetStats;
373+
374+
public IntPtr CreateMemoryInfo_V2;
375+
public IntPtr MemoryInfoGetDeviceMemType;
376+
public IntPtr MemoryInfoGetVendorId;
377+
378+
public IntPtr ValueInfo_GetValueProducer;
379+
public IntPtr ValueInfo_GetValueNumConsumers;
380+
public IntPtr ValueInfo_GetValueConsumers;
381+
public IntPtr ValueInfo_GetInitializerValue;
382+
public IntPtr ValueInfo_GetExternalInitializerInfo;
383+
public IntPtr ValueInfo_IsRequiredGraphInput;
384+
public IntPtr ValueInfo_IsOptionalGraphInput;
385+
public IntPtr ValueInfo_IsGraphOutput;
386+
public IntPtr ValueInfo_IsConstantInitializer;
387+
public IntPtr ValueInfo_IsFromOuterScope;
388+
public IntPtr Graph_GetName;
389+
public IntPtr Graph_GetModelPath;
390+
public IntPtr Graph_GetOnnxIRVersion;
391+
public IntPtr Graph_GetNumOperatorSets;
392+
public IntPtr Graph_GetOperatorSets;
393+
public IntPtr Graph_GetNumInputs;
394+
public IntPtr Graph_GetInputs;
395+
public IntPtr Graph_GetNumOutputs;
396+
public IntPtr Graph_GetOutputs;
397+
public IntPtr Graph_GetNumInitializers;
398+
public IntPtr Graph_GetInitializers;
399+
public IntPtr Graph_GetNumNodes;
400+
public IntPtr Graph_GetNodes;
401+
public IntPtr Graph_GetParentNode;
402+
public IntPtr Graph_GetGraphView;
403+
public IntPtr Node_GetId;
404+
public IntPtr Node_GetName;
405+
public IntPtr Node_GetOperatorType;
406+
public IntPtr Node_GetDomain;
407+
public IntPtr Node_GetSinceVersion;
408+
public IntPtr Node_GetNumInputs;
409+
public IntPtr Node_GetInputs;
410+
public IntPtr Node_GetNumOutputs;
411+
public IntPtr Node_GetOutputs;
412+
public IntPtr Node_GetNumImplicitInputs;
413+
public IntPtr Node_GetImplicitInputs;
414+
public IntPtr Node_GetNumAttributes;
415+
public IntPtr Node_GetAttributes;
416+
public IntPtr Node_GetAttributeByName;
417+
public IntPtr Node_GetTensorAttributeAsOrtValue;
418+
public IntPtr OpAttr_GetType;
419+
public IntPtr OpAttr_GetName;
420+
public IntPtr Node_GetNumSubgraphs;
421+
public IntPtr Node_GetSubgraphs;
422+
public IntPtr Node_GetGraph;
423+
public IntPtr Node_GetEpName;
424+
public IntPtr ReleaseExternalInitializerInfo;
425+
public IntPtr ExternalInitializerInfo_GetFilePath;
426+
public IntPtr ExternalInitializerInfo_GetFileOffset;
427+
public IntPtr ExternalInitializerInfo_GetByteSize;
428+
429+
public IntPtr GetRunConfigEntry;
430+
431+
public IntPtr EpDevice_MemoryInfo;
432+
433+
public IntPtr CreateSharedAllocator;
434+
public IntPtr GetSharedAllocator;
435+
public IntPtr ReleaseSharedAllocator;
436+
437+
public IntPtr GetTensorData;
438+
439+
public IntPtr GetSessionOptionsConfigEntries;
440+
441+
public IntPtr SessionGetMemoryInfoForInputs;
442+
public IntPtr SessionGetMemoryInfoForOutputs;
443+
public IntPtr SessionGetEpDeviceForInputs;
444+
445+
public IntPtr CreateSyncStreamForEpDevice;
446+
public IntPtr SyncStream_GetHandle;
447+
public IntPtr ReleaseSyncStream;
448+
449+
public IntPtr CopyTensors;
450+
451+
public IntPtr Graph_GetModelMetadata;
452+
public IntPtr GetModelCompatibilityForEpDevices;
371453
}
372454

373455
internal static class NativeMethods
@@ -704,6 +786,10 @@ static NativeMethods()
704786
(DSessionOptionsSetEpSelectionPolicyDelegate)Marshal.GetDelegateForFunctionPointer(
705787
api_.SessionOptionsSetEpSelectionPolicyDelegate,
706788
typeof(DSessionOptionsSetEpSelectionPolicyDelegate));
789+
790+
OrtGetModelCompatibilityForEpDevices = (DOrtGetModelCompatibilityForEpDevices)Marshal.GetDelegateForFunctionPointer(
791+
api_.GetModelCompatibilityForEpDevices,
792+
typeof(DOrtGetModelCompatibilityForEpDevices));
707793
}
708794

709795
internal class NativeLib
@@ -2456,6 +2542,18 @@ public delegate void DOrtRemoveKeyValuePair(IntPtr /* OrtKeyValuePairs* */ kvps,
24562542

24572543
public static DOrtGetEpDevices OrtGetEpDevices;
24582544

2545+
/// <summary>
2546+
/// Validate compiled model compatibility for the provided EP devices.
2547+
/// </summary>
2548+
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
2549+
public delegate IntPtr /* OrtStatus* */ DOrtGetModelCompatibilityForEpDevices(
2550+
IntPtr[] /* const OrtEpDevice* const* */ ep_devices,
2551+
UIntPtr /* size_t */ num_ep_devices,
2552+
byte[] /* const char* */ compatibility_info,
2553+
out int /* OrtCompiledModelCompatibility */ out_status);
2554+
2555+
public static DOrtGetModelCompatibilityForEpDevices OrtGetModelCompatibilityForEpDevices;
2556+
24592557
/// <summary>
24602558
/// Add execution provider devices to the session options.
24612559
/// Priority is based on the order of the OrtEpDevice instances. Highest priority first.

csharp/src/Microsoft.ML.OnnxRuntime/OrtEnv.shared.cs

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,21 @@
77

88
namespace Microsoft.ML.OnnxRuntime
99
{
10+
/// <summary>
11+
/// Represents the compatibility status of a pre-compiled model with one or more execution provider devices.
12+
/// </summary>
13+
/// <remarks>
14+
/// This enum is used to determine whether a pre-compiled model can be used with specific execution providers
15+
/// and devices, or if recompilation is needed.
16+
/// </remarks>
17+
public enum OrtCompiledModelCompatibility
18+
{
19+
EP_NOT_APPLICABLE = 0,
20+
EP_SUPPORTED_OPTIMAL = 1,
21+
EP_SUPPORTED_PREFER_RECOMPILATION = 2,
22+
EP_UNSUPPORTED = 3,
23+
}
24+
1025
/// <summary>
1126
/// Delegate for logging function callback.
1227
/// Supply your function and register it with the environment to receive logging callbacks via
@@ -361,6 +376,31 @@ public string[] GetAvailableProviders()
361376
}
362377
}
363378

379+
/// <summary>
/// Validate a compiled model's compatibility information for one or more EP devices.
/// </summary>
/// <param name="epDevices">The list of EP devices to validate against.
/// Must be non-null, non-empty and must not contain null entries.</param>
/// <param name="compatibilityInfo">The compatibility string from the precompiled model to validate.
/// Must be non-null.</param>
/// <returns>OrtCompiledModelCompatibility enum value denoting the compatibility status</returns>
/// <exception cref="ArgumentException">
/// <paramref name="epDevices"/> is null, empty, or contains a null entry.
/// </exception>
/// <exception cref="ArgumentNullException"><paramref name="compatibilityInfo"/> is null.</exception>
public OrtCompiledModelCompatibility GetModelCompatibilityForEpDevices(
    IReadOnlyList<OrtEpDevice> epDevices, string compatibilityInfo)
{
    if (epDevices == null || epDevices.Count == 0)
    {
        throw new ArgumentException("epDevices must be non-empty", nameof(epDevices));
    }

    // Reject a null string here so the caller gets an exception naming the right parameter;
    // how the UTF-8 conversion helper below treats null is not visible from this method.
    if (compatibilityInfo == null)
    {
        throw new ArgumentNullException(nameof(compatibilityInfo));
    }

    // Collect the native handles in the same order as the managed list.
    var devicePtrs = new IntPtr[epDevices.Count];
    for (int i = 0; i < epDevices.Count; ++i)
    {
        var device = epDevices[i];
        if (device == null)
        {
            // Without this check a null entry would fail with a NullReferenceException on .Handle.
            throw new ArgumentException("epDevices must not contain null entries", nameof(epDevices));
        }

        devicePtrs[i] = device.Handle;
    }

    var infoUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(compatibilityInfo);
    NativeApiStatus.VerifySuccess(
        NativeMethods.OrtGetModelCompatibilityForEpDevices(
            devicePtrs, (UIntPtr)devicePtrs.Length, infoUtf8, out int status));

    // The native call reports the status as an int; the managed enum mirrors those values.
    return (OrtCompiledModelCompatibility)status;
}
403+
364404

365405
/// <summary>
366406
/// Get/Set log level property of OrtEnv instance
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
// not supported on mobile platforms
5+
#if !(ANDROID || IOS)
6+
7+
namespace Microsoft.ML.OnnxRuntime.Tests;
8+
9+
using System;
10+
using System.Linq;
11+
using Xunit;
12+
using System.Collections.Generic;
13+
14+
/// <summary>
/// Unit tests for <c>OrtEnv.GetModelCompatibilityForEpDevices</c>.
/// </summary>
public class EpCompatibilityTests
{
    private readonly OrtEnv ortEnvInstance = OrtEnv.Instance();

    /// <summary>
    /// Fetches the EP devices from the environment and asserts that at least one exists.
    /// </summary>
    private IReadOnlyList<OrtEpDevice> GetDevices()
    {
        var epDevices = ortEnvInstance.GetEpDevices();
        Assert.NotNull(epDevices);
        Assert.NotEmpty(epDevices);
        return epDevices;
    }

    /// <summary>
    /// A null or empty device list must be rejected with an ArgumentException.
    /// </summary>
    [Fact]
    public void GetEpCompatibility_InvalidArgs()
    {
        Assert.Throws<ArgumentException>(() => ortEnvInstance.GetModelCompatibilityForEpDevices(null, "info"));
        Assert.Throws<ArgumentException>(() => ortEnvInstance.GetModelCompatibilityForEpDevices(new List<OrtEpDevice>(), "info"));
    }

    /// <summary>
    /// Querying compatibility for the CPU EP device alone returns EP_NOT_APPLICABLE.
    /// </summary>
    [Fact]
    public void GetEpCompatibility_SingleDeviceCpuProvider()
    {
        var devices = GetDevices();
        var someInfo = "arbitrary-compat-string";

        // Use CPU device. FirstOrDefault (rather than First) lets the assertion below report
        // a missing CPU device as a test failure instead of an InvalidOperationException.
        var cpu = devices.FirstOrDefault(d => d.EpName == "CPUExecutionProvider");
        Assert.NotNull(cpu);
        var selected = new List<OrtEpDevice> { cpu };
        var status = ortEnvInstance.GetModelCompatibilityForEpDevices(selected, someInfo);

        // CPU defaults to not applicable in this scenario
        Assert.Equal(OrtCompiledModelCompatibility.EP_NOT_APPLICABLE, status);
    }
}
49+
#endif

include/onnxruntime/core/session/onnxruntime_cxx_api.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1076,6 +1076,16 @@ struct EpDevice : detail::EpDeviceImpl<OrtEpDevice> {
10761076
ConstKeyValuePairs ep_metadata = {}, ConstKeyValuePairs ep_options = {});
10771077
};
10781078

1079+
/** \brief Validate a compiled model's compatibility for one or more EP devices.
1080+
*
1081+
* Throws on error. Returns the resulting compatibility status.
1082+
* \param ep_devices The EP devices to check compatibility against.
1083+
* \param compatibility_info The compatibility string from the precompiled model to validate.
1084+
*/
1085+
OrtCompiledModelCompatibility GetModelCompatibilityForEpDevices(
1086+
const std::vector<ConstEpDevice>& ep_devices,
1087+
const char* compatibility_info);
1088+
10791089
/** \brief The Env (Environment)
10801090
*
10811091
* The Env holds the logging state used by all other objects.

include/onnxruntime/core/session/onnxruntime_cxx_inline.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -859,6 +859,26 @@ inline void CustomOpDomain::Add(const OrtCustomOp* op) {
859859
ThrowOnError(GetApi().CustomOpDomain_Add(p_, op));
860860
}
861861

862+
// Validates a compiled model's compatibility string against one or more EP devices.
// Throws if `ep_devices` is empty or if the underlying C API call reports an error;
// otherwise returns the status written by the C API.
inline OrtCompiledModelCompatibility GetModelCompatibilityForEpDevices(
    const std::vector<ConstEpDevice>& ep_devices,
    const char* compatibility_info) {
  if (ep_devices.empty()) {
    ORT_CXX_API_THROW("ep_devices is empty", ORT_INVALID_ARGUMENT);
  }

  // Flatten the C++ wrappers into the contiguous raw-pointer array the C API takes.
  std::vector<const OrtEpDevice*> ptrs;
  ptrs.reserve(ep_devices.size());
  for (const auto& d : ep_devices) {
    ptrs.push_back(d);
  }

  OrtCompiledModelCompatibility status = OrtCompiledModelCompatibility_EP_NOT_APPLICABLE;
  // ptrs.data() is `const OrtEpDevice**`, which converts implicitly to
  // `const OrtEpDevice* const*`; no reinterpret_cast is needed, and omitting it keeps
  // the compiler checking the pointer type.
  ThrowOnError(GetApi().GetModelCompatibilityForEpDevices(
      ptrs.data(),
      ptrs.size(),
      compatibility_info,
      &status));
  return status;
}
881+
862882
inline LoraAdapter LoraAdapter::CreateLoraAdapter(const std::basic_string<ORTCHAR_T>& adapter_path,
863883
OrtAllocator* allocator) {
864884
OrtLoraAdapter* p;

onnxruntime/python/onnxruntime_pybind_state.cc

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1575,6 +1575,17 @@ void addGlobalMethods(py::module& m) {
15751575
R"pbdoc(Get the list of available OrtEpDevice instances.)pbdoc",
15761576
py::return_value_policy::reference);
15771577

1578+
m.def(
1579+
"get_model_compatibility_for_ep_devices",
1580+
[](const std::vector<const OrtEpDevice*>& ep_devices,
1581+
const std::string& compatibility_info) -> OrtCompiledModelCompatibility {
1582+
OrtCompiledModelCompatibility status = OrtCompiledModelCompatibility_EP_NOT_APPLICABLE;
1583+
Ort::ThrowOnError(Ort::GetApi().GetModelCompatibilityForEpDevices(
1584+
ep_devices.data(), ep_devices.size(), compatibility_info.c_str(), &status));
1585+
return status;
1586+
},
1587+
R"pbdoc("Validate a compiled model's compatibility information for one or more EP devices.)pbdoc");
1588+
15781589
#if defined(USE_OPENVINO) || defined(USE_OPENVINO_PROVIDER_INTERFACE)
15791590
m.def(
15801591
"get_available_openvino_device_ids", []() -> std::vector<std::string> {
@@ -1759,6 +1770,12 @@ void addObjectMethods(py::module& m, ExecutionProviderRegistrationFn ep_registra
17591770
.value("PRIORITY_BASED", ExecutionOrder::PRIORITY_BASED)
17601771
.value("MEMORY_EFFICIENT", ExecutionOrder::MEMORY_EFFICIENT);
17611772

1773+
py::enum_<OrtCompiledModelCompatibility>(m, "OrtCompiledModelCompatibility")
1774+
.value("EP_NOT_APPLICABLE", OrtCompiledModelCompatibility_EP_NOT_APPLICABLE)
1775+
.value("EP_SUPPORTED_OPTIMAL", OrtCompiledModelCompatibility_EP_SUPPORTED_OPTIMAL)
1776+
.value("EP_SUPPORTED_PREFER_RECOMPILATION", OrtCompiledModelCompatibility_EP_SUPPORTED_PREFER_RECOMPILATION)
1777+
.value("EP_UNSUPPORTED", OrtCompiledModelCompatibility_EP_UNSUPPORTED);
1778+
17621779
py::enum_<OrtAllocatorType>(m, "OrtAllocatorType")
17631780
.value("INVALID", OrtInvalidAllocator)
17641781
.value("ORT_DEVICE_ALLOCATOR", OrtDeviceAllocator)

onnxruntime/test/framework/ep_compatibility_test.cc

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "core/session/onnxruntime_ep_device_ep_metadata_keys.h"
1616
#include "core/session/utils.h"
1717
#include "core/session/onnxruntime_c_api.h"
18+
#include "core/session/onnxruntime_cxx_api.h"
1819
#include "core/session/abi_session_options_impl.h"
1920
#include "core/framework/error_code_helper.h"
2021
#include "dummy_provider.h"
@@ -499,3 +500,31 @@ TEST(EpCompatibilityCapiTest, CpuEpReturnsNotApplicableIfNoValidation) {
499500

500501
api->ReleaseEnv(env);
501502
}
503+
504+
// -----------------------------
505+
// C++ API unit tests
506+
// -----------------------------
507+
508+
TEST(EpCompatibilityCxxApiTest, SingleDeviceCpuProvider) {
509+
Ort::Env env{ORT_LOGGING_LEVEL_WARNING, "EpCompatCxx"};
510+
auto devices = env.GetEpDevices();
511+
ASSERT_FALSE(devices.empty());
512+
513+
std::vector<Ort::ConstEpDevice> selected;
514+
for (const auto& d : devices) {
515+
if (std::string{d.EpName()} == "CPUExecutionProvider") {
516+
selected.push_back(d);
517+
break;
518+
}
519+
}
520+
521+
ASSERT_FALSE(selected.empty());
522+
523+
// Pick a status that the CPU EP would never return to ensure the value is set correctly.
524+
OrtCompiledModelCompatibility status = OrtCompiledModelCompatibility_EP_SUPPORTED_PREFER_RECOMPILATION;
525+
ASSERT_NO_FATAL_FAILURE({
526+
status = Ort::GetModelCompatibilityForEpDevices(selected, "arbitrary-compat-string");
527+
});
528+
529+
ASSERT_TRUE(status == OrtCompiledModelCompatibility_EP_NOT_APPLICABLE);
530+
}
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# Copyright (c) Microsoft Corporation. All rights reserved.
2+
# Licensed under the MIT License.
3+
4+
import os
5+
import platform
6+
import sys
7+
import unittest
8+
9+
from onnxruntime.capi.onnxruntime_pybind11_state import (
10+
OrtCompiledModelCompatibility,
11+
get_ep_devices,
12+
get_model_compatibility_for_ep_devices,
13+
)
14+
15+
# From Python 3.8 onwards, loading a DLL from the current directory on Windows has to be
# explicitly allowed. Compare the version as a tuple: checking major and minor separately
# would wrongly skip this for a future major version whose minor is below 8.
if platform.system() == "Windows" and sys.version_info >= (3, 8):
    os.add_dll_directory(os.getcwd())
18+
19+
20+
class TestEpCompatibility(unittest.TestCase):
21+
def test_invalid_args(self):
22+
# empty devices
23+
with self.assertRaises(RuntimeError):
24+
get_model_compatibility_for_ep_devices([], "info")
25+
# None compatibility info should raise TypeError before native call
26+
with self.assertRaises(TypeError):
27+
get_model_compatibility_for_ep_devices(get_ep_devices(), None) # type: ignore[arg-type]
28+
29+
def test_basic_smoke(self):
30+
devices = list(get_ep_devices())
31+
if not devices:
32+
self.skipTest("No EP devices available in this build")
33+
34+
# Always select CPUExecutionProvider; skip if not present.
35+
cpu_devices = [d for d in devices if getattr(d, "ep_name", None) == "CPUExecutionProvider"]
36+
if not cpu_devices:
37+
self.skipTest("CPUExecutionProvider not available in this build")
38+
selected = [cpu_devices[0]]
39+
40+
# API requires all devices belong to the same EP; we pass only one.
41+
status = get_model_compatibility_for_ep_devices(selected, "arbitrary-compat-string")
42+
self.assertEqual(status, OrtCompiledModelCompatibility.EP_NOT_APPLICABLE)
43+
44+
45+
if __name__ == "__main__":
46+
unittest.main()

0 commit comments

Comments
 (0)