2020
2121#include " llvm/ADT/STLExtras.h"
2222#include " llvm/ADT/SmallVector.h"
23+ #include " llvm/Analysis/TargetTransformInfo.h"
2324#include " llvm/ExecutionEngine/ExecutionEngine.h"
2425#include " llvm/IR/IRBuilder.h"
26+ #include " llvm/IR/LegacyPassManager.h"
2527#include " llvm/IR/Verifier.h"
2628#include " llvm/Support/TargetSelect.h"
2729#include " llvm/Support/raw_ostream.h"
30+ #include " llvm/Transforms/IPO/PassManagerBuilder.h"
31+ #include " llvm/Transforms/IPO.h"
32+ #include " llvm/Transforms/Tapir/CilkABI.h"
2833
2934#include " Halide/Halide.h"
3035
@@ -202,7 +207,6 @@ class CodeGen_TC : public Halide::Internal::CodeGen_X86 {
202207
203208 using CodeGen_X86::codegen;
204209 using CodeGen_X86::llvm_type_of;
205- using CodeGen_X86::optimize_module;
206210 using CodeGen_X86::sym_get;
207211 using CodeGen_X86::sym_pop;
208212 using CodeGen_X86::sym_push;
@@ -294,6 +298,81 @@ class CodeGen_TC : public Halide::Internal::CodeGen_X86 {
294298
295299 value = sym_get (name);
296300 }
301+ public:
302+ void optimize_module () {
303+ Halide::Internal::debug (3 ) << " Optimizing module\n " ;
304+
305+ if (Halide::Internal::debug::debug_level () >= 3 ) {
306+ #if LLVM_VERSION >= 50
307+ module ->print (dbgs (), nullptr , false , true );
308+ #else
309+ module ->dump ();
310+ #endif
311+ }
312+
313+ // We override PassManager::add so that we have an opportunity to
314+ // blacklist problematic LLVM passes.
315+ class MyFunctionPassManager : public llvm ::legacy::FunctionPassManager {
316+ public:
317+ MyFunctionPassManager (llvm::Module *m) : llvm::legacy::FunctionPassManager(m) {}
318+ virtual void add (llvm::Pass *p) override {
319+ Halide::Internal::debug (2 ) << " Adding function pass: " << p->getPassName ().str () << " \n " ;
320+ llvm::legacy::FunctionPassManager::add (p);
321+ }
322+ };
323+
324+ class MyModulePassManager : public llvm ::legacy::PassManager {
325+ public:
326+ virtual void add (llvm::Pass *p) override {
327+ Halide::Internal::debug (2 ) << " Adding module pass: " << p->getPassName ().str () << " \n " ;
328+ llvm::legacy::PassManager::add (p);
329+ }
330+ };
331+
332+ MyFunctionPassManager function_pass_manager (module .get ());
333+ MyModulePassManager module_pass_manager;
334+
335+ std::unique_ptr<llvm::TargetMachine> TM = Halide::Internal::make_target_machine (*module );
336+ module_pass_manager.add (llvm::createTargetTransformInfoWrapperPass (TM ? TM->getTargetIRAnalysis () : llvm::TargetIRAnalysis ()));
337+ function_pass_manager.add (llvm::createTargetTransformInfoWrapperPass (TM ? TM->getTargetIRAnalysis () : llvm::TargetIRAnalysis ()));
338+
339+ llvm::PassManagerBuilder b;
340+ b.OptLevel = 3 ;
341+ b.tapirTarget = new llvm::tapir::CilkABI ();
342+ #if LLVM_VERSION >= 50
343+ b.Inliner = llvm::createFunctionInliningPass (b.OptLevel , 0 , false );
344+ #else
345+ b.Inliner = llvm::createFunctionInliningPass (b.OptLevel , 0 );
346+ #endif
347+ b.LoopVectorize = true ;
348+ b.SLPVectorize = true ;
349+
350+ #if LLVM_VERSION >= 50
351+ if (TM) {
352+ TM->adjustPassManager (b);
353+ }
354+ #endif
355+
356+ b.populateFunctionPassManager (function_pass_manager);
357+ b.populateModulePassManager (module_pass_manager);
358+
359+ // Run optimization passes
360+ function_pass_manager.doInitialization ();
361+ for (llvm::Module::iterator i = module ->begin (); i != module ->end (); i++) {
362+ function_pass_manager.run (*i);
363+ }
364+ function_pass_manager.doFinalization ();
365+ module_pass_manager.run (*module );
366+
367+ Halide::Internal::debug (3 ) << " After LLVM optimizations:\n " ;
368+ if (Halide::Internal::debug::debug_level () >= 2 ) {
369+ #if LLVM_VERSION >= 50
370+ module ->print (dbgs (), nullptr , false , true );
371+ #else
372+ module ->dump ();
373+ #endif
374+ }
375+ }
297376};
298377
299378class LLVMCodegen {
@@ -451,6 +530,17 @@ class LLVMCodegen {
451530 llvm::BasicBlock::Create (llvmCtx, " loop_latch" , function);
452531 auto * loopExitBB = llvm::BasicBlock::Create (llvmCtx, " loop_exit" , function);
453532
533+ bool parallel = true ;
534+
535+ llvm::Value* SyncRegion = nullptr ;
536+ if (parallel) {
537+ SyncRegion = halide_cg.get_builder ().CreateCall (
538+ llvm::Intrinsic::getDeclaration (function->getParent (), llvm::Intrinsic::syncregion_start),
539+ {},
540+ " syncreg"
541+ );
542+ }
543+
454544 halide_cg.get_builder ().CreateBr (headerBB);
455545
456546 llvm::PHINode* phi = nullptr ;
@@ -498,9 +588,20 @@ class LLVMCodegen {
498588 // Create Body
499589 {
500590 halide_cg.get_builder ().SetInsertPoint (loopBodyBB);
591+
592+ if (parallel) {
593+ auto * detachedBB = llvm::BasicBlock::Create (llvmCtx, " det.achd" , function);
594+ halide_cg.get_builder ().CreateDetach (detachedBB, loopLatchBB, SyncRegion);
595+ halide_cg.get_builder ().SetInsertPoint (detachedBB);
596+ }
501597 auto * currentBB = emitAst (node.for_get_body ());
502598 halide_cg.get_builder ().SetInsertPoint (currentBB);
503- halide_cg.get_builder ().CreateBr (loopLatchBB);
599+
600+ if (parallel) {
601+ halide_cg.get_builder ().CreateReattach (loopLatchBB, SyncRegion);
602+ } else {
603+ halide_cg.get_builder ().CreateBr (loopLatchBB);
604+ }
504605 }
505606
506607 // Create Latch
@@ -516,6 +617,11 @@ class LLVMCodegen {
516617
517618 halide_cg.get_builder ().SetInsertPoint (loopExitBB);
518619 halide_cg.sym_pop (node.for_get_iterator ().get_id ().get_name ());
620+ if (parallel) {
621+ auto * syncBB = llvm::BasicBlock::Create (llvmCtx, " synced" , function);
622+ halide_cg.get_builder ().CreateSync (syncBB, SyncRegion);
623+ halide_cg.get_builder ().SetInsertPoint (syncBB);
624+ }
519625 return halide_cg.get_builder ().GetInsertBlock ();
520626 }
521627
0 commit comments