2727#include " llvm/IR/IRBuilder.h"
2828#include " llvm/IR/LegacyPassManager.h"
2929#include " llvm/IR/Verifier.h"
30+ #include " llvm/Support/TargetRegistry.h"
3031#include " llvm/Support/TargetSelect.h"
3132#include " llvm/Support/raw_ostream.h"
3233#include " llvm/Transforms/IPO.h"
@@ -99,6 +100,32 @@ std::vector<int64_t> getTensorSizesWithoutLeadingDim(
99100 return sizes;
100101}
101102
103+ // Set some options, grabbed from Halide + we force fast math atm
104+ static llvm::TargetOptions makeTargetOptions () {
105+ bool use_soft_float_abi = false ;
106+ bool per_instruction_fast_math_flags = true ;
107+
108+ llvm::TargetOptions options;
109+ options.AllowFPOpFusion = per_instruction_fast_math_flags
110+ ? llvm::FPOpFusion::Strict
111+ : llvm::FPOpFusion::Fast;
112+ options.UnsafeFPMath = !per_instruction_fast_math_flags;
113+ options.NoInfsFPMath = !per_instruction_fast_math_flags;
114+ options.NoNaNsFPMath = !per_instruction_fast_math_flags;
115+ options.HonorSignDependentRoundingFPMathOption =
116+ !per_instruction_fast_math_flags;
117+ options.NoZerosInBSS = false ;
118+ options.GuaranteedTailCallOpt = false ;
119+ options.StackAlignmentOverride = 0 ;
120+ options.FunctionSections = true ;
121+ options.UseInitArray = false ;
122+ options.FloatABIType =
123+ use_soft_float_abi ? llvm::FloatABI::Soft : llvm::FloatABI::Hard;
124+ options.RelaxELFRelocations = false ;
125+
126+ return options;
127+ }
128+
102129static constexpr int kOptLevel = 3 ;
103130
104131class CodeGen_TC : public Halide ::Internal::CodeGen_X86 {
@@ -116,6 +143,7 @@ class CodeGen_TC : public Halide::Internal::CodeGen_X86 {
116143 const char * llvm_args[] = {" tc (LLVM argument parsing)" , nullptr };
117144 llvm::cl::ParseCommandLineOptions (
118145 sizeof (llvm_args) / sizeof (*llvm_args) - 1 , llvm_args);
146+
119147 init_context ();
120148 module =
121149 llvm::make_unique<llvm::Module>(" TensorComprehensionsModule" , *context);
@@ -198,33 +226,35 @@ class CodeGen_TC : public Halide::Internal::CodeGen_X86 {
198226 }
199227
200228 public:
201- void optimize_module () {
229+ void optimize_module (const llvm::TargetMachine& targetMachine ) {
202230 LOG_IF (INFO, FLAGS_llvm_dump_before_opt)
203231 << " [LLVM-IR] Before optimization:\n "
204232 << toString (module .get ());
205233
206- llvm::legacy::FunctionPassManager functionPassManager (module .get ());
207- llvm::legacy::PassManager modulePassManager;
234+ std::unique_ptr<llvm::TargetMachine> targetMachineWithOptions (
235+ targetMachine.getTarget ().createTargetMachine (
236+ targetMachine.getTargetTriple ().str (),
237+ targetMachine.getTargetCPU (),
238+ targetMachine.getTargetFeatureString (),
239+ makeTargetOptions (),
240+ llvm::Reloc::PIC_,
241+ llvm::CodeModel::Small,
242+ llvm::CodeGenOpt::Aggressive));
208243
209- std::unique_ptr<llvm::TargetMachine> targetMachine =
210- Halide::Internal::make_target_machine (*module );
244+ llvm::legacy::PassManager modulePassManager;
211245 modulePassManager.add (llvm::createTargetTransformInfoWrapperPass (
212- targetMachine ? targetMachine->getTargetIRAnalysis ()
213- : llvm::TargetIRAnalysis ()));
246+ targetMachineWithOptions->getTargetIRAnalysis ()));
247+
248+ llvm::legacy::FunctionPassManager functionPassManager (module .get ());
214249 functionPassManager.add (llvm::createTargetTransformInfoWrapperPass (
215- targetMachine ? targetMachine->getTargetIRAnalysis ()
216- : llvm::TargetIRAnalysis ()));
250+ targetMachineWithOptions->getTargetIRAnalysis ()));
217251
218252 llvm::PassManagerBuilder b;
219253 b.OptLevel = kOptLevel ;
220254 b.Inliner = llvm::createFunctionInliningPass (b.OptLevel , 0 , false );
221255 b.LoopVectorize = true ;
222256 b.SLPVectorize = true ;
223-
224- if (targetMachine) {
225- targetMachine->adjustPassManager (b);
226- }
227-
257+ targetMachineWithOptions->adjustPassManager (b);
228258 b.populateFunctionPassManager (functionPassManager);
229259 b.populateModulePassManager (modulePassManager);
230260
@@ -233,7 +263,6 @@ class CodeGen_TC : public Halide::Internal::CodeGen_X86 {
233263 for (llvm::Module::iterator i = module ->begin (); i != module ->end (); i++) {
234264 functionPassManager.run (*i);
235265 }
236-
237266 functionPassManager.doFinalization ();
238267 modulePassManager.run (*module );
239268
@@ -329,10 +358,12 @@ class LLVMCodegen {
329358 LLVMCodegen (
330359 const Scop& scop,
331360 const IteratorMapsType& iteratorMaps,
332- const StmtSubscriptExprMapType& stmtSubscripts)
361+ const StmtSubscriptExprMapType& stmtSubscripts,
362+ const llvm::TargetMachine& targetMachine)
333363 : scop_(scop),
334364 iteratorMaps_ (iteratorMaps),
335365 stmtSubscripts_(stmtSubscripts),
366+ targetMachine(targetMachine),
336367 halide_cg(Halide::Target(
337368 Halide::Target::OSUnknown,
338369 Halide::Target::X86,
@@ -558,6 +589,7 @@ class LLVMCodegen {
558589 std::vector<std::string> argNames_;
559590
560591 public:
592+ const llvm::TargetMachine& targetMachine;
561593 CodeGen_TC halide_cg;
562594};
563595
@@ -601,7 +633,7 @@ isl::ast_node collectIteratorMaps(
601633 return node.set_annotation (stmtId);
602634}
603635
604- IslCodegenRes codegenISL (const Scop& scop) {
636+ static IslCodegenRes codegenISL (const Scop& scop) {
605637 IteratorMapsType iteratorMaps;
606638 StmtSubscriptExprMapType stmtSubscripts;
607639 auto collect = [&iteratorMaps, &scop, &stmtSubscripts](
@@ -625,10 +657,10 @@ IslCodegenRes codegenISL(const Scop& scop) {
625657std::unique_ptr<llvm::Module> emitLLVMKernel (
626658 const std::string& specializedName,
627659 const Scop& scop,
628- const llvm::DataLayout& dataLayout ) {
660+ const llvm::TargetMachine& targetMachine ) {
629661 auto islCg = codegenISL (scop);
630- LLVMCodegen cg (scop, islCg.iteratorMaps , islCg.stmtSubscripts );
631- cg.halide_cg .get_module ()->setDataLayout (dataLayout );
662+ LLVMCodegen cg (scop, islCg.iteratorMaps , islCg.stmtSubscripts , targetMachine );
663+ cg.halide_cg .get_module ()->setDataLayout (targetMachine. createDataLayout () );
632664 cg.halide_cg .get_module ()->setTargetTriple (
633665 llvm::EngineBuilder ().selectTarget ()->getTargetTriple ().str ());
634666 auto entry = cg.createSignature (
@@ -643,7 +675,7 @@ std::unique_ptr<llvm::Module> emitLLVMKernel(
643675 TC_CHECK (!llvm::verifyModule (*cg.halide_cg .get_module ()))
644676 << " LLVM generated module is invalid." << cg.str ().c_str ();
645677
646- cg.halide_cg .optimize_module ();
678+ cg.halide_cg .optimize_module (cg. targetMachine );
647679 if (FLAGS_llvm_dump_asm) {
648680 std::string pat (" /tmp/tcXXXXXX" );
649681 std::vector<char > ifn (pat.begin (), pat.end ());
@@ -662,7 +694,10 @@ std::unique_ptr<llvm::Module> emitLLVMKernel(
662694 }
663695 utils::checkedSystemCall (
664696 std::string (TC_STRINGIFY (TC_LLVM_BIN_DIR)) + " /llc" ,
665- {FLAGS_llvm_dump_asm_options, utils::CPUID::llcFlags (), optFile, std::string (" -o " ) + asmFile});
697+ {FLAGS_llvm_dump_asm_options,
698+ utils::CPUID::llcFlags (),
699+ optFile,
700+ std::string (" -o " ) + asmFile});
666701 {
667702 std::ifstream is (asmFile);
668703 std::string str (
@@ -671,7 +706,6 @@ std::unique_ptr<llvm::Module> emitLLVMKernel(
671706 LOG (INFO) << str;
672707 }
673708 }
674-
675709 return cg.halide_cg .move_module ();
676710}
677711
0 commit comments