@@ -1392,7 +1392,7 @@ template <typename T> T waveActiveBitAnd(T A, UINT) {
13921392WAVE_OP (OpType::WaveActiveBitAnd, (waveActiveBitAnd(A, WaveSize)));
13931393
13941394template <typename T> T waveActiveBitOr (T A, UINT) {
1395- // We set the LSB to 0 in one of the lanes.
1395+ // We set the LSB to 1 in one of the lanes.
13961396 return static_cast <T>(A | static_cast <T>(1 ));
13971397}
13981398
@@ -1405,6 +1405,60 @@ template <typename T> T waveActiveBitXor(T A, UINT) {
14051405
14061406WAVE_OP (OpType::WaveActiveBitXor, (waveActiveBitXor(A, WaveSize)));
14071407
1408+ WAVE_OP (OpType::WaveMultiPrefixBitAnd, waveMultiPrefixBitAnd(A, WaveSize));
1409+
1410+ template <typename T> T waveMultiPrefixBitAnd (T A, UINT) {
1411+ // All lanes in the group mask use a mask to filter for only the second and
1412+ // third LSBs.
1413+ return static_cast <T>(A & static_cast <T>(0x6 ));
1414+ }
1415+
1416+ WAVE_OP (OpType::WaveMultiPrefixBitOr, waveMultiPrefixBitOr(A, WaveSize));
1417+
1418+ template <typename T> T waveMultiPrefixBitOr (T A, UINT) {
1419+ // All lanes in the group mask clear the second LSB.
1420+ return static_cast <T>(A & ~static_cast <T>(0x2 ));
1421+ }
1422+
1423+ template <typename T>
1424+ struct Op <OpType::WaveMultiPrefixBitXor, T, 1 > : StrictValidation {};
1425+
1426+ template <typename T> struct ExpectedBuilder <OpType::WaveMultiPrefixBitXor, T> {
1427+ static std::vector<T> buildExpected (Op<OpType::WaveMultiPrefixBitXor, T, 1 > &,
1428+ const InputSets<T> &Inputs, UINT) {
1429+ DXASSERT_NOMSG (Inputs.size () == 1 );
1430+
1431+ std::vector<T> Expected;
1432+ const size_t VectorSize = Inputs[0 ].size ();
1433+
1434+ // We get a little creative for MultiPrefixBitXor. The mask we use for the
1435+ // group in the shader is 0xE (0b1110), which includes lanes 1, 2, and 3.
1436+ // Prefix ops don't include the value of the current lane in their result.
1437+ // So, for this test we store the result of WaveMultiPrefixBitXor from lane
1438+ // 3. This means only the values from lanes 1 and 2 contribute to the result
1439+ // at lane 3.
1440+ //
1441+ // In the shader:
1442+ // - Lane 0: Set to 0 (not in mask, shouldn't affect result)
1443+ // - Lane 1: Keeps original input values
1444+ // - Lane 2: Lower half + last element set to 0, upper half keeps input
1445+ // - Lane 3: Stores the prefix XOR result (lanes 1 XOR lanes 2)
1446+ //
1447+ // Expected result: Lower half matches input (lane 1 XOR 0), upper half is
1448+ // 0s, except last element matches input.
1449+ for (size_t I = 0 ; I < VectorSize / 2 ; ++I)
1450+ Expected.push_back (Inputs[0 ][I]);
1451+ for (size_t I = VectorSize / 2 ; I < VectorSize - 1 ; ++I)
1452+ Expected.push_back (0 );
1453+
1454+ // We also set the last element to 0 on lane 2 so the last element in the
1455+ // output vector matches the last element in the input vector.
1456+ Expected.push_back (Inputs[0 ][VectorSize - 1 ]);
1457+
1458+ return Expected;
1459+ }
1460+ };
1461+
14081462template <typename T>
14091463struct Op <OpType::WaveActiveAllEqual, T, 1 > : StrictValidation {};
14101464
@@ -1463,16 +1517,29 @@ template <typename T> struct ExpectedBuilder<OpType::WaveReadLaneFirst, T> {
14631517WAVE_OP (OpType::WavePrefixSum, (wavePrefixSum(A, WaveSize)));
14641518
14651519template <typename T> T wavePrefixSum (T A, UINT WaveSize) {
1466- // We test the prefix sume in the 'middle' lane. This choice is arbitrary.
1467- return static_cast <T>(A * static_cast <T>(WaveSize / 2 ));
1520+ // We test the prefix sum in the 'middle' lane. This choice is arbitrary.
1521+ return A * static_cast <T>(WaveSize / 2 );
1522+ }
1523+
1524+ WAVE_OP (OpType::WaveMultiPrefixSum, (waveMultiPrefixSum(A, WaveSize)));
1525+
1526+ template <typename T> T waveMultiPrefixSum (T A, UINT) {
1527+ return A * static_cast <T>(2u );
14681528}
14691529
14701530WAVE_OP (OpType::WavePrefixProduct, (wavePrefixProduct(A, WaveSize)));
14711531
14721532template <typename T> T wavePrefixProduct (T A, UINT) {
14731533 // We test the the prefix product in the 3rd lane to avoid overflow issues.
14741534 // So the result is A * A.
1475- return static_cast <T>(A * A);
1535+ return A * A;
1536+ }
1537+
1538+ WAVE_OP (OpType::WaveMultiPrefixProduct, (waveMultiPrefixProduct(A, WaveSize)));
1539+
1540+ template <typename T> T waveMultiPrefixProduct (T A, UINT) {
1541+ // The group mask has 3 lanes.
1542+ return A * A;
14761543}
14771544
14781545#undef WAVE_OP
@@ -2404,6 +2471,11 @@ class DxilConf_SM69_Vectorized {
24042471 HLK_WAVEOP_TEST (WaveReadLaneFirst, int16_t );
24052472 HLK_WAVEOP_TEST (WavePrefixSum, int16_t );
24062473 HLK_WAVEOP_TEST (WavePrefixProduct, int16_t );
2474+ HLK_WAVEOP_TEST (WaveMultiPrefixSum, int16_t );
2475+ HLK_WAVEOP_TEST (WaveMultiPrefixProduct, int16_t );
2476+ HLK_WAVEOP_TEST (WaveMultiPrefixBitAnd, int16_t );
2477+ HLK_WAVEOP_TEST (WaveMultiPrefixBitOr, int16_t );
2478+ HLK_WAVEOP_TEST (WaveMultiPrefixBitXor, int16_t );
24072479 HLK_WAVEOP_TEST (WaveActiveSum, int32_t );
24082480 HLK_WAVEOP_TEST (WaveActiveMin, int32_t );
24092481 HLK_WAVEOP_TEST (WaveActiveMax, int32_t );
@@ -2412,7 +2484,12 @@ class DxilConf_SM69_Vectorized {
24122484 HLK_WAVEOP_TEST (WaveReadLaneAt, int32_t );
24132485 HLK_WAVEOP_TEST (WaveReadLaneFirst, int32_t );
24142486 HLK_WAVEOP_TEST (WavePrefixSum, int32_t );
2487+ HLK_WAVEOP_TEST (WaveMultiPrefixSum, int32_t );
2488+ HLK_WAVEOP_TEST (WaveMultiPrefixProduct, int32_t );
24152489 HLK_WAVEOP_TEST (WavePrefixProduct, int32_t );
2490+ HLK_WAVEOP_TEST (WaveMultiPrefixBitAnd, int32_t );
2491+ HLK_WAVEOP_TEST (WaveMultiPrefixBitOr, int32_t );
2492+ HLK_WAVEOP_TEST (WaveMultiPrefixBitXor, int32_t );
24162493 HLK_WAVEOP_TEST (WaveActiveSum, int64_t );
24172494 HLK_WAVEOP_TEST (WaveActiveMin, int64_t );
24182495 HLK_WAVEOP_TEST (WaveActiveMax, int64_t );
@@ -2422,7 +2499,14 @@ class DxilConf_SM69_Vectorized {
24222499 HLK_WAVEOP_TEST (WaveReadLaneFirst, int64_t );
24232500 HLK_WAVEOP_TEST (WavePrefixSum, int64_t );
24242501 HLK_WAVEOP_TEST (WavePrefixProduct, int64_t );
2502+ HLK_WAVEOP_TEST (WaveMultiPrefixSum, int64_t );
2503+ HLK_WAVEOP_TEST (WaveMultiPrefixProduct, int64_t );
2504+ HLK_WAVEOP_TEST (WaveMultiPrefixBitAnd, int64_t );
2505+ HLK_WAVEOP_TEST (WaveMultiPrefixBitOr, int64_t );
2506+ HLK_WAVEOP_TEST (WaveMultiPrefixBitXor, int64_t );
24252507
2508+ // Note: WaveActiveBit* ops don't support uint16_t in HLSL
2509+ // But the WaveMultiPrefixBit ops support all int and uint types
24262510 HLK_WAVEOP_TEST (WaveActiveSum, uint16_t );
24272511 HLK_WAVEOP_TEST (WaveActiveMin, uint16_t );
24282512 HLK_WAVEOP_TEST (WaveActiveMax, uint16_t );
@@ -2432,11 +2516,15 @@ class DxilConf_SM69_Vectorized {
24322516 HLK_WAVEOP_TEST (WaveReadLaneFirst, uint16_t );
24332517 HLK_WAVEOP_TEST (WavePrefixSum, uint16_t );
24342518 HLK_WAVEOP_TEST (WavePrefixProduct, uint16_t );
2519+ HLK_WAVEOP_TEST (WaveMultiPrefixSum, uint16_t );
2520+ HLK_WAVEOP_TEST (WaveMultiPrefixProduct, uint16_t );
2521+ HLK_WAVEOP_TEST (WaveMultiPrefixBitAnd, uint16_t );
2522+ HLK_WAVEOP_TEST (WaveMultiPrefixBitOr, uint16_t );
2523+ HLK_WAVEOP_TEST (WaveMultiPrefixBitXor, uint16_t );
24352524 HLK_WAVEOP_TEST (WaveActiveSum, uint32_t );
24362525 HLK_WAVEOP_TEST (WaveActiveMin, uint32_t );
24372526 HLK_WAVEOP_TEST (WaveActiveMax, uint32_t );
24382527 HLK_WAVEOP_TEST (WaveActiveProduct, uint32_t );
2439- // Note: WaveActiveBit* ops don't support uint16_t in HLSL
24402528 HLK_WAVEOP_TEST (WaveActiveBitAnd, uint32_t );
24412529 HLK_WAVEOP_TEST (WaveActiveBitOr, uint32_t );
24422530 HLK_WAVEOP_TEST (WaveActiveBitXor, uint32_t );
@@ -2445,6 +2533,11 @@ class DxilConf_SM69_Vectorized {
24452533 HLK_WAVEOP_TEST (WaveReadLaneFirst, uint32_t );
24462534 HLK_WAVEOP_TEST (WavePrefixSum, uint32_t );
24472535 HLK_WAVEOP_TEST (WavePrefixProduct, uint32_t );
2536+ HLK_WAVEOP_TEST (WaveMultiPrefixSum, uint32_t );
2537+ HLK_WAVEOP_TEST (WaveMultiPrefixProduct, uint32_t );
2538+ HLK_WAVEOP_TEST (WaveMultiPrefixBitAnd, uint32_t );
2539+ HLK_WAVEOP_TEST (WaveMultiPrefixBitOr, uint32_t );
2540+ HLK_WAVEOP_TEST (WaveMultiPrefixBitXor, uint32_t );
24482541 HLK_WAVEOP_TEST (WaveActiveSum, uint64_t );
24492542 HLK_WAVEOP_TEST (WaveActiveMin, uint64_t );
24502543 HLK_WAVEOP_TEST (WaveActiveMax, uint64_t );
@@ -2457,6 +2550,11 @@ class DxilConf_SM69_Vectorized {
24572550 HLK_WAVEOP_TEST (WaveReadLaneFirst, uint64_t );
24582551 HLK_WAVEOP_TEST (WavePrefixSum, uint64_t );
24592552 HLK_WAVEOP_TEST (WavePrefixProduct, uint64_t );
2553+ HLK_WAVEOP_TEST (WaveMultiPrefixSum, uint64_t );
2554+ HLK_WAVEOP_TEST (WaveMultiPrefixProduct, uint64_t );
2555+ HLK_WAVEOP_TEST (WaveMultiPrefixBitAnd, uint64_t );
2556+ HLK_WAVEOP_TEST (WaveMultiPrefixBitOr, uint64_t );
2557+ HLK_WAVEOP_TEST (WaveMultiPrefixBitXor, uint64_t );
24602558
24612559 HLK_WAVEOP_TEST (WaveActiveSum, HLSLHalf_t);
24622560 HLK_WAVEOP_TEST (WaveActiveMin, HLSLHalf_t);
@@ -2467,6 +2565,8 @@ class DxilConf_SM69_Vectorized {
24672565 HLK_WAVEOP_TEST (WaveReadLaneFirst, HLSLHalf_t);
24682566 HLK_WAVEOP_TEST (WavePrefixSum, HLSLHalf_t);
24692567 HLK_WAVEOP_TEST (WavePrefixProduct, HLSLHalf_t);
2568+ HLK_WAVEOP_TEST (WaveMultiPrefixSum, HLSLHalf_t);
2569+ HLK_WAVEOP_TEST (WaveMultiPrefixProduct, HLSLHalf_t);
24702570 HLK_WAVEOP_TEST (WaveActiveSum, float );
24712571 HLK_WAVEOP_TEST (WaveActiveMin, float );
24722572 HLK_WAVEOP_TEST (WaveActiveMax, float );
@@ -2476,6 +2576,8 @@ class DxilConf_SM69_Vectorized {
24762576 HLK_WAVEOP_TEST (WaveReadLaneFirst, float );
24772577 HLK_WAVEOP_TEST (WavePrefixSum, float );
24782578 HLK_WAVEOP_TEST (WavePrefixProduct, float );
2579+ HLK_WAVEOP_TEST (WaveMultiPrefixSum, float );
2580+ HLK_WAVEOP_TEST (WaveMultiPrefixProduct, float );
24792581 HLK_WAVEOP_TEST (WaveActiveSum, double );
24802582 HLK_WAVEOP_TEST (WaveActiveMin, double );
24812583 HLK_WAVEOP_TEST (WaveActiveMax, double );
@@ -2485,6 +2587,8 @@ class DxilConf_SM69_Vectorized {
24852587 HLK_WAVEOP_TEST (WaveReadLaneFirst, double );
24862588 HLK_WAVEOP_TEST (WavePrefixSum, double );
24872589 HLK_WAVEOP_TEST (WavePrefixProduct, double );
2590+ HLK_WAVEOP_TEST (WaveMultiPrefixSum, double );
2591+ HLK_WAVEOP_TEST (WaveMultiPrefixProduct, double );
24882592
24892593private:
24902594 bool Initialized = false ;
0 commit comments