|
17 | 17 | #include "triton/Tools/LayoutUtils.h" |
18 | 18 | #include "triton/Tools/LinearLayout.h" |
19 | 19 | #include "triton/Tools/Sys/GetEnv.hpp" |
| 20 | +#include "llvm/ADT/SmallSet.h" |
20 | 21 |
|
21 | 22 | namespace mlir { |
22 | 23 |
|
@@ -253,6 +254,11 @@ unsigned ScanLoweringHelper::getScratchSizeInBytes() { |
253 | 254 | return elementSizeInBytes * getScratchSizeInElems(); |
254 | 255 | } |
255 | 256 |
|
| 257 | +static SmallVector<DecomposedWarpConversion::TranspositionInfo> |
| 258 | +getTranspositionSelectors(SmallVector<std::pair<int, int>> &mixedTranspositions, |
| 259 | + std::vector<std::vector<int32_t>> &regBases, |
| 260 | + int bitwidth); |
| 261 | + |
256 | 262 | DecomposedWarpConversion |
257 | 263 | getWarpLayoutConvertDecomposition(RankedTensorType srcTy, |
258 | 264 | RankedTensorType dstTy) { |
@@ -474,6 +480,226 @@ getWarpLayoutConvertDecomposition(RankedTensorType srcTy, |
474 | 480 | return {std::move(pReg), std::move(pLane), std::move(mixedTranspositions)}; |
475 | 481 | } |
476 | 482 |
|
| 483 | +static SmallVector<DecomposedWarpConversion::TranspositionInfo> |
| 484 | +getTranspositionSelectors(SmallVector<std::pair<int, int>> &mixedTranspositions, |
| 485 | + std::vector<std::vector<int32_t>> ®Bases, |
| 486 | + int bitwidth) { |
| 487 | + // When possible, we fuse permutations of 'low' register bits together |
| 488 | + // with a mixed transposition, resulting in byte permute instructions instead |
| 489 | + // of `select` instructions. After processing, no low register bits appear in |
| 490 | + // the returned list of mixed transpositions. |
| 491 | + int m = mixedTranspositions.size(); |
| 492 | + int nRegBases = regBases.size(); |
| 493 | + int nPackPrelim = llvm::Log2_32(std::clamp(32 / bitwidth, 1, 4)); |
| 494 | + int nPack = std::min(nPackPrelim, nRegBases - m); |
| 495 | + |
| 496 | + SmallVector<DecomposedWarpConversion::TranspositionInfo> ret; |
| 497 | + ret.reserve(mixedTranspositions.size()); |
| 498 | + if (nPack == 0) { |
| 499 | + for (auto &t : mixedTranspositions) |
| 500 | + ret.push_back(DecomposedWarpConversion::TranspositionInfo{t}); |
| 501 | + return ret; |
| 502 | + } |
| 503 | + // Consider for example the cycle |
| 504 | + // |
| 505 | + // (r2 r1 l0 r0 r3) = (r0 l0) * (r2 r1 r0 r3) |
| 506 | + // = (r3 r0) * (r3 l0) * (r3 r1) * (r3 r2) |
| 507 | + // |
| 508 | + // with `nPack` = 2 so that r0 and r1 are considered low bits. We want to |
| 509 | + // factor out any low bits from `pReg` and to incorporate them into the data |
| 510 | + // of the mixed transposition. After processing, the contribution to `pReg` |
| 511 | + // is reduced to (r3 r2) and the mixed transposition recorded is (r3 l0), with |
| 512 | + // the effects of (r3 r0) and (r3 r1) encoded in the returned selectors. |
| 513 | + // In general, low bits occurring immediately before l_j modify the selectors |
| 514 | + // of the `prmt` before the shuffle, while low bits occurring immediately |
| 515 | + // after l_k modify the selectors of the `prmt` after the shuffle. Unmodified |
| 516 | + // selectors correspond to `select` instructions. |
| 517 | + // Cases like (l0 r0 r1) must be handled by selecting a 'partner' bit that is |
| 518 | + // not used in another mixed transposition and conjugating out a low bit: |
| 519 | + // |
| 520 | + // (l0 r0 r1) = (r2 r1) * (l0 r0 r2) * (r2 r1) |
| 521 | + // = (r2 r1) * (r2 r0) * (r2 l0) * (r2 r1). |
| 522 | + // |
| 523 | + // Conjugation does not affect `pReg`. However, the set of fused mixed and |
| 524 | + // low-bit transpositions is noncommutative in cases where there are no |
| 525 | + // intervening high bits in between distinct sequences of lane bits as the |
| 526 | + // paired low bit is used in modifying the selectors of both factors: |
| 527 | + // |
| 528 | + // (l0 r0 r1 l1 r2) = (r3 r0)(r3 l0)(r3 r0) * (r2 l1)(r2 r1)(r2 r0). |
| 529 | + // |
| 530 | + // The `*` is standard composition of permutations. The groupings correspond |
| 531 | + // to different `TranspositionInfo` objects. For example, the permutation |
| 532 | + // `(r3 r0)(r3 l0)(r3 r0) = (r0 l0)` has mixed transposition `(r3 l0)` with |
| 533 | + // pre- and post-shuffle selectors determined by the `r0` bit. |
| 534 | + // Processing of mixed transpositions is performed by determining the `head` |
| 535 | + // and `tail` of an excision of bits in cycles of `pReg` and building lists |
| 536 | + // of low bits acting as selector modifiers. In the noncommutative cases, we |
| 537 | + // opt to restrict the number of post-shuffle modifiers to one. |
| 538 | + |
| 539 | + auto permuteSelector = [nPack](uint16_t sel, int bitIdx) { |
| 540 | + int lo = bitIdx + (2 - nPack); |
| 541 | + uint16_t maskHi = 0x4444; |
| 542 | + uint16_t maskLo = 0x1111 << lo; |
| 543 | + uint16_t fixed = sel & ~maskHi & ~maskLo; |
| 544 | + int shift = 2 - lo; |
| 545 | + return fixed | ((maskHi & sel) >> shift) | ((maskLo & sel) << shift); |
| 546 | + }; |
| 547 | + auto generateSelectors = [&](int head, int tail, auto &&lowBits) { |
| 548 | + uint16_t topSel = 0x3210; |
| 549 | + uint16_t botSel = 0x7654; |
| 550 | + for (auto lowBit : lowBits) { |
| 551 | + topSel = permuteSelector(topSel, lowBit); |
| 552 | + botSel = permuteSelector(botSel, lowBit); |
| 553 | + if (lowBit != head && lowBit != tail) |
| 554 | + regBases[lowBit][0] = 1 << lowBit; |
| 555 | + } |
| 556 | + return std::pair{topSel, botSel}; |
| 557 | + }; |
| 558 | + |
| 559 | + llvm::SmallSet<int32_t, 6> pairedRegBits; |
| 560 | + for (auto [rBit, lBit] : mixedTranspositions) |
| 561 | + pairedRegBits.insert(rBit); |
| 562 | + |
| 563 | + // A low bit in a mixed transposition must be replaced by a high bit. The |
| 564 | + // choice of high bit can affect instruction count. If the first high bit |
| 565 | + // found when walking along `pReg` is unpaired, then that bit is the best |
| 566 | + // choice. We reorder the transpositions to guarantee this during processing. |
| 567 | + auto next = [&](int b) { return llvm::Log2_32(regBases[b][0]); }; |
| 568 | + auto nextHighFree = [&](auto p) { |
| 569 | + int curr = p.first; |
| 570 | + do { |
| 571 | + if (curr >= nPack) |
| 572 | + return curr == p.first || !pairedRegBits.contains(curr); |
| 573 | + curr = next(curr); |
| 574 | + } while (curr != p.first); |
| 575 | + return false; |
| 576 | + }; |
| 577 | + std::stable_partition(mixedTranspositions.begin(), mixedTranspositions.end(), |
| 578 | + nextHighFree); |
| 579 | + // If `P` has an isolated low-bit mixed transposition, and `pReg` maps a low |
| 580 | + // bit to an open high bit, then the high bit should be used as the partner. |
| 581 | + auto prev = [&](int b) { |
| 582 | + int tail = b; |
| 583 | + int curr = next(b); |
| 584 | + while (curr != b) { |
| 585 | + tail = curr; |
| 586 | + curr = next(curr); |
| 587 | + } |
| 588 | + return tail; |
| 589 | + }; |
| 590 | + auto findPartner = [&](int lowBit, auto &preShufLoBits) { |
| 591 | + if (nPack == 2) { |
| 592 | + int otherLow = 1 - lowBit; |
| 593 | + int b = next(otherLow); |
| 594 | + if (next(lowBit) == lowBit && b >= nPack && !pairedRegBits.contains(b) && |
| 595 | + !pairedRegBits.contains(otherLow)) { |
| 596 | + preShufLoBits.push_back(otherLow); |
| 597 | + regBases[prev(otherLow)][0] = 1 << b; |
| 598 | + pairedRegBits.insert(b); |
| 599 | + return b; |
| 600 | + } |
| 601 | + } |
| 602 | + int potentialPartner = nPack; |
| 603 | + while (pairedRegBits.contains(potentialPartner)) |
| 604 | + ++potentialPartner; |
| 605 | + pairedRegBits.insert(potentialPartner); |
| 606 | + return potentialPartner; |
| 607 | + }; |
| 608 | + |
| 609 | + for (auto p : mixedTranspositions) { |
| 610 | + int rBit = p.first; |
| 611 | + int lBit = p.second; |
| 612 | + SmallVector<int> cycle; |
| 613 | + int currBit = rBit; |
| 614 | + do { |
| 615 | + cycle.push_back(currBit); |
| 616 | + currBit = next(currBit); |
| 617 | + } while (currBit != rBit); |
| 618 | + |
| 619 | + // Find any low register bits adjacent to the excised lane bits which aren't |
| 620 | + // used in other mixed transpositions. |
| 621 | + auto isBoundary = [&](int bit) { |
| 622 | + return bit >= nPack || (pairedRegBits.contains(bit) && bit != rBit); |
| 623 | + }; |
| 624 | + auto forwardEnd = llvm::find_if(cycle, isBoundary); |
| 625 | + auto backwardEnd = std::find_if(cycle.rbegin(), cycle.rend(), isBoundary); |
| 626 | + SmallVector<int> postShufLoBits(cycle.begin(), forwardEnd); |
| 627 | + SmallVector<int> preShufLoBits(cycle.rbegin(), backwardEnd); |
| 628 | + int head; |
| 629 | + int tail; |
| 630 | + int partnerBit = -1; |
| 631 | + |
| 632 | + // Case work to determine what to conjugate out. |
| 633 | + if (forwardEnd != cycle.end()) { |
| 634 | + if (*forwardEnd == rBit || !pairedRegBits.contains(*forwardEnd)) { |
| 635 | + // End at original or unpaired high bit. E.g. (l0 r0 r2) or (l0 r2) |
| 636 | + // No conjugation needed. |
| 637 | + head = partnerBit = *forwardEnd; |
| 638 | + } else { |
| 639 | + // End at different paired bit. E.g. (l0 r0 r1 l1 r2) |
| 640 | + // Non-leading factor in a noncommutative case. |
| 641 | + // Conjugate by first low bit in forward walk. |
| 642 | + head = postShufLoBits.front(); |
| 643 | + preShufLoBits.push_back(head); |
| 644 | + postShufLoBits.resize(1); |
| 645 | + pairedRegBits.erase(head); |
| 646 | + } |
| 647 | + tail = *backwardEnd; |
| 648 | + if (tail < nPack && pairedRegBits.contains(tail)) { |
| 649 | + // Non-terminal factor in a noncommutative case. |
| 650 | + preShufLoBits.insert(preShufLoBits.begin(), tail); |
| 651 | + } |
| 652 | + } else { |
| 653 | + if (next(rBit) != rBit && pairedRegBits.contains(next(rBit))) { |
| 654 | + // Symmetric noncommutative case. E.g. (l0 r0 l1 r1) |
| 655 | + preShufLoBits.erase(preShufLoBits.begin()); |
| 656 | + postShufLoBits.pop_back(); |
| 657 | + pairedRegBits.erase(postShufLoBits.front()); |
| 658 | + head = rBit; |
| 659 | + tail = next(rBit); |
| 660 | + } else { |
| 661 | + // Isolated low bits with single mixed transposition. E.g. (l0 r0 r1) |
| 662 | + if (postShufLoBits.size() == 2) |
| 663 | + postShufLoBits.pop_back(); |
| 664 | + head = tail = preShufLoBits.front(); |
| 665 | + } |
| 666 | + } |
| 667 | + |
| 668 | + if (partnerBit < 0) |
| 669 | + partnerBit = findPartner(head, preShufLoBits); |
| 670 | + auto [topPostSel, botPostSel] = |
| 671 | + generateSelectors(head, tail, llvm::reverse(postShufLoBits)); |
| 672 | + auto [topPreSel, botPreSel] = generateSelectors(head, tail, preShufLoBits); |
| 673 | + regBases[tail][0] = 1 << head; |
| 674 | + |
| 675 | + DecomposedWarpConversion::TranspositionInfo info; |
| 676 | + info.transposition = {partnerBit, lBit}; |
| 677 | + info.topPreSel = topPreSel; |
| 678 | + info.botPreSel = botPreSel; |
| 679 | + info.topPostSel = topPostSel; |
| 680 | + info.botPostSel = botPostSel; |
| 681 | + |
| 682 | + // In noncommutative cases, post-shuffle selectors of non-leading terms come |
| 683 | + // from a single low bit by design, so we can determine where to insert a |
| 684 | + // non-terminal factor by examining processed selectors. |
| 685 | + if (!preShufLoBits.empty()) { |
| 686 | + uint16_t sel = (nPack - preShufLoBits.back()) == 2 ? 0x6240 : 0x5410; |
| 687 | + auto it = |
| 688 | + llvm::find_if(ret, [&](auto &t) { return t.topPostSel == sel; }); |
| 689 | + ret.insert(it, info); |
| 690 | + } else { |
| 691 | + ret.push_back(info); |
| 692 | + } |
| 693 | + } |
| 694 | + if (nPack == 2 && regBases[0][0] == 2 && regBases[1][0] == 1 && ret.size()) { |
| 695 | + // If (r0 r1) was originally in `P`, fold it into a mixed transposition. |
| 696 | + auto &t = ret.back(); |
| 697 | + t.topPostSel = 0x3120; |
| 698 | + t.botPostSel = 0x7564; |
| 699 | + } |
| 700 | + return ret; |
| 701 | +} |
| 702 | + |
477 | 703 | SmallVector<std::pair<SmallVector<int64_t>, SmallVector<int64_t>>> |
478 | 704 | getReshapeDecomposition(ArrayRef<int64_t> srcShape, |
479 | 705 | ArrayRef<int64_t> dstShape) { |
|
0 commit comments