Skip to content

Commit f04df24

Browse files
authored
[AMD][NFC] Refactor AccelerateAMDMatmul operand legalization (intel#4136)
- Choose the proper configuration of operands according to the number of conversions, if possible; - Get rid of complicated logic to find the operand config; - Remove helper `supportWMMA()` to get rid of implicit logical dependencies with AccelerateAMDMatmul Signed-off-by: Ilya Veselov <iveselov.nn@gmail.com>
1 parent a06add0 commit f04df24

File tree

2 files changed

+118
-99
lines changed

2 files changed

+118
-99
lines changed

lib/Analysis/Utility.cpp

Lines changed: 0 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -482,58 +482,6 @@ bool supportMFMA(triton::DotOp op) {
482482
return true;
483483
}
484484

485-
static bool supportWMMAGranularity(int m, int n, int k) {
486-
return m % 16 == 0 && n % 16 == 0 && k % 16 == 0;
487-
}
488-
489-
static bool supportWMMATypes(Type a, Type b, Type c, Type d) {
490-
if (a != b || c != d)
491-
return false;
492-
auto aWidth = a.getIntOrFloatBitWidth();
493-
auto cWidth = c.getIntOrFloatBitWidth();
494-
if (a.isIntOrIndex()) {
495-
if (!c.isIntOrIndex())
496-
return false;
497-
bool aValid = aWidth <= 8;
498-
bool cValid = cWidth <= 32;
499-
return aValid && cValid;
500-
} else if (isa<FloatType>(a) && isa<FloatType>(c)) {
501-
if (a.isBF16())
502-
return c.isBF16() || c.isF32();
503-
if (a.isF16())
504-
return c.isF16() || c.isF32();
505-
return aWidth <= cWidth && aWidth <= 16;
506-
}
507-
return false;
508-
}
509-
510-
bool supportWMMA(triton::DotOp op) {
511-
auto aTy = cast<RankedTensorType>(op.getA().getType());
512-
auto bTy = cast<RankedTensorType>(op.getB().getType());
513-
auto cTy = cast<RankedTensorType>(op.getC().getType());
514-
auto dTy = cast<RankedTensorType>(op.getResult().getType());
515-
516-
auto aElemTy = aTy.getElementType();
517-
auto bElemTy = bTy.getElementType();
518-
auto cElemTy = cTy.getElementType();
519-
auto dElemTy = dTy.getElementType();
520-
521-
if (!supportWMMATypes(aElemTy, bElemTy, cElemTy, dElemTy))
522-
return false;
523-
524-
auto aShape = aTy.getShape();
525-
auto bShape = bTy.getShape();
526-
527-
auto rank = aShape.size();
528-
assert(bShape.size() == rank);
529-
assert(aShape[rank - 1] == bShape[rank - 2]);
530-
if (!supportWMMAGranularity(aShape[rank - 2], bShape[rank - 1],
531-
aShape[rank - 1]))
532-
return false;
533-
534-
return true;
535-
}
536-
537485
bool supportMMA(triton::DotOp op, int version) {
538486
// Refer to mma section for the data type supported by Volta and Hopper
539487
// Tensor Core in

third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp

Lines changed: 118 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,87 @@ warpsPerTileWMMA(tt::DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps) {
112112
AMDWmmaEncodingAttr::getMNKDimPerWMMAInstr()[1]});
113113
}
114114

