|
9 | 9 | #include "core/session/abi_devices.h" |
10 | 10 | #include "core/session/onnxruntime_cxx_api.h" |
11 | 11 | #include "test/util/include/asserts.h" |
| 12 | +#include "test/util/include/test_environment.h" |
12 | 13 |
|
13 | 14 | namespace onnxruntime::test { |
14 | 15 |
|
@@ -56,24 +57,58 @@ struct TestOrtEpFactory : ::OrtEpFactory { |
56 | 57 |
|
57 | 58 | static TestOrtEpFactory g_test_ort_ep_factory{}; |
58 | 59 |
|
| 60 | +std::unique_ptr<OrtHardwareDevice> MakeTestOrtHardwareDevice(OrtHardwareDeviceType type) { |
| 61 | + auto hw_device = std::make_unique<OrtHardwareDevice>(); |
| 62 | + hw_device->type = type; |
| 63 | + hw_device->vendor_id = 0xBE57; |
| 64 | + hw_device->device_id = 0; |
| 65 | + hw_device->vendor = "Contoso"; |
| 66 | + return hw_device; |
| 67 | +} |
| 68 | + |
| 69 | +std::unique_ptr<OrtEpDevice> MakeTestOrtEpDevice(const OrtHardwareDevice* hardware_device, |
| 70 | + const OrtMemoryInfo* device_memory_info = nullptr, |
| 71 | + const OrtMemoryInfo* host_accessible_memory_info = nullptr) { |
| 72 | + auto ep_device = std::make_unique<OrtEpDevice>(); |
| 73 | + ep_device->ep_name = "TestOrtEp"; |
| 74 | + ep_device->ep_vendor = "Contoso"; |
| 75 | + ep_device->device = hardware_device; |
| 76 | + ep_device->ep_factory = &g_test_ort_ep_factory; |
| 77 | + ep_device->device_memory_info = device_memory_info; |
| 78 | + ep_device->host_accessible_memory_info = host_accessible_memory_info; |
| 79 | + return ep_device; |
| 80 | +} |
| 81 | + |
| 82 | +OrtDevice MakeTestOrtDevice(OrtDevice::DeviceType device_type, OrtDevice::MemoryType memory_type) { |
| 83 | + return OrtDevice(device_type, memory_type, /*vendor_id*/ 0xBE57, /*device_id*/ 0, /*alignment*/ 16); |
| 84 | +} |
| 85 | + |
59 | 86 | struct MakeTestOrtEpResult { |
60 | 87 | std::unique_ptr<IExecutionProvider> ep; // the IExecutionProvider wrapping the TestOrtEp |
61 | 88 | gsl::not_null<TestOrtEp*> ort_ep; // the wrapped TestOrtEp, owned by `ep` |
62 | 89 | }; |
63 | 90 |
|
64 | 91 | // Creates an IExecutionProvider that wraps a TestOrtEp. |
65 | 92 | // The TestOrtEp is also exposed so that tests can manipulate its function pointers directly. |
66 | | -MakeTestOrtEpResult MakeTestOrtEp() { |
| 93 | +MakeTestOrtEpResult MakeTestOrtEp(std::vector<const OrtEpDevice*> ep_devices = {}) { |
| 94 | + // Default OrtHardwareDevice and OrtEpDevice used if the caller does not explicitly provide ep_devices. |
| 95 | + static std::unique_ptr<OrtHardwareDevice> ort_hw_device = MakeTestOrtHardwareDevice(OrtHardwareDeviceType_CPU); |
| 96 | + static std::unique_ptr<OrtEpDevice> ort_ep_device = MakeTestOrtEpDevice(ort_hw_device.get()); |
| 97 | + |
67 | 98 | auto ort_ep_raw = std::make_unique<TestOrtEp>().release(); |
68 | 99 | auto ort_ep = UniqueOrtEp(ort_ep_raw, OrtEpDeleter{g_test_ort_ep_factory}); |
69 | 100 | auto ort_session_options = Ort::SessionOptions{}; |
70 | | - auto ort_ep_device = OrtEpDevice{}; |
71 | | - std::vector<const OrtEpDevice*> ep_devices{&ort_ep_device}; |
72 | 101 |
|
| 102 | + if (ep_devices.empty()) { |
| 103 | + ep_devices.push_back(ort_ep_device.get()); |
| 104 | + } |
| 105 | + |
| 106 | + auto& logging_manager = DefaultLoggingManager(); |
73 | 107 | auto ep = std::make_unique<PluginExecutionProvider>(std::move(ort_ep), |
74 | 108 | *static_cast<const OrtSessionOptions*>(ort_session_options), |
75 | 109 | g_test_ort_ep_factory, |
76 | | - ep_devices); |
| 110 | + ep_devices, |
| 111 | + logging_manager.DefaultLogger()); |
77 | 112 |
|
78 | 113 | auto result = MakeTestOrtEpResult{std::move(ep), ort_ep_raw}; |
79 | 114 | return result; |
@@ -177,4 +212,109 @@ TEST(PluginExecutionProviderTest, ShouldConvertDataLayoutForOp) { |
177 | 212 | #endif // !defined(ORT_NO_EXCEPTIONS) |
178 | 213 | } |
179 | 214 |
|
| 215 | +TEST(PluginExecutionProviderTest, InferOrtDeviceFromDeviceMemoryInfo) { |
| 216 | + // 1 OrtEpDevice without a device_memory_info. |
| 217 | + // PluginExecutionProvider should decide to use a default OrtDevice. |
| 218 | + { |
| 219 | + auto ort_hw_device = test_plugin_ep::MakeTestOrtHardwareDevice(OrtHardwareDeviceType_CPU); |
| 220 | + auto ort_ep_device = test_plugin_ep::MakeTestOrtEpDevice(ort_hw_device.get()); |
| 221 | + std::vector<const OrtEpDevice*> ep_devices{ort_ep_device.get()}; |
| 222 | + |
| 223 | + auto [ep, ort_ep] = test_plugin_ep::MakeTestOrtEp(ep_devices); |
| 224 | + ASSERT_EQ(ep->GetOrtDeviceByMemType(OrtMemTypeDefault), OrtDevice()); |
| 225 | + } |
| 226 | + |
| 227 | + // 1 OrtEpDevice with a device_memory_info. |
| 228 | + // PluginExecutionProvider should decide to use the OrtDevice from the device_memory_info. |
| 229 | + { |
| 230 | + auto ort_device = test_plugin_ep::MakeTestOrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT); |
| 231 | + auto ort_memory_info = std::make_unique<OrtMemoryInfo>("TestOrtEp GPU", OrtAllocatorType::OrtDeviceAllocator, |
| 232 | + ort_device, OrtMemTypeDefault); |
| 233 | + |
| 234 | + auto ort_hw_device = test_plugin_ep::MakeTestOrtHardwareDevice(OrtHardwareDeviceType_GPU); |
| 235 | + auto ort_ep_device = test_plugin_ep::MakeTestOrtEpDevice(ort_hw_device.get(), |
| 236 | + /*device_memory_info*/ ort_memory_info.get()); |
| 237 | + std::vector<const OrtEpDevice*> ep_devices{ort_ep_device.get()}; |
| 238 | + |
| 239 | + auto [ep, ort_ep] = test_plugin_ep::MakeTestOrtEp(ep_devices); |
| 240 | + ASSERT_EQ(ep->GetOrtDeviceByMemType(OrtMemTypeDefault), ort_device); |
| 241 | + } |
| 242 | + |
| 243 | + // 2 OrtEpDevice instances with the same device_memory_info. |
| 244 | + // PluginExecutionProvider should decide to use the OrtDevice from the device_memory_info. |
| 245 | + { |
| 246 | + auto ort_device = test_plugin_ep::MakeTestOrtDevice(OrtDevice::CPU, OrtDevice::MemType::DEFAULT); |
| 247 | + auto ort_memory_info = std::make_unique<OrtMemoryInfo>("TestOrtEp CPU", OrtAllocatorType::OrtDeviceAllocator, |
| 248 | + ort_device, OrtMemTypeDefault); |
| 249 | + |
| 250 | + auto ort_hw_device_gpu = test_plugin_ep::MakeTestOrtHardwareDevice(OrtHardwareDeviceType_GPU); |
| 251 | + auto ort_hw_device_npu = test_plugin_ep::MakeTestOrtHardwareDevice(OrtHardwareDeviceType_NPU); |
| 252 | + auto ort_ep_device_gpu = test_plugin_ep::MakeTestOrtEpDevice(ort_hw_device_gpu.get(), ort_memory_info.get()); |
| 253 | + auto ort_ep_device_npu = test_plugin_ep::MakeTestOrtEpDevice(ort_hw_device_npu.get(), ort_memory_info.get()); |
| 254 | + std::vector<const OrtEpDevice*> ep_devices{ort_ep_device_gpu.get(), ort_ep_device_npu.get()}; |
| 255 | + |
| 256 | + auto [ep, ort_ep] = test_plugin_ep::MakeTestOrtEp(ep_devices); |
| 257 | + ASSERT_EQ(ep->GetOrtDeviceByMemType(OrtMemTypeDefault), ort_device); |
| 258 | + } |
| 259 | + |
| 260 | + // 2 OrtEpDevice instances with the different (but equivalent) device_memory_info pointers. |
| 261 | + // PluginExecutionProvider should decide to use a OrtDevice that is equal to the devices used by both |
| 262 | + // device_memory_info pointers. |
| 263 | + { |
| 264 | + auto ort_device = test_plugin_ep::MakeTestOrtDevice(OrtDevice::CPU, OrtDevice::MemType::DEFAULT); |
| 265 | + auto ort_memory_info_0 = std::make_unique<OrtMemoryInfo>("TestOrtEp CPU", OrtAllocatorType::OrtDeviceAllocator, |
| 266 | + ort_device, OrtMemTypeDefault); |
| 267 | + auto ort_memory_info_1 = std::make_unique<OrtMemoryInfo>("TestOrtEp CPU", OrtAllocatorType::OrtDeviceAllocator, |
| 268 | + ort_device, OrtMemTypeDefault); |
| 269 | + |
| 270 | + auto ort_hw_device_gpu = test_plugin_ep::MakeTestOrtHardwareDevice(OrtHardwareDeviceType_GPU); |
| 271 | + auto ort_hw_device_npu = test_plugin_ep::MakeTestOrtHardwareDevice(OrtHardwareDeviceType_NPU); |
| 272 | + auto ort_ep_device_gpu = test_plugin_ep::MakeTestOrtEpDevice(ort_hw_device_gpu.get(), ort_memory_info_0.get()); |
| 273 | + auto ort_ep_device_npu = test_plugin_ep::MakeTestOrtEpDevice(ort_hw_device_npu.get(), ort_memory_info_1.get()); |
| 274 | + std::vector<const OrtEpDevice*> ep_devices{ort_ep_device_gpu.get(), ort_ep_device_npu.get()}; |
| 275 | + |
| 276 | + auto [ep, ort_ep] = test_plugin_ep::MakeTestOrtEp(ep_devices); |
| 277 | + ASSERT_EQ(ep->GetOrtDeviceByMemType(OrtMemTypeDefault), ort_device); |
| 278 | + } |
| 279 | + |
| 280 | + // 1 OrtEpDevice with only a host_accessible_memory_info. |
| 281 | + // PluginExecutionProvider should decide to use a default OrtDevice (cpu). |
| 282 | + { |
| 283 | + auto ort_device = test_plugin_ep::MakeTestOrtDevice(OrtDevice::GPU, OrtDevice::MemType::HOST_ACCESSIBLE); |
| 284 | + auto ort_memory_info = std::make_unique<OrtMemoryInfo>("TestOrtEp GPU", OrtAllocatorType::OrtDeviceAllocator, |
| 285 | + ort_device, OrtMemTypeDefault); |
| 286 | + |
| 287 | + auto ort_hw_device = test_plugin_ep::MakeTestOrtHardwareDevice(OrtHardwareDeviceType_GPU); |
| 288 | + auto ort_ep_device = test_plugin_ep::MakeTestOrtEpDevice(ort_hw_device.get(), |
| 289 | + /*device_memory_info*/ nullptr, |
| 290 | + /*host_accessible_memory_info*/ ort_memory_info.get()); |
| 291 | + std::vector<const OrtEpDevice*> ep_devices{ort_ep_device.get()}; |
| 292 | + |
| 293 | + auto [ep, ort_ep] = test_plugin_ep::MakeTestOrtEp(ep_devices); |
| 294 | + ASSERT_EQ(ep->GetOrtDeviceByMemType(OrtMemTypeDefault), OrtDevice()); |
| 295 | + } |
| 296 | + |
| 297 | +#if !defined(ORT_NO_EXCEPTIONS) |
| 298 | + // 2 OrtEpDevice instances with DIFFERENT device_memory_info instances. |
| 299 | + // Should throw an exception on construction of PluginExecutionProvider. |
| 300 | + { |
| 301 | + auto ort_device_gpu = test_plugin_ep::MakeTestOrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT); |
| 302 | + auto ort_memory_info_gpu = std::make_unique<OrtMemoryInfo>("TestOrtEp GPU", OrtAllocatorType::OrtDeviceAllocator, |
| 303 | + ort_device_gpu, OrtMemTypeDefault); |
| 304 | + |
| 305 | + auto ort_device_npu = test_plugin_ep::MakeTestOrtDevice(OrtDevice::NPU, OrtDevice::MemType::DEFAULT); |
| 306 | + auto ort_memory_info_npu = std::make_unique<OrtMemoryInfo>("TestOrtEp NPU", OrtAllocatorType::OrtDeviceAllocator, |
| 307 | + ort_device_npu, OrtMemTypeDefault); |
| 308 | + |
| 309 | + auto ort_hw_device_gpu = test_plugin_ep::MakeTestOrtHardwareDevice(OrtHardwareDeviceType_GPU); |
| 310 | + auto ort_hw_device_npu = test_plugin_ep::MakeTestOrtHardwareDevice(OrtHardwareDeviceType_NPU); |
| 311 | + auto ort_ep_device_gpu = test_plugin_ep::MakeTestOrtEpDevice(ort_hw_device_gpu.get(), ort_memory_info_gpu.get()); |
| 312 | + auto ort_ep_device_npu = test_plugin_ep::MakeTestOrtEpDevice(ort_hw_device_npu.get(), ort_memory_info_npu.get()); |
| 313 | + std::vector<const OrtEpDevice*> ep_devices{ort_ep_device_gpu.get(), ort_ep_device_npu.get()}; |
| 314 | + |
| 315 | + ASSERT_THROW(test_plugin_ep::MakeTestOrtEp(ep_devices), OnnxRuntimeException); |
| 316 | + } |
| 317 | +#endif // !defined(ORT_NO_EXCEPTIONS) |
| 318 | +} |
| 319 | + |
180 | 320 | } // namespace onnxruntime::test |
0 commit comments