Skip to content

Commit 0ccecf7

Browse files
[EP ABI] Infer OrtDevice for plugin EP from registered OrtMemoryInfo (microsoft#25308)
### Description - Infer `OrtDevice` for a plugin EP from the registered `OrtMemoryInfo` for device memory. - Fix potential `nullptr` dereference when a `PluginExecutionProvider` tries to log a message without a valid logger. Now, constructing a `PluginExecutionProvider` requires passing a valid logger. ### Motivation and Context Address a `TODO` to properly set the `OrtDevice` for a `PluginExecutionProvider` instance.
1 parent 7c18d89 commit 0ccecf7

File tree

5 files changed

+189
-13
lines changed

5 files changed

+189
-13
lines changed

include/onnxruntime/core/framework/execution_provider.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,10 @@ class IExecutionProvider {
7979
: default_device_(device), type_{type} {
8080
}
8181

82+
IExecutionProvider(const std::string& type, OrtDevice device, const logging::Logger& logger)
83+
: default_device_(device), type_{type}, logger_{&logger} {
84+
}
85+
8286
/*
8387
default device for this ExecutionProvider
8488
*/

onnxruntime/core/session/ep_plugin_provider_interfaces.cc

Lines changed: 38 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -53,11 +53,9 @@ PluginExecutionProviderFactory::CreateProvider(const OrtSessionOptions& session_
5353
ORT_THROW("Error creating execution provider: ", status.ToString());
5454
}
5555

56-
auto ep_wrapper = std::make_unique<PluginExecutionProvider>(UniqueOrtEp(ort_ep, OrtEpDeleter(ep_factory_)),
57-
session_options, ep_factory_, devices_);
58-
ep_wrapper->SetLogger(session_logger.ToInternal());
59-
60-
return ep_wrapper;
56+
return std::make_unique<PluginExecutionProvider>(UniqueOrtEp(ort_ep, OrtEpDeleter(ep_factory_)),
57+
session_options, ep_factory_, devices_,
58+
*session_logger.ToInternal());
6159
}
6260

6361
/// <summary>
@@ -86,10 +84,43 @@ struct PluginEpMetaDefNameFunctor {
8684
// PluginExecutionProvider
8785
//
8886

87+
static OrtDevice GetOrtDeviceForPluginEp(gsl::span<const OrtEpDevice* const> ep_devices) {
88+
// Get the OrtDevice from OrtEpDevice.device_memory_info if it is set. Otherwise, we set it to CPU.
89+
// If there are multiple OrtEpDevice instances, the device_memory_info must be consistent for all.
90+
91+
ORT_ENFORCE(!ep_devices.empty()); // Should not be possible to create an EP without OrtEpDevices.
92+
93+
const OrtMemoryInfo* device_memory_info = ep_devices[0]->device_memory_info;
94+
95+
// Check assertion that all OrtEpDevice instances must have equivalent device_memory_infos
96+
bool all_match = std::all_of(ep_devices.begin() + 1, ep_devices.end(),
97+
[mem_a = device_memory_info](const OrtEpDevice* ep_device) {
98+
const OrtMemoryInfo* mem_b = ep_device->device_memory_info;
99+
100+
if (mem_a == mem_b) {
101+
return true; // Point to the same OrtMemoryInfo instance.
102+
}
103+
104+
if (mem_a == nullptr || mem_b == nullptr) {
105+
return false; // One is nullptr and the other is not.
106+
}
107+
108+
// Both non-null but point to different instances. Use operator==.
109+
return *mem_a == *mem_b;
110+
});
111+
if (!all_match) {
112+
ORT_THROW("Error creating execution provider '", ep_devices[0]->ep_name,
113+
"': expected all OrtEpDevice instances to use the same device_memory_info.");
114+
}
115+
116+
return device_memory_info != nullptr ? device_memory_info->device : OrtDevice();
117+
}
118+
89119
PluginExecutionProvider::PluginExecutionProvider(UniqueOrtEp ep, const OrtSessionOptions& session_options,
90120
OrtEpFactory& ep_factory,
91-
gsl::span<const OrtEpDevice* const> ep_devices)
92-
: IExecutionProvider(ep->GetName(ep.get()), OrtDevice()), // TODO: What to do about OrtDevice for plugins?
121+
gsl::span<const OrtEpDevice* const> ep_devices,
122+
const logging::Logger& logger)
123+
: IExecutionProvider(ep->GetName(ep.get()), GetOrtDeviceForPluginEp(ep_devices), logger),
93124
ort_ep_(std::move(ep)),
94125
ep_factory_(ep_factory),
95126
ep_devices_(ep_devices.begin(), ep_devices.end()) {

onnxruntime/core/session/ep_plugin_provider_interfaces.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ class PluginExecutionProvider : public IExecutionProvider {
6565

6666
public:
6767
explicit PluginExecutionProvider(UniqueOrtEp ep, const OrtSessionOptions& session_options, OrtEpFactory& ep_factory,
68-
gsl::span<const OrtEpDevice* const> ep_devices);
68+
gsl::span<const OrtEpDevice* const> ep_devices, const logging::Logger& logger);
6969
~PluginExecutionProvider();
7070

7171
std::vector<std::unique_ptr<ComputeCapability>>

onnxruntime/core/session/provider_policy_context.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
#include "core/framework/error_code_helper.h"
1212
#include "core/session/abi_devices.h"
13+
#include "core/session/abi_logger.h"
1314
#include "core/session/ep_factory_internal.h"
1415
#include "core/session/ep_plugin_provider_interfaces.h"
1516
#include "core/session/inference_session.h"
@@ -355,7 +356,7 @@ Status ProviderPolicyContext::CreateExecutionProvider(const Environment& env, Or
355356
info.ep_factory->CreateEp(info.ep_factory, info.hardware_devices.data(), info.ep_metadata.data(),
356357
info.hardware_devices.size(), &options, &logger, &api_ep)));
357358
ep = std::make_unique<PluginExecutionProvider>(UniqueOrtEp(api_ep, OrtEpDeleter(*info.ep_factory)), options,
358-
*info.ep_factory, info.devices);
359+
*info.ep_factory, info.devices, *logger.ToInternal());
359360
}
360361

