Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tools/clang/unittests/HLSLExec/LongVectorOps.def
Original file line number Diff line number Diff line change
Expand Up @@ -207,5 +207,6 @@ OP_DEFAULT_DEFINES(Wave, WaveReadLaneAt, 1, "TestWaveReadLaneAt", "", " -DFUNC_W
OP_DEFAULT_DEFINES(Wave, WaveReadLaneFirst, 1, "TestWaveReadLaneFirst", "", " -DFUNC_WAVE_READ_LANE_FIRST=1")
OP_DEFAULT_DEFINES(Wave, WavePrefixSum, 1, "TestWavePrefixSum", "", " -DFUNC_WAVE_PREFIX_SUM=1 -DIS_WAVE_PREFIX_OP=1")
OP_DEFAULT_DEFINES(Wave, WavePrefixProduct, 1, "TestWavePrefixProduct", "", " -DFUNC_WAVE_PREFIX_PRODUCT=1 -DIS_WAVE_PREFIX_OP=1")
// WaveMatch is not a prefix op, but it still defines IS_WAVE_PREFIX_OP:
// TestWaveMatch returns void and stores its results into the output buffer
// itself instead of handing a result vector back to main to write.
OP_DEFAULT_DEFINES(Wave, WaveMatch, 1, "TestWaveMatch", "", " -DFUNC_WAVE_MATCH=1 -DIS_WAVE_PREFIX_OP=1")
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

-DIS_WAVE_PREFIX_OP=1 is required because the test function returns void: it writes to the output vector itself instead of delegating that write to main.


#undef OP
56 changes: 56 additions & 0 deletions tools/clang/unittests/HLSLExec/LongVectors.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1433,6 +1433,52 @@ template <typename T> T wavePrefixProduct(T A, UINT) {
return static_cast<T>(A * A);
}

// Op specialization for WaveMatch; it adds no members of its own and takes
// everything from StrictValidation.
// NOTE(review): the trailing template argument (1) looks like the operand
// count, matching the single input set the expected-value builder asserts on
// -- confirm against the primary Op template.
template <typename T> struct Op<OpType::WaveMatch, T, 1> : StrictValidation {};

template <typename T> struct ExpectedBuilder<OpType::WaveMatch, T> {
static std::vector<UINT> buildExpected(Op<OpType::WaveMatch, T, 1> &,
const InputSets<T> &Inputs,
UINT WaveSize) {
DXASSERT_NOMSG(Inputs.size() == 1);

std::vector<UINT> Expected;
const size_t VectorSize = Inputs[0].size();
Expected.assign(VectorSize, 0);

UINT wordShift = WaveSize / 32;
UINT bitShift = WaveSize % 32;
if (bitShift == 0) {
bitShift = 32;
wordShift--;
}
uint64_t result[4] = {((1ULL << bitShift) - 1) & ~1ULL, 0, 0, 0};

for (UINT I = wordShift; I > 0; I--)
result[I] = ~0ULL;

Expected[0] = 1;
Expected[1] = 0;
Expected[2] = 0;
if (VectorSize <= 3)
return Expected;

Expected[3] = 0;
for (UINT I = 1; I < WaveSize; I++) {
const UINT Index = I * 4;
if (Index < VectorSize)
Expected[Index] = static_cast<UINT>(result[0]);
if (Index + 1 < VectorSize)
Expected[Index + 1] = static_cast<UINT>(result[1]);
if (Index + 2 < VectorSize)
Expected[Index + 2] = static_cast<UINT>(result[2]);
if (Index + 3 < VectorSize)
Expected[Index + 3] = static_cast<UINT>(result[3]);
}

return Expected;
}
};

#undef WAVE_OP

//
Expand Down Expand Up @@ -2334,6 +2380,7 @@ class DxilConf_SM69_Vectorized {
HLK_WAVEOP_TEST(WaveActiveAllEqual, HLSLBool_t);
HLK_WAVEOP_TEST(WaveReadLaneAt, HLSLBool_t);
HLK_WAVEOP_TEST(WaveReadLaneFirst, HLSLBool_t);
HLK_WAVEOP_TEST(WaveMatch, HLSLBool_t);

HLK_WAVEOP_TEST(WaveActiveSum, int16_t);
HLK_WAVEOP_TEST(WaveActiveMin, int16_t);
Expand All @@ -2344,6 +2391,7 @@ class DxilConf_SM69_Vectorized {
HLK_WAVEOP_TEST(WaveReadLaneFirst, int16_t);
HLK_WAVEOP_TEST(WavePrefixSum, int16_t);
HLK_WAVEOP_TEST(WavePrefixProduct, int16_t);
HLK_WAVEOP_TEST(WaveMatch, int16_t);
HLK_WAVEOP_TEST(WaveActiveSum, int32_t);
HLK_WAVEOP_TEST(WaveActiveMin, int32_t);
HLK_WAVEOP_TEST(WaveActiveMax, int32_t);
Expand All @@ -2353,6 +2401,7 @@ class DxilConf_SM69_Vectorized {
HLK_WAVEOP_TEST(WaveReadLaneFirst, int32_t);
HLK_WAVEOP_TEST(WavePrefixSum, int32_t);
HLK_WAVEOP_TEST(WavePrefixProduct, int32_t);
HLK_WAVEOP_TEST(WaveMatch, int32_t);
HLK_WAVEOP_TEST(WaveActiveSum, int64_t);
HLK_WAVEOP_TEST(WaveActiveMin, int64_t);
HLK_WAVEOP_TEST(WaveActiveMax, int64_t);
Expand All @@ -2362,6 +2411,7 @@ class DxilConf_SM69_Vectorized {
HLK_WAVEOP_TEST(WaveReadLaneFirst, int64_t);
HLK_WAVEOP_TEST(WavePrefixSum, int64_t);
HLK_WAVEOP_TEST(WavePrefixProduct, int64_t);
HLK_WAVEOP_TEST(WaveMatch, int64_t);

HLK_WAVEOP_TEST(WaveActiveSum, uint16_t);
HLK_WAVEOP_TEST(WaveActiveMin, uint16_t);
Expand All @@ -2372,6 +2422,7 @@ class DxilConf_SM69_Vectorized {
HLK_WAVEOP_TEST(WaveReadLaneFirst, uint16_t);
HLK_WAVEOP_TEST(WavePrefixSum, uint16_t);
HLK_WAVEOP_TEST(WavePrefixProduct, uint16_t);
HLK_WAVEOP_TEST(WaveMatch, uint16_t);
HLK_WAVEOP_TEST(WaveActiveSum, uint32_t);
HLK_WAVEOP_TEST(WaveActiveMin, uint32_t);
HLK_WAVEOP_TEST(WaveActiveMax, uint32_t);
Expand All @@ -2385,6 +2436,7 @@ class DxilConf_SM69_Vectorized {
HLK_WAVEOP_TEST(WaveReadLaneFirst, uint32_t);
HLK_WAVEOP_TEST(WavePrefixSum, uint32_t);
HLK_WAVEOP_TEST(WavePrefixProduct, uint32_t);
HLK_WAVEOP_TEST(WaveMatch, uint32_t);
HLK_WAVEOP_TEST(WaveActiveSum, uint64_t);
HLK_WAVEOP_TEST(WaveActiveMin, uint64_t);
HLK_WAVEOP_TEST(WaveActiveMax, uint64_t);
Expand All @@ -2397,6 +2449,7 @@ class DxilConf_SM69_Vectorized {
HLK_WAVEOP_TEST(WaveReadLaneFirst, uint64_t);
HLK_WAVEOP_TEST(WavePrefixSum, uint64_t);
HLK_WAVEOP_TEST(WavePrefixProduct, uint64_t);
HLK_WAVEOP_TEST(WaveMatch, uint64_t);

HLK_WAVEOP_TEST(WaveActiveSum, HLSLHalf_t);
HLK_WAVEOP_TEST(WaveActiveMin, HLSLHalf_t);
Expand All @@ -2407,6 +2460,7 @@ class DxilConf_SM69_Vectorized {
HLK_WAVEOP_TEST(WaveReadLaneFirst, HLSLHalf_t);
HLK_WAVEOP_TEST(WavePrefixSum, HLSLHalf_t);
HLK_WAVEOP_TEST(WavePrefixProduct, HLSLHalf_t);
HLK_WAVEOP_TEST(WaveMatch, HLSLHalf_t);
HLK_WAVEOP_TEST(WaveActiveSum, float);
HLK_WAVEOP_TEST(WaveActiveMin, float);
HLK_WAVEOP_TEST(WaveActiveMax, float);
Expand All @@ -2416,6 +2470,7 @@ class DxilConf_SM69_Vectorized {
HLK_WAVEOP_TEST(WaveReadLaneFirst, float);
HLK_WAVEOP_TEST(WavePrefixSum, float);
HLK_WAVEOP_TEST(WavePrefixProduct, float);
HLK_WAVEOP_TEST(WaveMatch, float);
HLK_WAVEOP_TEST(WaveActiveSum, double);
HLK_WAVEOP_TEST(WaveActiveMin, double);
HLK_WAVEOP_TEST(WaveActiveMax, double);
Expand All @@ -2425,6 +2480,7 @@ class DxilConf_SM69_Vectorized {
HLK_WAVEOP_TEST(WaveReadLaneFirst, double);
HLK_WAVEOP_TEST(WavePrefixSum, double);
HLK_WAVEOP_TEST(WavePrefixProduct, double);
HLK_WAVEOP_TEST(WaveMatch, double);

private:
bool Initialized = false;
Expand Down
22 changes: 22 additions & 0 deletions tools/clang/unittests/HLSLExec/ShaderOpArith.xml
Original file line number Diff line number Diff line change
Expand Up @@ -4215,6 +4215,28 @@ void MSMain(uint GID : SV_GroupIndex,
}
#endif

#ifdef FUNC_WAVE_MATCH
// Runs WaveMatch on the input vector and writes each lane's uint4 result mask
// straight into g_OutputVector (nothing is returned to the caller).
void TestWaveMatch(vector<TYPE, NUM> Vector)
{
  const uint laneIndex = WaveGetLaneIndex();

  // Lane 0 replaces element 0 with a value different from the one it was
  // given: 1 becomes 0, anything else becomes 1.
  // NOTE(review): presumably all lanes start with the same vector, so this
  // makes lane 0 the only lane that does not match the others -- confirm.
  if (laneIndex == 0)
  {
    const bool startsAsOne = (Vector[0] == (TYPE)1);
    Vector[0] = startsAsOne ? (TYPE)0 : (TYPE)1;
  }

  const uint4 matchMask = WaveMatch(Vector);

  // Each lane owns four consecutive uint slots, starting at slot laneIndex * 4.
  const uint baseOffset = laneIndex * 4 * sizeof(uint);
  g_OutputVector.Store<uint>(baseOffset, matchMask.x);
  g_OutputVector.Store<uint>(baseOffset + 1 * sizeof(uint), matchMask.y);
  g_OutputVector.Store<uint>(baseOffset + 2 * sizeof(uint), matchMask.z);
  g_OutputVector.Store<uint>(baseOffset + 3 * sizeof(uint), matchMask.w);
}
#endif

#ifdef FUNC_TEST_SELECT
vector<OUT_TYPE, NUM> TestSelect(vector<TYPE, NUM> Vector1,
vector<TYPE, NUM> Vector2,
Expand Down