@@ -112,6 +112,87 @@ warpsPerTileWMMA(tt::DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps) {
112112 AMDWmmaEncodingAttr::getMNKDimPerWMMAInstr ()[1 ]});
113113}
114114
115+ using OperandTypesVector = SmallVector<Type, 4 >;
116+ OperandTypesVector
117+ selectMatrixCoreOperandTypes (tt::DotOp dot,
118+ ArrayRef<OperandTypesVector> applicableTypes) {
119+ SmallVector<Value> dotOperands = {dot.getA (), dot.getB (), dot.getC (),
120+ dot.getD ()};
121+ OperandTypesVector initElemTypes;
122+ llvm::transform (dotOperands, std::back_inserter (initElemTypes), [](Value v) {
123+ return cast<RankedTensorType>(v.getType ()).getElementType ();
124+ });
125+
126+ // Use simple costmodel to define optimal set of the dot operands.
127+ // Most expensive - accuracy loss conversions:
128+ // - any larger type -> any smaller type;
129+ // - float -> int;
130+ // - int -> float (not supported for now);
131+ // - signed int -> unsigned int;
132+ // - unsigned int -> signed int with same or less size.
133+ // They are never performed, better to use FMA.
134+ // Supported conversion for now costs `1`, no conversion costs `0`.
135+ // The model could be improved in the future. For example taken into account
136+ // chain dot could be detected and result conversion score is decreased.
137+ int maxConvertCost =
138+ std::numeric_limits<int32_t >::max () / applicableTypes.front ().size ();
139+ auto calcConvertCost = [&](Type fromTy, Type toTy) -> int32_t {
140+ if (fromTy == toTy)
141+ return 0 ;
142+
143+ // Skip conversion between int and float. Int16/int32 cases are lowered to
144+ // FMA.
145+ if (fromTy.isIntOrIndex () != toTy.isIntOrIndex ())
146+ return maxConvertCost;
147+
148+ if (fromTy.isIntOrIndex () && toTy.isIntOrIndex () &&
149+ fromTy.isUnsignedInteger () != toTy.isUnsignedInteger ())
150+ return fromTy.isUnsignedInteger () && fromTy.getIntOrFloatBitWidth () <
151+ toTy.getIntOrFloatBitWidth ()
152+ ? 1
153+ : maxConvertCost;
154+
155+ return fromTy.getIntOrFloatBitWidth () <= toTy.getIntOrFloatBitWidth ()
156+ ? 1
157+ : maxConvertCost;
158+ };
159+ auto minCost = maxConvertCost;
160+ auto optTypes = OperandTypesVector ();
161+ for (auto types : applicableTypes) {
162+ assert (types.size () == initElemTypes.size ());
163+ int accumulatedConvertCost = 0 ;
164+ for (int i = 0 ; i < initElemTypes.size (); ++i) {
165+ accumulatedConvertCost += calcConvertCost (initElemTypes[i], types[i]);
166+ }
167+ if (accumulatedConvertCost < minCost) {
168+ minCost = accumulatedConvertCost;
169+ optTypes = types;
170+ }
171+ }
172+ return optTypes;
173+ }
174+
175+ OperandTypesVector getOperandTypesForWmmaOp (mlir::PatternRewriter &rewriter,
176+ tt::DotOp dot) {
177+ Type f16 = rewriter.getF16Type ();
178+ Type f32 = rewriter.getF32Type ();
179+ Type bf16 = rewriter.getBF16Type ();
180+ Type i8 = rewriter.getIntegerType (8 );
181+ Type i32 = rewriter.getIntegerType (32 );
182+ SmallVector<OperandTypesVector> applicableTypes = {
183+ // clang-format off
184+ {f16 , f16 , f32 , f32 },
185+ {f16 , f16 , f16 , f16 },
186+ {bf16 , bf16 , f32 , f32 },
187+ {bf16 , bf16 , bf16 , bf16 },
188+ {i8 , i8 , i32 , i32 },
189+ // i4, i4, i32, i32 - is supported configuration
190+ // by WMMA instruction, but not supported by triton
191+ // clang-format on
192+ };
193+ return selectMatrixCoreOperandTypes (dot, applicableTypes);
194+ }
195+
115196/* *
116197 * @brief Convert layout and cast element type of a given tensor
117198 *
@@ -520,81 +601,71 @@ class BlockedToWMMA : public mlir::RewritePattern {
520601 mlir::LogicalResult
521602 matchAndRewrite (mlir::Operation *op,
522603 mlir::PatternRewriter &rewriter) const override {
604+ auto ctx = op->getContext ();
523605 auto dotOp = cast<tt::DotOp>(op);
524606
607+ Value a = dotOp.getA ();
608+ Value b = dotOp.getB ();
609+
525610 auto oldRetType = cast<RankedTensorType>(dotOp.getResult ().getType ());
526- if (!oldRetType.getEncoding () ||
527- !isa<ttg::BlockedEncodingAttr>(oldRetType.getEncoding ()))
611+ auto oldRetEncoding = oldRetType.getEncoding ();
612+ if (!oldRetEncoding || !isa<ttg::BlockedEncodingAttr>(oldRetEncoding))
613+ return failure ();
614+
615+ auto oldAType = cast<RankedTensorType>(a.getType ());
616+ auto oldBType = cast<RankedTensorType>(b.getType ());
617+ auto retShape = oldRetType.getShape ();
618+ auto aShape = oldAType.getShape ();
619+ auto bShape = oldBType.getShape ();
620+
621+ // check shape
622+ auto mnkDim = AMDWmmaEncodingAttr::getMNKDimPerWMMAInstr ();
623+ auto rank = aShape.size ();
624+ if (aShape[rank - 2 ] % mnkDim[0 ] != 0 || // m
625+ bShape[rank - 1 ] % mnkDim[1 ] != 0 || // n
626+ aShape[rank - 1 ] % mnkDim[2 ] != 0 ) // k
528627 return failure ();
529628
530- if (!supportWMMA (dotOp))
629+ // get operand types
630+ auto operandTypes = getOperandTypesForWmmaOp (rewriter, dotOp);
631+ if (operandTypes.empty ())
531632 return failure ();
532633
533634 // get WMMA encoding for the given number of warps
534- auto retShape = oldRetType.getShape ();
535635 auto mod = op->getParentOfType <mlir::ModuleOp>();
536636 int numWarps = ttg::TritonGPUDialect::getNumWarps (mod);
537637
538- // operands
539- Value a = dotOp.getA ();
540- Value b = dotOp.getB ();
541- auto oldAType = cast<RankedTensorType>(a.getType ());
542- auto oldBType = cast<RankedTensorType>(b.getType ());
543- auto ctx = oldAType.getContext ();
544-
545638 AMDWmmaEncodingAttr wmmaEnc;
546639
547- auto mnkDim = AMDWmmaEncodingAttr::getMNKDimPerWMMAInstr ();
548640 auto warpsPerTile = warpsPerTileWMMA (dotOp, retShape, numWarps);
549- // Not supported yet
550- // if (retShape[0] < warpsPerTile[0] * mnkDim[0] || retShape[1] <
551- // warpsPerTile[1] * mnkDim[1])
552- // return failure();
553- auto CTALayout = ttg::getCTALayout (oldRetType.getEncoding ());
554- wmmaEnc = AMDWmmaEncodingAttr::get (oldRetType.getContext (), warpsPerTile,
555- CTALayout);
556-
557- Type wmmaAccType;
558- auto oldRetElemType = oldRetType.getElementType ();
559- auto aElemType = oldAType.getElementType ();
560- auto bElemType = oldBType.getElementType ();
561- if (oldRetElemType.isIntOrIndex ()) {
562- wmmaAccType = rewriter.getIntegerType (32 );
563- } else if (isa<mlir::Float16Type, mlir::BFloat16Type>(oldRetElemType) &&
564- aElemType == oldRetElemType) {
565- wmmaAccType = oldRetElemType;
566- } else if (isa<mlir::FloatType>(oldRetElemType) &&
567- aElemType.getIntOrFloatBitWidth () < 16 ) {
568- aElemType = rewriter.getF16Type ();
569- bElemType = rewriter.getF16Type ();
570- wmmaAccType = rewriter.getF16Type ();
571- } else {
572- wmmaAccType = rewriter.getF32Type ();
573- }
574641
575- auto newRetType = RankedTensorType::get (retShape, wmmaAccType, wmmaEnc);
642+ auto CTALayout = ttg::getCTALayout (oldRetEncoding);
643+ wmmaEnc = AMDWmmaEncodingAttr::get (ctx, warpsPerTile, CTALayout);
644+
645+ auto newRetType = RankedTensorType::get (retShape, operandTypes[3 ], wmmaEnc);
576646
577647 // convert accumulator
578648 auto oldAcc = dotOp.getOperand (2 );
579- auto newAcc = convertAndCastTensor (rewriter, oldAcc, wmmaEnc, wmmaAccType);
649+ auto newAcc =
650+ convertAndCastTensor (rewriter, oldAcc, wmmaEnc, operandTypes[2 ]);
580651
581652 auto newAType = RankedTensorType::get (
582- oldAType. getShape (), aElemType ,
653+ aShape, operandTypes[ 0 ] ,
583654 ttg::DotOperandEncodingAttr::get (ctx, 0 , wmmaEnc, mnkDim[2 ]));
584655 auto newBType = RankedTensorType::get (
585- oldBType. getShape (), bElemType ,
656+ bShape, operandTypes[ 1 ] ,
586657 ttg::DotOperandEncodingAttr::get (ctx, 1 , wmmaEnc, mnkDim[2 ]));
587658
588- Value castedA =
589- convertAndCastTensor (rewriter, a, newAType. getEncoding (), aElemType );
590- Value castedB =
591- convertAndCastTensor (rewriter, b, newBType. getEncoding (), bElemType );
659+ Value castedA = convertAndCastTensor (rewriter, a, newAType. getEncoding (),
660+ operandTypes[ 0 ] );
661+ Value castedB = convertAndCastTensor (rewriter, b, newBType. getEncoding (),
662+ operandTypes[ 1 ] );
592663 auto newDot = rewriter.create <tt::DotOp>(
593664 dotOp.getLoc (), newRetType, castedA, castedB, newAcc,
594665 dotOp.getInputPrecision (), dotOp.getMaxNumImpreciseAcc ());
595666
596- Value dotOutput = convertAndCastTensor (
597- rewriter, newDot, oldRetType.getEncoding (), oldRetElemType );
667+ Value dotOutput = convertAndCastTensor (rewriter, newDot, oldRetEncoding,
668+ oldRetType.getElementType () );
598669 rewriter.replaceOp (op, dotOutput);
599670 return success ();
600671 }
0 commit comments