@@ -340,7 +340,7 @@ class LLVMCodegen {
340340
341341 // This creates a signature of the form:
342342 // input_data_types, output_data_types, parameters
343- void createSignature (
343+ llvm::BasicBlock* createSignature (
344344 const std::vector<Halide::ImageParam>& inputs,
345345 const std::vector<Halide::OutputImageParam>& outputs,
346346 const std::vector<Halide::Internal::Parameter>& params,
@@ -383,40 +383,37 @@ class LLVMCodegen {
383383 it->addAttr (llvm::Attribute::ReadOnly);
384384 }
385385
386- auto entryBB_ = llvm::BasicBlock::Create (llvmCtx, " entry" , function);
387- halide_cg.get_builder ().SetInsertPoint (entryBB_);
386+ return llvm::BasicBlock::Create (llvmCtx, " entry" , function);
388387 }
389388
390- void CodeGen (isl::ast_node node) {
391- emitAst (node);
392- halide_cg.get_builder ().CreateRetVoid ();
393-
394- if (llvm::verifyModule (*halide_cg.get_module ())) {
395- LOG (ERROR) << str ();
396- llvm::verifyModule (*halide_cg.get_module (), &llvm::outs ());
397- throw std::runtime_error (" LLVM generated module is invalid." );
398- }
399- }
400-
401- llvm::BasicBlock* emitAst (isl::ast_node node) {
389+ // This is the main entry point for emitting pieces of LLVM IR.
390+ // LLVM IR insertion is stateful: it is configured by SetInsertPoint.
391+ // We make the insertion block an explicit parameter to avoid implicit
392+ // conventions: all TC IR builder methods take an explicit insertionPoint.
393+ // The invariant on entry to all emit* (except emitAst, which sets it) is:
394+ // TC_CHECK_EQ(halide_cg.get_builder().GetInsertBlock(), insertionPoint);
395+ llvm::BasicBlock* emitAst (
396+ isl::ast_node node,
397+ llvm::BasicBlock* insertionPoint) {
398+ halide_cg.get_builder ().SetInsertPoint (insertionPoint);
402399 if (auto forNode = node.as <isl::ast_node_for>()) {
403- return emitFor (forNode);
400+ return emitFor (forNode, insertionPoint );
404401 } else if (auto userNode = node.as <isl::ast_node_user>()) {
405- return emitStmt (userNode);
402+ return emitStmt (userNode, insertionPoint );
406403 } else if (auto blockNode = node.as <isl::ast_node_block>()) {
407- llvm::BasicBlock* curBB;
404+ llvm::BasicBlock* curBB = insertionPoint ;
408405 for (auto child : blockNode.get_children ()) {
409- curBB = emitAst (child);
406+ curBB = emitAst (child, curBB );
410407 }
411408 return curBB;
412409 } else {
413410 if (auto cond = node.as <isl::ast_node_if>()) {
414- return emitIf (cond);
411+ return emitIf (cond, insertionPoint );
415412 } else {
416413 LOG (FATAL) << " NYI " << node << std::endl;
417414 }
418- return static_cast <llvm::BasicBlock*>(nullptr ); // avoid warning
419415 }
416+ return nullptr ;
420417 }
421418
422419 private:
@@ -432,18 +429,19 @@ class LLVMCodegen {
432429 return arrTy->getPointerTo ();
433430 }
434431
435- llvm::BasicBlock* emitIf (isl::ast_node_if node) {
436- auto * incoming = halide_cg.get_builder ().GetInsertBlock ();
437- auto * function = incoming->getParent ();
432+ llvm::BasicBlock* emitIf (
433+ isl::ast_node_if node,
434+ llvm::BasicBlock* insertionPoint) {
435+ TC_CHECK_EQ (halide_cg.get_builder ().GetInsertBlock (), insertionPoint);
436+ auto * function = insertionPoint->getParent ();
438437
439438 llvm::Value* condVal = halide_cg.codegen (node.get_cond ());
440439 auto * thenBB = llvm::BasicBlock::Create (llvmCtx, " then" , function);
441440 // Recursively emit "then" in a new thenBB
442- halide_cg.get_builder ().SetInsertPoint (thenBB);
443- auto innerBB = emitAst (node.get_then ());
441+ auto innerBB = emitAst (node.get_then (), thenBB);
444442
445443 // outer -> thenBB
446- halide_cg.get_builder ().SetInsertPoint (incoming );
444+ halide_cg.get_builder ().SetInsertPoint (insertionPoint );
447445 // outer ---------> if_exit
448446 // TODO: When we support "else", go to elseBB instead of exit
449447 auto * exit = llvm::BasicBlock::Create (llvmCtx, " if_exit" , function);
@@ -456,17 +454,17 @@ class LLVMCodegen {
456454 // Else is often empty in the absence of full tile extraction
457455 if (node.has_else ()) {
458456 LOG (FATAL) << " NYI: else conditional branch" ;
459- return halide_cg. get_builder (). GetInsertBlock () ;
457+ return exit ;
460458 }
461459
462- // Set the insertion point to if_exit
463- halide_cg.get_builder ().SetInsertPoint (exit);
464- return halide_cg.get_builder ().GetInsertBlock ();
460+ return exit;
465461 }
466462
467- llvm::BasicBlock* emitFor (isl::ast_node_for node) {
468- auto * incoming = halide_cg.get_builder ().GetInsertBlock ();
469- auto * function = incoming->getParent ();
463+ llvm::BasicBlock* emitFor (
464+ isl::ast_node_for node,
465+ llvm::BasicBlock* insertionPoint) {
466+ TC_CHECK_EQ (halide_cg.get_builder ().GetInsertBlock (), insertionPoint);
467+ auto * function = insertionPoint->getParent ();
470468 auto * headerBB = llvm::BasicBlock::Create (llvmCtx, " loop_header" , function);
471469 auto * loopBodyBB = llvm::BasicBlock::Create (llvmCtx, " loop_body" , function);
472470 auto * loopLatchBB =
@@ -485,16 +483,15 @@ class LLVMCodegen {
485483 phi = halide_cg.get_builder ().CreatePHI (
486484 initVal->getType (), 2 , iterator.get_name ());
487485 halide_cg.sym_push (iterator.get_name (), phi);
488- phi->addIncoming (initVal, incoming );
486+ phi->addIncoming (initVal, insertionPoint );
489487
490488 auto cond = halide_cg.codegen (node.get_cond ());
491489 halide_cg.get_builder ().CreateCondBr (cond, loopBodyBB, loopExitBB);
492490 }
493491
494492 // Create Body
495493 {
496- halide_cg.get_builder ().SetInsertPoint (loopBodyBB);
497- auto * currentBB = emitAst (node.get_body ());
494+ auto * currentBB = emitAst (node.get_body (), loopBodyBB);
498495 halide_cg.get_builder ().SetInsertPoint (currentBB);
499496 halide_cg.get_builder ().CreateBr (loopLatchBB);
500497 }
@@ -508,12 +505,14 @@ class LLVMCodegen {
508505 halide_cg.get_builder ().CreateBr (headerBB);
509506 }
510507
511- halide_cg.get_builder ().SetInsertPoint (loopExitBB);
512508 halide_cg.sym_pop (iterator.get_name ());
513- return halide_cg. get_builder (). GetInsertBlock () ;
509+ return loopExitBB ;
514510 }
515511
516- llvm::BasicBlock* emitStmt (isl::ast_node_user node) {
512+ llvm::BasicBlock* emitStmt (
513+ isl::ast_node_user node,
514+ llvm::BasicBlock* insertionPoint) {
515+ TC_CHECK_EQ (halide_cg.get_builder ().GetInsertBlock (), insertionPoint);
517516 isl::ast_expr_op usrExp = node.get_expr ().as <isl::ast_expr_op>();
518517 auto id = usrExp.get_arg (0 ).as <isl::ast_expr_id>().get_id ();
519518 auto provide = scop_.halide .statements .at (id);
@@ -535,6 +534,9 @@ class LLVMCodegen {
535534
536535 llvm::Value* rhs = halide_cg.codegen (op->values [0 ]);
537536 halide_cg.get_builder ().CreateStore (rhs, destAddr);
537+ // We must return halide_cg.get_builder().GetInsertBlock() rather than
538+ // insertionPoint: Halide does not adhere to our conventions, and when it
539+ // emits multiple blocks the builder may no longer be in insertionPoint.
538540 return halide_cg.get_builder ().GetInsertBlock ();
539541 }
540542
@@ -625,12 +627,18 @@ std::unique_ptr<llvm::Module> emitLLVMKernel(
625627 cg.halide_cg .get_module ()->setDataLayout (dataLayout);
626628 cg.halide_cg .get_module ()->setTargetTriple (
627629 llvm::EngineBuilder ().selectTarget ()->getTargetTriple ().str ());
628- cg.createSignature (
630+ auto entry = cg.createSignature (
629631 scop.halide .inputs ,
630632 scop.halide .outputs ,
631633 scop.halide .params ,
632634 specializedName);
633- cg.CodeGen (islCg.astNode );
635+ auto exit = cg.emitAst (islCg.astNode , entry);
636+ cg.halide_cg .get_builder ().SetInsertPoint (exit);
637+ cg.halide_cg .get_builder ().CreateRetVoid ();
638+
639+ TC_CHECK (!llvm::verifyModule (*cg.halide_cg .get_module ()))
640+ << " LLVM generated module is invalid." << cg.str ().c_str ();
641+
634642 cg.halide_cg .optimize_module ();
635643 return cg.halide_cg .move_module ();
636644}
0 commit comments