@@ -3322,21 +3322,6 @@ static const char *paramKind2Str(KernelParamKind K) {
33223322#undef CASE
33233323}
33243324
3325- // Removes all "(anonymous namespace)::" substrings from given string, and emits
3326- // it.
3327- static void emitWithoutAnonNamespaces (llvm::raw_ostream &OS, StringRef Source) {
3328- const char S1[] = " (anonymous namespace)::" ;
3329-
3330- size_t Pos;
3331-
3332- while ((Pos = Source.find (S1)) != StringRef::npos) {
3333- OS << Source.take_front (Pos);
3334- Source = Source.drop_front (Pos + sizeof (S1) - 1 );
3335- }
3336-
3337- OS << Source;
3338- }
3339-
33403325// Emits a forward declaration
33413326void SYCLIntegrationHeader::emitFwdDecl (raw_ostream &O, const Decl *D,
33423327 SourceLocation KernelLocation) {
@@ -3555,139 +3540,133 @@ void SYCLIntegrationHeader::emitForwardClassDecls(
35553540 }
35563541}
35573542
3558- static void emitCPPTypeString (raw_ostream &OS, QualType Ty) {
3559- LangOptions LO;
3560- PrintingPolicy P (LO);
3561- P.SuppressTypedefs = true ;
3562- emitWithoutAnonNamespaces (OS, Ty.getAsString (P));
3563- }
3564-
3565- static void printArguments (ASTContext &Ctx, raw_ostream &ArgOS,
3566- ArrayRef<TemplateArgument> Args,
3567- const PrintingPolicy &P);
3543+ class SYCLKernelNameTypePrinter
3544+ : public TypeVisitor<SYCLKernelNameTypePrinter>,
3545+ public ConstTemplateArgumentVisitor<SYCLKernelNameTypePrinter> {
3546+ using InnerTypeVisitor = TypeVisitor<SYCLKernelNameTypePrinter>;
3547+ using InnerTemplArgVisitor =
3548+ ConstTemplateArgumentVisitor<SYCLKernelNameTypePrinter>;
3549+ raw_ostream &OS;
3550+ PrintingPolicy &Policy;
3551+
3552+ void printTemplateArgs (ArrayRef<TemplateArgument> Args) {
3553+ for (size_t I = 0 , E = Args.size (); I < E; ++I) {
3554+ const TemplateArgument &Arg = Args[I];
3555+ // If argument is an empty pack argument, skip printing comma and
3556+ // argument.
3557+ if (Arg.getKind () == TemplateArgument::ArgKind::Pack && !Arg.pack_size ())
3558+ continue ;
35683559
3569- static void emitKernelNameType (QualType T, ASTContext &Ctx, raw_ostream &OS,
3570- const PrintingPolicy &TypePolicy) ;
3560+ if (I)
3561+ OS << " , " ;
35713562
3572- static void printArgument (ASTContext &Ctx, raw_ostream &ArgOS,
3573- TemplateArgument Arg, const PrintingPolicy &P) {
3574- switch (Arg.getKind ()) {
3575- case TemplateArgument::ArgKind::Pack: {
3576- printArguments (Ctx, ArgOS, Arg.getPackAsArray (), P);
3577- break ;
3578- }
3579- case TemplateArgument::ArgKind::Integral: {
3580- QualType T = Arg.getIntegralType ();
3581- const EnumType *ET = T->getAs <EnumType>();
3582-
3583- if (ET) {
3584- const llvm::APSInt &Val = Arg.getAsIntegral ();
3585- ArgOS << " static_cast<"
3586- << ET->getDecl ()->getQualifiedNameAsString (
3587- /* WithGlobalNsPrefix*/ true )
3588- << " >"
3589- << " (" << Val << " )" ;
3590- } else {
3591- Arg.print (P, ArgOS);
3563+ Visit (Arg);
35923564 }
3593- break ;
35943565 }
3595- case TemplateArgument::ArgKind::Type: {
3596- LangOptions LO;
3597- PrintingPolicy TypePolicy (LO);
3598- TypePolicy.SuppressTypedefs = true ;
3599- TypePolicy.SuppressTagKeyword = true ;
3600- QualType T = Arg.getAsType ();
36013566
3602- emitKernelNameType (T, Ctx, ArgOS, TypePolicy);
3603- break ;
3604- }
3605- case TemplateArgument::ArgKind::Template: {
3606- TemplateDecl *TD = Arg.getAsTemplate ().getAsTemplateDecl ();
3607- ArgOS << TD->getQualifiedNameAsString ();
3608- break ;
3609- }
3610- default :
3611- Arg.print (P, ArgOS);
3567+ void VisitQualifiers (Qualifiers Quals) {
3568+ Quals.print (OS, Policy, /* appendSpaceIfNotEmpty*/ true );
36123569 }
3613- }
36143570
3615- static void printArguments (ASTContext &Ctx, raw_ostream &ArgOS,
3616- ArrayRef<TemplateArgument> Args,
3617- const PrintingPolicy &P) {
3618- for (unsigned I = 0 ; I < Args.size (); I++) {
3619- const TemplateArgument &Arg = Args[I];
3571+ public:
3572+ SYCLKernelNameTypePrinter (raw_ostream &OS, PrintingPolicy &Policy)
3573+ : OS(OS), Policy(Policy) {}
36203574
3621- // If argument is an empty pack argument, skip printing comma and argument.
3622- if (Arg. getKind () == TemplateArgument::ArgKind::Pack && !Arg. pack_size ())
3623- continue ;
3575+ void Visit (QualType T) {
3576+ if (T. isNull ())
3577+ return ;
36243578
3625- if (I != 0 )
3626- ArgOS << " , " ;
3579+ QualType CT = T. getCanonicalType ();
3580+ VisitQualifiers (CT. getQualifiers ()) ;
36273581
3628- printArgument (Ctx, ArgOS, Arg, P );
3582+ InnerTypeVisitor::Visit (CT. getTypePtr () );
36293583 }
3630- }
36313584
3632- static void printTemplateArguments (ASTContext &Ctx, raw_ostream &ArgOS,
3633- ArrayRef<TemplateArgument> Args,
3634- const PrintingPolicy &P) {
3635- ArgOS << " <" ;
3636- printArguments (Ctx, ArgOS, Args, P);
3637- ArgOS << " >" ;
3638- }
3585+ void VisitType (const Type *T) {
3586+ OS << QualType::getAsString (T, Qualifiers (), Policy);
3587+ }
36393588
3640- static void emitRecordType (raw_ostream &OS, QualType T, const CXXRecordDecl *RD,
3641- const PrintingPolicy &TypePolicy) {
3642- SmallString<64 > Buf;
3643- llvm::raw_svector_ostream RecOS (Buf);
3644- T.getCanonicalType ().getQualifiers ().print (RecOS, TypePolicy,
3645- /* appendSpaceIfNotEmpty*/ true );
3646- if (const auto *TSD = dyn_cast<ClassTemplateSpecializationDecl>(RD)) {
3589+ void Visit (const TemplateArgument &TA) {
3590+ if (TA.isNull ())
3591+ return ;
3592+ InnerTemplArgVisitor::Visit (TA);
3593+ }
3594+
3595+ void VisitTagType (const TagType *T) {
3596+ TagDecl *RD = T->getDecl ();
3597+ if (const auto *TSD = dyn_cast<ClassTemplateSpecializationDecl>(RD)) {
36473598
3648- // Print template class name
3649- TSD->printQualifiedName (RecOS, TypePolicy , /* WithGlobalNsPrefix*/ true );
3599+ // Print template class name
3600+ TSD->printQualifiedName (OS, Policy , /* WithGlobalNsPrefix*/ true );
36503601
3651- // Print template arguments substituting enumerators
3652- ASTContext &Ctx = RD-> getASTContext () ;
3653- const TemplateArgumentList &Args = TSD-> getTemplateArgs ( );
3654- printTemplateArguments (Ctx, RecOS, Args. asArray (), TypePolicy) ;
3602+ ArrayRef<TemplateArgument> Args = TSD-> getTemplateArgs (). asArray ();
3603+ OS << " < " ;
3604+ printTemplateArgs (Args );
3605+ OS << " > " ;
36553606
3656- emitWithoutAnonNamespaces (OS, RecOS.str ());
3657- return ;
3607+ return ;
3608+ }
3609+ // TODO: Next part of code results in printing of "class" keyword before
3610+ // class name in case if kernel name doesn't belong to some namespace. It
3611+ // seems if we don't print it, the integration header still represents valid
3612+ // c++ code. Probably we don't need to print it at all.
3613+ if (RD->getDeclContext ()->isFunctionOrMethod ()) {
3614+ OS << QualType::getAsString (T, Qualifiers (), Policy);
3615+ return ;
3616+ }
3617+
3618+ const NamespaceDecl *NS = dyn_cast<NamespaceDecl>(RD->getDeclContext ());
3619+ RD->printQualifiedName (OS, Policy, !(NS && NS->isAnonymousNamespace ()));
36583620 }
3659- if (RD-> getDeclContext ()-> isFunctionOrMethod ()) {
3660- emitWithoutAnonNamespaces (OS, T. getCanonicalType (). getAsString (TypePolicy));
3661- return ;
3621+
3622+ void VisitTemplateArgument ( const TemplateArgument &TA) {
3623+ TA. print (Policy, OS) ;
36623624 }
36633625
3664- const NamespaceDecl *NS = dyn_cast<NamespaceDecl>(RD->getDeclContext ());
3665- RD->printQualifiedName (RecOS, TypePolicy,
3666- !(NS && NS->isAnonymousNamespace ()));
3667- emitWithoutAnonNamespaces (OS, RecOS.str ());
3668- }
3626+ void VisitTypeTemplateArgument (const TemplateArgument &TA) {
3627+ Policy.SuppressTagKeyword = true ;
3628+ QualType T = TA.getAsType ();
3629+ Visit (T);
3630+ Policy.SuppressTagKeyword = false ;
3631+ }
36693632
3670- static void emitKernelNameType (QualType T, ASTContext &Ctx, raw_ostream &OS,
3671- const PrintingPolicy &TypePolicy) {
3672- if (T->isRecordType ()) {
3673- emitRecordType (OS, T, T->getAsCXXRecordDecl (), TypePolicy);
3674- return ;
3633+ void VisitIntegralTemplateArgument (const TemplateArgument &TA) {
3634+ QualType T = TA.getIntegralType ();
3635+ if (const EnumType *ET = T->getAs <EnumType>()) {
3636+ const llvm::APSInt &Val = TA.getAsIntegral ();
3637+ OS << " static_cast<" ;
3638+ ET->getDecl ()->printQualifiedName (OS, Policy,
3639+ /* WithGlobalNsPrefix*/ true );
3640+ OS << " >(" << Val << " )" ;
3641+ } else {
3642+ TA.print (Policy, OS);
3643+ }
36753644 }
36763645
3677- if (T->isEnumeralType ())
3678- OS << " ::" ;
3679- emitWithoutAnonNamespaces (OS, T.getCanonicalType ().getAsString (TypePolicy));
3680- }
3646+ void VisitTemplateTemplateArgument (const TemplateArgument &TA) {
3647+ TemplateDecl *TD = TA.getAsTemplate ().getAsTemplateDecl ();
3648+ TD->printQualifiedName (OS, Policy);
3649+ }
3650+
3651+ void VisitPackTemplateArgument (const TemplateArgument &TA) {
3652+ printTemplateArgs (TA.getPackAsArray ());
3653+ }
3654+ };
36813655
36823656void SYCLIntegrationHeader::emit (raw_ostream &O) {
36833657 O << " // This is auto-generated SYCL integration header.\n " ;
36843658 O << " \n " ;
36853659
3686- O << " #include <CL/sycl/detail/defines .hpp>\n " ;
3660+ O << " #include <CL/sycl/detail/defines_elementary .hpp>\n " ;
36873661 O << " #include <CL/sycl/detail/kernel_desc.hpp>\n " ;
36883662
36893663 O << " \n " ;
36903664
3665+ LangOptions LO;
3666+ PrintingPolicy Policy (LO);
3667+ Policy.SuppressTypedefs = true ;
3668+ Policy.SuppressUnwrittenScope = true ;
3669+
36913670 if (SpecConsts.size () > 0 ) {
36923671 // Remove duplicates.
36933672 std::sort (SpecConsts.begin (), SpecConsts.end (),
@@ -3705,7 +3684,7 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
37053684 O << " // Specialization constants IDs:\n " ;
37063685 for (const auto &P : llvm::make_range (SpecConsts.begin (), End)) {
37073686 O << " template <> struct sycl::detail::SpecConstantInfo<" ;
3708- emitCPPTypeString (O, P.first );
3687+ O << P.first . getAsString (Policy );
37093688 O << " > {\n " ;
37103689 O << " static constexpr const char* getName() {\n " ;
37113690 O << " return \" " << P.second << " \" ;\n " ;
@@ -3773,19 +3752,17 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
37733752 O << " ', '" << c;
37743753 O << " '> {\n " ;
37753754 } else {
3776- LangOptions LO;
3777- PrintingPolicy P (LO);
3778- P.SuppressTypedefs = true ;
37793755 O << " template <> struct KernelInfo<" ;
3780- emitKernelNameType (K.NameType , S.getASTContext (), O, P);
3756+ SYCLKernelNameTypePrinter Printer (O, Policy);
3757+ Printer.Visit (K.NameType );
37813758 O << " > {\n " ;
37823759 }
3783- O << " DLL_LOCAL \n " ;
3760+ O << " __SYCL_DLL_LOCAL \n " ;
37843761 O << " static constexpr const char* getName() { return \" " << K.Name
37853762 << " \" ; }\n " ;
3786- O << " DLL_LOCAL \n " ;
3763+ O << " __SYCL_DLL_LOCAL \n " ;
37873764 O << " static constexpr unsigned getNumParams() { return " << N << " ; }\n " ;
3788- O << " DLL_LOCAL \n " ;
3765+ O << " __SYCL_DLL_LOCAL \n " ;
37893766 O << " static constexpr const kernel_param_desc_t& " ;
37903767 O << " getParamDesc(unsigned i) {\n " ;
37913768 O << " return kernel_signatures[i+" << CurStart << " ];\n " ;
0 commit comments