@@ -396,10 +396,49 @@ struct LdgWrapper {
396396 std::ostream& out_;
397397};
398398
399+ template <typename AFF>
400+ isl::ast_expr buildAccess (AFF access, const CodegenStatementContext& context) {
401+ return context.build ().access_from (access);
402+ }
403+
404+ void emitAccess (isl::ast_expr access, const CodegenStatementContext& context) {
405+ context.ss << access.to_C_str ();
406+ }
407+
399408template <typename AFF>
400409void emitAccess (AFF access, const CodegenStatementContext& context) {
410+ emitAccess (buildAccess (access, context), context);
411+ }
412+
413+ // Check that the given expression is an access with constant index expressions
414+ void checkConstantAccess (isl::ast_expr expr) {
415+ auto op = expr.as <isl::ast_expr_op>();
416+ auto access = op.as <isl::ast_op_access>();
417+ TC_CHECK (access);
418+ for (int i = 1 ; i < access.get_n_arg (); ++i) {
419+ auto arg = access.get_arg (i);
420+ TC_CHECK (arg.as <isl::ast_expr_int>())
421+ << " expected constant subscript, got " << arg.to_C_str ();
422+ }
423+ }
424+
425+ // Print an access to a(n array of) register(s), checking that
426+ // the index expressions are constant.
427+ void emitRegisterAccess (
428+ isl::pw_multi_aff access,
429+ const CodegenStatementContext& context) {
430+ auto expr = buildAccess (access, context);
431+ checkConstantAccess (expr);
432+ emitAccess (expr, context);
433+ }
434+
435+ // Print an access to global memory, wrapping the access in an "__ldg()"
436+ // call if the accessed tensor is known to be read-only.
437+ void emitGlobalAccess (
438+ isl::multi_pw_aff access,
439+ const CodegenStatementContext& context) {
401440 LdgWrapper ldgWrapper (context, access.get_tuple_id (isl::dim_type::out));
402- context. ss << context. build (). access_from ( access). to_C_str ( );
441+ emitAccess ( access, context );
403442}
404443} // namespace
405444
@@ -414,9 +453,9 @@ void emitCopyStmt(const CodegenStatementContext& context) {
414453 if (isRead) {
415454 emitAccess (isl::multi_pw_aff (promoted), context);
416455 context.ss << " = " ;
417- emitAccess (isl::multi_pw_aff (original), context);
456+ emitGlobalAccess (isl::multi_pw_aff (original), context);
418457 } else {
419- emitAccess (isl::multi_pw_aff (original), context);
458+ emitGlobalAccess (isl::multi_pw_aff (original), context);
420459 context.ss << " = " ;
421460 emitAccess (isl::multi_pw_aff (promoted), context);
422461 }
@@ -625,7 +664,8 @@ void emitMappedTensorAccess(
625664 return ;
626665 }
627666
628- auto tensorId = context.scop ().promotedDecl (promotionInfo.groupId ).tensorId ;
667+ auto decl = context.scop ().promotedDecl (promotionInfo.groupId );
668+ auto tensorId = decl.tensorId ;
629669
630670 // Here and below in comments: D = domain, O = original tensor, P = promoted
631671 // tensor, S = partial schedule, A = AST loops;
@@ -651,7 +691,11 @@ void emitMappedTensorAccess(
651691 auto astToPromoted =
652692 isl::pw_multi_aff (promotion).pullback (astToScheduledOriginal);
653693
654- emitAccess (astToPromoted, context);
694+ if (decl.kind == Scop::PromotedDecl::Kind::Register) {
695+ emitRegisterAccess (astToPromoted, context);
696+ } else {
697+ emitAccess (astToPromoted, context);
698+ }
655699}
656700
657701} // namespace detail
0 commit comments