Skip to content

Commit 9d11ae2

Browse files
authored
Plugin EP data transfer and Stream support. (microsoft#25254)
### Description <!-- Describe your changes. --> Plugin EP data transfer and Stream support. Add the ability for a plugin EP to provide an IDataTransfer implementation and an OrtSyncStream implementation to do async data copy outside of an inference session. Example usage added for CUDA EP. Caveat: Support for providing the OrtSyncStream from the data copy to Session.Run will be a follow up PR. For the CUDA EP we can pass in the native cudaStream_t from the OrtSyncStream used for the data copy to the Run via CUDA EP provider options. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
1 parent 8eea128 commit 9d11ae2

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

44 files changed

+3145
-602
lines changed

cmake/onnxruntime_unittests.cmake

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -498,17 +498,19 @@ set (ONNXRUNTIME_AUTOEP_TEST_SRC_DIR "${TEST_SRC_DIR}/autoep")
498498
set (ONNXRUNTIME_EP_GRAPH_TEST_SRC_DIR "${TEST_SRC_DIR}/ep_graph")
499499

500500
set (onnxruntime_shared_lib_test_SRC
501-
${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_fixture.h
502-
${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_session_options.cc
503-
${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_run_options.cc
501+
${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/custom_op_utils.h
502+
${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/custom_op_utils.cc
504503
${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_allocator.cc
505-
${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_nontensor_types.cc
504+
${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_data_copy.cc
505+
${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_fixture.h
506506
${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_model_loading.cc
507+
${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_nontensor_types.cc
507508
${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_ort_format_models.cc
509+
${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_run_options.cc
510+
${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_session_options.cc
508511
${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/utils.h
509512
${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/utils.cc
510-
${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/custom_op_utils.h
511-
${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/custom_op_utils.cc)
513+
)
512514

513515
if (NOT onnxruntime_MINIMAL_BUILD)
514516
list(APPEND onnxruntime_shared_lib_test_SRC ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_inference.cc)

include/onnxruntime/core/session/environment.h

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "core/common/status.h"
1616
#include "core/framework/allocator.h"
1717
#include "core/framework/execution_provider.h"
18+
#include "core/framework/data_transfer_manager.h"
1819
#include "core/platform/device_discovery.h"
1920
#include "core/platform/threadpool.h"
2021

@@ -140,6 +141,10 @@ class Environment {
140141
OrtDeviceMemoryType mem_type, OrtAllocatorType allocator_type,
141142
const OrtKeyValuePairs* allocator_options, OrtAllocator** allocator);
142143
Status ReleaseSharedAllocator(const OrtEpDevice& ep_device, OrtDeviceMemoryType mem_type);
144+
145+
const DataTransferManager& GetDataTransferManager() const {
146+
return data_transfer_mgr_;
147+
}
143148
#endif // !defined(ORT_MINIMAL_BUILD)
144149

145150
// return a shared allocator from a plugin EP or custom allocator added with RegisterAllocator
@@ -185,6 +190,23 @@ class Environment {
185190

186191
using OrtAllocatorUniquePtr = std::unique_ptr<OrtAllocator, std::function<void(OrtAllocator*)>>;
187192

193+
// if the user calls CreateSharedAllocator and wraps the plugin EP's allocator with an arena we end up with
194+
// OrtAllocator from EP -> wrapped in IAllocatorImplWrappingOrtAllocator -> inside a BFCArena IAllocator.
195+
// we can put that in shared_allocators_ for sessions to use, but to have an OrtAllocator available in
196+
// shared_ort_allocators_ that can be used outside of a session we need to additionally wrap that in an
197+
// OrtAllocatorImplWrappingIAllocator. way too many levels of indirection but that is what it is currently.
198+
// we need something to own that final OrtAllocator, so we add it to arena_ort_allocators_.
199+
//
200+
// TODO: we could split out the BFCArena implementation so it can be plugged into either an IAllocator
201+
// or an OrtAllocator instance to reduce the indirection a little.
202+
// with that we get an OrtAllocator from the EP, wrap it with an OrtAllocator based BFCArena, and wrap that with the
203+
// IAllocatorImplWrappingOrtAllocator which takes ownership of the OrtAllocator and is in shared_allocators_.
204+
//
205+
// Alternatively we can disable wrapping an EP's allocator with a BFCArena and say the EP should provide the arena
206+
// implementation directly. They're free to copy BFCArena as it came from TF originally. Or we could provide a
207+
// cut-and-paste BFCArena implementation that works using the EP API that can be included in the EP source.
208+
std::unordered_map<const OrtMemoryInfo*, std::unique_ptr<OrtAllocatorImplWrappingIAllocator>> arena_ort_allocators_;
209+
188210
#if !defined(ORT_MINIMAL_BUILD)
189211
// register EPs that are built into the ORT binary so they can take part in AutoEP selection
190212
// added to ep_libraries
@@ -207,7 +229,9 @@ class Environment {
207229

208230
std::unique_ptr<EpLibrary> library;
209231
std::vector<std::unique_ptr<OrtEpDevice>> execution_devices;
210-
std::vector<EpFactoryInternal*> internal_factories; // factories that can create IExecutionProvider instances
232+
std::vector<OrtEpFactory*> factories;
233+
std::vector<EpFactoryInternal*> internal_factories; // factories that can create IExecutionProvider instances
234+
std::vector<plugin_ep::DataTransfer*> data_transfers; // data transfer instances for this EP.
211235

212236
private:
213237
EpInfo() = default;
@@ -223,6 +247,9 @@ class Environment {
223247

224248
// lookup set for internal EPs so we can create an IExecutionProvider directly
225249
std::unordered_set<EpFactoryInternal*> internal_ep_factories_;
250+
251+
DataTransferManager data_transfer_mgr_; // plugin EP IDataTransfer instances
252+
226253
#endif // !defined(ORT_MINIMAL_BUILD)
227254
};
228255

include/onnxruntime/core/session/onnxruntime_c_api.h

Lines changed: 174 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ extern "C" {
6464
#define _Outptr_result_maybenull_
6565
#define _Outptr_result_maybenull_z_
6666
#define _In_reads_(X)
67+
#define _In_reads_opt_
6768
#define _Inout_updates_(X)
6869
#define _Out_writes_(X)
6970
#define _Out_writes_opt_(X)
@@ -322,6 +323,7 @@ ORT_RUNTIME_CLASS(ModelCompilationOptions);
322323
ORT_RUNTIME_CLASS(HardwareDevice);
323324
ORT_RUNTIME_CLASS(EpDevice);
324325
ORT_RUNTIME_CLASS(KeyValuePairs);
326+
ORT_RUNTIME_CLASS(SyncStream); // Opaque class to create an onnxruntime::Stream.
325327

326328
#ifdef _MSC_VER
327329
typedef _Return_type_success_(return == 0) OrtStatus* OrtStatusPtr;
@@ -426,10 +428,14 @@ typedef enum OrtAllocatorType {
426428
*/
427429
// Whenever this struct is updated, please also update the MakeKey function in onnxruntime / core / framework / execution_provider.cc
428430
typedef enum OrtMemType {
429-
OrtMemTypeCPUInput = -2, ///< Any CPU memory used by non-CPU execution provider
430-
OrtMemTypeCPUOutput = -1, ///< CPU accessible memory outputted by non-CPU execution provider, i.e. CUDA_PINNED
431-
OrtMemTypeCPU = OrtMemTypeCPUOutput, ///< Temporary CPU accessible memory allocated by non-CPU execution provider, i.e. CUDA_PINNED
432-
OrtMemTypeDefault = 0, ///< The default allocator for execution provider
431+
/// Any CPU memory used by non-CPU execution provider
432+
OrtMemTypeCPUInput = -2,
433+
/// CPU accessible memory outputted by non-CPU execution provider, i.e. HOST_ACCESSIBLE
434+
OrtMemTypeCPUOutput = -1,
435+
/// CPU accessible memory allocated by non-CPU execution provider, i.e. HOST_ACCESSIBLE
436+
OrtMemTypeCPU = OrtMemTypeCPUOutput,
437+
/// The default allocator for execution provider
438+
OrtMemTypeDefault = 0,
433439
} OrtMemType;
434440

435441
/** \brief This matches OrtDevice::MemoryType values */
@@ -1743,7 +1749,7 @@ struct OrtApi {
17431749
*/
17441750
ORT_API2_STATUS(MemoryInfoGetName, _In_ const OrtMemoryInfo* ptr, _Out_ const char** out);
17451751

1746-
/** \brief Get the id from ::OrtMemoryInfo
1752+
/** \brief Get the device id from ::OrtMemoryInfo
17471753
*/
17481754
ORT_API2_STATUS(MemoryInfoGetId, _In_ const OrtMemoryInfo* ptr, _Out_ int* out);
17491755

@@ -5384,10 +5390,32 @@ struct OrtApi {
53845390
* \since Version 1.23
53855391
*/
53865392
ORT_API2_STATUS(CreateMemoryInfo_V2, _In_ const char* name, _In_ enum OrtMemoryInfoDeviceType device_type,
5387-
_In_ uint32_t vendor_id, _In_ int16_t device_id, _In_ enum OrtDeviceMemoryType mem_type,
5393+
_In_ uint32_t vendor_id, _In_ int32_t device_id, _In_ enum OrtDeviceMemoryType mem_type,
53885394
_In_ size_t alignment, enum OrtAllocatorType allocator_type,
53895395
_Outptr_ OrtMemoryInfo** out);
53905396

5397+
/** \brief Get the device memory type from ::OrtMemoryInfo
5398+
*
5399+
* \param[in] ptr The OrtMemoryInfo instance to query.
5400+
* \param[out] out The device memory type.
5401+
*
5402+
* \snippet{doc} snippets.dox OrtStatus Return Value
5403+
*
5404+
* \since Version 1.23
5405+
*/
5406+
ORT_API2_STATUS(MemoryInfoGetDeviceMemType, _In_ const OrtMemoryInfo* ptr, _Out_ OrtDeviceMemoryType* out);
5407+
5408+
/** \brief Get the vendor id from ::OrtMemoryInfo
5409+
*
5410+
* \param[in] ptr The OrtMemoryInfo instance to query.
5411+
* \param[out] out The vendor id.
5412+
*
5413+
* \snippet{doc} snippets.dox OrtStatus Return Value
5414+
*
5415+
* \since Version 1.23
5416+
*/
5417+
ORT_API2_STATUS(MemoryInfoGetVendorId, _In_ const OrtMemoryInfo* ptr, _Out_ uint32_t* out);
5418+
53915419
/// \name OrtValueInfo
53925420
/// @{
53935421

@@ -6068,11 +6096,14 @@ struct OrtApi {
60686096
/** \brief Get the OrtMemoryInfo for the device.
60696097
*
60706098
* \param[in] ep_device The OrtEpDevice instance to query.
6071-
* \return A pointer to the OrtMemoryInfo for the device.
6099+
* \param[in] memory_type The memory type to return.
6100+
* \return A pointer to the OrtMemoryInfo for the device. This may be nullptr if not set.
6101+
* If memory_type is OrtDeviceMemoryType_DEFAULT and nullptr is returned the EP uses CPU memory.
60726102
*
60736103
* \since Version 1.23
60746104
*/
6075-
ORT_API_T(const OrtMemoryInfo*, EpDevice_MemoryInfo, _In_ const OrtEpDevice* ep_device);
6105+
ORT_API_T(const OrtMemoryInfo*, EpDevice_MemoryInfo, _In_ const OrtEpDevice* ep_device,
6106+
_In_ OrtDeviceMemoryType memory_type);
60766107

60776108
/** \brief Create/replace a shared allocator for the OrtEpDevice in the OrtEnv.
60786109
*
@@ -6164,6 +6195,141 @@ struct OrtApi {
61646195
* \since Version 1.23.
61656196
*/
61666197
ORT_API2_STATUS(GetSessionOptionsConfigEntries, _In_ const OrtSessionOptions* options, _Outptr_ OrtKeyValuePairs** out);
6198+
6199+
/** \brief Get the OrtMemoryInfo for each input of the session.
6200+
*
6201+
* The memory info can be used to determine where the input tensors are required.
6202+
*
6203+
* The session must be fully initialized before calling this function as the input locations are not known until
6204+
* this has occurred.
6205+
*
6206+
* \param[in] session The OrtSession instance.
6207+
* \param[out] inputs_memory_info Pre-allocated array of size `num_inputs` that will be filled with the
6208+
* OrtMemoryInfo* value for each input.
6209+
* The order is the same as returned by SessionGetInputName.
6210+
* \param[in] num_inputs The number of inputs in the session. Must match SessionGetInputCount.
6211+
*
6212+
* \snippet{doc} snippets.dox OrtStatus Return Value
6213+
*
6214+
* \since Version 1.23
6215+
*/
6216+
ORT_API2_STATUS(SessionGetMemoryInfoForInputs, _In_ const OrtSession* session,
6217+
_Out_writes_(num_inputs) const OrtMemoryInfo** inputs_memory_info,
6218+
_In_ size_t num_inputs);
6219+
6220+
/** \brief Get the OrtMemoryInfo for each output of the session.
6221+
*
6222+
* The memory info can be used to determine the device the output tensors are produced on.
6223+
* The user can pre-allocate an OrtValue using this information or use IOBinding to keep the data on the device.
6224+
* ORT will copy the output to CPU otherwise.
6225+
*
6226+
* The session must be fully initialized before calling this function as the output locations are not known until
6227+
* this has occurred.
6228+
*
6229+
* \param[in] session The OrtSession instance.
6230+
* \param[out] outputs_memory_info Pre-allocated array of size `num_outputs` that will be filled with
6231+
* OrtMemoryInfo* values for each output.
6232+
* The order is the same as returned by SessionGetOutputName.
6233+
* \param[in] num_outputs The number of outputs in the session. Must match SessionGetOutputCount.
6234+
*
6235+
* \snippet{doc} snippets.dox OrtStatus Return Value
6236+
*
6237+
* \since Version 1.23
6238+
*/
6239+
ORT_API2_STATUS(SessionGetMemoryInfoForOutputs, _In_ const OrtSession* session,
6240+
_Out_writes_(num_outputs) const OrtMemoryInfo** outputs_memory_info,
6241+
_In_ size_t num_outputs);
6242+
6243+
/** \brief Get the OrtEpDevice (if available) for each input of the session.
6244+
*
6245+
* An OrtEpDevice will be available if auto EP selection is enabled by calling
6246+
* SessionOptionsSetEpSelectionPolicy or SessionOptionsSetEpSelectionPolicyDelegate,
6247+
* or if the OrtEpDevice was manually added to the session using SessionOptionsAppendExecutionProvider_V2.
6248+
*
6249+
* If an OrtEpDevice is not available for the input a nullptr is returned.
6250+
*
6251+
* The returned OrtEpDevice can be used to create an OrtSyncStream via CreateSyncStreamForEpDevice to asynchronously
6252+
* provide input to the inference session Run.
6253+
*
6254+
* The session must be fully initialized before calling this function as the assigned EPs are not known until
6255+
* this has occurred.
6256+
*
6257+
* \param[in] session The OrtSession instance.
6258+
* \param[out] inputs_ep_devices Pre-allocated array of size `num_inputs` that will be filled with
6259+
* OrtEpDevice* values for each input.
6260+
* The order is the same as returned by SessionGetInputName.
6261+
* \param[in] num_inputs The number of inputs in the session. Must match SessionGetInputCount.
6262+
*
6263+
* \snippet{doc} snippets.dox OrtStatus Return Value
6264+
*
6265+
* \since Version 1.23
6266+
*/
6267+
ORT_API2_STATUS(SessionGetEpDeviceForInputs, _In_ const OrtSession* session,
6268+
_Out_writes_(num_inputs) const OrtEpDevice** inputs_ep_devices,
6269+
_In_ size_t num_inputs);
6270+
6271+
/** \brief Create an OrtSyncStream for the given OrtEpDevice.
6272+
*
6273+
* The OrtSyncStream can be used to enable asynchronous operations.
6274+
* e.g. async usage of CopyTensors to provide input to an OrtSession Run call.
6275+
*
6276+
* An error code of ORT_NOT_IMPLEMENTED will be returned if the EP does not support OrtSyncStream.
6277+
*
6278+
* \param[in] ep_device The OrtEpDevice instance to create the sync stream for.
6279+
* \param[in] stream_options Options for OrtSyncStream creation. May be nullptr.
6280+
* \param[out] stream Output parameter set to the created OrtSyncStream instance.
6281+
*
6282+
* \snippet{doc} snippets.dox OrtStatus Return Value
6283+
*
6284+
* \since Version 1.23
6285+
*/
6286+
ORT_API2_STATUS(CreateSyncStreamForEpDevice, _In_ const OrtEpDevice* ep_device,
6287+
_In_opt_ const OrtKeyValuePairs* stream_options,
6288+
_Outptr_ OrtSyncStream** stream);
6289+
6290+
/** \brief Get the native handle of the sync stream.
6291+
*
6292+
* This returns the native handle for the stream. e.g. cudaStream_t for CUDA streams.
6293+
*
6294+
* \param[in] stream The OrtSyncStream instance to get the handle from.
6295+
*
6296+
* \returns The native handle of the stream.
6297+
*
6298+
* \since Version 1.23
6299+
*/
6300+
ORT_API_T(void*, SyncStream_GetHandle, _In_ OrtSyncStream* stream);
6301+
6302+
ORT_CLASS_RELEASE(SyncStream);
6303+
6304+
/** \brief Copy OrtValue instances containing Tensors between devices.
6305+
*
6306+
* The overall copy must be between a single source device and a single destination device. i.e.
6307+
* - all src_tensors must have matching OrtMemoryInfo,
6308+
* - all dst_tensors must have matching OrtMemoryInfo.
6309+
*
6310+
* OrtValue instances can be created by:
6311+
* - Use GetSharedAllocator to get the shared allocator for the OrtMemoryInfo if you need to allocate memory
6312+
* on the device.
6313+
* - Use CreateTensorAsOrtValue, CreateTensorWithDataAsOrtValue or CreateTensorWithDataAndDeleterAsOrtValue
6314+
* to create an OrtValue containing a tensor depending on whether you have existing data or not, and whether
6315+
* you want ORT to free the existing data once it is done with the OrtValue.
6316+
*
6317+
* \param[in] env The OrtEnv instance to use. The data transfer implementation is provided by an execution provider
6318+
* that is registered in this OrtEnv.
6319+
* \param[in] src_tensors Array of OrtValue instances containing the source tensors to copy.
6320+
* \param[in] dst_tensors Array of OrtValue instances to copy the source tensors to.
6321+
* \param[in] stream Optional OrtSyncStream that can be used to perform the copy asynchronously. May be nullptr.
6322+
* \param[in] num_tensors The number of tensors to copy. The size of `src_tensors` and `dst_tensors` must match.
6323+
*
6324+
* \snippet{doc} snippets.dox OrtStatus Return Value
6325+
*
6326+
* \since Version 1.23
6327+
*/
6328+
ORT_API2_STATUS(CopyTensors, _In_ const OrtEnv* env,
6329+
_In_reads_(num_tensors) const OrtValue* const* src_tensors,
6330+
_In_reads_(num_tensors) OrtValue* const* dst_tensors,
6331+
_In_opt_ OrtSyncStream* stream,
6332+
_In_ size_t num_tensors);
61676333
};
61686334

61696335
/*

0 commit comments

Comments
 (0)