Skip to content

Commit 8440909

Browse files
arnamoy10 and igcbot
authored and committed
Add round_to_tf32() lowering support.
This patch adds round_to_tf32() lowering support in IGC.
1 parent 36d40c2 commit 8440909

File tree

8 files changed

+124
-1
lines changed

8 files changed

+124
-1
lines changed

IGC/AdaptorOCL/SPIRV/SPIRVInternal.h

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -443,6 +443,8 @@ _SPIRV_OP(OpSatConvertUToS)
443443
_SPIRV_OP(OpSatConvertSToU)
444444
_SPIRV_OP(OpConvertFToBF16INTEL)
445445
_SPIRV_OP(OpConvertBF16ToFINTEL)
446+
// Rounding builtins
447+
_SPIRV_OP(OpRoundFToTF32INTEL)
446448
// SPV_INTEL_arithmetic_fence
447449
_SPIRV_OP(OpArithmeticFenceINTEL)
448450
// Arithmetic Instructions

IGC/AdaptorOCL/SPIRV/libSPIRV/SPIRVInstruction.h

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1057,6 +1057,7 @@ _SPIRV_OP(Any)
10571057
_SPIRV_OP(All)
10581058
_SPIRV_OP(ConvertFToBF16INTEL)
10591059
_SPIRV_OP(ConvertBF16ToFINTEL)
1060+
_SPIRV_OP(RoundFToTF32INTEL)
10601061
_SPIRV_OP(ArithmeticFenceINTEL)
10611062
_SPIRV_OP(BitReverse)
10621063
#undef _SPIRV_OP

IGC/AdaptorOCL/SPIRV/libSPIRV/SPIRVOpCodeEnum.h

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -515,6 +515,7 @@ _SPIRV_OP(TypeTokenINTEL, 6113)
515515
//_SPIRV_OP(DebugInfoModuleINTEL, 6114)
516516
_SPIRV_OP(ConvertFToBF16INTEL, 6116)
517517
_SPIRV_OP(ConvertBF16ToFINTEL, 6117)
518+
_SPIRV_OP(RoundFToTF32INTEL, 6426)
518519
// SPV_INTEL_matrix
519520
//_SPIRV_OP(TypeJointMatrixINTEL_OLD, 6119) Replaced by 6184
520521
_SPIRV_OP(TypeJointMatrixINTEL, 6184)

