5555using namespace Halide ;
5656
5757namespace tc {
58-
5958namespace polyhedral {
60-
59+ namespace {
6160using IteratorMapType = std::unordered_map<std::string, isl::ast_expr>;
6261using IteratorMapsType =
6362 std::unordered_map<isl::id, IteratorMapType, isl::IslIdIslHash>;
6463
6564using StmtSubscriptExprMapType =
6665 std::unordered_map<isl::id, std::vector<isl::ast_expr>, isl::IslIdIslHash>;
6766
68- namespace {
67+ struct IslCodegenRes {
68+ IteratorMapsType iteratorMaps;
69+ StmtSubscriptExprMapType stmtSubscripts;
70+ isl::ast_node astNode;
71+ };
72+
73+ isl::ast_node collectIteratorMaps (
74+ isl::ast_node node,
75+ isl::ast_build build,
76+ IteratorMapsType& iteratorMaps,
77+ const Scop& scop,
78+ StmtSubscriptExprMapType& stmtSubscripts) {
79+ auto user = node.as <isl::ast_node_user>();
80+ TC_CHECK (user);
81+ auto expr = user.get_expr ().as <isl::ast_expr_op>();
82+ auto schedule = build.get_schedule ();
83+ auto scheduleMap = isl::map::from_union_map (schedule);
84+
85+ auto stmtId = expr.get_arg (0 ).as <isl::ast_expr_id>().get_id ();
86+ TC_CHECK_EQ (0u , iteratorMaps.count (stmtId)) << " entry exists: " << stmtId;
87+ auto iteratorMap = isl::pw_multi_aff (scheduleMap.reverse ());
88+ auto tuple = scop.halide .domains .at (stmtId).tuple ;
89+ auto & stmtIteratorMap = iteratorMaps[stmtId];
90+ for (int i = 0 ; i < tuple.size (); ++i) {
91+ auto expr = build.expr_from (iteratorMap.get_pw_aff (i));
92+ stmtIteratorMap.emplace (tuple.get_id (i).get_name (), expr);
93+ }
94+ auto & subscripts = stmtSubscripts[stmtId];
95+ auto provide =
96+ scop.halide .statements .at (stmtId).as <Halide::Internal::Provide>();
97+ for (auto e : provide->args ) {
98+ const auto & map = iteratorMap;
99+ auto aff = scop.makeIslAffFromStmtExpr (stmtId, e);
100+ auto pulled = isl::pw_aff (aff).pullback (map);
101+ TC_CHECK_EQ (pulled.n_piece (), 1 );
102+ subscripts.push_back (build.expr_from (pulled));
103+ }
104+ return node.set_annotation (stmtId);
105+ }
106+
107+ static IslCodegenRes codegenISL (const Scop& scop) {
108+ IteratorMapsType iteratorMaps;
109+ StmtSubscriptExprMapType stmtSubscripts;
110+ auto collect = [&iteratorMaps, &scop, &stmtSubscripts](
111+ isl::ast_node n, isl::ast_build b) -> isl::ast_node {
112+ auto & uv = iteratorMaps;
113+ return collectIteratorMaps (n, b, uv, scop, stmtSubscripts);
114+ };
115+
116+ auto schedule = detail::toIslSchedule (scop.scheduleRoot ());
117+ auto astBuild = isl::ast_build (schedule.get_ctx ());
118+ astBuild = astBuild.set_at_each_domain (collect);
119+ auto root = scop.scheduleRoot ();
120+ astBuild = astBuild.set_iterators (Codegen::makeLoopIterators (root));
121+ auto astNode = astBuild.node_from (schedule);
122+ return {
123+ std::move (iteratorMaps), std::move (stmtSubscripts), std::move (astNode)};
124+ }
69125
70126thread_local llvm::LLVMContext llvmCtx;
71127
@@ -324,6 +380,71 @@ Halide::Expr CodeGen_TC::makeHalideExpr(isl::ast_expr expr) {
324380}
325381
326382class LLVMCodegen {
383+ public:
384+ LLVMCodegen (
385+ const std::string& specializedName,
386+ const Scop& scop,
387+ const llvm::TargetMachine& targetMachine)
388+ : scop_(scop),
389+ islCg_ (codegenISL(scop_)),
390+ iteratorMaps_(islCg_.iteratorMaps),
391+ stmtSubscripts_(islCg_.stmtSubscripts),
392+ targetMachine(targetMachine),
393+ // we don't use Halide to tinker with llvm::Module optimization so we
394+ // tthe Halide target can be whatever.
395+ halide_cg(Halide::get_host_target()) {
396+ halide_cg.set_context (llvmCtx);
397+ halide_cg.init_module ();
398+ halide_cg.get_module ()->setDataLayout (targetMachine.createDataLayout ());
399+ halide_cg.get_module ()->setTargetTriple (
400+ targetMachine.getTargetTriple ().str ());
401+ auto entry = createSignature (
402+ scop.halide .inputs ,
403+ scop.halide .outputs ,
404+ scop.halide .params ,
405+ specializedName);
406+ auto exit = emitAst (islCg_.astNode , entry);
407+ halide_cg.get_builder ().SetInsertPoint (exit);
408+ halide_cg.get_builder ().CreateRetVoid ();
409+
410+ TC_CHECK (!llvm::verifyModule (*halide_cg.get_module ()))
411+ << " LLVM generated module is invalid." << str ().c_str ();
412+
413+ halide_cg.optimize_module (targetMachine);
414+
415+ if (FLAGS_llvm_dump_asm) {
416+ std::string pat (" /tmp/tcXXXXXX" );
417+ std::vector<char > ifn (pat.begin (), pat.end ());
418+ TC_CHECK_GE (mkstemp (ifn.data ()), 0 ); // string.c_str is const char*
419+ std::string fileName (ifn.begin (), ifn.end ());
420+ std::string optFile = fileName + " -opt.ll" ;
421+ std::string asmFile = fileName + " .s" ;
422+ // cstdio's std::remove to delete files
423+ tc::ScopeGuard sgi ([&]() {
424+ std::remove (optFile.c_str ());
425+ std::remove (asmFile.c_str ());
426+ });
427+ {
428+ std::ofstream ostream (optFile, std::ios::binary);
429+ ostream << str ();
430+ }
431+ utils::checkedSystemCall (
432+ std::string (TC_STRINGIFY (TC_LLVM_BIN_DIR)) + " /llc" ,
433+ {FLAGS_llvm_dump_asm_options,
434+ utils::CPUID::llcFlags (),
435+ optFile,
436+ std::string (" -o " ) + asmFile});
437+
438+ std::ifstream is (asmFile);
439+ std::string str (
440+ (std::istreambuf_iterator<char >(is)),
441+ std::istreambuf_iterator<char >());
442+ LOG (INFO) << " Dumping asm for: " << utils::CPUID::llcFlags () << " \n "
443+ << str;
444+ }
445+ }
446+
447+ private:
327448 void collectTensor (const Halide::OutputImageParam& t) {
328449 auto sizes = getTensorSizesWithoutLeadingDim (t, scop_.context ());
329450 if (not sizes.empty ()) {
@@ -354,23 +475,16 @@ class LLVMCodegen {
354475 }
355476 }
356477
357- public:
358- LLVMCodegen (
359- const Scop& scop,
360- const IteratorMapsType& iteratorMaps,
361- const StmtSubscriptExprMapType& stmtSubscripts,
362- const llvm::TargetMachine& targetMachine)
363- : scop_(scop),
364- iteratorMaps_ (iteratorMaps),
365- stmtSubscripts_(stmtSubscripts),
366- targetMachine(targetMachine),
367- halide_cg(Halide::Target(
368- Halide::Target::OSUnknown,
369- Halide::Target::X86,
370- 64 )) {
371- halide_cg.set_context (llvmCtx);
372-
373- halide_cg.init_module ();
478+ llvm::Type* makePtrToArrayType (
479+ llvm::Type* baseTy,
480+ const std::vector<int64_t >& sizes) {
481+ TC_CHECK_GE (sizes.size (), 1u );
482+ TC_CHECK (baseTy);
483+ llvm::Type* arrTy = llvm::ArrayType::get (baseTy, sizes.back ());
484+ for (auto s = sizes.rbegin () + 1 ; s != sizes.rend (); ++s) {
485+ arrTy = llvm::ArrayType::get (arrTy, *s);
486+ }
487+ return arrTy->getPointerTo ();
374488 }
375489
376490 // This creates a signature of the form:
@@ -451,19 +565,6 @@ class LLVMCodegen {
451565 return nullptr ;
452566 }
453567
454- private:
455- llvm::Type* makePtrToArrayType (
456- llvm::Type* baseTy,
457- const std::vector<int64_t >& sizes) {
458- TC_CHECK_GE (sizes.size (), 1u );
459- TC_CHECK (baseTy);
460- llvm::Type* arrTy = llvm::ArrayType::get (baseTy, sizes.back ());
461- for (auto s = sizes.rbegin () + 1 ; s != sizes.rend (); ++s) {
462- arrTy = llvm::ArrayType::get (arrTy, *s);
463- }
464- return arrTy->getPointerTo ();
465- }
466-
467568 llvm::BasicBlock* emitIf (
468569 isl::ast_node_if node,
469570 llvm::BasicBlock* insertionPoint) {
@@ -582,6 +683,7 @@ class LLVMCodegen {
582683
583684 private:
584685 const Scop& scop_;
686+ const IslCodegenRes islCg_;
585687 const IteratorMapsType& iteratorMaps_;
586688 const StmtSubscriptExprMapType& stmtSubscripts_;
587689
@@ -592,120 +694,13 @@ class LLVMCodegen {
592694 const llvm::TargetMachine& targetMachine;
593695 CodeGen_TC halide_cg;
594696};
595-
596- struct IslCodegenRes {
597- IteratorMapsType iteratorMaps;
598- StmtSubscriptExprMapType stmtSubscripts;
599- isl::ast_node astNode;
600- };
601-
602- isl::ast_node collectIteratorMaps (
603- isl::ast_node node,
604- isl::ast_build build,
605- IteratorMapsType& iteratorMaps,
606- const Scop& scop,
607- StmtSubscriptExprMapType& stmtSubscripts) {
608- auto user = node.as <isl::ast_node_user>();
609- TC_CHECK (user);
610- auto expr = user.get_expr ().as <isl::ast_expr_op>();
611- auto schedule = build.get_schedule ();
612- auto scheduleMap = isl::map::from_union_map (schedule);
613-
614- auto stmtId = expr.get_arg (0 ).as <isl::ast_expr_id>().get_id ();
615- TC_CHECK_EQ (0u , iteratorMaps.count (stmtId)) << " entry exists: " << stmtId;
616- auto iteratorMap = isl::pw_multi_aff (scheduleMap.reverse ());
617- auto tuple = scop.halide .domains .at (stmtId).tuple ;
618- auto & stmtIteratorMap = iteratorMaps[stmtId];
619- for (int i = 0 ; i < tuple.size (); ++i) {
620- auto expr = build.expr_from (iteratorMap.get_pw_aff (i));
621- stmtIteratorMap.emplace (tuple.get_id (i).get_name (), expr);
622- }
623- auto & subscripts = stmtSubscripts[stmtId];
624- auto provide =
625- scop.halide .statements .at (stmtId).as <Halide::Internal::Provide>();
626- for (auto e : provide->args ) {
627- const auto & map = iteratorMap;
628- auto aff = scop.makeIslAffFromStmtExpr (stmtId, e);
629- auto pulled = isl::pw_aff (aff).pullback (map);
630- TC_CHECK_EQ (pulled.n_piece (), 1 );
631- subscripts.push_back (build.expr_from (pulled));
632- }
633- return node.set_annotation (stmtId);
634- }
635-
636- static IslCodegenRes codegenISL (const Scop& scop) {
637- IteratorMapsType iteratorMaps;
638- StmtSubscriptExprMapType stmtSubscripts;
639- auto collect = [&iteratorMaps, &scop, &stmtSubscripts](
640- isl::ast_node n, isl::ast_build b) -> isl::ast_node {
641- auto & uv = iteratorMaps;
642- return collectIteratorMaps (n, b, uv, scop, stmtSubscripts);
643- };
644-
645- auto schedule = detail::toIslSchedule (scop.scheduleRoot ());
646- auto astBuild = isl::ast_build (schedule.get_ctx ());
647- astBuild = astBuild.set_at_each_domain (collect);
648- auto root = scop.scheduleRoot ();
649- astBuild = astBuild.set_iterators (Codegen::makeLoopIterators (root));
650- auto astNode = astBuild.node_from (schedule);
651- return {
652- std::move (iteratorMaps), std::move (stmtSubscripts), std::move (astNode)};
653- }
654-
655697} // namespace
656698
657699std::unique_ptr<llvm::Module> emitLLVMKernel (
658700 const std::string& specializedName,
659701 const Scop& scop,
660702 const llvm::TargetMachine& targetMachine) {
661- auto islCg = codegenISL (scop);
662- LLVMCodegen cg (scop, islCg.iteratorMaps , islCg.stmtSubscripts , targetMachine);
663- cg.halide_cg .get_module ()->setDataLayout (targetMachine.createDataLayout ());
664- cg.halide_cg .get_module ()->setTargetTriple (
665- llvm::EngineBuilder ().selectTarget ()->getTargetTriple ().str ());
666- auto entry = cg.createSignature (
667- scop.halide .inputs ,
668- scop.halide .outputs ,
669- scop.halide .params ,
670- specializedName);
671- auto exit = cg.emitAst (islCg.astNode , entry);
672- cg.halide_cg .get_builder ().SetInsertPoint (exit);
673- cg.halide_cg .get_builder ().CreateRetVoid ();
674-
675- TC_CHECK (!llvm::verifyModule (*cg.halide_cg .get_module ()))
676- << " LLVM generated module is invalid." << cg.str ().c_str ();
677-
678- cg.halide_cg .optimize_module (cg.targetMachine );
679- if (FLAGS_llvm_dump_asm) {
680- std::string pat (" /tmp/tcXXXXXX" );
681- std::vector<char > ifn (pat.begin (), pat.end ());
682- TC_CHECK_GE (mkstemp (ifn.data ()), 0 ); // string.c_str is const char*
683- std::string fileName (ifn.begin (), ifn.end ());
684- std::string optFile = fileName + " -opt.ll" ;
685- std::string asmFile = fileName + " .s" ;
686- // cstdio's std::remove to delete files
687- tc::ScopeGuard sgi ([&]() {
688- std::remove (optFile.c_str ());
689- std::remove (asmFile.c_str ());
690- });
691- {
692- std::ofstream ostream (optFile, std::ios::binary);
693- ostream << cg.str ();
694- }
695- utils::checkedSystemCall (
696- std::string (TC_STRINGIFY (TC_LLVM_BIN_DIR)) + " /llc" ,
697- {FLAGS_llvm_dump_asm_options,
698- utils::CPUID::llcFlags (),
699- optFile,
700- std::string (" -o " ) + asmFile});
701- {
702- std::ifstream is (asmFile);
703- std::string str (
704- (std::istreambuf_iterator<char >(is)),
705- std::istreambuf_iterator<char >());
706- LOG (INFO) << str;
707- }
708- }
703+ LLVMCodegen cg (specializedName, scop, targetMachine);
709704 return cg.halide_cg .move_module ();
710705}
711706
0 commit comments