@@ -13,80 +13,6 @@ using namespace mlir::triton::gpu;
1313using namespace mlir ::triton::gpu::intel;
1414
1515namespace {
16- SmallVector<Value> convertMxfp4x2ToBf16x2 (RewriterBase &rewriter, Location loc,
17- ArrayRef<Value> values) {
18- auto b = TritonLLVMOpBuilder (loc, rewriter);
19- SmallVector<Value> results;
20- for (auto v : values) {
21- auto em0 = b.and_ (v, b.i8_val (0x7 ));
22- auto em1 = b.and_ (v, b.i8_val (0x70 ));
23- Value v0 =
24- b.or_ (b.shl (b.zext (i16_ty, em0), b.i16_val (6 )),
25- b.shl (b.zext (i16_ty, b.and_ (v, b.i8_val (0x8 ))), b.i16_val (12 )));
26- Value v1 =
27- b.or_ (b.shl (b.zext (i16_ty, em1), b.i16_val (2 )),
28- b.shl (b.zext (i16_ty, b.and_ (v, b.i8_val (0x80 ))), b.i16_val (8 )));
29- // Three cases:
30- // 1) x is normal and non-zero: Correct bias
31- v0 = b.select (b.icmp_ne (b.and_ (em0, b.i8_val (0x6 )), b.i8_val (0 )),
32- b.add (v0, b.i16_val ((127 - 1 ) << 7 )), v0);
33- v1 = b.select (b.icmp_ne (b.and_ (em1, b.i8_val (0x60 )), b.i8_val (0 )),
34- b.add (v1, b.i16_val ((127 - 1 ) << 7 )), v1);
35- // 2) x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in
36- // bf16
37- v0 = b.bitcast (
38- b.select (b.icmp_eq (em0, b.i8_val (0x1 )),
39- b.or_ (b.i16_val (16128 ), b.and_ (v0, b.i16_val (0x8000 ))), v0),
40- bf16_ty);
41- v1 = b.bitcast (
42- b.select (b.icmp_eq (em1, b.i8_val (0x10 )),
43- b.or_ (b.i16_val (16128 ), b.and_ (v1, b.i16_val (0x8000 ))), v1),
44- bf16_ty);
45- // 3) x is zero, nothing to do
46- results.push_back (v0);
47- results.push_back (v1);
48- }
49- return results;
50- }
51-
52- SmallVector<Value> convertMxfp4x2ToFp16x2 (RewriterBase &rewriter, Location loc,
53- ArrayRef<Value> values) {
54- auto b = TritonLLVMOpBuilder (loc, rewriter);
55- SmallVector<Value> results;
56- for (auto v : values) {
57- auto em0 = b.and_ (v, b.i8_val (0x7 ));
58- auto em1 = b.and_ (v, b.i8_val (0x70 ));
59- // FP16 bits: sign = 1, exponent = 5, mantissa = 10
60- Value v0 =
61- b.or_ (b.shl (b.zext (i16_ty, em0), b.i16_val (10 - 1 )),
62- b.shl (b.zext (i16_ty, b.and_ (v, b.i8_val (0x8 ))), b.i16_val (12 )));
63- Value v1 =
64- b.or_ (b.shl (b.zext (i16_ty, em1), b.i16_val (10 - 1 - 4 )),
65- b.shl (b.zext (i16_ty, b.and_ (v, b.i8_val (0x80 ))), b.i16_val (8 )));
66-
67- // Three cases:
68- // 1) x is normal and non-zero: Correct bias
69- v0 = b.select (b.icmp_ne (b.and_ (em0, b.i8_val (0x6 )), b.i8_val (0 )),
70- b.add (v0, b.i16_val ((15 - 1 ) << 10 )), v0);
71- v1 = b.select (b.icmp_ne (b.and_ (em1, b.i8_val (0x60 )), b.i8_val (0 )),
72- b.add (v1, b.i16_val ((15 - 1 ) << 10 )), v1);
73-
74- // 2) x is subnormal (x == 0bs001 where s is the sign): Map to fp16 +-0.5
75- v0 = b.bitcast (
76- b.select (b.icmp_eq (em0, b.i8_val (0x1 )),
77- b.or_ (b.i16_val (0x3800 ), b.and_ (v0, b.i16_val (0x8000 ))), v0),
78- f16_ty);
79- v1 = b.bitcast (
80- b.select (b.icmp_eq (em1, b.i8_val (0x10 )),
81- b.or_ (b.i16_val (0x3800 ), b.and_ (v1, b.i16_val (0x8000 ))), v1),
82- f16_ty);
83- // 3) x is zero, nothing to do
84- results.push_back (v0);
85- results.push_back (v1);
86- }
87- return results;
88- }
89-
9016class Fp4ToFpOpPattern : public ConvertOpToLLVMPattern <Fp4ToFpOp> {
9117public:
9218 Fp4ToFpOpPattern (LLVMTypeConverter &typeConverter, PatternBenefit benefit)
@@ -96,21 +22,51 @@ class Fp4ToFpOpPattern : public ConvertOpToLLVMPattern<Fp4ToFpOp> {
9622 matchAndRewrite (Fp4ToFpOp op, OpAdaptor adaptor,
9723 ConversionPatternRewriter &rewriter) const override {
9824 Location loc = op.getLoc ();
99- auto *ctx = op.getContext ();
10025 Type elemType = op.getType ().getElementType ();
10126 assert (elemType == f16_ty || elemType == bf16_ty);
102- bool toFp16 = elemType == f16_ty;
103-
104- SmallVector<Value> xVals =
105- unpackLLElements (loc, adaptor.getSrc (), rewriter);
106- xVals = toFp16 ? convertMxfp4x2ToFp16x2 (rewriter, loc, xVals)
107- : convertMxfp4x2ToBf16x2 (rewriter, loc, xVals);
10827
109- Value result =
110- packLLElements (loc, getTypeConverter (), xVals, rewriter, op.getType ());
111- rewriter.replaceOp (op, result);
28+ SmallVector<Value> results;
29+ {
30+ SmallVector<Value> xVals =
31+ unpackLLElements (loc, adaptor.getSrc (), rewriter);
32+ convertMxfp4x2ToFloat (rewriter, loc, xVals, results,
33+ elemType == f16_ty ? f16_ty : bf16_ty);
34+ }
35+ rewriter.replaceOp (op, packLLElements (loc, getTypeConverter (), results,
36+ rewriter, op.getType ()));
11237 return success ();
11338 }
39+
40+ private:
41+ static void convertMxfp4x2ToFloat (RewriterBase &rewriter, Location loc,
42+ SmallVector<Value> &values,
43+ SmallVector<Value> &results,
44+ FloatType floatTy) {
45+ assert (results.empty () && !values.empty ());
46+
47+ Value table;
48+ { // Create a constant vector containing all the possible values
49+ auto vecTy = VectorType::get ({16 }, floatTy);
50+ SmallVector<Attribute, 16 > values;
51+ for (double v : {0 ., 0.5 , 1 ., 1.5 , 2 ., 3 ., 4 ., 6 ., -0 ., -0.5 , -1 ., -1.5 ,
52+ -2 ., -3 ., -4 ., -6 .})
53+ values.push_back (rewriter.getFloatAttr (floatTy, v));
54+ table = rewriter.create <LLVM::ConstantOp>(
55+ loc, vecTy, DenseElementsAttr::get (vecTy, values));
56+ }
57+
58+ TritonLLVMOpBuilder b (loc, rewriter);
59+ Value i8_4 = b.i8_val (4 );
60+ Value i8_15 = b.i8_val (15 );
61+ results.reserve (values.size () * 2 );
62+ for (Value v : values) {
63+ // The first and last 4 bits are the values indices in the table
64+ Value idx1 = b.and_ (v, i8_15);
65+ Value idx2 = b.lshr (v, i8_4);
66+ results.push_back (b.extract_element (table, idx1));
67+ results.push_back (b.extract_element (table, idx2));
68+ }
69+ }
11470};
11571} // anonymous namespace
11672
0 commit comments