IGC/BiFModule/Headers/spirv.h

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3697,6 +3697,13 @@ float4 SPIRV_OVERLOADABLE SPIRV_BUILTIN(ConvertBF16ToFINTEL, _v4i16, )(short4 x
36973697
float8 SPIRV_OVERLOADABLE SPIRV_BUILTIN(ConvertBF16ToFINTEL, _v8i16, )(short8 x);
36983698
float16 SPIRV_OVERLOADABLE SPIRV_BUILTIN(ConvertBF16ToFINTEL, _v16i16, )(short16 x);
36993699

3700+
// Declarations of the RoundFToTF32INTEL builtin overloads: one for a scalar
// float and one for each float vector width (2, 3, 4, 8, 16). Each overload
// returns an int (or int vector) with the same element count as its argument.
// NOTE(review): the int result presumably carries the TF32 bit pattern rather
// than a numeric conversion - confirm against the builtin implementation.
int SPIRV_OVERLOADABLE SPIRV_BUILTIN(RoundFToTF32INTEL, _f32, )(float x);
3701+
int2 SPIRV_OVERLOADABLE SPIRV_BUILTIN(RoundFToTF32INTEL, _v2f32, )(float2 x);
3702+
int3 SPIRV_OVERLOADABLE SPIRV_BUILTIN(RoundFToTF32INTEL, _v3f32, )(float3 x);
3703+
int4 SPIRV_OVERLOADABLE SPIRV_BUILTIN(RoundFToTF32INTEL, _v4f32, )(float4 x);
3704+
int8 SPIRV_OVERLOADABLE SPIRV_BUILTIN(RoundFToTF32INTEL, _v8f32, )(float8 x);
3705+
int16 SPIRV_OVERLOADABLE SPIRV_BUILTIN(RoundFToTF32INTEL, _v16f32, )(float16 x);
3706+
37003707
#if (__OPENCL_C_VERSION__ >= CL_VERSION_2_0)
37013708
private void* SPIRV_OVERLOADABLE SPIRV_BUILTIN(GenericCastToPtrExplicit, _p0i8_p4i8_i32, _ToPrivate)(generic char *Pointer, int Storage);
37023709
local void* SPIRV_OVERLOADABLE SPIRV_BUILTIN(GenericCastToPtrExplicit, _p3i8_p4i8_i32, _ToLocal)(generic char *Pointer, int Storage);

IGC/BiFModule/Implementation/conversions.cl

Lines changed: 30 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -925,6 +925,36 @@ float16 SPIRV_OVERLOADABLE SPIRV_BUILTIN(ConvertBF16ToFINTEL, _v16i16, )(short1
925925
return __builtin_IB_bftof_16(Value);
926926
}
927927

928+
// RoundFToTF32INTEL builtin definitions. Every overload is a thin wrapper that
// forwards its float (or float vector) argument to the __builtin_IB_ftotf32_N
// intrinsic whose suffix N equals the element count, and returns that
// intrinsic's result as an int (or int vector) of the same width.
// NOTE(review): the rounding behaviour itself lives in the intrinsic, which is
// not visible here - confirm its semantics before relying on them.

// Scalar overload (N = 1).
int SPIRV_OVERLOADABLE SPIRV_BUILTIN(RoundFToTF32INTEL, _f32, )(float Value)
929+
{
930+
return __builtin_IB_ftotf32_1(Value);
931+
}
932+
933+
// 2-element vector overload.
int2 SPIRV_OVERLOADABLE SPIRV_BUILTIN(RoundFToTF32INTEL, _v2f32, )(float2 Value)
934+
{
935+
return __builtin_IB_ftotf32_2(Value);
936+
}
937+
938+
// 3-element vector overload.
int3 SPIRV_OVERLOADABLE SPIRV_BUILTIN(RoundFToTF32INTEL, _v3f32, )(float3 Value)
939+
{
940+
return __builtin_IB_ftotf32_3(Value);
941+
}
942+
943+
// 4-element vector overload.
int4 SPIRV_OVERLOADABLE SPIRV_BUILTIN(RoundFToTF32INTEL, _v4f32, )(float4 Value)
944+
{
945+
return __builtin_IB_ftotf32_4(Value);
946+
}
947+
948+
// 8-element vector overload.
int8 SPIRV_OVERLOADABLE SPIRV_BUILTIN(RoundFToTF32INTEL, _v8f32, )(float8 Value)
949+
{
950+
return __builtin_IB_ftotf32_8(Value);
951+
}
952+
953+
// 16-element vector overload.
int16 SPIRV_OVERLOADABLE SPIRV_BUILTIN(RoundFToTF32INTEL, _v16f32, )(float16 Value)
954+
{
955+
return __builtin_IB_ftotf32_16(Value);
956+
}
957+
928958
/*
929959
// Next is all Scalar types with Rounding modes [RTE,RTZ,RTN,RTP] and Sat
930960
//

IGC/Compiler/CISACodeGen/EmitVISAPass.cpp

Lines changed: 12 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -20392,7 +20392,18 @@ void EmitPass::emitfcvt(llvm::GenIntrinsicInst* GII)
2039220392
}
2039320393
else if (id == GenISAIntrinsic::GenISA_ftotf32) {
2039420394
tDst = m_currShader->GetNewAlias(dst, ISA_TYPE_UD, 0, 0);
20395-
tSrc = src;
20395+
// Does not support immediate source of type float, therefore we
20396+
// need a temporary "general" variable and copy the immediate
20397+
// value to that temporary variable first. Then we can use this
20398+
// temporary as an operand of fcvt.
20399+
if (src->IsImmediate()) {
20400+
CVariable *tfSrc = m_currShader->GetNewVariable(
20401+
1, ISA_TYPE_F, EALIGN_GRF, "tmp_cvt");
20402+
m_encoder->Copy(tfSrc, src);
20403+
tSrc = tfSrc;
20404+
} else {
20405+
tSrc = src;
20406+
}
2039620407
}
2039720408
else {
2039820409
IGC_ASSERT_EXIT_MESSAGE(0, "Something wrong in cvt!");

IGC/Compiler/Optimizer/BuiltInFuncImport.cpp

Lines changed: 62 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -535,6 +535,67 @@ void BIImport::fixSPIRFunctionsReturnType(Module& M)
535535
F->eraseFromParent();
536536
}
537537

538+
// The built-in definition returns i32, however, at this point the function call
539+
// that has been added for round_to_tf32() call returns a float (as the orig
540+
// matrix type was float). So we need to: 1) Change the return type of the
541+
// function declaration to int so that it matches the builtin definition; 2)
542+
// Cast the returned value of the function back to float so that the previous
543+
// users of the return value are happy.
544+
// Rewrites every float-returning declaration whose name contains
// "OpRoundFToTF32INTEL" into an i32-returning declaration with the same
// parameters, linkage and (finally) the same name. Each call to the old
// declaration is replaced by a call to the new one followed by a bitcast of
// the i32 result back to the call's original type. The old declarations are
// erased after the module walk.
//
// M - module whose declarations and call sites are updated in place.
void fixRoundToTF32ReturnType(Module &M) {
545+
// Old declarations are only queued here and erased after the loop over M.
SmallPtrSet<Function *, 8> funcsToRemove;
546+
for (auto &F : M) {
547+
// Only declarations are rewritten; functions with a body are left alone.
if (!F.isDeclaration())
548+
continue;
549+
auto FuncName = F.getName();
550+
551+
// Skip unrelated functions and anything already carrying the "_old"
// suffix that this function appends further down.
if (!FuncName.contains("OpRoundFToTF32INTEL") ||
552+
FuncName.contains("_old"))
553+
continue;
554+
// NOTE(review): isFloatTy() matches only the scalar float type, so
// declarations returning float vectors appear to be left untouched -
// confirm that only the scalar form can reach this point.
if (!F.getReturnType()->isFloatTy())
555+
continue;
556+
557+
FunctionType *FT = F.getFunctionType();
558+
559+
// Same parameter list, non-variadic, but returning i32.
FunctionType *NewFT = FunctionType::get(
560+
Type::getInt32Ty(M.getContext()), FT->params(), false);
561+
// Created under a temporary ".cloned" name; it receives the original
// name once F has been renamed below.
auto *NewF =
562+
Function::Create(NewFT, F.getLinkage(), FuncName + ".cloned", M);
563+
564+
// Gather the calls up front: the rewrite loop erases them, so it must
// not run while F.users() is being iterated.
SmallPtrSet<CallInst *, 16> Calls;
565+
566+
// NOTE(review): this accepts any CallInst user without checking that F is
// the callee, and non-call users are ignored even though F is erased
// later - confirm F is only ever used as a direct callee.
for (auto user : F.users())
567+
if (CallInst *CI = dyn_cast<CallInst>(user))
568+
Calls.insert(CI);
569+
570+
for (auto CI : Calls) {
571+
// Insert the replacement call immediately before the old one.
IRBuilder<> builder(CI);
572+
573+
SmallVector<Value *, 4> Args;
574+
for (auto &Arg : CI->args())
575+
Args.push_back(Arg);
576+
577+
auto *newCall = builder.CreateCall(NewF, Args);
578+
// Carry over the calling convention and attributes of the old call.
newCall->setCallingConv(CI->getCallingConv());
579+
newCall->setAttributes(CI->getAttributes());
580+
// Convert the value back so that previous users of
581+
// the return value are happy
582+
auto *converted = builder.CreateBitCast(newCall, CI->getType());
583+
584+
CI->replaceAllUsesWith(converted);
585+
CI->eraseFromParent();
586+
}
587+
588+
// Copy the name into an owning string before F is renamed, then hand the
// original name to the new declaration.
std::string originalName = FuncName.str();
589+
F.setName(FuncName + "_old");
590+
NewF->setName(originalName);
591+
592+
funcsToRemove.insert(&F);
593+
}
594+
595+
// Erase the superseded declarations now that iteration over M is finished.
for (auto *F : funcsToRemove)
596+
F->eraseFromParent();
597+
}
598+
538599
// Older Clang versions generate invalid bitcast instructions for explicit
539600
// C-style casts with specified address space. For example:
540601
// %0 = bitcast i8 addrspace(1)* %mem to i32 addrspace(4)*
@@ -594,6 +655,7 @@ bool BIImport::runOnModule(Module& M)
594655
}
595656

596657
fixSPIRFunctionsReturnType(M);
658+
fixRoundToTF32ReturnType(M);
597659

598660
for (auto& F : M)
599661
{

IGC/Compiler/Optimizer/OpenCLPasses/JointMatrixFuncsResolutionPass.cpp

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1094,6 +1094,15 @@ Value *JointMatrixFuncsResolutionPass::ResolveFill(CallInst *CI) {
10941094
fillValue = builder.CreateLoad(vectorElementType, fillValue);
10951095
}
10961096

1097+
// For TF32 type, the slice has a type of i32, however, the value we are
1098+
// filling with has a type of float. So we need a bitcast.
1099+
bool isTF32 = (desc.isFloating) && (desc.bitWidth == 32);
1100+
if (isTF32) {
1101+
fillValue = builder.CreateBitCast(
1102+
fillValue, Type::getIntNTy(builder.getContext(),
1103+
getResolvedVectorElemSize(matTy)));
1104+
}
1105+
10971106
Value *slice = fillValue;
10981107

10991108
if (IGCLLVM::FixedVectorType *ty =

0 commit comments

Comments
 (0)