Skip to content

Commit 0dd71ba

Browse files
Retrieve Device and Command buffer for DML (microsoft#25533)
### Description Retrieve Device and Command buffer for DML to inject other gfx command into the same device and command buffer for efficient GPU pipelining ### Motivation and Context Retrieve Device and Command buffer for DML to inject other gfx command into the same device and command buffer for efficient GPU pipelining
1 parent 08e18b2 commit 0dd71ba

File tree

3 files changed

+61
-0
lines changed

3 files changed

+61
-0
lines changed

include/onnxruntime/core/providers/dml/dml_provider_factory.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,18 @@ struct OrtDmlApi {
151151
* (high power, low power, or default) and a device filter (None, GPU, or NPU).
152152
*/
153153
ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_DML2, _In_ OrtSessionOptions* options, OrtDmlDeviceOptions* device_opts);
154+
155+
/**
156+
* GetDMLDevice
157+
* returns the DML device attched to the DML execution provider
158+
*/
159+
ORT_API2_STATUS(GetDMLDevice, _In_ OrtSessionOptions* options, _Out_ IDMLDevice** dmlDevice);
160+
161+
/**
162+
* GetDMLCommandQueue
163+
* returns the command queue used by DML
164+
*/
165+
ORT_API2_STATUS(GetDMLCommandQueue, _In_ OrtSessionOptions* options, _Out_ ID3D12CommandQueue** dmlCommandQueue);
154166
};
155167

156168
#ifdef __cplusplus

onnxruntime/core/providers/dml/DmlExecutionProvider/inc/DmlExecutionProvider.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,4 +52,6 @@ namespace Dml
5252
void RegisterDmlOperators(IMLOperatorRegistry* registry);
5353
void RegisterCpuOperatorsAsDml(onnxruntime::KernelRegistry* registry);
5454

55+
void getDMLDevice(onnxruntime::IExecutionProvider* provider, _COM_Outptr_ IDMLDevice** dmlDevice);
56+
5557
} // namespace Dml

onnxruntime/core/providers/dml/dml_provider_factory.cc

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,9 @@ struct DMLProviderFactory : IExecutionProviderFactory {
6565

6666
void SetMetacommandsEnabled(bool metacommands_enabled);
6767

68+
IDMLDevice* GetDMLDevice();
69+
ID3D12CommandQueue* GetDMLCommandQueue();
70+
6871
private:
6972
ComPtr<IDMLDevice> dml_device_{};
7073
ComPtr<ID3D12CommandQueue> cmd_queue_{};
@@ -101,6 +104,14 @@ void DMLProviderFactory::SetMetacommandsEnabled(bool metacommands_enabled) {
101104
metacommands_enabled_ = metacommands_enabled;
102105
}
103106

107+
IDMLDevice* DMLProviderFactory::GetDMLDevice() {
108+
return dml_device_.Get();
109+
}
110+
111+
ID3D12CommandQueue* DMLProviderFactory::GetDMLCommandQueue() {
112+
return cmd_queue_.Get();
113+
}
114+
104115
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_DML(const ConfigOptions& config_options,
105116
IDMLDevice* dml_device,
106117
ID3D12CommandQueue* cmd_queue,
@@ -712,13 +723,49 @@ ORT_API_STATUS_IMPL(GetD3D12ResourceFromAllocation, _In_ OrtAllocator* ort_alloc
712723
API_IMPL_END
713724
}
714725

726+
ORT_API_STATUS_IMPL(GetDMLDevice, _In_ OrtSessionOptions* options, _Out_ IDMLDevice** dmlDevice) {
727+
API_IMPL_BEGIN
728+
729+
*dmlDevice = nullptr;
730+
#ifdef USE_DML
731+
if (options) {
732+
for (auto& factory : options->provider_factories) {
733+
if (auto dml_provider_factory = static_cast<onnxruntime::DMLProviderFactory*>(factory.get())) {
734+
*dmlDevice = dml_provider_factory->GetDMLDevice();
735+
}
736+
}
737+
}
738+
#endif // USE_DML
739+
return nullptr;
740+
API_IMPL_END
741+
}
742+
743+
ORT_API_STATUS_IMPL(GetDMLCommandQueue, _In_ OrtSessionOptions* options, _Out_ ID3D12CommandQueue** dmlCommandQ) {
744+
API_IMPL_BEGIN
745+
746+
*dmlCommandQ = nullptr;
747+
#ifdef USE_DML
748+
if (options) {
749+
for (auto& factory : options->provider_factories) {
750+
if (auto dml_provider_factory = static_cast<onnxruntime::DMLProviderFactory*>(factory.get())) {
751+
*dmlCommandQ = dml_provider_factory->GetDMLCommandQueue();
752+
}
753+
}
754+
}
755+
#endif // USE_DML
756+
return nullptr;
757+
API_IMPL_END
758+
}
759+
715760
static constexpr OrtDmlApi ort_dml_api_10_to_x = {
716761
&OrtSessionOptionsAppendExecutionProvider_DML,
717762
&OrtSessionOptionsAppendExecutionProviderEx_DML,
718763
&CreateGPUAllocationFromD3DResource,
719764
&FreeGPUAllocation,
720765
&GetD3D12ResourceFromAllocation,
721766
&OrtSessionOptionsAppendExecutionProvider_DML2,
767+
&GetDMLDevice,
768+
&GetDMLCommandQueue,
722769
};
723770

724771
const OrtDmlApi* GetOrtDmlApi(_In_ uint32_t /*version*/) NO_EXCEPTION {

0 commit comments

Comments
 (0)