115+
using OperandTypesVector = SmallVector<Type, 4>;
116+
OperandTypesVector
117+
selectMatrixCoreOperandTypes(tt::DotOp dot,
118+
ArrayRef<OperandTypesVector> applicableTypes) {
119+
SmallVector<Value> dotOperands = {dot.getA(), dot.getB(), dot.getC(),
120+
dot.getD()};
121+
OperandTypesVector initElemTypes;
122+
llvm::transform(dotOperands, std::back_inserter(initElemTypes), [](Value v) {
123+
return cast<RankedTensorType>(v.getType()).getElementType();
124+
});
125+
126+
// Use simple costmodel to define optimal set of the dot operands.
127+
// Most expensive - accuracy loss conversions:
128+
// - any larger type -> any smaller type;
129+
// - float -> int;
130+
// - int -> float (not supported for now);
131+
// - signed int -> unsigned int;
132+
// - unsigned int -> signed int with same or less size.
133+
// They are never performed, better to use FMA.
134+
// Supported conversion for now costs `1`, no conversion costs `0`.
135+
// The model could be improved in the future. For example taken into account
136+
// chain dot could be detected and result conversion score is decreased.
137+
int maxConvertCost =
138+
std::numeric_limits<int32_t>::max() / applicableTypes.front().size();
139+
auto calcConvertCost = [&](Type fromTy, Type toTy) -> int32_t {
140+
if (fromTy == toTy)
141+
return 0;
142+
143+
// Skip conversion between int and float. Int16/int32 cases are lowered to
144+
// FMA.
145+
if (fromTy.isIntOrIndex() != toTy.isIntOrIndex())
146+
return maxConvertCost;
147+
148+
if (fromTy.isIntOrIndex() && toTy.isIntOrIndex() &&
149+
fromTy.isUnsignedInteger() != toTy.isUnsignedInteger())
150+
return fromTy.isUnsignedInteger() && fromTy.getIntOrFloatBitWidth() <
151+
toTy.getIntOrFloatBitWidth()
152+
? 1
153+
: maxConvertCost;
154+
155+
return fromTy.getIntOrFloatBitWidth() <= toTy.getIntOrFloatBitWidth()
156+
? 1
157+
: maxConvertCost;
158+
};
159+
auto minCost = maxConvertCost;
160+
auto optTypes = OperandTypesVector();
161+
for (auto types : applicableTypes) {
162+
assert(types.size() == initElemTypes.size());
163+
int accumulatedConvertCost = 0;
164+
for (int i = 0; i < initElemTypes.size(); ++i) {
165+
accumulatedConvertCost += calcConvertCost(initElemTypes[i], types[i]);
166+
}
167+
if (accumulatedConvertCost < minCost) {
168+
minCost = accumulatedConvertCost;
169+
optTypes = types;
170+
}
171+
}
172+
return optTypes;
173+
}
174+
175+
OperandTypesVector getOperandTypesForWmmaOp(mlir::PatternRewriter &rewriter,
176+
tt::DotOp dot) {
177+
Type f16 = rewriter.getF16Type();
178+
Type f32 = rewriter.getF32Type();
179+
Type bf16 = rewriter.getBF16Type();
180+
Type i8 = rewriter.getIntegerType(8);
181+
Type i32 = rewriter.getIntegerType(32);
182+
SmallVector<OperandTypesVector> applicableTypes = {
183+
// clang-format off
184+
{f16, f16, f32, f32},
185+
{f16, f16, f16, f16},
186+
{bf16, bf16, f32, f32},
187+
{bf16, bf16, bf16, bf16},
188+
{i8, i8, i32, i32},
189+
// i4, i4, i32, i32 - is supported configuration
190+
// by WMMA instruction, but not supported by triton
191+
// clang-format on
192+
};
193+
return selectMatrixCoreOperandTypes(dot, applicableTypes);
194+
}
195+
115196
/**
116197
* @brief Convert layout and cast element type of a given tensor
117198
*
@@ -520,81 +601,71 @@ class BlockedToWMMA : public mlir::RewritePattern {
520601
mlir::LogicalResult
521602
matchAndRewrite(mlir::Operation *op,
522603
mlir::PatternRewriter &rewriter) const override {
604+
auto ctx = op->getContext();
523605
auto dotOp = cast<tt::DotOp>(op);
524606

607+
Value a = dotOp.getA();
608+
Value b = dotOp.getB();
609+
525610
auto oldRetType = cast<RankedTensorType>(dotOp.getResult().getType());
526-
if (!oldRetType.getEncoding() ||
527-
!isa<ttg::BlockedEncodingAttr>(oldRetType.getEncoding()))
611+
auto oldRetEncoding = oldRetType.getEncoding();
612+
if (!oldRetEncoding || !isa<ttg::BlockedEncodingAttr>(oldRetEncoding))
613+
return failure();
614+
615+
auto oldAType = cast<RankedTensorType>(a.getType());
616+
auto oldBType = cast<RankedTensorType>(b.getType());
617+
auto retShape = oldRetType.getShape();
618+
auto aShape = oldAType.getShape();
619+
auto bShape = oldBType.getShape();
620+
621+
// check shape
622+
auto mnkDim = AMDWmmaEncodingAttr::getMNKDimPerWMMAInstr();
623+
auto rank = aShape.size();
624+
if (aShape[rank - 2] % mnkDim[0] != 0 || // m
625+
bShape[rank - 1] % mnkDim[1] != 0 || // n
626+
aShape[rank - 1] % mnkDim[2] != 0) // k
528627
return failure();
529628

530-
if (!supportWMMA(dotOp))
629+
// get operand types
630+
auto operandTypes = getOperandTypesForWmmaOp(rewriter, dotOp);
631+
if (operandTypes.empty())
531632
return failure();
532633

533634
// get WMMA encoding for the given number of warps
534-
auto retShape = oldRetType.getShape();
535635
auto mod = op->getParentOfType<mlir::ModuleOp>();
536636
int numWarps = ttg::TritonGPUDialect::getNumWarps(mod);
537637

538-
// operands
539-
Value a = dotOp.getA();
540-
Value b = dotOp.getB();
541-
auto oldAType = cast<RankedTensorType>(a.getType());
542-
auto oldBType = cast<RankedTensorType>(b.getType());
543-
auto ctx = oldAType.getContext();
544-
545638
AMDWmmaEncodingAttr wmmaEnc;
546639

547-
auto mnkDim = AMDWmmaEncodingAttr::getMNKDimPerWMMAInstr();
548640
auto warpsPerTile = warpsPerTileWMMA(dotOp, retShape, numWarps);
549-
// Not supported yet
550-
// if (retShape[0] < warpsPerTile[0] * mnkDim[0] || retShape[1] <
551-
// warpsPerTile[1] * mnkDim[1])
552-
// return failure();
553-
auto CTALayout = ttg::getCTALayout(oldRetType.getEncoding());
554-
wmmaEnc = AMDWmmaEncodingAttr::get(oldRetType.getContext(), warpsPerTile,
555-
CTALayout);
556-
557-
Type wmmaAccType;
558-
auto oldRetElemType = oldRetType.getElementType();
559-
auto aElemType = oldAType.getElementType();
560-
auto bElemType = oldBType.getElementType();
561-
if (oldRetElemType.isIntOrIndex()) {
562-
wmmaAccType = rewriter.getIntegerType(32);
563-
} else if (isa<mlir::Float16Type, mlir::BFloat16Type>(oldRetElemType) &&
564-
aElemType == oldRetElemType) {
565-
wmmaAccType = oldRetElemType;
566-
} else if (isa<mlir::FloatType>(oldRetElemType) &&
567-
aElemType.getIntOrFloatBitWidth() < 16) {
568-
aElemType = rewriter.getF16Type();
569-
bElemType = rewriter.getF16Type();
570-
wmmaAccType = rewriter.getF16Type();
571-
} else {
572-
wmmaAccType = rewriter.getF32Type();
573-
}
574641

575-
auto newRetType = RankedTensorType::get(retShape, wmmaAccType, wmmaEnc);
642+
auto CTALayout = ttg::getCTALayout(oldRetEncoding);
643+
wmmaEnc = AMDWmmaEncodingAttr::get(ctx, warpsPerTile, CTALayout);
644+
645+
auto newRetType = RankedTensorType::get(retShape, operandTypes[3], wmmaEnc);
576646

577647
// convert accumulator
578648
auto oldAcc = dotOp.getOperand(2);
579-
auto newAcc = convertAndCastTensor(rewriter, oldAcc, wmmaEnc, wmmaAccType);
649+
auto newAcc =
650+
convertAndCastTensor(rewriter, oldAcc, wmmaEnc, operandTypes[2]);
580651

581652
auto newAType = RankedTensorType::get(
582-
oldAType.getShape(), aElemType,
653+
aShape, operandTypes[0],
583654
ttg::DotOperandEncodingAttr::get(ctx, 0, wmmaEnc, mnkDim[2]));
584655
auto newBType = RankedTensorType::get(
585-
oldBType.getShape(), bElemType,
656+
bShape, operandTypes[1],
586657
ttg::DotOperandEncodingAttr::get(ctx, 1, wmmaEnc, mnkDim[2]));
587658

588-
Value castedA =
589-
convertAndCastTensor(rewriter, a, newAType.getEncoding(), aElemType);
590-
Value castedB =
591-
convertAndCastTensor(rewriter, b, newBType.getEncoding(), bElemType);
659+
Value castedA = convertAndCastTensor(rewriter, a, newAType.getEncoding(),
660+
operandTypes[0]);
661+
Value castedB = convertAndCastTensor(rewriter, b, newBType.getEncoding(),
662+
operandTypes[1]);
592663
auto newDot = rewriter.create<tt::DotOp>(
593664
dotOp.getLoc(), newRetType, castedA, castedB, newAcc,
594665
dotOp.getInputPrecision(), dotOp.getMaxNumImpreciseAcc());
595666

596-
Value dotOutput = convertAndCastTensor(
597-
rewriter, newDot, oldRetType.getEncoding(), oldRetElemType);
667+
Value dotOutput = convertAndCastTensor(rewriter, newDot, oldRetEncoding,
668+
oldRetType.getElementType());
598669
rewriter.replaceOp(op, dotOutput);
599670
return success();
600671
}

0 commit comments

Comments
 (0)