@@ -31,6 +31,14 @@ LLVMCodeBuilder::LLVMCodeBuilder(const std::string &id, bool warp) :
3131
3232std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize ()
3333{
34+ // Do not create coroutine if there are no yield instructions
35+ if (!m_warp) {
36+ auto it = std::find_if (m_steps.begin (), m_steps.end (), [](const Step &step) { return step.type == Step::Type::Yield; });
37+
38+ if (it == m_steps.end ())
39+ m_warp = true ;
40+ }
41+
3442 // Create function
3543 // void *f(Target *)
3644 llvm::PointerType *pointerType = llvm::PointerType::get (llvm::Type::getInt8Ty (m_ctx), 0 );
@@ -83,6 +91,7 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
8391 case Step::Type::Yield:
8492 if (!m_warp) {
8593 freeHeap ();
94+ m_builder.CreateStore (m_builder.getInt1 (true ), coro.didSuspend );
8695 llvm::BasicBlock *resumeBranch = llvm::BasicBlock::Create (m_ctx, " " , func);
8796 llvm::Value *noneToken = llvm::ConstantTokenNone::get (m_ctx);
8897 llvm::Value *suspendResult = m_builder.CreateCall (llvm::Intrinsic::getDeclaration (m_module.get (), llvm::Intrinsic::coro_suspend), { noneToken, m_builder.getInt1 (false ) });
@@ -296,11 +305,16 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
296305 // Add final suspend point
297306 if (!m_warp) {
298307 llvm::BasicBlock *endBranch = llvm::BasicBlock::Create (m_ctx, " end" , func);
308+ llvm::BasicBlock *finalSuspendBranch = llvm::BasicBlock::Create (m_ctx, " finalSuspend" , func);
309+ m_builder.CreateCondBr (m_builder.CreateLoad (m_builder.getInt1Ty (), coro.didSuspend ), finalSuspendBranch, endBranch);
310+
311+ m_builder.SetInsertPoint (finalSuspendBranch);
299312 llvm::Value *suspendResult =
300313 m_builder.CreateCall (llvm::Intrinsic::getDeclaration (m_module.get (), llvm::Intrinsic::coro_suspend), { llvm::ConstantTokenNone::get (m_ctx), m_builder.getInt1 (true ) });
301314 llvm::SwitchInst *sw = m_builder.CreateSwitch (suspendResult, coro.suspend , 2 );
302- sw->addCase (m_builder.getInt8 (0 ), endBranch);
315+ sw->addCase (m_builder.getInt8 (0 ), endBranch); // unreachable
303316 sw->addCase (m_builder.getInt8 (1 ), coro.cleanup );
317+
304318 m_builder.SetInsertPoint (endBranch);
305319 }
306320
@@ -314,7 +328,7 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
314328 if (m_warp)
315329 m_builder.CreateRet (llvm::ConstantPointerNull::get (pointerType));
316330 else
317- m_builder.CreateBr (coro.cleanup );
331+ m_builder.CreateBr (coro.freeMemRet );
318332
319333 verifyFunction (func);
320334
@@ -514,6 +528,8 @@ LLVMCodeBuilder::Coroutine LLVMCodeBuilder::initCoroutine(llvm::Function *func)
514528
515529 // Begin
516530 coro.handle = m_builder.CreateCall (coroBegin, { coroIdRet, alloc });
531+ coro.didSuspend = m_builder.CreateAlloca (m_builder.getInt1Ty (), nullptr , " didSuspend" );
532+ m_builder.CreateStore (m_builder.getInt1 (false ), coro.didSuspend );
517533 llvm::BasicBlock *entry = m_builder.GetInsertBlock ();
518534
519535 // Create suspend branch
@@ -522,7 +538,12 @@ LLVMCodeBuilder::Coroutine LLVMCodeBuilder::initCoroutine(llvm::Function *func)
522538 m_builder.CreateCall (coroEnd, { coro.handle , m_builder.getInt1 (false ), llvm::ConstantTokenNone::get (m_ctx) });
523539 m_builder.CreateRet (coro.handle );
524540
525- // Create free branch
541+ // Create free branches
542+ coro.freeMemRet = llvm::BasicBlock::Create (m_ctx, " freeMemRet" , func);
543+ m_builder.SetInsertPoint (coro.freeMemRet );
544+ m_builder.CreateFree (alloc);
545+ m_builder.CreateRet (llvm::ConstantPointerNull::get (pointerType));
546+
526547 llvm::BasicBlock *freeBranch = llvm::BasicBlock::Create (m_ctx, " free" , func);
527548 m_builder.SetInsertPoint (freeBranch);
528549 m_builder.CreateFree (alloc);
0 commit comments