Skip to content

Commit 131cf40

Browse files
authored
Update OrtEpFactory in MiGraphX EP (microsoft#25567)
### Description Update OrtEpFactory in new EPs to add allocator, data transfer and stream stubs. ### Motivation and Context
1 parent b957547 commit 131cf40

File tree

1 file changed

+59
-0
lines changed

1 file changed

+59
-0
lines changed

onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,11 +182,22 @@ struct MigraphXEpFactory : OrtEpFactory {
182182
OrtHardwareDeviceType hw_type,
183183
const OrtLogger& default_logger_in)
184184
: ort_api{ort_api_in}, default_logger{default_logger_in}, ep_name{ep_name}, ort_hw_device_type{hw_type} {
185+
ort_version_supported = ORT_API_VERSION; // set to the ORT version we were compiled with
185186
GetName = GetNameImpl;
186187
GetVendor = GetVendorImpl;
188+
GetVendorId = GetVendorIdImpl;
189+
GetVersion = GetVersionImpl;
190+
187191
GetSupportedDevices = GetSupportedDevicesImpl;
188192
CreateEp = CreateEpImpl;
189193
ReleaseEp = ReleaseEpImpl;
194+
195+
CreateAllocator = CreateAllocatorImpl;
196+
ReleaseAllocator = ReleaseAllocatorImpl;
197+
CreateDataTransfer = CreateDataTransferImpl;
198+
199+
IsStreamAware = IsStreamAwareImpl;
200+
CreateSyncStreamForDevice = CreateSyncStreamForDeviceImpl;
190201
}
191202

192203
// Returns the name for the EP. Each unique factory configuration must have a unique name.
@@ -201,6 +212,16 @@ struct MigraphXEpFactory : OrtEpFactory {
201212
return factory->vendor.c_str();
202213
}
203214

215+
static uint32_t GetVendorIdImpl(const OrtEpFactory* this_ptr) noexcept {
216+
const auto* factory = static_cast<const MigraphXEpFactory*>(this_ptr);
217+
return factory->vendor_id;
218+
}
219+
220+
static const char* GetVersionImpl(const OrtEpFactory* this_ptr) noexcept {
221+
const auto* factory = static_cast<const MigraphXEpFactory*>(this_ptr);
222+
return factory->version.c_str();
223+
}
224+
204225
// Creates and returns OrtEpDevice instances for all OrtHardwareDevices that this factory supports.
205226
// An EP created with this factory is expected to be able to execute a model with *all* supported
206227
// hardware devices at once. A single instance of MigraphX EP is not currently setup to partition a model among
@@ -245,10 +266,48 @@ struct MigraphXEpFactory : OrtEpFactory {
245266
// no-op as we never create an EP here.
246267
}
247268

269+
static OrtStatus* ORT_API_CALL CreateAllocatorImpl(OrtEpFactory* this_ptr,
270+
const OrtMemoryInfo* /*memory_info*/,
271+
const OrtKeyValuePairs* /*allocator_options*/,
272+
OrtAllocator** allocator) noexcept {
273+
auto* factory = static_cast<MigraphXEpFactory*>(this_ptr);
274+
275+
*allocator = nullptr;
276+
return factory->ort_api.CreateStatus(
277+
ORT_INVALID_ARGUMENT,
278+
"CreateAllocator should not be called as we did not add OrtMemoryInfo to our OrtEpDevice.");
279+
}
280+
281+
static void ORT_API_CALL ReleaseAllocatorImpl(OrtEpFactory* /*this_ptr*/, OrtAllocator* /*allocator*/) noexcept {
282+
// should never be called as we don't implement CreateAllocator
283+
}
284+
285+
static OrtStatus* ORT_API_CALL CreateDataTransferImpl(OrtEpFactory* /*this_ptr*/,
286+
OrtDataTransferImpl** data_transfer) noexcept {
287+
*data_transfer = nullptr; // not implemented
288+
return nullptr;
289+
}
290+
291+
static bool ORT_API_CALL IsStreamAwareImpl(const OrtEpFactory* /*this_ptr*/) noexcept {
292+
return false;
293+
}
294+
295+
static OrtStatus* ORT_API_CALL CreateSyncStreamForDeviceImpl(OrtEpFactory* this_ptr,
296+
const OrtMemoryDevice* /*memory_device*/,
297+
const OrtKeyValuePairs* /*stream_options*/,
298+
OrtSyncStreamImpl** stream) noexcept {
299+
auto* factory = static_cast<MigraphXEpFactory*>(this_ptr);
300+
301+
*stream = nullptr;
302+
return factory->ort_api.CreateStatus(
303+
ORT_INVALID_ARGUMENT, "CreateSyncStreamForDevice should not be called as IsStreamAware returned false.");
304+
}
305+
248306
const OrtApi& ort_api;
249307
const OrtLogger& default_logger;
250308
const std::string ep_name;
251309
const std::string vendor{"AMD"};
310+
const std::string version{"1.0.0"}; // MigraphX EP version
252311

253312
const uint32_t vendor_id{0x1002};
254313
const OrtHardwareDeviceType ort_hw_device_type; // Supported OrtHardwareDevice

0 commit comments

Comments
 (0)