Skip to content

Commit 8d53a48

Browse files
authored
Add getTranspositionSelectors && TranspositionInfo from 58ae6f0 to reduce merge conflicts (#5008)
Part of #4941 Signed-off-by: Anatoly Myachev <anatoly.myachev@intel.com>
1 parent 17ecb05 commit 8d53a48

File tree

2 files changed

+234
-0
lines changed

2 files changed

+234
-0
lines changed

include/triton/Analysis/Utility.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,14 @@ class GatherLoweringHelper {
183183
// corresponding to the transposition (r_i l_j) of the i-th register basis
184184
// vector with the j-th lane basis vector.
185185
struct DecomposedWarpConversion {
186+
// One mixed (register bit, lane bit) transposition together with the four
// 16-bit selector values that accompany it. `getTranspositionSelectors` in
// lib/Analysis/Utility.cpp fills these in; its comments describe the "pre"
// selectors as belonging to the `prmt` before the shuffle and the "post"
// selectors to the `prmt` after the shuffle.
struct TranspositionInfo {
187+
// (register bit index, lane bit index) of the mixed transposition.
std::pair<int, int> transposition;
188+
// The defaults 0x3210 / 0x7654 are the unmodified selector values; per the
// comments in `getTranspositionSelectors`, unmodified selectors correspond
// to `select` instructions.
uint16_t topPreSel = 0x3210;
189+
uint16_t botPreSel = 0x7654;
190+
uint16_t topPostSel = 0x3210;
191+
uint16_t botPostSel = 0x7654;
192+
};
193+
186194
// NOTE(review): presumably the register-only and lane-only factors of the
// decomposition -- confirm against getWarpLayoutConvertDecomposition.
triton::LinearLayout pReg, pLane;
187195
// (i, j) index pairs, one per mixed transposition (r_i l_j).
SmallVector<std::pair<int, int>> mixedTranspositions;
188196
};

lib/Analysis/Utility.cpp

Lines changed: 226 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "triton/Tools/LayoutUtils.h"
1818
#include "triton/Tools/LinearLayout.h"
1919
#include "triton/Tools/Sys/GetEnv.hpp"
20+
#include "llvm/ADT/SmallSet.h"
2021

2122
namespace mlir {
2223

@@ -253,6 +254,11 @@ unsigned ScanLoweringHelper::getScratchSizeInBytes() {
253254
return elementSizeInBytes * getScratchSizeInElems();
254255
}
255256

257+
// Forward declaration; the definition appears further down in this file.
static SmallVector<DecomposedWarpConversion::TranspositionInfo>
258+
getTranspositionSelectors(SmallVector<std::pair<int, int>> &mixedTranspositions,
259+
std::vector<std::vector<int32_t>> &regBases,
260+
int bitwidth);
261+
256262
DecomposedWarpConversion
257263
getWarpLayoutConvertDecomposition(RankedTensorType srcTy,
258264
RankedTensorType dstTy) {
@@ -474,6 +480,226 @@ getWarpLayoutConvertDecomposition(RankedTensorType srcTy,
474480
return {std::move(pReg), std::move(pLane), std::move(mixedTranspositions)};
475481
}
476482

483+
// Builds one `TranspositionInfo` per entry of `mixedTranspositions`, folding
// permutations of low register bits into the pre-/post-shuffle selectors.
//
// Parameters:
//   mixedTranspositions - (register bit, lane bit) pairs. Reordered in place
//                         by the `std::stable_partition` call below.
//   regBases            - register basis vectors. Only `regBases[b][0]` is
//                         accessed here; it is passed to `llvm::Log2_32`, so
//                         each entry is treated as a single set bit naming the
//                         image of register bit `b`. Entries are rewritten in
//                         place as low bits are factored out.
//   bitwidth            - presumably the element width in bits; it is used as
//                         a divisor in `32 / bitwidth`, so it must be nonzero.
//
// Returns the list of fused transpositions. When `nPack` evaluates to 0 the
// input pairs are returned in their original order with default selectors.
static SmallVector<DecomposedWarpConversion::TranspositionInfo>
484+
getTranspositionSelectors(SmallVector<std::pair<int, int>> &mixedTranspositions,
485+
std::vector<std::vector<int32_t>> &regBases,
486+
int bitwidth) {
487+
// When possible, we fuse permutations of 'low' register bits together
488+
// with a mixed transposition, resulting in byte permute instructions instead
489+
// of `select` instructions. After processing, no low register bits appear in
490+
// the returned list of mixed transpositions.
491+
int m = mixedTranspositions.size();
492+
int nRegBases = regBases.size();
493+
// `32 / bitwidth` is clamped to [1, 4], so `nPackPrelim` is 0, 1 or 2.
int nPackPrelim = llvm::Log2_32(std::clamp(32 / bitwidth, 1, 4));
494+
// `nPack` is the number of register bits treated as 'low'; it is capped by
// the number of register bases left over after one per mixed transposition.
int nPack = std::min(nPackPrelim, nRegBases - m);
495+
496+
SmallVector<DecomposedWarpConversion::TranspositionInfo> ret;
497+
ret.reserve(mixedTranspositions.size());
498+
// No low bits: every transposition keeps its default selectors.
if (nPack == 0) {
499+
for (auto &t : mixedTranspositions)
500+
ret.push_back(DecomposedWarpConversion::TranspositionInfo{t});
501+
return ret;
502+
}
503+
// Consider for example the cycle
504+
//
505+
// (r2 r1 l0 r0 r3) = (r0 l0) * (r2 r1 r0 r3)
506+
// = (r3 r0) * (r3 l0) * (r3 r1) * (r3 r2)
507+
//
508+
// with `nPack` = 2 so that r0 and r1 are considered low bits. We want to
509+
// factor out any low bits from `pReg` and to incorporate them into the data
510+
// of the mixed transposition. After processing, the contribution to `pReg`
511+
// is reduced to (r3 r2) and the mixed transposition recorded is (r3 l0), with
512+
// the effects of (r3 r0) and (r3 r1) encoded in the returned selectors.
513+
// In general, low bits occurring immediately before l_j modify the selectors
514+
// of the `prmt` before the shuffle, while low bits occurring immediately
515+
// after l_k modify the selectors of the `prmt` after the shuffle. Unmodified
516+
// selectors correspond to `select` instructions.
517+
// Cases like (l0 r0 r1) must be handled by selecting a 'partner' bit that is
518+
// not used in another mixed transposition and conjugating out a low bit:
519+
//
520+
// (l0 r0 r1) = (r2 r1) * (l0 r0 r2) * (r2 r1)
521+
// = (r2 r1) * (r2 r0) * (r2 l0) * (r2 r1).
522+
//
523+
// Conjugation does not affect `pReg`. However, the set of fused mixed and
524+
// low-bit transpositions is noncommutative in cases where there are no
525+
// intervening high bits in between distinct sequences of lane bits as the
526+
// paired low bit is used in modifying the selectors of both factors:
527+
//
528+
// (l0 r0 r1 l1 r2) = (r3 r0)(r3 l0)(r3 r0) * (r2 l1)(r2 r1)(r2 r0).
529+
//
530+
// The `*` is standard composition of permutations. The groupings correspond
531+
// to different `TranspositionInfo` objects. For example, the permutation
532+
// `(r3 r0)(r3 l0)(r3 r0) = (r0 l0)` has mixed transposition `(r3 l0)` with
533+
// pre- and post-shuffle selectors determined by the `r0` bit.
534+
// Processing of mixed transpositions is performed by determining the `head`
535+
// and `tail` of an excision of bits in cycles of `pReg` and building lists
536+
// of low bits acting as selector modifiers. In the noncommutative cases, we
537+
// opt to restrict the number of post-shuffle modifiers to one.
538+
539+
// Treats `sel` as four 4-bit nibbles and, in every nibble, swaps bit 2 with
// bit `lo`, where `lo = bitIdx + (2 - nPack)`. Applied to 0x3210 this yields
// 0x6240 for `lo == 0` and 0x5410 for `lo == 1`.
auto permuteSelector = [nPack](uint16_t sel, int bitIdx) {
540+
int lo = bitIdx + (2 - nPack);
541+
uint16_t maskHi = 0x4444;
542+
uint16_t maskLo = 0x1111 << lo;
543+
// Bits of `sel` that are untouched by the swap.
uint16_t fixed = sel & ~maskHi & ~maskLo;
544+
int shift = 2 - lo;
545+
return fixed | ((maskHi & sel) >> shift) | ((maskLo & sel) << shift);
546+
};
547+
// Starts from the unmodified selectors 0x3210 / 0x7654 and applies
// `permuteSelector` once per entry of `lowBits`, in iteration order. As a
// side effect, every low bit other than `head` and `tail` is reset in
// `regBases` so that it maps to itself. Returns {top selector, bottom
// selector}.
auto generateSelectors = [&](int head, int tail, auto &&lowBits) {
548+
uint16_t topSel = 0x3210;
549+
uint16_t botSel = 0x7654;
550+
for (auto lowBit : lowBits) {
551+
topSel = permuteSelector(topSel, lowBit);
552+
botSel = permuteSelector(botSel, lowBit);
553+
if (lowBit != head && lowBit != tail)
554+
regBases[lowBit][0] = 1 << lowBit;
555+
}
556+
return std::pair{topSel, botSel};
557+
};
558+
559+
// Register bits that currently take part in some mixed transposition. The
// set is updated while transpositions are processed below.
llvm::SmallSet<int32_t, 6> pairedRegBits;
560+
for (auto [rBit, lBit] : mixedTranspositions)
561+
pairedRegBits.insert(rBit);
562+
563+
// A low bit in a mixed transposition must be replaced by a high bit. The
564+
// choice of high bit can affect instruction count. If the first high bit
565+
// found when walking along `pReg` is unpaired, then that bit is the best
566+
// choice. We reorder the transpositions to guarantee this during processing.
567+
// Successor of register bit `b`, read from the current state of `regBases`.
auto next = [&](int b) { return llvm::Log2_32(regBases[b][0]); };
568+
// Predicate for the partition below: walks from `p.first` along `next` to
// the first high bit (index >= nPack) and reports whether that bit is
// `p.first` itself or is unpaired. Returns false if the walk returns to
// `p.first` without meeting a high bit.
auto nextHighFree = [&](auto p) {
569+
int curr = p.first;
570+
do {
571+
if (curr >= nPack)
572+
return curr == p.first || !pairedRegBits.contains(curr);
573+
curr = next(curr);
574+
} while (curr != p.first);
575+
return false;
576+
};
577+
// Transpositions satisfying `nextHighFree` are moved to the front; relative
// order within each group is preserved.
std::stable_partition(mixedTranspositions.begin(), mixedTranspositions.end(),
578+
nextHighFree);
579+
// If `P` has an isolated low-bit mixed transposition, and `pReg` maps a low
580+
// bit to an open high bit, then the high bit should be used as the partner.
581+
// Predecessor of `b`: walks the whole cycle from `b` and returns the bit whose
// successor is `b` (or `b` itself when `b` is a fixed point).
auto prev = [&](int b) {
582+
int tail = b;
583+
int curr = next(b);
584+
while (curr != b) {
585+
tail = curr;
586+
curr = next(curr);
587+
}
588+
return tail;
589+
};
590+
// Picks the high register bit recorded in the output transposition in place
// of `lowBit`, and marks it as paired. With two low bits, the special case
// described above is tried first; it also appends the other low bit to
// `preShufLoBits` and rewires `regBases` so that the other low bit's
// predecessor maps straight to the chosen high bit. Otherwise the lowest
// unpaired bit with index >= nPack is taken.
auto findPartner = [&](int lowBit, auto &preShufLoBits) {
591+
if (nPack == 2) {
592+
// Assumes `lowBit` is 0 or 1, so this is the other low bit.
int otherLow = 1 - lowBit;
593+
int b = next(otherLow);
594+
if (next(lowBit) == lowBit && b >= nPack && !pairedRegBits.contains(b) &&
595+
!pairedRegBits.contains(otherLow)) {
596+
preShufLoBits.push_back(otherLow);
597+
regBases[prev(otherLow)][0] = 1 << b;
598+
pairedRegBits.insert(b);
599+
return b;
600+
}
601+
}
602+
int potentialPartner = nPack;
603+
while (pairedRegBits.contains(potentialPartner))
604+
++potentialPartner;
605+
pairedRegBits.insert(potentialPartner);
606+
return potentialPartner;
607+
};
608+
609+
for (auto p : mixedTranspositions) {
610+
int rBit = p.first;
611+
int lBit = p.second;
612+
// Bits visited from `rBit` by repeatedly applying `next`, starting with
// `rBit` itself.
SmallVector<int> cycle;
613+
int currBit = rBit;
614+
do {
615+
cycle.push_back(currBit);
616+
currBit = next(currBit);
617+
} while (currBit != rBit);
618+
619+
// Find any low register bits adjacent to the excised lane bits which aren't
620+
// used in other mixed transpositions.
621+
// A boundary is a high bit, or a bit other than `rBit` that is paired.
auto isBoundary = [&](int bit) {
622+
return bit >= nPack || (pairedRegBits.contains(bit) && bit != rBit);
623+
};
624+
auto forwardEnd = llvm::find_if(cycle, isBoundary);
625+
auto backwardEnd = std::find_if(cycle.rbegin(), cycle.rend(), isBoundary);
626+
// Leading run of the cycle before the first boundary.
SmallVector<int> postShufLoBits(cycle.begin(), forwardEnd);
627+
// Trailing run of the cycle, collected back-to-front, before the first
// boundary met when walking backwards.
SmallVector<int> preShufLoBits(cycle.rbegin(), backwardEnd);
628+
int head;
629+
int tail;
630+
// -1 means no partner has been chosen yet; resolved by `findPartner` below.
int partnerBit = -1;
631+
632+
// Case work to determine what to conjugate out.
633+
if (forwardEnd != cycle.end()) {
634+
if (*forwardEnd == rBit || !pairedRegBits.contains(*forwardEnd)) {
635+
// End at original or unpaired high bit. E.g. (l0 r0 r2) or (l0 r2)
636+
// No conjugation needed.
637+
head = partnerBit = *forwardEnd;
638+
} else {
639+
// End at different paired bit. E.g. (l0 r0 r1 l1 r2)
640+
// Non-leading factor in a noncommutative case.
641+
// Conjugate by first low bit in forward walk.
642+
head = postShufLoBits.front();
643+
preShufLoBits.push_back(head);
644+
postShufLoBits.resize(1);
645+
pairedRegBits.erase(head);
646+
}
647+
tail = *backwardEnd;
648+
if (tail < nPack && pairedRegBits.contains(tail)) {
649+
// Non-terminal factor in a noncommutative case.
650+
preShufLoBits.insert(preShufLoBits.begin(), tail);
651+
}
652+
} else {
653+
if (next(rBit) != rBit && pairedRegBits.contains(next(rBit))) {
654+
// Symmetric noncommutative case. E.g. (l0 r0 l1 r1)
655+
preShufLoBits.erase(preShufLoBits.begin());
656+
postShufLoBits.pop_back();
657+
pairedRegBits.erase(postShufLoBits.front());
658+
head = rBit;
659+
tail = next(rBit);
660+
} else {
661+
// Isolated low bits with single mixed transposition. E.g. (l0 r0 r1)
662+
if (postShufLoBits.size() == 2)
663+
postShufLoBits.pop_back();
664+
head = tail = preShufLoBits.front();
665+
}
666+
}
667+
668+
if (partnerBit < 0)
669+
partnerBit = findPartner(head, preShufLoBits);
670+
// Post-shuffle selectors are generated first, from the post-shuffle bits in
// reverse order; both calls also reset entries of `regBases` as described on
// `generateSelectors`.
auto [topPostSel, botPostSel] =
671+
generateSelectors(head, tail, llvm::reverse(postShufLoBits));
672+
auto [topPreSel, botPreSel] = generateSelectors(head, tail, preShufLoBits);
673+
// Make `head` the successor of `tail`, closing up the excision.
regBases[tail][0] = 1 << head;
674+
675+
DecomposedWarpConversion::TranspositionInfo info;
676+
info.transposition = {partnerBit, lBit};
677+
info.topPreSel = topPreSel;
678+
info.botPreSel = botPreSel;
679+
info.topPostSel = topPostSel;
680+
info.botPostSel = botPostSel;
681+
682+
// In noncommutative cases, post-shuffle selectors of non-leading terms come
683+
// from a single low bit by design, so we can determine where to insert a
684+
// non-terminal factor by examining processed selectors.
685+
if (!preShufLoBits.empty()) {
686+
// 0x6240 / 0x5410 are what `permuteSelector` produces from 0x3210 for a
// single low bit: 0x6240 when nPack == 2 and the bit is 0, 0x5410 otherwise.
uint16_t sel = (nPack - preShufLoBits.back()) == 2 ? 0x6240 : 0x5410;
687+
// Insert before the first already-processed entry whose top post-shuffle
// selector equals `sel`; when none matches, `it` is `ret.end()` and the
// insert appends.
auto it =
688+
llvm::find_if(ret, [&](auto &t) { return t.topPostSel == sel; });
689+
ret.insert(it, info);
690+
} else {
691+
ret.push_back(info);
692+
}
693+
}
694+
if (nPack == 2 && regBases[0][0] == 2 && regBases[1][0] == 1 && ret.size()) {
695+
// If (r0 r1) was originally in `P`, fold it into a mixed transposition.
696+
// 0x3120 / 0x7564 are 0x3210 / 0x7654 with bits 0 and 1 of every nibble
// swapped. The last entry's post-shuffle selectors are overwritten.
auto &t = ret.back();
697+
t.topPostSel = 0x3120;
698+
t.botPostSel = 0x7564;
699+
}
700+
return ret;
701+
}
702+
477703
SmallVector<std::pair<SmallVector<int64_t>, SmallVector<int64_t>>>
478704
getReshapeDecomposition(ArrayRef<int64_t> srcShape,
479705
ArrayRef<int64_t> dstShape) {

0 commit comments

Comments
 (0)