88#include < unordered_set>
99#include < utility>
1010#include < vector>
11+ #include " core/framework/abi_pointer_array.h"
1112#include " core/framework/compute_capability.h"
1213#include " core/framework/error_code_helper.h"
1314#include " core/framework/model_metadef_id_generator.h"
1415#include " core/graph/ep_api_types.h"
15- #include " core/session/ort_apis .h"
16+ #include " core/graph/model_editor_api_types .h"
1617#include " core/session/abi_devices.h"
1718#include " core/session/abi_ep_types.h"
1819#include " core/session/abi_logger.h"
20+ #include " core/session/abi_session_options_impl.h"
1921#include " core/session/allocator_adapters.h"
22+ #include " core/session/ort_apis.h"
2023#include " core/providers/partitioning_utils.h"
2124
2225namespace onnxruntime {
@@ -48,7 +51,8 @@ PluginExecutionProviderFactory::CreateProvider(const OrtSessionOptions& session_
4851 ORT_THROW (" Error creating execution provider: " , status.ToString ());
4952 }
5053
51- auto ep_wrapper = std::make_unique<PluginExecutionProvider>(UniqueOrtEp (ort_ep, OrtEpDeleter (ep_factory_)));
54+ auto ep_wrapper = std::make_unique<PluginExecutionProvider>(UniqueOrtEp (ort_ep, OrtEpDeleter (ep_factory_)),
55+ session_options);
5256 ep_wrapper->SetLogger (session_logger.ToInternal ());
5357
5458 return ep_wrapper;
@@ -80,9 +84,10 @@ struct PluginEpMetaDefNameFunctor {
8084// PluginExecutionProvider
8185//
8286
83- PluginExecutionProvider::PluginExecutionProvider (UniqueOrtEp ep)
87+ PluginExecutionProvider::PluginExecutionProvider (UniqueOrtEp ep, const OrtSessionOptions& session_options )
8488 : IExecutionProvider(ep->GetName (ep.get()), OrtDevice()), // TODO: What to do about OrtDevice for plugins?
8589 ort_ep_(std::move(ep)) {
90+ generate_ep_ctx_model_ = session_options.value .GetEpContextGenerationOptions ().enable ;
8691}
8792
8893PluginExecutionProvider::~PluginExecutionProvider () {
@@ -185,6 +190,87 @@ Status PluginExecutionProvider::FusedNodeState::AddFusedNode(const Node& fused_n
185190 return Status::OK ();
186191}
187192
193+ // / <summary>
194+ // / Converts the EPContext nodes provided by the plugin EP (OrtNode instances) to onnxruntime::Node instances.
195+ // / Note that the EP plugin uses the model editor API to create the OrtNode instances.
196+ // / </summary>
197+ // / <param name="ep_name">Name of the plugin EP.</param>
198+ // / <param name="plugin_ep_context_nodes">EPContext nodes provided by the plugin EP.</param>
199+ // / <param name="result_nodes">Output parameter set to the resulting array of EPContext nodes.</param>
200+ // / <param name="result_node_args">Output parameter that stores the NodeArgs used by the EPContext nodes.</param>
201+ // / <returns>A status indicating success or an error.</returns>
202+ static Status ConvertEpContextNodes (const std::string& ep_name, const std::vector<OrtNode*> plugin_ep_context_nodes,
203+ /* out*/ std::vector<std::unique_ptr<Node>>& result_nodes,
204+ /* out*/ std::vector<std::unique_ptr<NodeArg>>& result_node_args) {
205+ #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS)
206+ if (plugin_ep_context_nodes.empty ()) {
207+ return Status::OK (); // No EPContext nodes.
208+ }
209+
210+ std::vector<std::unique_ptr<Node>> ep_context_nodes_holder;
211+ std::vector<std::unique_ptr<NodeArg>> ep_context_node_args_holder;
212+
213+ ep_context_nodes_holder.reserve (plugin_ep_context_nodes.size ());
214+
215+ for (const OrtNode* ort_node : plugin_ep_context_nodes) {
216+ ORT_RETURN_IF_NOT (ort_node != nullptr , ep_name, " : OrtEp::Compile() returned a NULL EPContext node." );
217+
218+ const ModelEditorNode* editor_node = ModelEditorNode::ToInternal (ort_node);
219+ ORT_RETURN_IF_NOT (editor_node != nullptr , ep_name, " : OrtEp::Compile() returned OrtNode objects " ,
220+ " that were not created with OrtModelEditorApi." );
221+
222+ // Create NodeArg for each input/output.
223+ std::vector<NodeArg*> input_node_args;
224+ std::vector<NodeArg*> output_node_args;
225+
226+ input_node_args.reserve (editor_node->input_names .size ());
227+ output_node_args.reserve (editor_node->output_names .size ());
228+
229+ for (const std::string& input_name : editor_node->input_names ) {
230+ auto node_arg = std::make_unique<NodeArg>(input_name, /* p_arg_type*/ nullptr ); // Graph.Resolve() sets type.
231+ input_node_args.push_back (node_arg.get ());
232+ ep_context_node_args_holder.push_back (std::move (node_arg));
233+ }
234+
235+ for (const std::string& output_name : editor_node->output_names ) {
236+ auto node_arg = std::make_unique<NodeArg>(output_name, /* p_arg_type*/ nullptr ); // Graph.Resolve() sets type.
237+ output_node_args.push_back (node_arg.get ());
238+ ep_context_node_args_holder.push_back (std::move (node_arg));
239+ }
240+
241+ // Create a name -> attribute map.
242+ NodeAttributes attributes;
243+ attributes.reserve (editor_node->attributes .size ());
244+
245+ for (const ONNX_NAMESPACE::AttributeProto& attr : editor_node->attributes ) {
246+ attributes.emplace (attr.name (), attr);
247+ }
248+
249+ // Create Node
250+ auto internal_node = std::make_unique<Node>(editor_node->node_name ,
251+ editor_node->operator_name ,
252+ " EPContext node for " + ep_name,
253+ input_node_args,
254+ output_node_args,
255+ &attributes,
256+ editor_node->domain_name );
257+
258+ ep_context_nodes_holder.push_back (std::move (internal_node));
259+ }
260+
261+ result_nodes = std::move (ep_context_nodes_holder);
262+ result_node_args = std::move (ep_context_node_args_holder);
263+
264+ return Status::OK ();
265+ #else
266+ ORT_UNUSED_PARAMETER (ep_name);
267+ ORT_UNUSED_PARAMETER (plugin_ep_context_nodes);
268+ ORT_UNUSED_PARAMETER (result_nodes);
269+ ORT_UNUSED_PARAMETER (result_node_args);
270+ return ORT_MAKE_STATUS (ONNXRUNTIME, NOT_IMPLEMENTED, " Creating EPContext models is not supported in this build" );
271+ #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS)
272+ }
273+
188274common::Status PluginExecutionProvider::Compile (const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
189275 std::vector<NodeComputeInfo>& node_compute_infos) {
190276 const logging::Logger* logger = GetLogger ();
@@ -220,8 +306,21 @@ common::Status PluginExecutionProvider::Compile(const std::vector<FusedNodeAndGr
220306 api_fused_nodes.push_back (ep_fused_node->ToExternal ());
221307 }
222308
223- ORT_RETURN_IF_ERROR (ToStatusAndRelease (ort_ep_->Compile (ort_ep_.get (), api_graphs.data (), api_fused_nodes.data (),
224- num_graphs, api_node_compute_infos.data ())));
309+ // Provide an output buffer for the plugin EP to store EPContext nodes if it needs to (i.e., enabled in session options).
310+ std::vector<std::unique_ptr<OrtNode, decltype (&OrtApis::ReleaseNode)>> plugin_ep_context_nodes_holder;
311+ std::vector<OrtNode*> plugin_ep_context_nodes;
312+ plugin_ep_context_nodes_holder.reserve (num_graphs);
313+ plugin_ep_context_nodes.resize (num_graphs, nullptr );
314+
315+ Status compile_status = ToStatusAndRelease (ort_ep_->Compile (ort_ep_.get (), api_graphs.data (), api_fused_nodes.data (),
316+ num_graphs, api_node_compute_infos.data (),
317+ plugin_ep_context_nodes.data ()));
318+
319+ // Store any EPContext nodes provided by the plugin EP in std::unique_ptr so that they are always properly released.
320+ for (OrtNode* ort_node : plugin_ep_context_nodes) {
321+ auto unique_ort_node = std::unique_ptr<OrtNode, decltype (&OrtApis::ReleaseNode)>(ort_node, OrtApis::ReleaseNode);
322+ plugin_ep_context_nodes_holder.push_back (std::move (unique_ort_node));
323+ }
225324
226325 // Save OrtNodeComputeInfo created by OrtEp instance. They're freed when this IExecutionProvider
227326 // is destroyed.
@@ -231,6 +330,8 @@ common::Status PluginExecutionProvider::Compile(const std::vector<FusedNodeAndGr
231330 }
232331 }
233332
333+ ORT_RETURN_IF_ERROR (compile_status);
334+
234335 // Initialize node_compute_infos as wrappers to api_node_compute_infos.
235336 for (size_t i = 0 ; i < num_graphs; i++) {
236337 OrtNodeComputeInfo* api_node_compute_info = api_node_compute_infos[i];
@@ -268,6 +369,25 @@ common::Status PluginExecutionProvider::Compile(const std::vector<FusedNodeAndGr
268369 node_compute_infos.push_back (std::move (compute_info));
269370 }
270371
372+ // Convert the EPContext nodes provided by the plugin EP into onnxruntime::Node instances.
373+ // We store the converted Node and NodeArg instances as members to ensure they can be returned to the ORT graph
374+ // partitioner via a call to IExecutionProvider::GetEpContextNodes().
375+ if (generate_ep_ctx_model_) {
376+ ORT_RETURN_IF_ERROR (ConvertEpContextNodes (Type (), plugin_ep_context_nodes,
377+ /* out*/ ep_context_nodes_, /* out*/ ep_context_node_args_));
378+ }
379+
271380 return Status::OK ();
272381}
382+
383+ const InlinedVector<const Node*> PluginExecutionProvider::GetEpContextNodes () const {
384+ InlinedVector<const Node*> result;
385+
386+ for (const std::unique_ptr<Node>& node : ep_context_nodes_) {
387+ result.push_back (node.get ());
388+ }
389+
390+ return result;
391+ }
392+
273393} // namespace onnxruntime
0 commit comments