Skip to content

Commit b96ae5b

Browse files
committed
Fix coroutines without any suspend point
1 parent 6676d78 commit b96ae5b

File tree

3 files changed

+39
-3
lines changed

3 files changed

+39
-3
lines changed

src/dev/engine/internal/llvmcodebuilder.cpp

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,14 @@ LLVMCodeBuilder::LLVMCodeBuilder(const std::string &id, bool warp) :
3131

3232
std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
3333
{
34+
// Do not create coroutine if there are no yield instructions
35+
if (!m_warp) {
36+
auto it = std::find_if(m_steps.begin(), m_steps.end(), [](const Step &step) { return step.type == Step::Type::Yield; });
37+
38+
if (it == m_steps.end())
39+
m_warp = true;
40+
}
41+
3442
// Create function
3543
// void *f(Target *)
3644
llvm::PointerType *pointerType = llvm::PointerType::get(llvm::Type::getInt8Ty(m_ctx), 0);
@@ -83,6 +91,7 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
8391
case Step::Type::Yield:
8492
if (!m_warp) {
8593
freeHeap();
94+
m_builder.CreateStore(m_builder.getInt1(true), coro.didSuspend);
8695
llvm::BasicBlock *resumeBranch = llvm::BasicBlock::Create(m_ctx, "", func);
8796
llvm::Value *noneToken = llvm::ConstantTokenNone::get(m_ctx);
8897
llvm::Value *suspendResult = m_builder.CreateCall(llvm::Intrinsic::getDeclaration(m_module.get(), llvm::Intrinsic::coro_suspend), { noneToken, m_builder.getInt1(false) });
@@ -296,11 +305,16 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
296305
// Add final suspend point
297306
if (!m_warp) {
298307
llvm::BasicBlock *endBranch = llvm::BasicBlock::Create(m_ctx, "end", func);
308+
llvm::BasicBlock *finalSuspendBranch = llvm::BasicBlock::Create(m_ctx, "finalSuspend", func);
309+
m_builder.CreateCondBr(m_builder.CreateLoad(m_builder.getInt1Ty(), coro.didSuspend), finalSuspendBranch, endBranch);
310+
311+
m_builder.SetInsertPoint(finalSuspendBranch);
299312
llvm::Value *suspendResult =
300313
m_builder.CreateCall(llvm::Intrinsic::getDeclaration(m_module.get(), llvm::Intrinsic::coro_suspend), { llvm::ConstantTokenNone::get(m_ctx), m_builder.getInt1(true) });
301314
llvm::SwitchInst *sw = m_builder.CreateSwitch(suspendResult, coro.suspend, 2);
302-
sw->addCase(m_builder.getInt8(0), endBranch);
315+
sw->addCase(m_builder.getInt8(0), endBranch); // unreachable
303316
sw->addCase(m_builder.getInt8(1), coro.cleanup);
317+
304318
m_builder.SetInsertPoint(endBranch);
305319
}
306320

@@ -314,7 +328,7 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
314328
if (m_warp)
315329
m_builder.CreateRet(llvm::ConstantPointerNull::get(pointerType));
316330
else
317-
m_builder.CreateBr(coro.cleanup);
331+
m_builder.CreateBr(coro.freeMemRet);
318332

319333
verifyFunction(func);
320334

@@ -514,6 +528,8 @@ LLVMCodeBuilder::Coroutine LLVMCodeBuilder::initCoroutine(llvm::Function *func)
514528

515529
// Begin
516530
coro.handle = m_builder.CreateCall(coroBegin, { coroIdRet, alloc });
531+
coro.didSuspend = m_builder.CreateAlloca(m_builder.getInt1Ty(), nullptr, "didSuspend");
532+
m_builder.CreateStore(m_builder.getInt1(false), coro.didSuspend);
517533
llvm::BasicBlock *entry = m_builder.GetInsertBlock();
518534

519535
// Create suspend branch
@@ -522,7 +538,12 @@ LLVMCodeBuilder::Coroutine LLVMCodeBuilder::initCoroutine(llvm::Function *func)
522538
m_builder.CreateCall(coroEnd, { coro.handle, m_builder.getInt1(false), llvm::ConstantTokenNone::get(m_ctx) });
523539
m_builder.CreateRet(coro.handle);
524540

525-
// Create free branch
541+
// Create free branches
542+
coro.freeMemRet = llvm::BasicBlock::Create(m_ctx, "freeMemRet", func);
543+
m_builder.SetInsertPoint(coro.freeMemRet);
544+
m_builder.CreateFree(alloc);
545+
m_builder.CreateRet(llvm::ConstantPointerNull::get(pointerType));
546+
526547
llvm::BasicBlock *freeBranch = llvm::BasicBlock::Create(m_ctx, "free", func);
527548
m_builder.SetInsertPoint(freeBranch);
528549
m_builder.CreateFree(alloc);

src/dev/engine/internal/llvmcodebuilder.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,8 @@ class LLVMCodeBuilder : public ICodeBuilder
103103
llvm::Value *handle = nullptr;
104104
llvm::BasicBlock *suspend = nullptr;
105105
llvm::BasicBlock *cleanup = nullptr;
106+
llvm::BasicBlock *freeMemRet = nullptr;
107+
llvm::Value *didSuspend = nullptr;
106108
};
107109

108110
struct Procedure

test/dev/llvm/llvmcodebuilder_test.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ TEST_F(LLVMCodeBuilderTest, FunctionCalls)
6262
testing::internal::CaptureStdout();
6363
code->run(ctx.get());
6464
ASSERT_EQ(testing::internal::GetCapturedStdout(), expected);
65+
ASSERT_TRUE(code->isFinished(ctx.get()));
6566
}
6667
}
6768

@@ -558,6 +559,18 @@ TEST_F(LLVMCodeBuilderTest, RepeatLoop)
558559
code->run(ctx.get());
559560
ASSERT_TRUE(testing::internal::GetCapturedStdout().empty());
560561
ASSERT_TRUE(code->isFinished(ctx.get()));
562+
563+
// No warp no-op loop
564+
createBuilder(false);
565+
566+
m_builder->addConstValue(0); // don't yield
567+
m_builder->beginRepeatLoop();
568+
m_builder->endLoop();
569+
570+
code = m_builder->finalize();
571+
ctx = code->createExecutionContext(&m_target);
572+
code->run(ctx.get());
573+
ASSERT_TRUE(code->isFinished(ctx.get()));
561574
}
562575

563576
TEST_F(LLVMCodeBuilderTest, WhileLoop)

0 commit comments

Comments
 (0)