3333
3434#include " Halide.h"
3535
36- #include " isl/ast.h"
37-
3836#include " tc/core/constants.h"
3937#include " tc/core/flags.h"
4038#include " tc/core/halide2isl.h"
@@ -83,12 +81,12 @@ llvm::Value* getLLVMConstantSignedInt64(int64_t v) {
8381}
8482
8583int64_t IslExprToSInt (isl::ast_expr e) {
86- CHECK (isl_ast_expr_get_type (e.get ()) == isl_ast_expr_type::isl_ast_expr_int);
87- return toSInt (isl::manage (isl_ast_expr_get_val (e.get ())));
84+ auto intExpr = e.as <isl::ast_expr_int>();
85+ CHECK (intExpr);
86+ return toSInt (intExpr.get_val ());
8887}
8988
90- int64_t islIdToInt (isl::ast_expr e, isl::set context) {
91- CHECK (isl_ast_expr_get_type (e.get ()) == isl_ast_expr_type::isl_ast_expr_id);
89+ int64_t islIdToInt (isl::ast_expr_id e, isl::set context) {
9290 auto space = context.get_space ();
9391 isl::aff param (isl::aff::param_on_domain_space (space, e.get_id ()));
9492 auto p = context.sample_point ();
@@ -127,22 +125,21 @@ class IslAstExprInterpeter {
127125 IslAstExprInterpeter (isl::set context) : context_(context){};
128126
129127 int64_t interpret (isl::ast_expr e) {
130- switch (isl_ast_expr_get_type (e.get ())) {
131- case isl_ast_expr_type::isl_ast_expr_int:
132- return IslExprToSInt (e);
133- case isl_ast_expr_type::isl_ast_expr_id:
134- return islIdToInt (e, context_);
135- case isl_ast_expr_type::isl_ast_expr_op:
136- return interpretOp (e);
137- default :
138- CHECK (false ) << " NYI" ;
139- return 0 ; // avoid warning
128+ if (auto intExpr = e.as <isl::ast_expr_int>()) {
129+ return IslExprToSInt (intExpr);
130+ } else if (auto idExpr = e.as <isl::ast_expr_id>()) {
131+ return islIdToInt (idExpr, context_);
132+ } else if (auto opExpr = e.as <isl::ast_expr_op>()) {
133+ return interpretOp (opExpr);
134+ } else {
135+ CHECK (false ) << " NYI" ;
136+ return 0 ; // avoid warning
140137 }
141138 };
142139
143140 private:
144- int64_t interpretOp (isl::ast_expr e) {
145- switch (e.get_op_n_arg ()) {
141+ int64_t interpretOp (isl::ast_expr_op e) {
142+ switch (e.get_n_arg ()) {
146143 case 1 :
147144 return interpretUnaryOp (e);
148145 case 2 :
@@ -153,28 +150,26 @@ class IslAstExprInterpeter {
153150 }
154151 }
155152
156- int64_t interpretBinaryOp (isl::ast_expr e) {
157- auto left = interpret (e.get_op_arg (0 ));
158- auto right = interpret (e.get_op_arg (1 ));
159- switch (e.get_op_type ()) {
160- case isl::ast_op_type::add:
161- return left + right;
162- case isl::ast_op_type::sub:
163- return left - right;
164- default :
165- CHECK (false ) << " NYI: " << e;
166- return 0 ; // avoid warning
153+ int64_t interpretBinaryOp (isl::ast_expr_op e) {
154+ auto left = interpret (e.get_arg (0 ));
155+ auto right = interpret (e.get_arg (1 ));
156+ if (e.as <isl::ast_op_add>()) {
157+ return left + right;
158+ } else if (e.as <isl::ast_op_sub>()) {
159+ return left - right;
160+ } else {
161+ CHECK (false ) << " NYI: " << e;
162+ return 0 ; // avoid warning
167163 }
168164 }
169165
170- int64_t interpretUnaryOp (isl::ast_expr e) {
171- auto val = interpret (e.get_op_arg (0 ));
172- switch (e.get_op_type ()) {
173- case isl::ast_op_type::minus:
174- return -val;
175- default :
176- CHECK (false ) << " NYI" ;
177- return 0 ; // avoid warning
166+ int64_t interpretUnaryOp (isl::ast_expr_op e) {
167+ auto val = interpret (e.get_arg (0 ));
168+ if (e.as <isl::ast_op_minus>()) {
169+ return -val;
170+ } else {
171+ CHECK (false ) << " NYI" ;
172+ return 0 ; // avoid warning
178173 }
179174 }
180175};
@@ -301,16 +296,13 @@ class CodeGen_TC : public Halide::Internal::CodeGen_X86 {
301296};
302297
303298llvm::Value* CodeGen_TC::getValue (isl::ast_expr expr) {
304- switch (isl_ast_expr_get_type (expr.get ())) {
305- case isl_ast_expr_type::isl_ast_expr_id:
306- return sym_get (expr.get_id ().get_name ());
307- case isl_ast_expr_type::isl_ast_expr_int: {
308- auto val = isl::manage (isl_ast_expr_get_val (expr.get ()));
309- return getLLVMConstantSignedInt64 (toSInt (val));
310- }
311- default :
312- LOG (FATAL) << " NYI" ;
313- return nullptr ;
299+ if (auto idExpr = expr.as <isl::ast_expr_id>()) {
300+ return sym_get (idExpr.get_id ().get_name ());
301+ } else if (auto intExpr = expr.as <isl::ast_expr_int>()) {
302+ return getLLVMConstantSignedInt64 (toSInt (intExpr.get_val ()));
303+ } else {
304+ LOG (FATAL) << " NYI" ;
305+ return nullptr ;
314306 }
315307}
316308
@@ -483,7 +475,7 @@ class LLVMCodegen {
483475 halide_cg.get_builder ().CreateBr (headerBB);
484476
485477 llvm::PHINode* phi = nullptr ;
486- auto iterator = node.get_iterator ().get_id ();
478+ auto iterator = node.get_iterator ().as <isl::ast_expr_id>(). get_id ();
487479
488480 // Loop Header
489481 {
@@ -494,30 +486,25 @@ class LLVMCodegen {
494486 halide_cg.sym_push (iterator.get_name (), phi);
495487 phi->addIncoming (getLLVMConstantSignedInt64 (initVal), incoming);
496488
497- auto cond_expr = node.get_cond ();
498- CHECK (
499- cond_expr.get_op_type () == isl::ast_op_type::lt or
500- cond_expr.get_op_type () == isl::ast_op_type::le)
489+ auto cond_expr = node.get_cond ().as <isl::ast_expr_op>();
490+ CHECK (cond_expr.as <isl::ast_op_lt>() or cond_expr.as <isl::ast_op_le>())
501491 << " I only know how to codegen lt and le" ;
502- auto condLHS = cond_expr.get_op_arg (0 );
503- CHECK (
504- isl_ast_expr_get_type (condLHS.get ()) ==
505- isl_ast_expr_type::isl_ast_expr_id);
492+ auto condLHS = cond_expr.get_arg (0 ).as <isl::ast_expr_id>();
493+ CHECK (condLHS);
506494 CHECK_EQ (condLHS.get_id (), iterator);
507495
508496 IslAstExprInterpeter i (scop_.globalParameterContext );
509- auto condRHSVal = i.interpret (cond_expr.get_op_arg (1 ));
497+ auto condRHSVal = i.interpret (cond_expr.get_arg (1 ));
510498
511499 auto cond = [&]() {
512500 auto constant = getLLVMConstantSignedInt64 (condRHSVal);
513- switch (cond_expr.get_op_type ()) {
514- case isl::ast_op_type::lt:
515- return halide_cg.get_builder ().CreateICmpSLT (phi, constant);
516- case isl::ast_op_type::le:
517- return halide_cg.get_builder ().CreateICmpSLE (phi, constant);
518- default :
519- CHECK (false ) << " NYI" ;
520- return static_cast <llvm::Value*>(nullptr ); // avoid warning
501+ if (cond_expr.as <isl::ast_op_lt>()) {
502+ return halide_cg.get_builder ().CreateICmpSLT (phi, constant);
503+ } else if (cond_expr.as <isl::ast_op_le>()) {
504+ return halide_cg.get_builder ().CreateICmpSLE (phi, constant);
505+ } else {
506+ CHECK (false ) << " NYI" ;
507+ return static_cast <llvm::Value*>(nullptr ); // avoid warning
521508 }
522509 }();
523510 halide_cg.get_builder ().CreateCondBr (cond, loopBodyBB, loopExitBB);
@@ -572,8 +559,8 @@ class LLVMCodegen {
572559 }
573560
574561 llvm::BasicBlock* emitStmt (isl::ast_node_user node) {
575- isl::ast_expr usrExp = node.get_expr ();
576- auto id = usrExp.get_op_arg ( 0 ).get_id ();
562+ isl::ast_expr_op usrExp = node.get_expr (). as <isl::ast_expr_op> ();
563+ auto id = usrExp.get_arg ( 0 ). as <isl::ast_expr_id>( ).get_id ();
577564 auto provide = scop_.halide .statements .at (id);
578565 auto op = provide.as <Halide::Internal::Provide>();
579566 CHECK (op) << " Expected a Provide node: " << provide << ' \n ' ;
@@ -632,11 +619,11 @@ IslCodegenRes codegenISL(const Scop& scop) {
632619 StmtSubscriptExprMapType& stmtSubscripts) -> isl::ast_node {
633620 auto user = node.as <isl::ast_node_user>();
634621 CHECK (user);
635- auto expr = user.get_expr ();
622+ auto expr = user.get_expr (). as <isl::ast_expr_op>() ;
636623 auto schedule = build.get_schedule ();
637624 auto scheduleMap = isl::map::from_union_map (schedule);
638625
639- auto stmtId = expr.get_op_arg ( 0 ).get_id ();
626+ auto stmtId = expr.get_arg ( 0 ). as <isl::ast_expr_id>( ).get_id ();
640627 CHECK_EQ (0u , iteratorMaps.count (stmtId)) << " entry exists: " << stmtId;
641628 auto iteratorMap = isl::pw_multi_aff (scheduleMap.reverse ());
642629 auto iterators = scop.halide .iterators .at (stmtId);
0 commit comments