88#include < atlcomcli.h>
99#include < d3d12.h>
1010#include < dxgi1_4.h>
11+ #include < optional>
1112
1213static bool useDebugIfaces () { return true ; }
1314
@@ -319,6 +320,7 @@ enableExperimentalShaderModels(UUID AdditionalFeatures[] = nullptr,
319320}
320321
321322HRESULT enableAgilitySDK (HMODULE Runtime) {
323+
322324 // D3D12SDKVersion > 1 will use provided version, otherwise, auto-detect.
323325 // D3D12SDKVersion == 1 means fail if we can't auto-detect.
324326 UINT SDKVersion = 0 ;
@@ -390,17 +392,173 @@ HRESULT enableExperimentalMode(HMODULE Runtime) {
390392 return HR;
391393}
392394
393- HRESULT enableDebugLayer () {
394- // The debug layer does net yet validate DXIL programs that require
395- // rewriting, but basic logging should work properly.
396- HRESULT HR = S_FALSE;
397- if (useDebugIfaces ()) {
398- CComPtr<ID3D12Debug> DebugController;
399- HR = D3D12GetDebugInterface (IID_PPV_ARGS (&DebugController));
400- if (SUCCEEDED (HR)) {
401- DebugController->EnableDebugLayer ();
402- HR = S_OK;
395+ static bool enableDebugLayer () {
396+ using namespace hlsl_test ;
397+
398+ CComPtr<ID3D12Debug> DebugController;
399+ HRESULT HR;
400+ if (FAILED (HR = D3D12GetDebugInterface (IID_PPV_ARGS (&DebugController)))) {
401+ LogErrorFmt (L" Failed to get ID3D12Debug: 0x%08x" , HR);
402+ return false ;
403+ }
404+
405+ DebugController->EnableDebugLayer ();
406+ return true ;
407+ }
408+
409+ struct AgilitySDKConfiguration {
410+ WEX::Common::String SDKPath;
411+ UINT SDKVersion = 0 ;
412+ bool MustFind = false ;
413+ };
414+
415+ static std::optional<AgilitySDKConfiguration> getAgilitySDKConfiguration () {
416+ using hlsl_test::LogErrorFmt;
417+ using WEX::TestExecution::RuntimeParameters;
418+
419+ AgilitySDKConfiguration C;
420+
421+ // For global configuration, D3D12SDKPath must be relative path from .exe,
422+ // which means relative to TE.exe location, and must start with ".\\", such as
423+ // with the default: ".\\D3D12\\".
424+ //
425+ // For ID3D12DeviceFactory-style configuration, D3D12SDKPath can be an
426+ // absolute path.
427+ if (SUCCEEDED (RuntimeParameters::TryGetValue (L" D3D12SDKPath" , C.SDKPath ))) {
428+ // Make sure path ends in backslash
429+ if (!C.SDKPath .IsEmpty () && C.SDKPath .Right (1 ) != " \\ " )
430+ C.SDKPath .Append (" \\ " );
431+ }
432+
433+ if (C.SDKPath .IsEmpty ())
434+ C.SDKPath = L" .\\ D3D12\\ " ;
435+
436+ // D3D12SDKVersion > 1 will use provided version, otherwise, auto-detect.
437+ // D3D12SDKVersion == 1 means fail if we can't auto-detect.
438+ RuntimeParameters::TryGetValue (L" D3D12SDKVersion" , C.SDKVersion );
439+
440+ C.MustFind = C.SDKVersion >= 1 ;
441+
442+ if (C.SDKVersion <= 1 ) {
443+ // Use the version supported by the SDK in the path.
444+ C.SDKVersion = getD3D12SDKVersion (std::wstring (C.SDKPath ));
445+ if (C.SDKVersion == 0 ) {
446+ if (C.MustFind ) {
447+ LogErrorFmt (L" Agility SDK not found in relative path: %s" ,
448+ static_cast <const wchar_t *>(C.SDKPath ));
449+ return std::nullopt ;
450+ }
451+
452+ // No AgilitySDK found, caller indicated that they just want to use the
453+ // inbox D3D12 in this case.
454+ return AgilitySDKConfiguration{};
403455 }
404456 }
405- return HR;
457+
458+ return C;
459+ }
460+
461+ static bool enableGlobalAgilitySDK () {
462+ using namespace hlsl_test ;
463+
464+ std::optional<AgilitySDKConfiguration> C = getAgilitySDKConfiguration ();
465+ if (!C)
466+ return false ;
467+
468+ if (C->SDKVersion == 0 )
469+ return true ;
470+
471+ CComPtr<ID3D12SDKConfiguration> SDKConfig;
472+ HRESULT HR;
473+ if (FAILED (HR = D3D12GetInterface (CLSID_D3D12SDKConfiguration,
474+ IID_PPV_ARGS (&SDKConfig)))) {
475+ LogErrorFmt (L" Failed to get ID3D12SDKConfiguration instance: 0x%08x" , HR);
476+ return !C->MustFind ;
477+ }
478+
479+ if (FAILED (HR = SDKConfig->SetSDKVersion (C->SDKVersion , CW2A (C->SDKPath )))) {
480+ LogErrorFmt (L" SetSDKVersion(%d, %s) failed: 0x%08x" , C->SDKVersion ,
481+ static_cast <const wchar_t *>(C->SDKPath ), HR);
482+ return !C->MustFind ;
483+ }
484+
485+ // Currently, it appears that the SetSDKVersion will succeed even when
486+ // D3D12Core is not found, or its version doesn't match. When that's the
487+ // case, will cause a failure in the very next thing that actually requires
488+ // D3D12Core.dll to be loaded instead. So, we attempt to clear experimental
489+ // features next, which is a valid use case and a no-op at this point. This
490+ // requires D3D12Core to be loaded. If this fails, we know the AgilitySDK
491+ // setting actually failed.
492+ if (FAILED (
493+ HR = D3D12EnableExperimentalFeatures (0 , nullptr , nullptr , nullptr ))) {
494+ LogErrorFmt (L" D3D12EnableExperimentalFeatures(0...) failed: 0x%08x" , HR);
495+ return !C->MustFind ;
496+ }
497+
498+ return true ;
406499}
500+
501+ static bool enableGlobalExperimentalMode () {
502+ using namespace hlsl_test ;
503+
504+ const bool ExperimentalShaders = GetTestParamBool (L" ExperimentalShaders" );
505+
506+ if (!ExperimentalShaders)
507+ return false ;
508+
509+ HRESULT HR;
510+ if (FAILED (HR = D3D12EnableExperimentalFeatures (
511+ 1 , &D3D12ExperimentalShaderModels, nullptr , nullptr ))) {
512+ LogErrorFmt (L" D3D12EnableExperimentalFeatures("
513+ L" D3D12ExperimentalShaderModels) failed: 0x%08x" ,
514+ HR);
515+ return false ;
516+ }
517+
518+ return true ;
519+ }
520+
521+ static void setGlobalConfiguration () {
522+ using namespace hlsl_test ;
523+
524+ if (enableGlobalAgilitySDK ())
525+ LogCommentFmt (L" Agility SDK enabled." );
526+ else
527+ LogCommentFmt (L" Agility SDK not enabled." );
528+
529+ if (enableGlobalExperimentalMode ())
530+ LogCommentFmt (L" Experimental mode enabled." );
531+ else
532+ LogCommentFmt (L" Experimental mode not enabled." );
533+ }
534+
535+ std::optional<D3D12SDK> D3D12SDK::create () {
536+ using namespace hlsl_test ;
537+
538+ if (enableDebugLayer ())
539+ LogCommentFmt (L" Debug layer enabled" );
540+ else
541+ LogCommentFmt (L" Debug layer not enabled" );
542+
543+ // CComPtr<ID3D12SDKConfiguration1> Config1;
544+ // if (FAILED(D3D12GetInterface(CLSID_D3D12SDKConfiguration,
545+ // IID_PPV_ARGS(&Config1))))
546+ { return D3D12SDK (nullptr ); }
547+
548+ CComPtr<ID3D12DeviceFactory> DeviceFactory;
549+
550+ // ...
551+
552+ return D3D12SDK (DeviceFactory);
553+ }
554+
555+ D3D12SDK::D3D12SDK (CComPtr<ID3D12DeviceFactory> DeviceFactory)
556+ : DeviceFactory(std::move(DeviceFactory)) {}
557+
558+ D3D12SDK::~D3D12SDK () = default ;
559+
560+ bool D3D12SDK::createDevice (ID3D12Device **D3DDevice,
561+ ExecTestUtils::D3D_SHADER_MODEL TestModel,
562+ bool SkipUnsupported) {
563+ return ::createDevice (D3DDevice, TestModel, SkipUnsupported);
564+ }
0 commit comments