@@ -535,6 +535,67 @@ void BIImport::fixSPIRFunctionsReturnType(Module& M)
535535 F->eraseFromParent ();
536536}
537537
538+ // The built-in definition returns i32, however, at this point the function call
539+ // that has been added for round_to_tf32() call returns a float (as the orig
540+ // matrix type was float). So we need to: 1) Change the return type of the
541+ // function declaration to int so that it matches the builtin definition; 2)
542+ // Cast the returned value of the function back to float so that the previous
543+ // users of the return value are happy.
544+ void fixRoundToTF32ReturnType (Module &M) {
545+ SmallPtrSet<Function *, 8 > funcsToRemove;
546+ for (auto &F : M) {
547+ if (!F.isDeclaration ())
548+ continue ;
549+ auto FuncName = F.getName ();
550+
551+ if (!FuncName.contains (" OpRoundFToTF32INTEL" ) ||
552+ FuncName.contains (" _old" ))
553+ continue ;
554+ if (!F.getReturnType ()->isFloatTy ())
555+ continue ;
556+
557+ FunctionType *FT = F.getFunctionType ();
558+
559+ FunctionType *NewFT = FunctionType::get (
560+ Type::getInt32Ty (M.getContext ()), FT->params (), false );
561+ auto *NewF =
562+ Function::Create (NewFT, F.getLinkage (), FuncName + " .cloned" , M);
563+
564+ SmallPtrSet<CallInst *, 16 > Calls;
565+
566+ for (auto user : F.users ())
567+ if (CallInst *CI = dyn_cast<CallInst>(user))
568+ Calls.insert (CI);
569+
570+ for (auto CI : Calls) {
571+ IRBuilder<> builder (CI);
572+
573+ SmallVector<Value *, 4 > Args;
574+ for (auto &Arg : CI->args ())
575+ Args.push_back (Arg);
576+
577+ auto *newCall = builder.CreateCall (NewF, Args);
578+ newCall->setCallingConv (CI->getCallingConv ());
579+ newCall->setAttributes (CI->getAttributes ());
580+ // Convert the value back so that previous users of
581+ // the return value are happy
582+ auto *converted = builder.CreateBitCast (newCall, CI->getType ());
583+
584+ CI->replaceAllUsesWith (converted);
585+ CI->eraseFromParent ();
586+ }
587+
588+ std::string originalName = FuncName.str ();
589+ F.setName (FuncName + " _old" );
590+ NewF->setName (originalName);
591+
592+ funcsToRemove.insert (&F);
593+ }
594+
595+ for (auto *F : funcsToRemove)
596+ F->eraseFromParent ();
597+ }
598+
538599// Older Clang versions generate invalid bitcast instructions for explicit
539600// C-style casts with specified address space. For example:
540601// %0 = bitcast i8 addrspace(1)* %mem to i32 addrspace(4)*
@@ -594,6 +655,7 @@ bool BIImport::runOnModule(Module& M)
594655 }
595656
596657 fixSPIRFunctionsReturnType (M);
658+ fixRoundToTF32ReturnType (M);
597659
598660 for (auto & F : M)
599661 {
0 commit comments