Skip to content

Commit 9c9e3a6

Browse files
authored
Allow DML EP to be used with any CPU EP (microsoft#25664)
### Description <!-- Describe your changes. --> Relax restriction on DML EP so other CPU based EPs can be used. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> microsoft#25504
1 parent e0786fe commit 9c9e3a6

File tree

2 files changed

+13
-6
lines changed

2 files changed

+13
-6
lines changed

include/onnxruntime/core/framework/execution_provider.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,12 @@ class IExecutionProvider {
179179
/**
180180
Get the device id of current execution provider
181181
*/
182-
virtual int GetDeviceId() const { return default_device_.Id(); };
182+
virtual int GetDeviceId() const { return default_device_.Id(); }
183+
184+
/**
185+
* Get the OrtDevice the execution provider was registered with.
186+
*/
187+
const OrtDevice& GetDevice() const { return default_device_; }
183188

184189
/**
185190
Get execution provider's configuration options.

onnxruntime/core/session/inference_session.cc

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1984,13 +1984,15 @@ static void ResolveMemoryPatternFlags(SessionState& session_state) {
19841984
// For now, this function only checks for invalid combination of DML EP with other EPs.
19851985
// TODO: extend this function to check for other invalid combinations of EPs.
19861986
common::Status InferenceSession::HasInvalidCombinationOfExecutionProviders() const {
1987-
// DML EP is only allowed with CPU EP
1987+
// DML EP is not allowed with other GPU or NPU EPs.
1988+
// historical reason for this is unknown. relaxing the limit that it must only be used with the CPU EP to support
1989+
// scenarios where alternative EPs are CPU based (e.g. openvino).
19881990
bool has_dml_ep = execution_providers_.Get(kDmlExecutionProvider) != nullptr;
19891991
if (has_dml_ep) {
1990-
const auto& ep_list = execution_providers_.GetIds();
1991-
for (const auto& ep : ep_list) {
1992-
if (ep == kDmlExecutionProvider || ep == kCpuExecutionProvider) continue;
1993-
return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "DML EP can be used with only CPU EP.");
1992+
for (const auto& ep : execution_providers_) {
1993+
if (ep->Type() != kDmlExecutionProvider && ep->GetDevice().Type() != OrtDevice::CPU) {
1994+
return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "DML EP can only be used with CPU EPs.");
1995+
}
19941996
}
19951997
}
19961998
return Status::OK();

0 commit comments

Comments
 (0)