@@ -359,28 +359,33 @@ void emitReductionInit(
359359 context.ss << " ;" << endl;
360360}
361361
362- void emitCopyStmt (const CodegenStatementContext& context) {
363- using detail::emitDirectSubscripts;
362+ namespace {
363+ template <typename AFF>
364+ void emitAccess (AFF access, const CodegenStatementContext& context) {
365+ // Use a temporary isl::ast_build to print the expression.
366+ // Ideally, this should use the build at the point
367+ // where the user statement was created.
368+ auto astBuild = isl::ast_build::from_context (access.domain ());
369+ context.ss << astBuild.access_from (access).to_C_str ();
370+ }
371+ } // namespace
364372
373+ void emitCopyStmt (const CodegenStatementContext& context) {
365374 auto stmtId = context.statementId ();
366375
367376 auto iteratorMap = context.iteratorMap ();
368377 auto promoted = iteratorMap.range_factor_range ();
369378 auto original = iteratorMap.range_factor_domain ().range_factor_range ();
370379 auto isRead = stmtId.get_name () == kReadIdName ;
371- auto originalName = original.get_tuple_id (isl::dim_type::out).get_name ();
372- auto promotedName = promoted.get_tuple_id (isl::dim_type::out).get_name ();
373380
374381 if (isRead) {
375- context.ss << promotedName;
376- emitDirectSubscripts (promoted, context);
377- context.ss << " = " << originalName;
378- emitDirectSubscripts (original, context);
382+ emitAccess (isl::multi_pw_aff (promoted), context);
383+ context.ss << " = " ;
384+ emitAccess (isl::multi_pw_aff (original), context);
379385 } else {
380- context.ss << originalName;
381- emitDirectSubscripts (original, context);
382- context.ss << " = " << promotedName;
383- emitDirectSubscripts (promoted, context);
386+ emitAccess (isl::multi_pw_aff (original), context);
387+ context.ss << " = " ;
388+ emitAccess (isl::multi_pw_aff (promoted), context);
384389 }
385390 context.ss << " ;" << std::endl;
386391}
@@ -447,14 +452,6 @@ void AstPrinter::emitAst(isl::ast_node node) {
447452
448453namespace detail {
449454
450- std::string toString (isl::pw_aff subscript) {
451- // Use a temporary isl::ast_build to print the expression.
452- // Ideally, this should use the build at the point
453- // where the user statement was created.
454- auto astBuild = isl::ast_build::from_context (subscript.domain ());
455- return astBuild.expr_from (subscript).to_C_str ();
456- }
457-
458455isl::pw_aff makeAffFromMappedExpr (
459456 const Halide::Expr& expr,
460457 const CodegenStatementContext& context) {
@@ -498,18 +495,35 @@ isl::multi_aff makeMultiAffAccess(
498495 return ma;
499496}
500497
498+ namespace {
499+ bool is_identifier_or_nonnegative_integer (isl::ast_expr expr) {
500+ if (isl_ast_expr_get_type (expr.get ()) == isl_ast_expr_id)
501+ return true ;
502+ if (isl_ast_expr_get_type (expr.get ()) != isl_ast_expr_int)
503+ return false ;
504+ return isl::manage (isl_ast_expr_get_val (expr.get ())).is_nonneg ();
505+ }
506+ } // namespace
507+
501508void emitHalideExpr (
502509 const Halide::Expr& e,
503510 const CodegenStatementContext& context,
504511 const map<string, string>& substitutions) {
505512 class EmitHalide : public Halide ::Internal::IRPrinter {
506513 using Halide::Internal::IRPrinter::visit;
507514 void visit (const Halide::Internal::Variable* op) {
508- // This is probably needlessly indirect, given that we just have
509- // a name to look up somewhere.
510515 auto pwAff = tc::polyhedral::detail::makeAffFromMappedExpr (
511516 Halide::Expr (op), context);
512- context.ss << tc::polyhedral::detail::toString (pwAff);
517+ // Use a temporary isl::ast_build to print the expression.
518+ // Ideally, this should use the build at the point
519+ // where the user statement was created.
520+ auto astBuild = isl::ast_build::from_context (pwAff.domain ());
521+ auto expr = astBuild.expr_from (pwAff);
522+ auto s = expr.to_C_str ();
523+ if (!is_identifier_or_nonnegative_integer (expr)) {
524+ s = " (" + s + " )" ;
525+ }
526+ context.ss << s;
513527 }
514528 void visit (const Halide::Internal::Call* op) {
515529 if (substitutions.count (op->name )) {
@@ -613,19 +627,7 @@ void emitMappedTensorAccess(
613627 auto astToPromoted =
614628 isl::pw_multi_aff (promotion).pullback (astToScheduledOriginal);
615629
616- auto astBuild = isl::ast_build::from_context (astToPromoted.domain ());
617- context.ss << astBuild.access_from (astToPromoted).to_C_str ();
618- }
619-
620- void emitDirectSubscripts (
621- isl::pw_multi_aff subscripts,
622- const CodegenStatementContext& context) {
623- auto mpa = isl::multi_pw_aff (subscripts); // this conversion is safe
624- for (auto pa : isl::MPA (mpa)) {
625- context.ss << " [" ;
626- context.ss << toString (pa.pa );
627- context.ss << " ]" ;
628- }
630+ emitAccess (astToPromoted, context);
629631}
630632
631633} // namespace detail
0 commit comments