
Commit 55a2172

chengjunlu, yudongsi, and whitneywhtsang authored
Add f32 RTNE to tf32 in DPAS (#3803)
Uses the SPIR-V built-in `_Z25__spirv_RoundFToTF32INTELf` (provided by the SPV_INTEL_tensor_float32_rounding extension) to round fp32 to tf32 with round-to-nearest-even before feeding the DPAS operands.

Co-authored-by: Si, Yudong <yudong.si@intel.com>
Co-authored-by: Whitney Tsang <whitney.tsang@intel.com>
Parent: d1121e2

5 files changed, +80 -6 lines
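For context, the rounding this commit relies on is the usual f32-to-TF32 round-to-nearest-even: the sign and 8 exponent bits are kept, and the 23-bit f32 mantissa is rounded to the 10 explicit mantissa bits TF32 retains. The following host-side sketch illustrates that bit manipulation only; it is not part of this commit, the helper name is made up for illustration, NaN/Inf handling is omitted, and on the device the conversion is done by the SPIR-V built-in, not by code like this.

#include <cstdint>
#include <cstdio>
#include <cstring>

// Illustrative only: round an f32 to TF32 precision (10 explicit mantissa
// bits) with round-to-nearest-even, returning the result widened back to f32.
static float round_f32_to_tf32(float x) {
  uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  // TF32 drops the low 13 mantissa bits of f32. Apply the standard RNE trick:
  // add a bias of 0x0FFF plus the lowest surviving bit, then clear the
  // dropped bits. A carry out of the mantissa correctly bumps the exponent.
  uint32_t lsb = (bits >> 13) & 1u;
  bits += 0x0FFFu + lsb;
  bits &= ~0x1FFFu;
  float out;
  std::memcpy(&out, &bits, sizeof(out));
  return out;
}

int main() {
  float v = 1.0000123f;
  std::printf("%.9g -> %.9g\n", v, round_f32_to_tf32(v));
  return 0;
}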

test/Conversion/intel/tritongpu_to_gen_dot.mlir

Lines changed: 13 additions & 2 deletions
@@ -72,9 +72,20 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
 #dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#dpas, kWidth=1}>
 
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
-  // CHECK-LABEL: dot_f32_tf32_tf32_f32_1
+  // CHECK-LABEL: llvm.func spir_kernelcc @dot_f32_tf32_tf32_f32_1(
+  // CHECK-SAME: %[[A:.*]]: !llvm.struct<(f32, f32, f32, f32)>, %[[B:.*]]: !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>,
+  // CHECK-SAME: %[[C:.*]]: !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>) attributes {intel_reqd_sub_group_size = 32 : i32, triton_gen.max_work_group_size = array<i32: 32, 1, 1>} {
   tt.func @dot_f32_tf32_tf32_f32_1(%a: tensor<8x8xf32, #dot_operand_a>, %b: tensor<8x16xf32, #dot_operand_b>, %c: tensor<8x16xf32, #dpas>) {
-    // CHECK: llvm.call spir_funccc @_Z39intel_sub_group_tf32_tf32_matrix_mad_k8Dv4_fDv8_fS0_(%{{.*}}, %{{.*}}, %{{.*}}) {{.*}} : (vector<4xf32>, vector<8xf32>, vector<8xf32>) -> vector<8xf32>
+    // COM: To simplify, only check RTNE and its usage for the last element of A, B, C
+    // CHECK %[[A_LAST_VAL:.*]] = llvm.extractvalue %[[A]][3]
+    // CHECK %[[A_RTNE_VAL:.*]] = llvm.call spir_funccc @_Z25__spirv_RoundFToTF32INTELf(%[[A_LAST_VAL]])
+    // CHECK %[[A_0:.*]] = llvm.insertelement %[[A_RTNE_VAL]], %{{.*}}{{\[}}%{{.*}} : i32] : vector<4xf32>
+    // CHECK %[[B_LAST_VAL:.*]] = llvm.extractvalue %[[B]][7]
+    // CHECK %[[B_RTNE_VAL:.*]] = llvm.call spir_funccc @_Z25__spirv_RoundFToTF32INTELf(%[[B_LAST_VAL]])
+    // CHECK %[[B_0:.*]] = llvm.insertelement %[[B_RTNE_VAL]], %{{.*}}{{\[}}%{{.*}} : i32] : vector<8xf32>
+    // CHECK %[[C_LAST_VAL:.*]] = llvm.extractvalue %[[C]][7]
+    // CHECK %[[C_0:.*]] = llvm.insertelement %[[C_LAST_VAL]], %{{.*}}{{\[}}%{{.*}} : i32] : vector<8xf32>
+    // CHECK : llvm.call spir_funccc @_Z39intel_sub_group_tf32_tf32_matrix_mad_k8Dv4_fDv8_fS0_(%[[A_0]], %[[B_0]], %[[C_0]]) {{.*}} : (vector<4xf32>, vector<8xf32>, vector<8xf32>) -> vector<8xf32>
     %0 = tt.dot %a, %b, %c, inputPrecision = tf32 : tensor<8x8xf32, #dot_operand_a> * tensor<8x16xf32, #dot_operand_b> -> tensor<8x16xf32, #dpas>
     tt.return
   }

third_party/intel/include/Dialect/TritonGEN/IR/TritonGENOps.td

Lines changed: 16 additions & 0 deletions
@@ -20,6 +20,7 @@ include "mlir/IR/EnumAttr.td"
 include "mlir/Dialect/LLVMIR/LLVMTypes.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/IR/OpAsmInterface.td"
+include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
 
 //===----------------------------------------------------------------------===//
 // TritonGEN op definitions
@@ -313,4 +314,19 @@ def TritonGEN_SubGroupBlockWriteOp : TritonGEN_Op<"sub_group_block_write"> {
   }];
 }
 
+def TritonGEN_FToTf32Op
+    : TritonGEN_Op<"f_to_tf32", [SameOperandsAndResultType]> {
+  let summary = "Rounding instruction from float to tensor float (TF32) data format";
+
+  let description = [{
+    The op converts value numerically from
+    a 32-bit floating point type to TF32 with rounding to the nearest even.
+  }];
+
+  let arguments = (ins F32:$val);
+  let results = (outs F32:$res);
+  let assemblyFormat = [{
+    $val attr-dict `:` type($val)
+  }];
+}
 #endif // TRITONGEN_OPS
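Going by the assemblyFormat declared above, the new op would be written in textual IR roughly as `%0 = triton_gen.f_to_tf32 %val : f32`; that spelling is inferred from the declaration here rather than copied from a dialect test, so treat it as a sketch.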

third_party/intel/lib/Target/SPIRV/SPIRVTranslation.cpp

Lines changed: 2 additions & 1 deletion
@@ -107,7 +107,7 @@ class SmallVectorBuffer : public std::streambuf {
 
 static SPIRV::TranslatorOpts getSPIRVOopts() {
   SPIRV::TranslatorOpts SPIRVOpts;
-  static constexpr std::array<SPIRV::ExtensionID, 12> AllowedExtensions{
+  static constexpr std::array<SPIRV::ExtensionID, 13> AllowedExtensions{
       SPIRV::ExtensionID::SPV_EXT_shader_atomic_float_add,
      SPIRV::ExtensionID::SPV_INTEL_arbitrary_precision_integers,
      SPIRV::ExtensionID::SPV_INTEL_arithmetic_fence,
@@ -116,6 +116,7 @@ static SPIRV::TranslatorOpts getSPIRVOopts() {
       SPIRV::ExtensionID::SPV_INTEL_kernel_attributes,
       SPIRV::ExtensionID::SPV_INTEL_memory_access_aliasing,
       SPIRV::ExtensionID::SPV_INTEL_subgroups,
+      SPIRV::ExtensionID::SPV_INTEL_tensor_float32_rounding,
       SPIRV::ExtensionID::SPV_INTEL_unstructured_loop_controls,
       SPIRV::ExtensionID::SPV_INTEL_vector_compute,
       SPIRV::ExtensionID::SPV_KHR_bit_instructions,

third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp

Lines changed: 27 additions & 1 deletion
@@ -699,6 +699,31 @@ struct TritonSubGroupBlockWriteLowering
   }
 };
 
+struct TritonFToTf32OpLowering
+    : public ConvertOpToLLVMPattern<TritonGEN::FToTf32Op> {
+  using ConvertOpToLLVMPattern<TritonGEN::FToTf32Op>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(TritonGEN::FToTf32Op op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    MLIRContext *ctx = rewriter.getContext();
+    Location loc = op->getLoc();
+    auto b = TritonLLVMOpBuilder(loc, rewriter);
+
+    Value value = op->getOperand(0);
+    SmallVector<Type> argTypes{f32_ty};
+    SmallVector<Value> args{value};
+
+    const StringLiteral funcName = "_Z25__spirv_RoundFToTF32INTELf";
+    auto retType = f32_ty;
+    auto callOp = intel::createDeviceFunctionCall(
+        rewriter, funcName, retType, {argTypes}, {args}, {},
+        intel::noUnwindWillReturnAttrs);
+    rewriter.replaceOp(op, callOp);
+    return success();
+  }
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -760,7 +785,8 @@ void mlir::triton::populateTritonGENToLLVMConversionPatterns(
       .add<TritonMatrixDPASLowering, TritonMatrix2DBlockLoadLowering,
            TritonMatrix2DBlockStoreLowering,
            TritonMatrix2DBlockPrefetchLowering, TritonSubGroupBlockReadLowering,
-           TritonSubGroupBlockWriteLowering>(converter);
+           TritonSubGroupBlockWriteLowering, TritonFToTf32OpLowering>(
+          converter);
 }
 
 void registerConvertTritonTritonGENToLLVMInterface(DialectRegistry &registry) {

third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM/DPAS.cpp

Lines changed: 22 additions & 2 deletions
@@ -310,16 +310,21 @@ class DotOpDPASConversionHelper {
     size_t rank = repCluster.size();
     unsigned repClusterOuter = 0u;
     unsigned repClusterInner = 0u;
+    bool isOperandA = false;
+    bool isOperandB = false;
+    bool isFToTF32Enabled = false;
    switch (opIdx) {
    case DpasEncodingAttr::OpIdx::OperandA:
      // operand A
      repClusterOuter = repCluster[rank - 2];
      repClusterInner = 1;
+      isOperandA = true;
      break;
    case DpasEncodingAttr::OpIdx::OperandB:
      // operand B
      repClusterInner = 1;
      repClusterOuter = repCluster[rank - 1];
+      isOperandB = true;
      break;
    case DpasEncodingAttr::OpIdx::OperandC:
      // operand C
@@ -333,6 +338,11 @@ class DotOpDPASConversionHelper {
         totalElems /
         ((batch * outer * inner) * (repClusterOuter * repClusterInner));
     VectorType dotOpTy = vec_ty(elemTy, numElemsPerOperand);
+    // TODO: IGC bug, Update isFToTF32Enabled as follows once issue #3870 is
+    // fixed. isFToTF32Enabled = elemTy.isFloat(32) && (isOperandA ||
+    // isOperandB)
+    isFToTF32Enabled = elemTy.isFloat(32) &&
+                       ((rank == 3) ? isOperandA : (isOperandA || isOperandB));
 
     auto tb = TritonLLVMOpBuilder(loc, rewriter);
     int offset = 0;
@@ -344,8 +354,18 @@ class DotOpDPASConversionHelper {
         for (int repInner = 0; repInner < repClusterInner; ++repInner) {
           Value matVal = rewriter.create<LLVM::UndefOp>(loc, dotOpTy);
           for (int k = 0; k < numElemsPerOperand; ++k) {
-            matVal = tb.insert_element(dotOpTy, matVal, elems[offset++],
-                                       tb.i32_val(k));
+            if (isFToTF32Enabled) {
+              Value f32Val = elems[offset++];
+              auto t32Val =
+                  rewriter.create<TritonGEN::FToTf32Op>(loc, f32Val)
+                      .getResult();
+              matVal =
+                  tb.insert_element(dotOpTy, matVal, t32Val, tb.i32_val(k));
+
+            } else {
+              matVal = tb.insert_element(dotOpTy, matVal, elems[offset++],
+                                         tb.i32_val(k));
+            }
           }
           vals[{b, i * repClusterOuter + repOuter,
                 j * repClusterInner + repInner}] = matVal;