3636#include " isl/ast.h"
3737
3838#include " tc/core/constants.h"
39- // #include "tc/core/polyhedral/isl_mu_wrappers.h"
4039#include " tc/core/flags.h"
4140#include " tc/core/polyhedral/codegen.h"
4241#include " tc/core/polyhedral/schedule_isl_conversion.h"
4342#include " tc/core/polyhedral/scop.h"
4443#include " tc/core/scope_guard.h"
44+ #include " tc/external/isl.h"
4545
4646#ifndef LLVM_VERSION_MAJOR
4747#error LLVM_VERSION_MAJOR not set
@@ -76,10 +76,9 @@ namespace {
7676thread_local llvm::LLVMContext llvmCtx;
7777
7878int64_t toSInt (isl::val v) {
79- auto n = v.get_num_si ();
80- auto d = v.get_den_si ();
81- CHECK_EQ (n % d, 0 );
82- return n / d;
79+ CHECK (v.is_int ());
80+ static_assert (sizeof (long ) <= 8 , " long is assumed to fit into 64bits" );
81+ return v.get_num_si ();
8382}
8483
8584llvm::Value* getLLVMConstantSignedInt64 (int64_t v) {
@@ -88,25 +87,16 @@ llvm::Value* getLLVMConstantSignedInt64(int64_t v) {
8887
8988int64_t IslExprToSInt (isl::ast_expr e) {
9089 CHECK (isl_ast_expr_get_type (e.get ()) == isl_ast_expr_type::isl_ast_expr_int);
91- assert (sizeof (long ) <= 8 ); // long is assumed to fit to 64bits
9290 return toSInt (isl::manage (isl_ast_expr_get_val (e.get ())));
9391}
9492
9593int64_t islIdToInt (isl::ast_expr e, isl::set context) {
9694 CHECK (isl_ast_expr_get_type (e.get ()) == isl_ast_expr_type::isl_ast_expr_id);
97- CHECK_NE (-1 , context.find_dim_by_id (isl::dim_type::param, e.get_id ()));
98- while (context.dim (isl::dim_type::param) > 1 ) {
99- for (unsigned int d = 0 ; d < context.dim (isl::dim_type::param); ++d) {
100- if (d == context.find_dim_by_id (isl::dim_type::param, e.get_id ())) {
101- continue ;
102- }
103- context = context.remove_dims (isl::dim_type::param, d, 1 );
104- }
105- }
95+ auto space = context.get_space ();
96+ isl::aff param (isl::aff::param_on_domain_space (space, e.get_id ()));
10697 auto p = context.sample_point ();
107-
108- auto val = toSInt (p.get_coordinate_val (isl::dim_type::param, 0 ));
109- return val;
98+ CHECK (context.is_equal (p));
99+ return toSInt (param.eval (p));
110100}
111101
112102int64_t getTensorSize (isl::set context, const Halide::Expr& e) {
@@ -319,8 +309,7 @@ llvm::Value* CodeGen_TC::getValue(isl::ast_expr expr) {
319309 return sym_get (expr.get_id ().get_name ());
320310 case isl_ast_expr_type::isl_ast_expr_int: {
321311 auto val = isl::manage (isl_ast_expr_get_val (expr.get ()));
322- CHECK (val.is_int ());
323- return getLLVMConstantSignedInt64 (val.get_num_si ());
312+ return getLLVMConstantSignedInt64 (toSInt (val));
324313 }
325314 default :
326315 LOG (FATAL) << " NYI" ;
@@ -497,16 +486,15 @@ class LLVMCodegen {
497486 halide_cg.get_builder ().CreateBr (headerBB);
498487
499488 llvm::PHINode* phi = nullptr ;
489+ auto iterator = node.get_iterator ().get_id ();
500490
501491 // Loop Header
502492 {
503493 auto initVal = IslExprToSInt (node.get_init ());
504494 halide_cg.get_builder ().SetInsertPoint (headerBB);
505495 phi = halide_cg.get_builder ().CreatePHI (
506- llvm::Type::getInt64Ty (llvmCtx),
507- 2 ,
508- node.get_iterator ().get_id ().get_name ());
509- halide_cg.sym_push (node.get_iterator ().get_id ().get_name (), phi);
496+ llvm::Type::getInt64Ty (llvmCtx), 2 , iterator.get_name ());
497+ halide_cg.sym_push (iterator.get_name (), phi);
510498 phi->addIncoming (getLLVMConstantSignedInt64 (initVal), incoming);
511499
512500 auto cond_expr = node.get_cond ();
@@ -518,7 +506,7 @@ class LLVMCodegen {
518506 CHECK (
519507 isl_ast_expr_get_type (condLHS.get ()) ==
520508 isl_ast_expr_type::isl_ast_expr_id);
521- CHECK_EQ (condLHS.get_id (), node. get_iterator (). get_id () );
509+ CHECK_EQ (condLHS.get_id (), iterator );
522510
523511 IslAstExprInterpeter i (scop_.globalParameterContext );
524512 auto condRHSVal = i.interpret (cond_expr.get_op_arg (1 ));
@@ -575,7 +563,7 @@ class LLVMCodegen {
575563 }
576564
577565 halide_cg.get_builder ().SetInsertPoint (loopExitBB);
578- halide_cg.sym_pop (node. get_iterator (). get_id () .get_name ());
566+ halide_cg.sym_pop (iterator .get_name ());
579567#ifdef TAPIR_VERSION_MAJOR
580568 if (parallel) {
581569 auto * syncBB = llvm::BasicBlock::Create (llvmCtx, " synced" , function);
@@ -652,9 +640,6 @@ IslCodegenRes codegenISL(const Scop& scop) {
652640 auto scheduleMap = isl::map::from_union_map (schedule);
653641
654642 auto stmtId = expr.get_op_arg (0 ).get_id ();
655- // auto nodeId = isl::id(
656- // node.get_ctx(),
657- // std::string(kAstNodeIdPrefix) + std::to_string(nAstNodes()++));
658643 CHECK_EQ (0 , iteratorMaps.count (stmtId)) << " entry exists: " << stmtId;
659644 auto iteratorMap = isl::pw_multi_aff (scheduleMap.reverse ());
660645 auto iterators = scop.halide .iterators .at (stmtId);
0 commit comments