Skip to content

Commit cc875ee

Browse files
committed
Hook up ID3D12DeviceFactory-style
1 parent 4fb6998 commit cc875ee

File tree

1 file changed

+125
-28
lines changed

1 file changed

+125
-28
lines changed

tools/clang/unittests/HLSLExec/HlslExecTestUtils.cpp

Lines changed: 125 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include <atlcomcli.h>
99
#include <d3d12.h>
1010
#include <dxgi1_4.h>
11+
#include <filesystem>
1112
#include <optional>
1213

1314
static bool useDebugIfaces() { return true; }
@@ -80,7 +81,10 @@ static std::wstring getModuleName() {
8081
return std::wstring(ModuleName, Length);
8182
}
8283

83-
static std::wstring computeSDKFullPath(std::wstring SDKPath) {
84+
static std::wstring computeSDKFullPath(const std::wstring &SDKPath) {
85+
if (std::filesystem::path(SDKPath).is_absolute())
86+
return SDKPath;
87+
8488
std::wstring ModulePath = getModuleName();
8589
const size_t Pos = ModulePath.rfind('\\');
8690

@@ -94,6 +98,8 @@ static std::wstring computeSDKFullPath(std::wstring SDKPath) {
9498
}
9599

96100
static UINT getD3D12SDKVersion(std::wstring SDKPath) {
101+
using namespace hlsl_test;
102+
97103
// Try to automatically get the D3D12SDKVersion from the DLL
98104
UINT SDKVersion = 0;
99105
std::wstring D3DCorePath = computeSDKFullPath(SDKPath);
@@ -104,13 +110,21 @@ static UINT getD3D12SDKVersion(std::wstring SDKPath) {
104110
(UINT *)GetProcAddress(D3DCore, "D3D12SDKVersion"))
105111
SDKVersion = *SDKVersionOut;
106112
FreeModule(D3DCore);
113+
LogCommentFmt(L"%s - D3D12SDKVersion is %d", D3DCorePath.c_str(),
114+
SDKVersion);
115+
} else {
116+
LogCommentFmt(L"%s - unable to load", D3DCorePath.c_str());
107117
}
108118
return SDKVersion;
109119
}
110120

111-
bool createDevice(ID3D12Device **D3DDevice,
112-
ExecTestUtils::D3D_SHADER_MODEL TestModel,
113-
bool SkipUnsupported) {
121+
bool createDevice(
122+
ID3D12Device **D3DDevice, ExecTestUtils::D3D_SHADER_MODEL TestModel,
123+
bool SkipUnsupported,
124+
std::function<HRESULT(IUnknown *, D3D_FEATURE_LEVEL, REFIID, void **)>
125+
CreateDevice
126+
127+
) {
114128
if (TestModel > ExecTestUtils::D3D_HIGHEST_SHADER_MODEL) {
115129
const UINT Minor = (UINT)TestModel & 0x0f;
116130
hlsl_test::LogCommentFmt(L"Installed SDK does not support "
@@ -159,8 +173,8 @@ bool createDevice(ID3D12Device **D3DDevice,
159173
// Create the WARP device
160174
CComPtr<IDXGIAdapter> WarpAdapter;
161175
VERIFY_SUCCEEDED(DXGIFactory->EnumWarpAdapter(IID_PPV_ARGS(&WarpAdapter)));
162-
HRESULT CreateHR = D3D12CreateDevice(WarpAdapter, D3D_FEATURE_LEVEL_11_0,
163-
IID_PPV_ARGS(&D3DDeviceCom));
176+
HRESULT CreateHR = CreateDevice(WarpAdapter, D3D_FEATURE_LEVEL_11_0,
177+
IID_PPV_ARGS(&D3DDeviceCom));
164178
if (FAILED(CreateHR)) {
165179
hlsl_test::LogCommentFmt(
166180
L"The available version of WARP does not support d3d12.");
@@ -196,8 +210,8 @@ bool createDevice(ID3D12Device **D3DDevice,
196210
WEX::Logging::Log::Comment(
197211
L"Using default hardware adapter with D3D12 support.");
198212

199-
VERIFY_SUCCEEDED(D3D12CreateDevice(HardwareAdapter, D3D_FEATURE_LEVEL_11_0,
200-
IID_PPV_ARGS(&D3DDeviceCom)));
213+
VERIFY_SUCCEEDED(CreateDevice(HardwareAdapter, D3D_FEATURE_LEVEL_11_0,
214+
IID_PPV_ARGS(&D3DDeviceCom)));
201215
}
202216
// retrieve adapter information
203217
const LUID AdapterID = D3DDeviceCom->GetAdapterLuid();
@@ -224,8 +238,8 @@ bool createDevice(ID3D12Device **D3DDevice,
224238
SMData.HighestShaderModel < TestModel) {
225239
const UINT Minor = (UINT)TestModel & 0x0f;
226240
hlsl_test::LogCommentFmt(L"The selected device does not support "
227-
L"shader model 6.%1u",
228-
Minor);
241+
L"shader model 6.%1u (highest is 6.%1u)",
242+
Minor, SMData.HighestShaderModel & 0x0f);
229243

230244
if (SkipUnsupported)
231245
WEX::Logging::Log::Result(WEX::Logging::TestResults::Skipped);
@@ -458,15 +472,15 @@ static std::optional<AgilitySDKConfiguration> getAgilitySDKConfiguration() {
458472
return C;
459473
}
460474

461-
static bool enableGlobalAgilitySDK() {
475+
static bool
476+
enableGlobalAgilitySDK(const std::optional<AgilitySDKConfiguration> &C) {
462477
using namespace hlsl_test;
463478

464-
std::optional<AgilitySDKConfiguration> C = getAgilitySDKConfiguration();
465479
if (!C)
466480
return false;
467481

468482
if (C->SDKVersion == 0)
469-
return true;
483+
return false;
470484

471485
CComPtr<ID3D12SDKConfiguration> SDKConfig;
472486
HRESULT HR;
@@ -498,12 +512,14 @@ static bool enableGlobalAgilitySDK() {
498512
return true;
499513
}
500514

515+
static bool isExperimentalShadersEnabled() {
516+
return hlsl_test::GetTestParamBool(L"ExperimentalShaders");
517+
}
518+
501519
static bool enableGlobalExperimentalMode() {
502520
using namespace hlsl_test;
503521

504-
const bool ExperimentalShaders = GetTestParamBool(L"ExperimentalShaders");
505-
506-
if (!ExperimentalShaders)
522+
if (!isExperimentalShadersEnabled())
507523
return false;
508524

509525
HRESULT HR;
@@ -518,10 +534,11 @@ static bool enableGlobalExperimentalMode() {
518534
return true;
519535
}
520536

521-
static void setGlobalConfiguration() {
537+
static void
538+
setGlobalConfiguration(const std::optional<AgilitySDKConfiguration> &C) {
522539
using namespace hlsl_test;
523540

524-
if (enableGlobalAgilitySDK())
541+
if (enableGlobalAgilitySDK(C))
525542
LogCommentFmt(L"Agility SDK enabled.");
526543
else
527544
LogCommentFmt(L"Agility SDK not enabled.");
@@ -532,6 +549,58 @@ static void setGlobalConfiguration() {
532549
LogCommentFmt(L"Experimental mode not enabled.");
533550
}
534551

552+
static bool enableExperimentalMode(ID3D12DeviceFactory *DeviceFactory) {
553+
using namespace hlsl_test;
554+
555+
if (!isExperimentalShadersEnabled())
556+
return false;
557+
558+
HRESULT HR;
559+
if (FAILED(HR = DeviceFactory->EnableExperimentalFeatures(
560+
1, &D3D12ExperimentalShaderModels, nullptr, nullptr))) {
561+
LogWarningFmt(L"EnableExperimentalFeature(D3D12ExperimentalShaderModels) "
562+
L"failed: 0x%08x",
563+
HR);
564+
return false;
565+
}
566+
567+
return true;
568+
}
569+
570+
static CComPtr<ID3D12DeviceFactory>
571+
createDeviceFactorySDK(const AgilitySDKConfiguration &C) {
572+
using namespace hlsl_test;
573+
574+
HRESULT HR;
575+
576+
CComPtr<ID3D12SDKConfiguration1> SDKConfig;
577+
if (FAILED(HR = D3D12GetInterface(CLSID_D3D12SDKConfiguration,
578+
IID_PPV_ARGS(&SDKConfig)))) {
579+
LogCommentFmt(L"Failed to get ID3D12SDKConfiguration1 interface: 0x%08x",
580+
HR);
581+
return nullptr;
582+
}
583+
584+
CComPtr<ID3D12DeviceFactory> DeviceFactory;
585+
if (FAILED(
586+
HR = SDKConfig->CreateDeviceFactory(C.SDKVersion, CW2A(C.SDKPath),
587+
IID_PPV_ARGS(&DeviceFactory)))) {
588+
LogCommentFmt(L"CreateDeviceFactory(%d, '%s', ...) failed: 0x%08x",
589+
C.SDKVersion, static_cast<const wchar_t *>(C.SDKPath), HR);
590+
return nullptr;
591+
}
592+
593+
LogCommentFmt(L"Using DeviceFactory for SDKVersion %d, SDKPath %s",
594+
C.SDKVersion, static_cast<const wchar_t *>(C.SDKPath));
595+
596+
if (enableExperimentalMode(DeviceFactory))
597+
LogCommentFmt(L"Experimental mode enabled.");
598+
else
599+
LogCommentFmt(L"Experimental mode not enabled.");
600+
601+
return DeviceFactory;
602+
}
603+
535604
std::optional<D3D12SDK> D3D12SDK::create() {
536605
using namespace hlsl_test;
537606

@@ -540,25 +609,53 @@ std::optional<D3D12SDK> D3D12SDK::create() {
540609
else
541610
LogCommentFmt(L"Debug layer not enabled");
542611

543-
// CComPtr<ID3D12SDKConfiguration1> Config1;
544-
// if (FAILED(D3D12GetInterface(CLSID_D3D12SDKConfiguration,
545-
// IID_PPV_ARGS(&Config1))))
546-
{ return D3D12SDK(nullptr); }
547-
548-
CComPtr<ID3D12DeviceFactory> DeviceFactory;
612+
std::optional<AgilitySDKConfiguration> C = getAgilitySDKConfiguration();
549613

550-
// ...
614+
if (C && C->SDKVersion > 0) {
615+
CComPtr<ID3D12DeviceFactory> DeviceFactory = createDeviceFactorySDK(*C);
616+
if (DeviceFactory)
617+
return D3D12SDK(DeviceFactory);
618+
}
551619

552-
return D3D12SDK(DeviceFactory);
620+
setGlobalConfiguration(C);
621+
return D3D12SDK(nullptr);
553622
}
554623

555624
D3D12SDK::D3D12SDK(CComPtr<ID3D12DeviceFactory> DeviceFactory)
556625
: DeviceFactory(std::move(DeviceFactory)) {}
557626

558-
D3D12SDK::~D3D12SDK() = default;
627+
D3D12SDK::~D3D12SDK() {
628+
using namespace hlsl_test;
629+
630+
if (DeviceFactory) {
631+
DeviceFactory.Release();
632+
633+
HRESULT HR;
634+
CComPtr<ID3D12SDKConfiguration1> SDKConfig;
635+
if (FAILED(HR = D3D12GetInterface(CLSID_D3D12SDKConfiguration,
636+
IID_PPV_ARGS(&SDKConfig)))) {
637+
LogCommentFmt(L"Failed to get ID3D12SDKConfiguration1 interface: 0x%08x",
638+
HR);
639+
return;
640+
}
641+
642+
SDKConfig->FreeUnusedSDKs();
643+
}
644+
}
559645

560646
bool D3D12SDK::createDevice(ID3D12Device **D3DDevice,
561647
ExecTestUtils::D3D_SHADER_MODEL TestModel,
562648
bool SkipUnsupported) {
563-
return ::createDevice(D3DDevice, TestModel, SkipUnsupported);
649+
650+
if (DeviceFactory) {
651+
hlsl_test::LogCommentFmt(L"Creating device using DeviceFactory");
652+
return ::createDevice(
653+
D3DDevice, TestModel, SkipUnsupported,
654+
[&](IUnknown *A, D3D_FEATURE_LEVEL FL, REFIID R, void **P) {
655+
return DeviceFactory->CreateDevice(A, FL, R, P);
656+
});
657+
}
658+
659+
return ::createDevice(D3DDevice, TestModel, SkipUnsupported,
660+
D3D12CreateDevice);
564661
}

0 commit comments

Comments
 (0)