361362
return Status::OK();

onnxruntime/test/framework/ep_plugin_provider_test.cc

Lines changed: 144 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "core/session/abi_devices.h"
1010
#include "core/session/onnxruntime_cxx_api.h"
1111
#include "test/util/include/asserts.h"
12+
#include "test/util/include/test_environment.h"
1213

1314
namespace onnxruntime::test {
1415

@@ -56,24 +57,58 @@ struct TestOrtEpFactory : ::OrtEpFactory {
5657

5758
static TestOrtEpFactory g_test_ort_ep_factory{};
5859

60+
std::unique_ptr<OrtHardwareDevice> MakeTestOrtHardwareDevice(OrtHardwareDeviceType type) {
61+
auto hw_device = std::make_unique<OrtHardwareDevice>();
62+
hw_device->type = type;
63+
hw_device->vendor_id = 0xBE57;
64+
hw_device->device_id = 0;
65+
hw_device->vendor = "Contoso";
66+
return hw_device;
67+
}
68+
69+
std::unique_ptr<OrtEpDevice> MakeTestOrtEpDevice(const OrtHardwareDevice* hardware_device,
70+
const OrtMemoryInfo* device_memory_info = nullptr,
71+
const OrtMemoryInfo* host_accessible_memory_info = nullptr) {
72+
auto ep_device = std::make_unique<OrtEpDevice>();
73+
ep_device->ep_name = "TestOrtEp";
74+
ep_device->ep_vendor = "Contoso";
75+
ep_device->device = hardware_device;
76+
ep_device->ep_factory = &g_test_ort_ep_factory;
77+
ep_device->device_memory_info = device_memory_info;
78+
ep_device->host_accessible_memory_info = host_accessible_memory_info;
79+
return ep_device;
80+
}
81+
82+
OrtDevice MakeTestOrtDevice(OrtDevice::DeviceType device_type, OrtDevice::MemoryType memory_type) {
83+
return OrtDevice(device_type, memory_type, /*vendor_id*/ 0xBE57, /*device_id*/ 0, /*alignment*/ 16);
84+
}
85+
5986
struct MakeTestOrtEpResult {
6087
std::unique_ptr<IExecutionProvider> ep; // the IExecutionProvider wrapping the TestOrtEp
6188
gsl::not_null<TestOrtEp*> ort_ep; // the wrapped TestOrtEp, owned by `ep`
6289
};
6390

6491
// Creates an IExecutionProvider that wraps a TestOrtEp.
6592
// The TestOrtEp is also exposed so that tests can manipulate its function pointers directly.
66-
MakeTestOrtEpResult MakeTestOrtEp() {
93+
MakeTestOrtEpResult MakeTestOrtEp(std::vector<const OrtEpDevice*> ep_devices = {}) {
94+
// Default OrtHardwareDevice and OrtEpDevice used if the caller does not explicitly provide ep_devices.
95+
static std::unique_ptr<OrtHardwareDevice> ort_hw_device = MakeTestOrtHardwareDevice(OrtHardwareDeviceType_CPU);
96+
static std::unique_ptr<OrtEpDevice> ort_ep_device = MakeTestOrtEpDevice(ort_hw_device.get());
97+
6798
auto ort_ep_raw = std::make_unique<TestOrtEp>().release();
6899
auto ort_ep = UniqueOrtEp(ort_ep_raw, OrtEpDeleter{g_test_ort_ep_factory});
69100
auto ort_session_options = Ort::SessionOptions{};
70-
auto ort_ep_device = OrtEpDevice{};
71-
std::vector<const OrtEpDevice*> ep_devices{&ort_ep_device};
72101

102+
if (ep_devices.empty()) {
103+
ep_devices.push_back(ort_ep_device.get());
104+
}
105+
106+
auto& logging_manager = DefaultLoggingManager();
73107
auto ep = std::make_unique<PluginExecutionProvider>(std::move(ort_ep),
74108
*static_cast<const OrtSessionOptions*>(ort_session_options),
75109
g_test_ort_ep_factory,
76-
ep_devices);
110+
ep_devices,
111+
logging_manager.DefaultLogger());
77112

78113
auto result = MakeTestOrtEpResult{std::move(ep), ort_ep_raw};
79114
return result;
@@ -177,4 +212,109 @@ TEST(PluginExecutionProviderTest, ShouldConvertDataLayoutForOp) {
177212
#endif // !defined(ORT_NO_EXCEPTIONS)
178213
}
179214

215+
TEST(PluginExecutionProviderTest, InferOrtDeviceFromDeviceMemoryInfo) {
216+
// 1 OrtEpDevice without a device_memory_info.
217+
// PluginExecutionProvider should decide to use a default OrtDevice.
218+
{
219+
auto ort_hw_device = test_plugin_ep::MakeTestOrtHardwareDevice(OrtHardwareDeviceType_CPU);
220+
auto ort_ep_device = test_plugin_ep::MakeTestOrtEpDevice(ort_hw_device.get());
221+
std::vector<const OrtEpDevice*> ep_devices{ort_ep_device.get()};
222+
223+
auto [ep, ort_ep] = test_plugin_ep::MakeTestOrtEp(ep_devices);
224+
ASSERT_EQ(ep->GetOrtDeviceByMemType(OrtMemTypeDefault), OrtDevice());
225+
}
226+
227+
// 1 OrtEpDevice with a device_memory_info.
228+
// PluginExecutionProvider should decide to use the OrtDevice from the device_memory_info.
229+
{
230+
auto ort_device = test_plugin_ep::MakeTestOrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT);
231+
auto ort_memory_info = std::make_unique<OrtMemoryInfo>("TestOrtEp GPU", OrtAllocatorType::OrtDeviceAllocator,
232+
ort_device, OrtMemTypeDefault);
233+
234+
auto ort_hw_device = test_plugin_ep::MakeTestOrtHardwareDevice(OrtHardwareDeviceType_GPU);
235+
auto ort_ep_device = test_plugin_ep::MakeTestOrtEpDevice(ort_hw_device.get(),
236+
/*device_memory_info*/ ort_memory_info.get());
237+
std::vector<const OrtEpDevice*> ep_devices{ort_ep_device.get()};
238+
239+
auto [ep, ort_ep] = test_plugin_ep::MakeTestOrtEp(ep_devices);
240+
ASSERT_EQ(ep->GetOrtDeviceByMemType(OrtMemTypeDefault), ort_device);
241+
}
242+
243+
// 2 OrtEpDevice instances with the same device_memory_info.
244+
// PluginExecutionProvider should decide to use the OrtDevice from the device_memory_info.
245+
{
246+
auto ort_device = test_plugin_ep::MakeTestOrtDevice(OrtDevice::CPU, OrtDevice::MemType::DEFAULT);
247+
auto ort_memory_info = std::make_unique<OrtMemoryInfo>("TestOrtEp CPU", OrtAllocatorType::OrtDeviceAllocator,
248+
ort_device, OrtMemTypeDefault);
249+
250+
auto ort_hw_device_gpu = test_plugin_ep::MakeTestOrtHardwareDevice(OrtHardwareDeviceType_GPU);
251+
auto ort_hw_device_npu = test_plugin_ep::MakeTestOrtHardwareDevice(OrtHardwareDeviceType_NPU);
252+
auto ort_ep_device_gpu = test_plugin_ep::MakeTestOrtEpDevice(ort_hw_device_gpu.get(), ort_memory_info.get());
253+
auto ort_ep_device_npu = test_plugin_ep::MakeTestOrtEpDevice(ort_hw_device_npu.get(), ort_memory_info.get());
254+
std::vector<const OrtEpDevice*> ep_devices{ort_ep_device_gpu.get(), ort_ep_device_npu.get()};
255+
256+
auto [ep, ort_ep] = test_plugin_ep::MakeTestOrtEp(ep_devices);
257+
ASSERT_EQ(ep->GetOrtDeviceByMemType(OrtMemTypeDefault), ort_device);
258+
}
259+
260+
// 2 OrtEpDevice instances with the different (but equivalent) device_memory_info pointers.
261+
// PluginExecutionProvider should decide to use a OrtDevice that is equal to the devices used by both
262+
// device_memory_info pointers.
263+
{
264+
auto ort_device = test_plugin_ep::MakeTestOrtDevice(OrtDevice::CPU, OrtDevice::MemType::DEFAULT);
265+
auto ort_memory_info_0 = std::make_unique<OrtMemoryInfo>("TestOrtEp CPU", OrtAllocatorType::OrtDeviceAllocator,
266+
ort_device, OrtMemTypeDefault);
267+
auto ort_memory_info_1 = std::make_unique<OrtMemoryInfo>("TestOrtEp CPU", OrtAllocatorType::OrtDeviceAllocator,
268+
ort_device, OrtMemTypeDefault);
269+
270+
auto ort_hw_device_gpu = test_plugin_ep::MakeTestOrtHardwareDevice(OrtHardwareDeviceType_GPU);
271+
auto ort_hw_device_npu = test_plugin_ep::MakeTestOrtHardwareDevice(OrtHardwareDeviceType_NPU);
272+
auto ort_ep_device_gpu = test_plugin_ep::MakeTestOrtEpDevice(ort_hw_device_gpu.get(), ort_memory_info_0.get());
273+
auto ort_ep_device_npu = test_plugin_ep::MakeTestOrtEpDevice(ort_hw_device_npu.get(), ort_memory_info_1.get());
274+
std::vector<const OrtEpDevice*> ep_devices{ort_ep_device_gpu.get(), ort_ep_device_npu.get()};
275+
276+
auto [ep, ort_ep] = test_plugin_ep::MakeTestOrtEp(ep_devices);
277+
ASSERT_EQ(ep->GetOrtDeviceByMemType(OrtMemTypeDefault), ort_device);
278+
}
279+
280+
// 1 OrtEpDevice with only a host_accessible_memory_info.
281+
// PluginExecutionProvider should decide to use a default OrtDevice (cpu).
282+
{
283+
auto ort_device = test_plugin_ep::MakeTestOrtDevice(OrtDevice::GPU, OrtDevice::MemType::HOST_ACCESSIBLE);
284+
auto ort_memory_info = std::make_unique<OrtMemoryInfo>("TestOrtEp GPU", OrtAllocatorType::OrtDeviceAllocator,
285+
ort_device, OrtMemTypeDefault);
286+
287+
auto ort_hw_device = test_plugin_ep::MakeTestOrtHardwareDevice(OrtHardwareDeviceType_GPU);
288+
auto ort_ep_device = test_plugin_ep::MakeTestOrtEpDevice(ort_hw_device.get(),
289+
/*device_memory_info*/ nullptr,
290+
/*host_accessible_memory_info*/ ort_memory_info.get());
291+
std::vector<const OrtEpDevice*> ep_devices{ort_ep_device.get()};
292+
293+
auto [ep, ort_ep] = test_plugin_ep::MakeTestOrtEp(ep_devices);
294+
ASSERT_EQ(ep->GetOrtDeviceByMemType(OrtMemTypeDefault), OrtDevice());
295+
}
296+
297+
#if !defined(ORT_NO_EXCEPTIONS)
298+
// 2 OrtEpDevice instances with DIFFERENT device_memory_info instances.
299+
// Should throw an exception on construction of PluginExecutionProvider.
300+
{
301+
auto ort_device_gpu = test_plugin_ep::MakeTestOrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT);
302+
auto ort_memory_info_gpu = std::make_unique<OrtMemoryInfo>("TestOrtEp GPU", OrtAllocatorType::OrtDeviceAllocator,
303+
ort_device_gpu, OrtMemTypeDefault);
304+
305+
auto ort_device_npu = test_plugin_ep::MakeTestOrtDevice(OrtDevice::NPU, OrtDevice::MemType::DEFAULT);
306+
auto ort_memory_info_npu = std::make_unique<OrtMemoryInfo>("TestOrtEp NPU", OrtAllocatorType::OrtDeviceAllocator,
307+
ort_device_npu, OrtMemTypeDefault);
308+
309+
auto ort_hw_device_gpu = test_plugin_ep::MakeTestOrtHardwareDevice(OrtHardwareDeviceType_GPU);
310+
auto ort_hw_device_npu = test_plugin_ep::MakeTestOrtHardwareDevice(OrtHardwareDeviceType_NPU);
311+
auto ort_ep_device_gpu = test_plugin_ep::MakeTestOrtEpDevice(ort_hw_device_gpu.get(), ort_memory_info_gpu.get());
312+
auto ort_ep_device_npu = test_plugin_ep::MakeTestOrtEpDevice(ort_hw_device_npu.get(), ort_memory_info_npu.get());
313+
std::vector<const OrtEpDevice*> ep_devices{ort_ep_device_gpu.get(), ort_ep_device_npu.get()};
314+
315+
ASSERT_THROW(test_plugin_ep::MakeTestOrtEp(ep_devices), OnnxRuntimeException);
316+
}
317+
#endif // !defined(ORT_NO_EXCEPTIONS)
318+
}
319+
180320
} // namespace onnxruntime::test

0 commit comments

Comments
 (0)