@@ -85,6 +85,20 @@ class SPIRVRegularizeLLVM : public ModulePass {
8585 // / @spirv.llvm_memset_* and replace it with @llvm.memset.
8686 void lowerMemset (MemSetInst *MSI);
8787
88+ // / No SPIR-V counterpart for @llvm.fshl.i* intrinsic. It will be lowered
89+ // / to a newly generated @spirv.llvm_fshl_i* function.
90+ // / Conceptually, FSHL:
91+ // / 1. concatenates the ints, the first one being the more significant;
92+ // / 2. performs a left shift-rotate on the resulting doubled-sized int;
93+ // / 3. returns the most significant bits of the shift-rotate result,
94+ // / the number of bits being equal to the size of the original integers.
95+ // / The actual implementation algorithm will be slightly different to speed
96+ // / things up.
97+ void lowerFunnelShiftLeft (IntrinsicInst *FSHLIntrinsic);
98+ void buildFunnelShiftLeftFunc (Function *FSHLFunc);
99+
100+ static std::string lowerLLVMIntrinsicName (IntrinsicInst *II);
101+
88102 static char ID;
89103
90104private:
@@ -94,17 +108,22 @@ class SPIRVRegularizeLLVM : public ModulePass {
94108
95109char SPIRVRegularizeLLVM::ID = 0 ;
96110
97- void SPIRVRegularizeLLVM::lowerMemset (MemSetInst *MSI) {
98- if (isa<Constant>(MSI->getValue ()) && isa<ConstantInt>(MSI->getLength ()))
99- return ; // To be handled in LLVMToSPIRV::transIntrinsicInst
100- Function *IntrinsicFunc = MSI->getCalledFunction ();
111+ std::string SPIRVRegularizeLLVM::lowerLLVMIntrinsicName (IntrinsicInst *II) {
112+ Function *IntrinsicFunc = II->getCalledFunction ();
101113 assert (IntrinsicFunc && " Missing function" );
102114 std::string FuncName = IntrinsicFunc->getName ().str ();
103115 std::replace (FuncName.begin (), FuncName.end (), ' .' , ' _' );
104116 FuncName = " spirv." + FuncName;
117+ return FuncName;
118+ }
119+
120+ void SPIRVRegularizeLLVM::lowerMemset (MemSetInst *MSI) {
121+ if (isa<Constant>(MSI->getValue ()) && isa<ConstantInt>(MSI->getLength ()))
122+ return ; // To be handled in LLVMToSPIRV::transIntrinsicInst
123+
124+ std::string FuncName = lowerLLVMIntrinsicName (MSI);
105125 if (MSI->isVolatile ())
106126 FuncName += " .volatile" ;
107-
108127 // Redirect @llvm.memset.* call to @spirv.llvm_memset_*
109128 Function *F = M->getFunction (FuncName);
110129 if (F) {
@@ -137,6 +156,75 @@ void SPIRVRegularizeLLVM::lowerMemset(MemSetInst *MSI) {
137156 return ;
138157}
139158
159+ void SPIRVRegularizeLLVM::buildFunnelShiftLeftFunc (Function *FSHLFunc) {
160+ if (!FSHLFunc->empty ())
161+ return ;
162+
163+ auto *IntTy = dyn_cast<IntegerType>(FSHLFunc->getReturnType ());
164+ assert (IntTy && " llvm.fshl: expected an integer return type" );
165+ assert (FSHLFunc->arg_size () == 3 && " llvm.fshl: expected 3 arguments" );
166+ for (Argument &Arg : FSHLFunc->args ())
167+ assert (Arg.getType ()->getTypeID () == IntTy->getTypeID () &&
168+ " llvm.fshl: mismatched return type and argument types" );
169+
170+ // Our function will require 3 basic blocks; the purpose of each will be
171+ // clarified below.
172+ auto *CondBB = BasicBlock::Create (M->getContext (), " cond" , FSHLFunc);
173+ auto *RotateBB =
174+ BasicBlock::Create (M->getContext (), " rotate" , FSHLFunc); // Main logic
175+ auto *PhiBB = BasicBlock::Create (M->getContext (), " phi" , FSHLFunc);
176+
177+ IRBuilder<> Builder (CondBB);
178+ // If the number of bits to rotate for is divisible by the bitsize,
179+ // the shift becomes useless, and we should bypass the main logic in that
180+ // case.
181+ unsigned BitWidth = IntTy->getIntegerBitWidth ();
182+ ConstantInt *BitWidthConstant = Builder.getInt ({BitWidth, BitWidth});
183+ auto *RotateModVal =
184+ Builder.CreateURem (/* Rotate*/ FSHLFunc->getArg (2 ), BitWidthConstant);
185+ ConstantInt *ZeroConstant = Builder.getInt ({BitWidth, 0 });
186+ auto *CheckRotateModIfZero = Builder.CreateICmpEQ (RotateModVal, ZeroConstant);
187+ Builder.CreateCondBr (CheckRotateModIfZero, /* True*/ PhiBB,
188+ /* False*/ RotateBB);
189+
190+ // Build the actual funnel shift rotate logic.
191+ Builder.SetInsertPoint (RotateBB);
192+ // Shift the more significant number left, the "rotate" number of bits
193+ // will be 0-filled on the right as a result of this regular shift.
194+ auto *ShiftLeft = Builder.CreateShl (FSHLFunc->getArg (0 ), RotateModVal);
195+ // We want the "rotate" number of the second int's MSBs to occupy the
196+ // rightmost "0 space" left by the previous operation. Therefore,
197+ // subtract the "rotate" number from the integer bitsize...
198+ auto *SubRotateVal = Builder.CreateSub (BitWidthConstant, RotateModVal);
199+ // ...and right-shift the second int by this number, zero-filling the MSBs.
200+ auto *ShiftRight = Builder.CreateLShr (FSHLFunc->getArg (1 ), SubRotateVal);
201+ // A simple binary addition of the shifted ints yields the final result.
202+ auto *FunnelShiftRes = Builder.CreateOr (ShiftLeft, ShiftRight);
203+ Builder.CreateBr (PhiBB);
204+
205+ // PHI basic block. If no actual rotate was required, return the first, more
206+ // significant int. E.g. for 32-bit integers, it's equivalent to concatenating
207+ // the 2 ints and taking 32 MSBs.
208+ Builder.SetInsertPoint (PhiBB);
209+ PHINode *Phi = Builder.CreatePHI (IntTy, 0 );
210+ Phi->addIncoming (FunnelShiftRes, RotateBB);
211+ Phi->addIncoming (FSHLFunc->getArg (0 ), CondBB);
212+ Builder.CreateRet (Phi);
213+ }
214+
215+ void SPIRVRegularizeLLVM::lowerFunnelShiftLeft (IntrinsicInst *FSHLIntrinsic) {
216+ // Get a separate function - otherwise, we'd have to rework the CFG of the
217+ // current one. Then simply replace the intrinsic uses with a call to the new
218+ // function.
219+ FunctionType *FSHLFuncTy = FSHLIntrinsic->getFunctionType ();
220+ Type *FSHLRetTy = FSHLFuncTy->getReturnType ();
221+ const std::string FuncName = lowerLLVMIntrinsicName (FSHLIntrinsic);
222+ Function *FSHLFunc =
223+ getOrCreateFunction (M, FSHLRetTy, FSHLFuncTy->params (), FuncName);
224+ buildFunnelShiftLeftFunc (FSHLFunc);
225+ FSHLIntrinsic->setCalledFunction (FSHLFunc);
226+ }
227+
140228bool SPIRVRegularizeLLVM::runOnModule (Module &Module) {
141229 M = &Module;
142230 Ctx = &M->getContext ();
@@ -170,8 +258,11 @@ bool SPIRVRegularizeLLVM::regularize() {
170258 Function *CF = Call->getCalledFunction ();
171259 if (CF && CF->isIntrinsic ()) {
172260 removeFnAttr (Call, Attribute::NoUnwind);
173- if (auto *MSI = dyn_cast<MemSetInst>(Call))
261+ auto *II = cast<IntrinsicInst>(Call);
262+ if (auto *MSI = dyn_cast<MemSetInst>(II))
174263 lowerMemset (MSI);
264+ else if (II->getIntrinsicID () == Intrinsic::fshl)
265+ lowerFunnelShiftLeft (II);
175266 }
176267 }
177268
@@ -254,7 +345,7 @@ bool SPIRVRegularizeLLVM::regularize() {
254345 }
255346 }
256347 for (Instruction *V : ToErase) {
257- assert (V->user_empty ());
348+ assert (V->user_empty () && " User non-empty \n " );
258349 V->eraseFromParent ();
259350 }
260351 }
0 commit comments