diff --git a/tools/clang/unittests/HLSLExec/LongVectorOps.def b/tools/clang/unittests/HLSLExec/LongVectorOps.def index c9fc281246..c9df9b8f28 100644 --- a/tools/clang/unittests/HLSLExec/LongVectorOps.def +++ b/tools/clang/unittests/HLSLExec/LongVectorOps.def @@ -219,5 +219,6 @@ OP(Wave, WaveMultiPrefixProduct, 1, "TestWaveMultiPrefixProduct", "", " -DFUNC_W OP(Wave, WaveMultiPrefixBitAnd, 1, "TestWaveMultiPrefixBitAnd", "", " -DFUNC_WAVE_MULTI_PREFIX_BIT_AND=1 -DIS_WAVE_PREFIX_OP=1", "LongVectorOp", WaveMultiPrefixBitwise, Default2, Default3) OP(Wave, WaveMultiPrefixBitOr, 1, "TestWaveMultiPrefixBitOr", "", " -DFUNC_WAVE_MULTI_PREFIX_BIT_OR=1 -DIS_WAVE_PREFIX_OP=1", "LongVectorOp", WaveMultiPrefixBitwise, Default2, Default3) OP(Wave, WaveMultiPrefixBitXor, 1, "TestWaveMultiPrefixBitXor", "", " -DFUNC_WAVE_MULTI_PREFIX_BIT_XOR=1 -DIS_WAVE_PREFIX_OP=1", "LongVectorOp", WaveMultiPrefixBitwise, Default2, Default3) +OP_DEFAULT_DEFINES(Wave, WaveMatch, 1, "TestWaveMatch", "", " -DFUNC_WAVE_MATCH=1 -DIS_WAVE_PREFIX_OP=1") #undef OP diff --git a/tools/clang/unittests/HLSLExec/LongVectors.cpp b/tools/clang/unittests/HLSLExec/LongVectors.cpp index b646b6a4b9..f84c97ebaa 100644 --- a/tools/clang/unittests/HLSLExec/LongVectors.cpp +++ b/tools/clang/unittests/HLSLExec/LongVectors.cpp @@ -1542,6 +1542,49 @@ template T waveMultiPrefixProduct(T A, UINT) { return A * A; } +template struct Op : StrictValidation {}; + +template struct ExpectedBuilder { + static std::vector buildExpected(Op &, + const InputSets &, + const UINT WaveSize) { + // For this test, the shader arranges it so that lane 0 is different from + // all the other lanes. Besides that all other lines write their result of + // WaveMatch as well. + + std::vector Expected; + Expected.assign(WaveSize * 4, 0); + + const UINT LowWaves = std::min(64U, WaveSize); + const UINT HighWaves = WaveSize - LowWaves; + + const uint64_t LowWaveMask = + (LowWaves < 64) ? (1ULL << LowWaves) - 1 : ~0ULL; + + const uint64_t HighWaveMask = + (HighWaves < 64) ? (1ULL << HighWaves) - 1 : ~0ULL; + + const uint64_t LowExpected = ~1ULL & LowWaveMask; + const uint64_t HighExpected = ~0ULL & HighWaveMask; + + Expected[0] = 1; + Expected[1] = 0; + Expected[2] = 0; + Expected[3] = 0; + + // all lanes other than the first one have the same result + for (UINT I = 1; I < WaveSize; ++I) { + const UINT Index = I * 4; + Expected[Index] = static_cast(LowExpected); + Expected[Index + 1] = static_cast(LowExpected >> 32); + Expected[Index + 2] = static_cast(HighExpected); + Expected[Index + 3] = static_cast(HighExpected >> 32); + } + + return Expected; + } +}; + #undef WAVE_OP // @@ -2461,6 +2504,7 @@ class DxilConf_SM69_Vectorized { HLK_WAVEOP_TEST(WaveActiveAllEqual, HLSLBool_t); HLK_WAVEOP_TEST(WaveReadLaneAt, HLSLBool_t); HLK_WAVEOP_TEST(WaveReadLaneFirst, HLSLBool_t); + HLK_WAVEOP_TEST(WaveMatch, HLSLBool_t); HLK_WAVEOP_TEST(WaveActiveSum, int16_t); HLK_WAVEOP_TEST(WaveActiveMin, int16_t); @@ -2476,6 +2520,7 @@ class DxilConf_SM69_Vectorized { HLK_WAVEOP_TEST(WaveMultiPrefixBitAnd, int16_t); HLK_WAVEOP_TEST(WaveMultiPrefixBitOr, int16_t); HLK_WAVEOP_TEST(WaveMultiPrefixBitXor, int16_t); + HLK_WAVEOP_TEST(WaveMatch, int16_t); HLK_WAVEOP_TEST(WaveActiveSum, int32_t); HLK_WAVEOP_TEST(WaveActiveMin, int32_t); HLK_WAVEOP_TEST(WaveActiveMax, int32_t); @@ -2490,6 +2535,7 @@ class DxilConf_SM69_Vectorized { HLK_WAVEOP_TEST(WaveMultiPrefixBitAnd, int32_t); HLK_WAVEOP_TEST(WaveMultiPrefixBitOr, int32_t); HLK_WAVEOP_TEST(WaveMultiPrefixBitXor, int32_t); + HLK_WAVEOP_TEST(WaveMatch, int32_t); HLK_WAVEOP_TEST(WaveActiveSum, int64_t); HLK_WAVEOP_TEST(WaveActiveMin, int64_t); HLK_WAVEOP_TEST(WaveActiveMax, int64_t); @@ -2504,6 +2550,7 @@ class DxilConf_SM69_Vectorized { HLK_WAVEOP_TEST(WaveMultiPrefixBitAnd, int64_t); HLK_WAVEOP_TEST(WaveMultiPrefixBitOr, int64_t); HLK_WAVEOP_TEST(WaveMultiPrefixBitXor, int64_t); + HLK_WAVEOP_TEST(WaveMatch, int64_t); // Note: WaveActiveBit* ops don't support uint16_t in HLSL // But the WaveMultiPrefixBit ops support all int and uint types @@ -2521,6 +2568,7 @@ class DxilConf_SM69_Vectorized { HLK_WAVEOP_TEST(WaveMultiPrefixBitAnd, uint16_t); HLK_WAVEOP_TEST(WaveMultiPrefixBitOr, uint16_t); HLK_WAVEOP_TEST(WaveMultiPrefixBitXor, uint16_t); + HLK_WAVEOP_TEST(WaveMatch, uint16_t); HLK_WAVEOP_TEST(WaveActiveSum, uint32_t); HLK_WAVEOP_TEST(WaveActiveMin, uint32_t); HLK_WAVEOP_TEST(WaveActiveMax, uint32_t); @@ -2538,6 +2586,7 @@ class DxilConf_SM69_Vectorized { HLK_WAVEOP_TEST(WaveMultiPrefixBitAnd, uint32_t); HLK_WAVEOP_TEST(WaveMultiPrefixBitOr, uint32_t); HLK_WAVEOP_TEST(WaveMultiPrefixBitXor, uint32_t); + HLK_WAVEOP_TEST(WaveMatch, uint32_t); HLK_WAVEOP_TEST(WaveActiveSum, uint64_t); HLK_WAVEOP_TEST(WaveActiveMin, uint64_t); HLK_WAVEOP_TEST(WaveActiveMax, uint64_t); @@ -2555,6 +2604,7 @@ class DxilConf_SM69_Vectorized { HLK_WAVEOP_TEST(WaveMultiPrefixBitAnd, uint64_t); HLK_WAVEOP_TEST(WaveMultiPrefixBitOr, uint64_t); HLK_WAVEOP_TEST(WaveMultiPrefixBitXor, uint64_t); + HLK_WAVEOP_TEST(WaveMatch, uint64_t); HLK_WAVEOP_TEST(WaveActiveSum, HLSLHalf_t); HLK_WAVEOP_TEST(WaveActiveMin, HLSLHalf_t); @@ -2567,6 +2617,7 @@ class DxilConf_SM69_Vectorized { HLK_WAVEOP_TEST(WavePrefixProduct, HLSLHalf_t); HLK_WAVEOP_TEST(WaveMultiPrefixSum, HLSLHalf_t); HLK_WAVEOP_TEST(WaveMultiPrefixProduct, HLSLHalf_t); + HLK_WAVEOP_TEST(WaveMatch, HLSLHalf_t); HLK_WAVEOP_TEST(WaveActiveSum, float); HLK_WAVEOP_TEST(WaveActiveMin, float); HLK_WAVEOP_TEST(WaveActiveMax, float); @@ -2578,6 +2629,7 @@ class DxilConf_SM69_Vectorized { HLK_WAVEOP_TEST(WavePrefixProduct, float); HLK_WAVEOP_TEST(WaveMultiPrefixSum, float); HLK_WAVEOP_TEST(WaveMultiPrefixProduct, float); + HLK_WAVEOP_TEST(WaveMatch, float); HLK_WAVEOP_TEST(WaveActiveSum, double); HLK_WAVEOP_TEST(WaveActiveMin, double); HLK_WAVEOP_TEST(WaveActiveMax, double); @@ -2589,6 +2641,7 @@ class DxilConf_SM69_Vectorized { HLK_WAVEOP_TEST(WavePrefixProduct, double); HLK_WAVEOP_TEST(WaveMultiPrefixSum, double); HLK_WAVEOP_TEST(WaveMultiPrefixProduct, double); + HLK_WAVEOP_TEST(WaveMatch, double); private: bool Initialized = false; diff --git a/tools/clang/unittests/HLSLExec/ShaderOpArith.xml b/tools/clang/unittests/HLSLExec/ShaderOpArith.xml index 4bac1dddd1..968f8993a1 100644 --- a/tools/clang/unittests/HLSLExec/ShaderOpArith.xml +++ b/tools/clang/unittests/HLSLExec/ShaderOpArith.xml @@ -4405,6 +4405,25 @@ void MSMain(uint GID : SV_GroupIndex, } #endif + #ifdef FUNC_WAVE_MATCH + void TestWaveMatch(vector Vector) + { + if(WaveGetLaneIndex() == 0) + { + if(Vector[0] == (TYPE)0) + Vector[0] = (TYPE) 1; + else if(Vector[0] == (TYPE)1) + Vector[0] = (TYPE) 0; + else + Vector[0] = (TYPE) 1; + } + uint4 result = WaveMatch(Vector); + uint index = WaveGetLaneIndex(); + + g_OutputVector.Store(index * sizeof(uint4), result); + } + #endif + #ifdef FUNC_TEST_SELECT vector TestSelect(vector Vector1, vector Vector2,