@@ -72,8 +72,9 @@ isl::aff makeIslAffFromExpr(isl::space space, const Halide::Expr& e);
7272
7373namespace polyhedral {
7474
75+ using IteratorMapType = std::unordered_map<std::string, isl::ast_expr>;
7576using IteratorMapsType =
76- std::unordered_map<isl::id, isl::pw_multi_aff , isl::IslIdIslHash>;
77+ std::unordered_map<isl::id, IteratorMapType , isl::IslIdIslHash>;
7778
7879using IteratorLLVMValueMapType =
7980 std::unordered_map<isl::id, llvm::Value*, isl::IslIdIslHash>;
@@ -96,14 +97,6 @@ llvm::Value* getLLVMConstantSignedInt64(int64_t v) {
9697 return llvm::ConstantInt::get (llvm::Type::getInt64Ty (llvmCtx), v, true );
9798}
9899
99- isl::aff extractAff (isl::pw_multi_aff pma) {
100- isl::PMA pma_ (pma);
101- CHECK_EQ (pma_.size (), 1 );
102- isl::MA ma (pma_[0 ].second );
103- CHECK_EQ (ma.size (), 1 );
104- return ma[0 ];
105- }
106-
107100int64_t IslExprToSInt (isl::ast_expr e) {
108101 CHECK (isl_ast_expr_get_type (e.get ()) == isl_ast_expr_type::isl_ast_expr_int);
109102 assert (sizeof (long ) <= 8 ); // long is assumed to fit to 64bits
@@ -214,7 +207,7 @@ static constexpr int kOptLevel = 3;
214207
215208class CodeGen_TC : public Halide ::Internal::CodeGen_X86 {
216209 public:
217- const isl::pw_multi_aff * iteratorMap_;
210+ const IteratorMapType * iteratorMap_;
218211 CodeGen_TC (Target t) : CodeGen_X86(t) {}
219212
220213 using CodeGen_X86::codegen;
@@ -249,6 +242,11 @@ class CodeGen_TC : public Halide::Internal::CodeGen_X86 {
249242 return std::move (module );
250243 }
251244
245+ // Convert an isl AST expression into an llvm::Value.
246+ // Only expressions that consist of a pure identifier or
247+ // a pure integer constant are currently supported.
248+ llvm::Value* getValue (isl::ast_expr expr);
249+
252250 protected:
253251 using CodeGen_X86::visit;
254252 void visit (const Halide::Internal::Call* call) override {
@@ -272,44 +270,7 @@ class CodeGen_TC : public Halide::Internal::CodeGen_X86 {
272270 }
273271 }
274272 void visit (const Halide::Internal::Variable* op) override {
275- auto aff = halide2isl::makeIslAffFromExpr (
276- iteratorMap_->get_space ().range (), Halide::Expr (op));
277-
278- auto subscriptPma = isl::pw_aff (aff).pullback (*iteratorMap_);
279- auto subscriptAff = extractAff (subscriptPma);
280-
281- // sanity checks
282- CHECK_EQ (subscriptAff.dim (isl::dim_type::div), 0 );
283- CHECK_EQ (subscriptAff.dim (isl::dim_type::out), 1 );
284- for (int d = 0 ; d < subscriptAff.dim (isl::dim_type::param); ++d) {
285- auto v = subscriptAff.get_coefficient_val (isl::dim_type::param, d);
286- CHECK (v.is_zero ());
287- }
288-
289- llvm::Optional<int > posOne;
290- int sum = 0 ;
291- for (int d = 0 ; d < subscriptAff.dim (isl::dim_type::in); ++d) {
292- auto v = subscriptAff.get_coefficient_val (isl::dim_type::in, d);
293- CHECK (v.is_zero () or v.is_one ());
294- if (v.is_zero ()) {
295- continue ;
296- }
297- ++sum;
298- posOne = d;
299- }
300- CHECK_LE (sum, 1 );
301-
302- if (sum == 0 ) {
303- value =
304- getLLVMConstantSignedInt64 (toSInt (subscriptAff.get_constant_val ()));
305- return ;
306- }
307- CHECK (posOne);
308-
309- std::string name (
310- isl_aff_get_dim_name (subscriptAff.get (), isl_dim_in, *posOne));
311-
312- value = sym_get (name);
273+ value = getValue (iteratorMap_->at (op->name ));
313274 }
314275
315276 public:
@@ -361,6 +322,21 @@ class CodeGen_TC : public Halide::Internal::CodeGen_X86 {
361322 }
362323};
363324
325+ llvm::Value* CodeGen_TC::getValue (isl::ast_expr expr) {
326+ switch (isl_ast_expr_get_type (expr.get ())) {
327+ case isl_ast_expr_type::isl_ast_expr_id:
328+ return sym_get (expr.get_id ().get_name ());
329+ case isl_ast_expr_type::isl_ast_expr_int: {
330+ auto val = isl::manage (isl_ast_expr_get_val (expr.get ()));
331+ CHECK (val.is_int ());
332+ return getLLVMConstantSignedInt64 (val.get_num_si ());
333+ }
334+ default :
335+ LOG (FATAL) << " NYI" ;
336+ return nullptr ;
337+ }
338+ }
339+
364340class LLVMCodegen {
365341 void collectTensor (const Halide::OutputImageParam& t) {
366342 auto sizes =
@@ -638,22 +614,7 @@ class LLVMCodegen {
638614 llvm::SmallVector<llvm::Value*, 5 > subscriptValues;
639615
640616 for (const auto & subscript : subscripts) {
641- switch (isl_ast_expr_get_type (subscript.get ())) {
642- case isl_ast_expr_type::isl_ast_expr_id: {
643- subscriptValues.push_back (
644- halide_cg.sym_get (subscript.get_id ().get_name ()));
645- break ;
646- }
647- case isl_ast_expr_type::isl_ast_expr_int: {
648- auto val = isl::manage (isl_ast_expr_get_val (subscript.get ()));
649- CHECK_EQ (val.get_den_si (), 1 );
650- subscriptValues.push_back (
651- getLLVMConstantSignedInt64 (val.get_num_si ()));
652- break ;
653- }
654- default :
655- LOG (FATAL) << " NYI" ;
656- }
617+ subscriptValues.push_back (halide_cg.getValue (subscript));
657618 }
658619
659620 auto destAddr = halide_cg.get_builder ().CreateInBoundsGEP (
@@ -703,34 +664,28 @@ IslCodegenRes codegenISL(const Scop& scop) {
703664 const Scop& scop,
704665 StmtSubscriptExprMapType& stmtSubscripts) -> isl::ast_node {
705666 auto expr = node.user_get_expr ();
706- // We rename loop-related dimensions manually.
707667 auto schedule = build.get_schedule ();
708- auto scheduleSpace = build.get_schedule_space ();
709668 auto scheduleMap = isl::map::from_union_map (schedule);
710669
711670 auto stmtId = expr.get_op_arg (0 ).get_id ();
712671 // auto nodeId = isl::id(
713672 // node.get_ctx(),
714673 // std::string(kAstNodeIdPrefix) + std::to_string(nAstNodes()++));
715674 CHECK_EQ (0 , iteratorMaps.count (stmtId)) << " entry exists: " << stmtId;
716- CHECK_EQ (
717- scheduleMap.dim (isl::dim_type::out),
718- scheduleSpace.dim (isl::dim_type::set));
719- for (int i = 0 ; i < scheduleSpace.dim (isl::dim_type::set); ++i) {
720- scheduleMap = scheduleMap.set_dim_id (
721- isl::dim_type::out,
722- i,
723- scheduleSpace.get_dim_id (isl::dim_type::set, i));
724- }
725675 auto iteratorMap = isl::pw_multi_aff (scheduleMap.reverse ());
726- iteratorMaps.emplace (stmtId, iteratorMap);
676+ auto iterators = scop.halide .iterators .at (stmtId);
677+ auto & stmtIteratorMap = iteratorMaps[stmtId];
678+ for (int i = 0 ; i < iterators.size (); ++i) {
679+ auto expr = build.expr_from (iteratorMap.get_pw_aff (i));
680+ stmtIteratorMap.emplace (iterators[i], expr);
681+ }
727682 auto & subscripts = stmtSubscripts[stmtId];
728683 auto provide =
729684 scop.halide .statements .at (stmtId).as <Halide::Internal::Provide>();
730685 for (auto e : provide->args ) {
731686 const auto & map = iteratorMap;
732- auto space = map.get_space ().range ();
733- auto aff = halide2isl::makeIslAffFromExpr ( space, e);
687+ auto space = map.get_space ().params ();
688+ auto aff = scop. makeIslAffFromStmtExpr (stmtId, space, e);
734689 auto pulled = isl::pw_aff (aff).pullback (map);
735690 CHECK_EQ (pulled.n_piece (), 1 );
736691 subscripts.push_back (build.expr_from (pulled));
0 commit comments