From e512f79b89aaf2dcf4b168324d55b924aa24e54b Mon Sep 17 00:00:00 2001 From: "Podchishchaeva, Mariya" Date: Tue, 24 Feb 2026 14:03:17 -0800 Subject: [PATCH] [clang][SYCL] Decomposition support WIP --- clang/include/clang/AST/ASTNodeTraverser.h | 4 +- clang/include/clang/AST/RecursiveASTVisitor.h | 1 + clang/include/clang/AST/StmtSYCL.h | 26 +- clang/include/clang/Basic/Attr.td | 7 + clang/include/clang/Sema/ScopeInfo.h | 4 + clang/include/clang/Sema/SemaSYCL.h | 9 +- clang/lib/Sema/SemaDecl.cpp | 18 +- clang/lib/Sema/SemaDeclAttr.cpp | 3 + clang/lib/Sema/SemaSYCL.cpp | 485 ++++++++++++++++-- clang/lib/Sema/TreeTransform.h | 7 +- .../ast-dump-sycl-kernel-call-stmt.cpp | 3 + .../ast-dump-sycl-kernel-decomposition.cpp | 141 +++++ .../ast-dump-sycl-kernel-entry-point.cpp | 3 + .../ASTSYCL/ast-print-sycl-kernel-call.cpp | 3 + clang/test/CodeGenSYCL/function-attrs.cpp | 3 + .../CodeGenSYCL/kernel-arg-decomposition.cpp | 96 ++++ .../CodeGenSYCL/kernel-caller-entry-point.cpp | 3 + .../sycl-kernel-entry-point-exceptions.cpp | 3 + .../unique_stable_name_windows_diff.cpp | 3 + ...-kernel-entry-point-attr-appertainment.cpp | 3 + ...kernel-entry-point-attr-device-odr-use.cpp | 3 + .../sycl-kernel-entry-point-attr-grammar.cpp | 3 + ...el-entry-point-attr-kernel-name-module.cpp | 3 + ...ernel-entry-point-attr-kernel-name-pch.cpp | 3 + ...cl-kernel-entry-point-attr-kernel-name.cpp | 3 + .../sycl-kernel-entry-point-attr-sfinae.cpp | 3 + .../sycl-kernel-entry-point-attr-this.cpp | 3 + .../SemaSYCL/sycl-kernel-launch-ms-compat.cpp | 3 + clang/test/SemaSYCL/sycl-kernel-launch.cpp | 3 + 29 files changed, 802 insertions(+), 50 deletions(-) create mode 100644 clang/test/ASTSYCL/ast-dump-sycl-kernel-decomposition.cpp create mode 100644 clang/test/CodeGenSYCL/kernel-arg-decomposition.cpp diff --git a/clang/include/clang/AST/ASTNodeTraverser.h b/clang/include/clang/AST/ASTNodeTraverser.h index 5e9463d54747d..c4bafc2017609 100644 --- a/clang/include/clang/AST/ASTNodeTraverser.h +++ b/clang/include/clang/AST/ASTNodeTraverser.h @@ -858,8 +858,10 @@ class ASTNodeTraverser void VisitUnresolvedSYCLKernelCallStmt(const UnresolvedSYCLKernelCallStmt *Node) { Visit(Node->getOriginalStmt()); - if (Traversal != TK_IgnoreUnlessSpelledInSource) + if (Traversal != TK_IgnoreUnlessSpelledInSource) { Visit(Node->getKernelLaunchIdExpr()); + Visit(Node->getSpecArgsIdExpr()); + } } void VisitOMPExecutableDirective(const OMPExecutableDirective *Node) { diff --git a/clang/include/clang/AST/RecursiveASTVisitor.h b/clang/include/clang/AST/RecursiveASTVisitor.h index b5be0910194bd..fdb828a7cb680 100644 --- a/clang/include/clang/AST/RecursiveASTVisitor.h +++ b/clang/include/clang/AST/RecursiveASTVisitor.h @@ -3034,6 +3034,7 @@ DEF_TRAVERSE_STMT(UnresolvedSYCLKernelCallStmt, { if (getDerived().shouldVisitImplicitCode()) { TRY_TO(TraverseStmt(S->getOriginalStmt())); TRY_TO(TraverseStmt(S->getKernelLaunchIdExpr())); + TRY_TO(TraverseStmt(S->getSpecArgsIdExpr())); ShouldVisitChildren = false; } }) diff --git a/clang/include/clang/AST/StmtSYCL.h b/clang/include/clang/AST/StmtSYCL.h index 79ac88532e143..cd682f4cea594 100644 --- a/clang/include/clang/AST/StmtSYCL.h +++ b/clang/include/clang/AST/StmtSYCL.h @@ -105,12 +105,19 @@ class UnresolvedSYCLKernelCallStmt : public Stmt { Stmt *OriginalStmt = nullptr; // KernelLaunchIdExpr stores an UnresolvedLookupExpr or UnresolvedMemberExpr // corresponding to the SYCL kernel launch function for which a call - // will be synthesized during template instantiation. + // will be synthesized during template instantiation of the host code. Expr *KernelLaunchIdExpr = nullptr; - - UnresolvedSYCLKernelCallStmt(CompoundStmt *CS, Expr *IdExpr) + // Similar to KernelLaunchIdExpr HandleSYCLSpecialParamsIdExpr stores an + // UnresolvedLookupExpr or UnresolvedMemberExpr corresponding to the fuction + // handling of special SYCL kernel parameters for which a call will be + // synthesized during template instantiation of the device code. + Expr *HandleSYCLSpecialParamsIdExpr = nullptr; + + UnresolvedSYCLKernelCallStmt(CompoundStmt *CS, Expr *IdExpr, + Expr *HandleSYCLSpecialParamsIdExpr) : Stmt(UnresolvedSYCLKernelCallStmtClass), OriginalStmt(CS), - KernelLaunchIdExpr(IdExpr) {} + KernelLaunchIdExpr(IdExpr), + HandleSYCLSpecialParamsIdExpr(HandleSYCLSpecialParamsIdExpr) {} void setOriginalStmt(CompoundStmt *CS) { OriginalStmt = CS; } @@ -118,12 +125,13 @@ class UnresolvedSYCLKernelCallStmt : public Stmt { public: static UnresolvedSYCLKernelCallStmt *Create(const ASTContext &C, - CompoundStmt *CS, Expr *IdExpr) { - return new (C) UnresolvedSYCLKernelCallStmt(CS, IdExpr); + CompoundStmt *CS, Expr *IdExpr, + Expr *SpecArgsExpr) { + return new (C) UnresolvedSYCLKernelCallStmt(CS, IdExpr, SpecArgsExpr); } static UnresolvedSYCLKernelCallStmt *CreateEmpty(const ASTContext &C) { - return new (C) UnresolvedSYCLKernelCallStmt(nullptr, nullptr); + return new (C) UnresolvedSYCLKernelCallStmt(nullptr, nullptr, nullptr); } CompoundStmt *getOriginalStmt() { return cast(OriginalStmt); } @@ -133,6 +141,10 @@ class UnresolvedSYCLKernelCallStmt : public Stmt { Expr *getKernelLaunchIdExpr() { return KernelLaunchIdExpr; } const Expr *getKernelLaunchIdExpr() const { return KernelLaunchIdExpr; } + Expr *getSpecArgsIdExpr() { return HandleSYCLSpecialParamsIdExpr; } + const Expr *getSpecArgsIdExpr() const { + return HandleSYCLSpecialParamsIdExpr; + } SourceLocation getBeginLoc() const LLVM_READONLY { return getOriginalStmt()->getBeginLoc(); diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td index 70b5773f95b08..5f991590638f1 100644 --- a/clang/include/clang/Basic/Attr.td +++ b/clang/include/clang/Basic/Attr.td @@ -1744,6 +1744,13 @@ def SYCLSpecialClass: InheritableAttr { let Documentation = [SYCLSpecialClassDocs]; } +def SYCLSpecialKernelParameter : InheritableAttr { + let Spellings = [CXX11<"clang", "sycl_special_kernel_parameter">]; + let Subjects = SubjectList<[CXXRecord]>; + let LangOpts = [SYCLHost, SYCLDevice]; + let Documentation = [Undocumented]; +} + def C11NoReturn : InheritableAttr { let Spellings = [CustomKeyword<"_Noreturn">]; let Subjects = SubjectList<[Function], ErrorDiag>; diff --git a/clang/include/clang/Sema/ScopeInfo.h b/clang/include/clang/Sema/ScopeInfo.h index f334f58ebd0a7..1e76e3c676385 100644 --- a/clang/include/clang/Sema/ScopeInfo.h +++ b/clang/include/clang/Sema/ScopeInfo.h @@ -249,6 +249,10 @@ class FunctionScopeInfo { /// to a SYCL kernel launch function in a dependent context. Expr *SYCLKernelLaunchIdExpr = nullptr; + /// An unresolved identifier lookup expression for an implicit call + /// to a handling function for SYCL kernel special parameters. + Expr *HandleSYCLSpecialParamsIdExpr = nullptr; + public: /// Represents a simple identification of a weak object. /// diff --git a/clang/include/clang/Sema/SemaSYCL.h b/clang/include/clang/Sema/SemaSYCL.h index 4980aa44c3012..268f31d8947cb 100644 --- a/clang/include/clang/Sema/SemaSYCL.h +++ b/clang/include/clang/Sema/SemaSYCL.h @@ -83,19 +83,22 @@ class SemaSYCL : public SemaBase { /// passed as the 'LaunchIdExpr' argument in a call to either /// BuildSYCLKernelCallStmt() or BuildUnresolvedSYCLKernelCallStmt() after /// the function body has been parsed. - ExprResult BuildSYCLKernelLaunchIdExpr(FunctionDecl *FD, QualType KernelName); + ExprResult BuildSYCLKernelLaunchIdExpr(FunctionDecl *FD, QualType KernelName, + StringRef FuncName); /// Builds a SYCLKernelCallStmt to wrap 'Body' and to be used as the body of /// 'FD'. 'LaunchIdExpr' specifies the lookup result returned by a previous /// call to BuildSYCLKernelLaunchIdExpr(). StmtResult BuildSYCLKernelCallStmt(FunctionDecl *FD, CompoundStmt *Body, - Expr *LaunchIdExpr); + Expr *LaunchIdExpr, + Expr *HandleSpecParamsExpr); /// Builds an UnresolvedSYCLKernelCallStmt to wrap 'Body'. 'LaunchIdExpr' /// specifies the lookup result returned by a previous call to /// BuildSYCLKernelLaunchIdExpr(). StmtResult BuildUnresolvedSYCLKernelCallStmt(CompoundStmt *Body, - Expr *LaunchIdExpr); + Expr *LaunchIdExpr, + Expr *HandleSpecParamsExpr); }; } // namespace clang diff --git a/clang/lib/Sema/SemaDecl.cpp b/clang/lib/Sema/SemaDecl.cpp index eb5b6d65b4d58..744b120a775a0 100644 --- a/clang/lib/Sema/SemaDecl.cpp +++ b/clang/lib/Sema/SemaDecl.cpp @@ -16466,7 +16466,7 @@ Decl *Sema::ActOnStartOfFunctionDef(Scope *FnBodyScope, Decl *D, const auto *SKEPAttr = FD->getAttr(); if (!SKEPAttr->isInvalidAttr()) { ExprResult LaunchIdExpr = - SYCL().BuildSYCLKernelLaunchIdExpr(FD, SKEPAttr->getKernelName()); + SYCL().BuildSYCLKernelLaunchIdExpr(FD, SKEPAttr->getKernelName(), "sycl_kernel_launch"); // Do not mark 'FD' as invalid if construction of `LaunchIDExpr` produces // an invalid result. Name lookup failure for 'sycl_kernel_launch' is // treated as an error in the definition of 'FD'; treating it as an error @@ -16475,6 +16475,13 @@ Decl *Sema::ActOnStartOfFunctionDef(Scope *FnBodyScope, Decl *D, // 'LaunchIDExpr' failed, then 'SYCLKernelLaunchIdExpr' will be assigned // a null pointer value below; that is expected. getCurFunction()->SYCLKernelLaunchIdExpr = LaunchIdExpr.get(); + if (!LaunchIdExpr.isInvalid() && + !LaunchIdExpr.get()->getType()->isVoidType()) { + ExprResult HSPSPIdExpr = SYCL().BuildSYCLKernelLaunchIdExpr( + FD, SKEPAttr->getKernelName(), + "sycl_handle_special_kernel_parameters"); + getCurFunction()->HandleSYCLSpecialParamsIdExpr = HSPSPIdExpr.get(); + } } } @@ -16690,7 +16697,8 @@ Decl *Sema::ActOnFinishFunctionBody(Decl *dcl, Stmt *Body, bool IsInstantiation, // The function body should already be a SYCLKernelCallStmt in this // case, but might not be if there were previous errors. SR = Body; - } else if (!getCurFunction()->SYCLKernelLaunchIdExpr) { + } else if (!getCurFunction()->SYCLKernelLaunchIdExpr || + !getCurFunction()->HandleSYCLSpecialParamsIdExpr) { // If name lookup for a template named sycl_kernel_launch failed // earlier, don't try to build a SYCL kernel call statement as that // would cause additional errors to be issued; just proceed with the @@ -16698,11 +16706,13 @@ Decl *Sema::ActOnFinishFunctionBody(Decl *dcl, Stmt *Body, bool IsInstantiation, SR = Body; } else if (FD->isTemplated()) { SR = SYCL().BuildUnresolvedSYCLKernelCallStmt( - cast(Body), getCurFunction()->SYCLKernelLaunchIdExpr); + cast(Body), getCurFunction()->SYCLKernelLaunchIdExpr, + getCurFunction()->HandleSYCLSpecialParamsIdExpr); } else { SR = SYCL().BuildSYCLKernelCallStmt( FD, cast(Body), - getCurFunction()->SYCLKernelLaunchIdExpr); + getCurFunction()->SYCLKernelLaunchIdExpr, + getCurFunction()->HandleSYCLSpecialParamsIdExpr); } // If construction of the replacement body fails, just continue with the // original function body. An early error return here is not valid; the diff --git a/clang/lib/Sema/SemaDeclAttr.cpp b/clang/lib/Sema/SemaDeclAttr.cpp index 386651fa691e0..03f86ec1bf480 100644 --- a/clang/lib/Sema/SemaDeclAttr.cpp +++ b/clang/lib/Sema/SemaDeclAttr.cpp @@ -7753,6 +7753,9 @@ ProcessDeclAttribute(Sema &S, Scope *scope, Decl *D, const ParsedAttr &AL, case ParsedAttr::AT_SYCLSpecialClass: handleSimpleAttribute(S, D, AL); break; + case ParsedAttr::AT_SYCLSpecialKernelParameter: + handleSimpleAttribute(S, D, AL); + break; case ParsedAttr::AT_Format: handleFormatAttr(S, D, AL); break; diff --git a/clang/lib/Sema/SemaSYCL.cpp b/clang/lib/Sema/SemaSYCL.cpp index 112a6e4416df2..b3de40cfdd68c 100644 --- a/clang/lib/Sema/SemaSYCL.cpp +++ b/clang/lib/Sema/SemaSYCL.cpp @@ -425,7 +425,8 @@ void SemaSYCL::CheckSYCLEntryPointFunctionDecl(FunctionDecl *FD) { } ExprResult SemaSYCL::BuildSYCLKernelLaunchIdExpr(FunctionDecl *FD, - QualType KNT) { + QualType KNT, + StringRef FuncName) { // The current context must be the function definition context to ensure // that name lookup is performed within the correct scope. assert(SemaRef.CurContext == FD && "The current declaration context does not " @@ -440,12 +441,13 @@ ExprResult SemaSYCL::BuildSYCLKernelLaunchIdExpr(FunctionDecl *FD, ASTContext &Ctx = SemaRef.getASTContext(); IdentifierInfo &SYCLKernelLaunchID = - Ctx.Idents.get("sycl_kernel_launch", tok::TokenKind::identifier); + Ctx.Idents.get(FuncName, tok::TokenKind::identifier); // Establish a code synthesis context for the implicit name lookup of // a template named 'sycl_kernel_launch'. In the event of an error, this // ensures an appropriate diagnostic note is issued to explain why the // lookup was performed. + // FIXME: Extend diagnostics for handle special parameters function Sema::CodeSynthesisContext CSC; CSC.Kind = Sema::CodeSynthesisContext::SYCLKernelLaunchLookup; CSC.Entity = FD; @@ -492,16 +494,316 @@ ExprResult SemaSYCL::BuildSYCLKernelLaunchIdExpr(FunctionDecl *FD, return IdExpr; } +static bool isSyclSpecialType(QualType Ty) { + if (const auto *RT = Ty->getAsRecordDecl()) + return RT->getMostRecentDecl()->hasAttr(); + return false; +} + namespace { +/// A special visitor to visit subobjects within a type, i.e. fields of a +/// class or elements of an array. Useful for SYCl because in SYCL kernels are +/// defined via lambda expressions or named callable objects and kernel +/// parameters are fields of these. These visitors will be used for diagnosing +/// invalid kernel arugments as well as for functional transformations. +class SubobjectVisitor { + ASTContext &Ctx; + + // These enable handler execution only when previous Handlers succeed. + template + bool handleField(FieldDecl *FD, QualType FDTy, Tn &&...tn) { + bool result = true; + (void)std::initializer_list{(result = result && tn(FD, FDTy), 0)...}; + return result; + } + template + bool handleField(const CXXBaseSpecifier &BD, QualType BDTy, Tn &&...tn) { + bool result = true; + std::initializer_list{(result = result && tn(BD, BDTy), 0)...}; + return result; + } + +#define KF_FOR_EACH(FUNC, Item, Qt) \ + handleField(Item, Qt, ([&](FieldDecl *FD, QualType FDTy) { \ + return Handlers.FUNC(FD, FDTy); \ + })...) + + // Parent contains the FieldDecl or CXXBaseSpecifier that was used to enter + // the Wrapper structure that we're currently visiting. Owner is the parent + // type (which doesn't exist in cases where it is a FieldDecl in the + // 'root'), and Wrapper is the current struct being unwrapped. + template + void visitComplexRecord(const CXXRecordDecl *Owner, ParentTy &Parent, + const CXXRecordDecl *Wrapper, QualType RecordTy, + HandlerTys &...Handlers) { + (void)std::initializer_list{ + (Handlers.enterStruct(Owner, Parent, RecordTy), 0)...}; + visitRecordHelper(Wrapper, Wrapper->bases(), Handlers...); + visitRecordHelper(Wrapper, Wrapper->fields(), Handlers...); + (void)std::initializer_list{ + (Handlers.leaveStruct(Owner, Parent, RecordTy), 0)...}; + } + + template + void visitArray(const CXXRecordDecl *Owner, FieldDecl *Field, + QualType ArrayTy, HandlerTys &...Handlers) { + // TODO add support for simple array visiting, i.e. without entering array + // elements. + visitComplexArray(Owner, Field, ArrayTy, Handlers...); + } + + template + void visitRecord(const CXXRecordDecl *Owner, ParentTy &Parent, + const CXXRecordDecl *Wrapper, QualType RecordTy, + HandlerTys &...Handlers) { + // TODO add support for simple record visiting, i.e. without entering record + // fields. + visitComplexRecord(Owner, Parent, Wrapper, RecordTy, Handlers...); + } + + template + void visitRecordHelper(const CXXRecordDecl *Owner, + clang::CXXRecordDecl::base_class_const_range Range, + HandlerTys &...Handlers) { + for (const auto &Base : Range) { + QualType BaseTy = Base.getType(); + visitRecord(Owner, Base, BaseTy->getAsCXXRecordDecl(), BaseTy, + Handlers...); + } + } + + template + void visitRecordHelper(const CXXRecordDecl *Owner, RecordDecl::field_range, + HandlerTys &...Handlers) { + visitRecordFields(Owner, Handlers...); + } + + template + void visitArrayElementImpl(const CXXRecordDecl *Owner, FieldDecl *ArrayField, + QualType ElementTy, uint64_t Index, + HandlerTys &...Handlers) { + visitField(Owner, ArrayField, ElementTy, Handlers...); + } + + template + void visitNthArrayElement(const CXXRecordDecl *Owner, FieldDecl *ArrayField, + QualType ElementTy, uint64_t Index, + HandlerTys &...Handlers) { + visitArrayElementImpl(Owner, ArrayField, ElementTy, Index, Handlers...); + } + + template + void visitComplexArray(const CXXRecordDecl *Owner, FieldDecl *Field, + QualType ArrayTy, HandlerTys &...Handlers) { + // Array workflow is: + // handleArrayType + // enterArray + // visitField (same as before, note that The FieldDecl is the of array + // itself, not the element) + // ... repeat per element, opt-out for duplicates. + // leaveArray + + if (!KF_FOR_EACH(handleArrayType, Field, ArrayTy)) + return; + + const ConstantArrayType *CAT = Ctx.getAsConstantArrayType(ArrayTy); + assert(CAT && "Should only be called on constant-size array."); + QualType ET = CAT->getElementType(); + uint64_t ElemCount = CAT->getSize().getZExtValue(); + + (void)std::initializer_list{ + (Handlers.enterArray(Field, ArrayTy, ET), 0)...}; + + for (uint64_t Index = 0; Index < ElemCount; ++Index) + visitNthArrayElement(Owner, Field, ET, Index, Handlers...); + + (void)std::initializer_list{ + (Handlers.leaveArray(Field, ArrayTy, ET), 0)...}; + } + + template + void visitField(const CXXRecordDecl *Owner, FieldDecl *Field, + QualType FieldTy, HandlerTys &...Handlers) { + if (FieldTy->isStructureOrClassType()) { + if (KF_FOR_EACH(handleStructType, Field, FieldTy)) { + CXXRecordDecl *RD = FieldTy->getAsCXXRecordDecl(); + visitRecord(Owner, Field, RD, FieldTy, Handlers...); + } + } else if (FieldTy->isUnionType()) + KF_FOR_EACH(handleUnionType, Field, FieldTy); + else if (FieldTy->isReferenceType()) + KF_FOR_EACH(handleReferenceType, Field, FieldTy); + else if (FieldTy->isPointerType()) + KF_FOR_EACH(handlePointerType, Field, FieldTy); + else if (FieldTy->isArrayType()) + visitArray(Owner, Field, FieldTy, Handlers...); + else if (FieldTy->isScalarType() || FieldTy->isVectorType()) + KF_FOR_EACH(handleScalarType, Field, FieldTy); + else + KF_FOR_EACH(handleOtherType, Field, FieldTy); + } + +public: + SubobjectVisitor(ASTContext &C) : Ctx(C) {} + + template + void visitRecordBases(const CXXRecordDecl *KernelFunctor, + HandlerTys &...Handlers) { + visitRecordHelper(KernelFunctor, KernelFunctor->bases(), Handlers...); + } + + template + void visitRecordFields(const CXXRecordDecl *Owner, HandlerTys &...Handlers) { + for (const auto Field : Owner->fields()) + visitField(Owner, Field, Field->getType(), Handlers...); + } + +#undef KF_FOR_EACH +}; + +class SyclKernelFieldHandlerBase { +public: + virtual bool handleStructType(FieldDecl *, QualType) { return true; } + virtual bool handleUnionType(FieldDecl *, QualType) { return true; } + virtual bool handleReferenceType(FieldDecl *, QualType) { return true; } + virtual bool handlePointerType(FieldDecl *, QualType) { return true; } + virtual bool handleArrayType(FieldDecl *, QualType) { return true; } + virtual bool handleScalarType(FieldDecl *, QualType) { return true; } + // Most handlers shouldn't be handling this, just the field checker. + virtual bool handleOtherType(FieldDecl *, QualType) { return true; } + + virtual bool enterStruct(const CXXRecordDecl *, FieldDecl *, QualType) { + return true; + } + virtual bool leaveStruct(const CXXRecordDecl *, FieldDecl *, QualType) { + return true; + } + virtual bool enterStruct(const CXXRecordDecl *, const CXXBaseSpecifier &, + QualType) { + return true; + } + virtual bool leaveStruct(const CXXRecordDecl *, const CXXBaseSpecifier &, + QualType) { + return true; + } + // The following are used for stepping through array elements. + virtual bool enterArray(FieldDecl *, QualType, QualType) { return true; } + virtual bool leaveArray(FieldDecl *, QualType, QualType) { return true; } + + virtual ~SyclKernelFieldHandlerBase() = default; +}; + +// A class to act as the direct base for all the SYCL Kernel related +// tasks that contains a reference to Sema (and potentially any other +// universally required data). +class SyclKernelFieldHandler : public SyclKernelFieldHandlerBase { +protected: + SemaSYCL &SemaSYCLRef; + SyclKernelFieldHandler(SemaSYCL &S) : SemaSYCLRef(S) {} +}; + +// A type to check the validity of all of the argument types. +class SyclKernelSpecObjFinder : public SyclKernelFieldHandler { + SourceLocation SrcLoc; + ValueDecl *SKEPFArgObj; + llvm::SmallVector MemberExprBases; + llvm::SmallVectorImpl &ResultingArgs; + + bool isArrayElement(const FieldDecl *FD, QualType Ty) const { + return !SemaSYCLRef.getASTContext().hasSameType(FD->getType(), Ty); + } + +public: + /// Constructor for the SyclKernelFieldChecker + /// \param S The SemaSYCL reference used for diagnostics and context. + /// \param FFLoc Free function location, used to report diagnostics + explicit SyclKernelSpecObjFinder(SemaSYCL &S, SourceLocation Loc, + ValueDecl *SKEPFArgObj, + llvm::SmallVectorImpl &ResultingArgs) + : SyclKernelFieldHandler(S), SrcLoc(Loc), SKEPFArgObj(SKEPFArgObj), + ResultingArgs(ResultingArgs) { + QualType KernelObjT = SKEPFArgObj->getType().getNonReferenceType(); + Expr *Base = SemaSYCLRef.SemaRef.BuildDeclRefExpr(SKEPFArgObj, KernelObjT, + VK_LValue, SrcLoc); + MemberExprBases.push_back(Base); + } + + MemberExpr *buildMemberExpr(Expr *Base, ValueDecl *Member) { + DeclAccessPair MemberDAP = DeclAccessPair::make(Member, AS_none); + MemberExpr *Result = SemaSYCLRef.SemaRef.BuildMemberExpr( + Base, /*IsArrow */ false, SrcLoc, NestedNameSpecifierLoc(), + SrcLoc, Member, MemberDAP, + /*HadMultipleCandidates*/ false, + DeclarationNameInfo(Member->getDeclName(), SrcLoc), + Member->getType(), VK_LValue, OK_Ordinary); + return Result; + } + void addFieldMemberExpr(FieldDecl *FD, QualType Ty) { + if (!isArrayElement(FD, Ty)) + MemberExprBases.push_back(buildMemberExpr(MemberExprBases.back(), FD)); + } + void removeFieldMemberExpr(const FieldDecl *FD, QualType Ty) { + if (!isArrayElement(FD, Ty)) + MemberExprBases.pop_back(); + } + + bool enterStruct(const CXXRecordDecl *, FieldDecl *FD, QualType Ty) final { + addFieldMemberExpr(FD, Ty); + return true; + } + + bool leaveStruct(const CXXRecordDecl *, FieldDecl *FD, QualType Ty) final { + if (isSyclSpecialType(Ty)) { + ResultingArgs.push_back(MemberExprBases.back()); + } + removeFieldMemberExpr(FD, Ty); + return true; + } + + bool enterStruct(const CXXRecordDecl *RD, const CXXBaseSpecifier &BS, + QualType) final { + CXXCastPath BasePath; + QualType DerivedTy = SemaSYCLRef.getASTContext().getCanonicalTagType(RD); + QualType BaseTy = BS.getType(); + SemaSYCLRef.SemaRef.CheckDerivedToBaseConversion( + DerivedTy, BaseTy, SrcLoc, SourceRange(), &BasePath, + /*IgnoreBaseAccess*/ true); + auto Cast = ImplicitCastExpr::Create( + SemaSYCLRef.getASTContext(), BaseTy, CK_DerivedToBase, + MemberExprBases.back(), + /* CXXCastPath=*/&BasePath, VK_LValue, FPOptionsOverride()); + MemberExprBases.push_back(Cast); + return true; + } + + bool leaveStruct(const CXXRecordDecl *, const CXXBaseSpecifier &, + QualType) final { + MemberExprBases.pop_back(); + return true; + } +}; + +static void createArgumentsForSpecialTypes(SmallVectorImpl &Args, + ValueDecl *KernelArgObj, + SourceLocation Loc, Sema &SemaRef) { + QualType KernelArgObjTy = KernelArgObj->getType(); + const CXXRecordDecl *KernelArgObjRecord = + KernelArgObjTy->getAsCXXRecordDecl(); + assert(KernelArgObjRecord && "SYCL kernel object is expected"); + SyclKernelSpecObjFinder Finder(SemaRef.SYCL(), Loc, KernelArgObj, Args); + SubobjectVisitor Visitor{SemaRef.getASTContext()}; + Visitor.visitRecordBases(KernelArgObjRecord, Finder); + Visitor.visitRecordFields(KernelArgObjRecord, Finder); +} // Constructs the arguments to be passed for the SYCL kernel launch call. // The first argument is a string literal that contains the SYCL kernel // name. The remaining arguments are the parameters of 'FD' passed as // move-elligible xvalues. Returns true on error and false otherwise. -bool BuildSYCLKernelLaunchCallArgs(Sema &SemaRef, FunctionDecl *FD, - const SYCLKernelInfo *SKI, - SmallVectorImpl &Args, - SourceLocation Loc) { +static bool BuildSYCLKernelLaunchCallArgs(Sema &SemaRef, FunctionDecl *FD, + const SYCLKernelInfo *SKI, + SmallVectorImpl &Args, + SourceLocation Loc) { // The current context must be the function definition context to ensure // that parameter references occur within the correct scope. assert(SemaRef.CurContext == FD && "The current declaration context does not " @@ -540,9 +842,9 @@ bool BuildSYCLKernelLaunchCallArgs(Sema &SemaRef, FunctionDecl *FD, } // Constructs the SYCL kernel launch call. -StmtResult BuildSYCLKernelLaunchCallStmt(Sema &SemaRef, FunctionDecl *FD, - const SYCLKernelInfo *SKI, - Expr *IdExpr, SourceLocation Loc) { +StmtResult BuildSYCLKernelLaunchCallStmt( + Sema &SemaRef, FunctionDecl *FD, const SYCLKernelInfo *SKI, Expr *IdExpr, + SourceLocation Loc, SmallVectorImpl &SpecialArgTys) { SmallVector Stmts; // IdExpr may be null if name lookup failed. if (IdExpr) { @@ -574,8 +876,38 @@ StmtResult BuildSYCLKernelLaunchCallStmt(Sema &SemaRef, FunctionDecl *FD, SemaRef.BuildCallExpr(SemaRef.getCurScope(), IdExpr, Loc, Args, Loc); if (LaunchResult.isInvalid()) return StmtError(); - - Stmts.push_back(SemaRef.MaybeCreateExprWithCleanups(LaunchResult).get()); + Expr *BaseForSubsCall = + SemaRef.MaybeCreateExprWithCleanups(LaunchResult).get(); + if (!BaseForSubsCall->getType()->isVoidType()) { + // FIXME: diagnose that sycl_kernel_launch call returned a callable + // object. Default diagnostic here is very uncldear + llvm::SmallVector SpecialArgs; + for (auto Param : FD->parameters()) { + if (Param->getType()->isRecordType()) + createArgumentsForSpecialTypes(SpecialArgs, Param, Loc, SemaRef); + } + ExprResult Result = SemaRef.BuildCallExpr( + SemaRef.getCurScope(), BaseForSubsCall, Loc, SpecialArgs, Loc); + if (Result.isInvalid()) + return StmtError(); + + // Now gather types for device code generation. Callable object returned + // by sycl_kernel_launch call returns type_list object whose template + // arguments describe types of additional kernel arguments required for + // special objects, i.e. SYCL accessors/samplers/streams etc. + QualType Ty = Result.get()->getType(); + // FIXME: that also needs to be diagnosed somewhere. + auto *TST = Ty->getAs(); + if (!TST) + return StmtError(); + for (auto Arg : TST->template_arguments()) { + SpecialArgTys.push_back(Arg.getAsType().getCanonicalType()); + } + + Stmts.push_back(SemaRef.MaybeCreateExprWithCleanups(Result).get()); + } else { + Stmts.push_back(BaseForSubsCall); + } } } @@ -638,14 +970,21 @@ class OutlinedFunctionDeclBodyInstantiator FunctionDecl *FD; }; -OutlinedFunctionDecl *BuildSYCLKernelEntryPointOutline(Sema &SemaRef, - FunctionDecl *FD, - CompoundStmt *Body) { +DeclResult +BuildSYCLKernelEntryPointOutline(Sema &SemaRef, FunctionDecl *FD, + CompoundStmt *Body, + SmallVectorImpl &SpecialArgTys, + Expr *IdExpr, SourceLocation Loc) { using ParmDeclMap = OutlinedFunctionDeclBodyInstantiator::ParmDeclMap; ParmDeclMap ParmMap; OutlinedFunctionDecl *OFD = OutlinedFunctionDecl::Create( - SemaRef.getASTContext(), FD, FD->getNumParams()); + SemaRef.getASTContext(), FD, FD->getNumParams() + SpecialArgTys.size()); + + // CurContext is skep-attributed function but we're actually building device + // version of it which is a different DeclContext, so push it on the stack. + Sema::ContextRAII SavedContext(SemaRef, OFD); + unsigned i = 0; for (ParmVarDecl *PVD : FD->parameters()) { ImplicitParamDecl *IPD = ImplicitParamDecl::Create( @@ -655,21 +994,96 @@ OutlinedFunctionDecl *BuildSYCLKernelEntryPointOutline(Sema &SemaRef, ParmMap[PVD] = IPD; ++i; } - OutlinedFunctionDeclBodyInstantiator OFDBodyInstantiator(SemaRef, ParmMap, FD); - Stmt *OFDBody = OFDBodyInstantiator.TransformStmt(Body).get(); + Stmt *TransformedBody = OFDBodyInstantiator.TransformStmt(Body).get(); + + // Create kernel parameters for special types and create arguments to + // sycl_handle_special_kernel_parameters call. + // This is synthesizing the following pseudo-code: + // void kernel-entry-point(lambda-from-f kernelFunc, buffer_t* X, int Y) { + // sycl_handle_special_kernel_parameters(kernelFunc.sout)(X, Y); + // { + // // This is copied body of the orignal skep-attributed function. + // kernelFunc(); + // } + // } + // where sout is has type marked with sycl_special_kernel_parameter attribute. + Stmt *OFDBody; + if (IdExpr && !SpecialArgTys.empty()) { + SmallVector HandleArgs; + for (unsigned I = 0; I < FD->getNumParams(); ++I) { + auto Param = OFD->getParam(I); + if (Param->getType()->isRecordType()) + createArgumentsForSpecialTypes(HandleArgs, Param, Loc, SemaRef); + } + + // FIXME add better diagnosing. + // Sema::CodeSynthesisContext CSC; + // CSC.Kind = + // Sema::CodeSynthesisContext::SYCLKernelLaunchOverloadResolution; + // CSC.Entity = FD; + // CSC.CallArgs = Args.data(); + // CSC.NumCallArgs = Args.size(); + // Sema::ScopedCodeSynthesisContext ScopedCSC(SemaRef, CSC); + + // Handle args for sycl_handle_special_kernel_parameters call, these are + // coming from subobjects with sycl_special_kernel_parameter attribute + // within skep-attributed function arguments, SpecialArgs are additional kernel + // arguments that are needed to initialize special subobjects and they go + // to the subsequent call. + SmallVector SpecialArgs; + for (auto QT : SpecialArgTys) { + ImplicitParamDecl *IPD = ImplicitParamDecl::Create( + SemaRef.getASTContext(), OFD, SourceLocation(), + &SemaRef.getASTContext().Idents.get("idk"), QT, + ImplicitParamKind::Other); + OFD->setParam(i, IPD); + ++i; + ExprResult Arg = + SemaRef.BuildDeclRefExpr(IPD, QT, VK_LValue, SourceLocation()); + assert(!Arg.isInvalid() && "synthesized code generation failed?"); + SpecialArgs.push_back(Arg.get()); + } + + // This generates sycl_handle_special_kernel_parameters(kernelFunc.sout) + ExprResult FirstHandleCallResult = SemaRef.BuildCallExpr( + SemaRef.getCurScope(), IdExpr, Loc, HandleArgs, Loc); + if (FirstHandleCallResult.isInvalid()) + return true; + // FIXME: diagnose that sycl_special_kernel_parameter call returned a + // callable object. Default diagnostic here is very uncldear + + Expr *BaseForSubsCall = + SemaRef.MaybeCreateExprWithCleanups(FirstHandleCallResult).get(); + ExprResult Result = SemaRef.BuildCallExpr( + SemaRef.getCurScope(), BaseForSubsCall, Loc, SpecialArgs, Loc); + if (Result.isInvalid()) + return true; + + SmallVector Stmts; + // Make sure to push kernel argument processing result first, before the + // transformed body of skep-attributed function. + Stmts.push_back(SemaRef.MaybeCreateExprWithCleanups(Result).get()); + Stmts.push_back(TransformedBody); + OFDBody = CompoundStmt::Create(SemaRef.getASTContext(), Stmts, + FPOptionsOverride(), Loc, Loc); + } else { + OFDBody = TransformedBody; + } + OFD->setBody(OFDBody); OFD->setNothrow(); - return OFD; + } } // unnamed namespace StmtResult SemaSYCL::BuildSYCLKernelCallStmt(FunctionDecl *FD, CompoundStmt *Body, - Expr *LaunchIdExpr) { + Expr *LaunchIdExpr, + Expr *HandleSpecParamsExpr) { assert(!FD->isInvalidDecl()); assert(!FD->isTemplated()); assert(FD->hasPrototype()); @@ -690,28 +1104,31 @@ StmtResult SemaSYCL::BuildSYCLKernelCallStmt(FunctionDecl *FD, getASTContext().getSYCLKernelInfo(SKEPAttr->getKernelName()); assert(declaresSameEntity(SKI.getKernelEntryPointDecl(), FD) && "SYCL kernel name conflict"); - - // Build the outline of the synthesized device entry point function. - OutlinedFunctionDecl *OFD = - BuildSYCLKernelEntryPointOutline(SemaRef, FD, Body); - assert(OFD); - + SourceLocation Loc = Body->getLBracLoc(); // Build the host kernel launch statement. An appropriate source location // is required to emit diagnostics. - SourceLocation Loc = Body->getLBracLoc(); - StmtResult LaunchResult = - BuildSYCLKernelLaunchCallStmt(SemaRef, FD, &SKI, LaunchIdExpr, Loc); + llvm::SmallVector SpecialArgTys; + StmtResult LaunchResult = BuildSYCLKernelLaunchCallStmt( + SemaRef, FD, &SKI, LaunchIdExpr, Loc, SpecialArgTys); + if (LaunchResult.isInvalid()) return StmtError(); - Stmt *NewBody = - new (getASTContext()) SYCLKernelCallStmt(Body, LaunchResult.get(), OFD); + // Build the outline of the synthesized device entry point function. + DeclResult OFD = BuildSYCLKernelEntryPointOutline( + SemaRef, FD, Body, SpecialArgTys, HandleSpecParamsExpr, Loc); + + if (OFD.isInvalid()) + return StmtError(); + + Stmt *NewBody = new (getASTContext()) SYCLKernelCallStmt( + Body, LaunchResult.get(), cast(OFD.get())); return NewBody; } -StmtResult SemaSYCL::BuildUnresolvedSYCLKernelCallStmt(CompoundStmt *Body, - Expr *LaunchIdExpr) { - return UnresolvedSYCLKernelCallStmt::Create(SemaRef.getASTContext(), Body, - LaunchIdExpr); +StmtResult SemaSYCL::BuildUnresolvedSYCLKernelCallStmt( + CompoundStmt *Body, Expr *LaunchIdExpr, Expr *HandleSpecParamsExpr) { + return UnresolvedSYCLKernelCallStmt::Create( + SemaRef.getASTContext(), Body, LaunchIdExpr, HandleSpecParamsExpr); } diff --git a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h index 40187f71231bd..f4a8d7736438d 100644 --- a/clang/lib/Sema/TreeTransform.h +++ b/clang/lib/Sema/TreeTransform.h @@ -13148,13 +13148,18 @@ StmtResult TreeTransform::TransformUnresolvedSYCLKernelCallStmt( if (IdExpr.isInvalid()) return StmtError(); + ExprResult SpecArgsIdExpr = + getDerived().TransformExpr(S->getSpecArgsIdExpr()); + if (SpecArgsIdExpr.isInvalid()) + return StmtError(); + StmtResult Body = getDerived().TransformStmt(S->getOriginalStmt()); if (Body.isInvalid()) return StmtError(); StmtResult SR = SemaRef.SYCL().BuildSYCLKernelCallStmt( cast(SemaRef.CurContext), cast(Body.get()), - IdExpr.get()); + IdExpr.get(), SpecArgsIdExpr.get()); if (SR.isInvalid()) return StmtError(); diff --git a/clang/test/ASTSYCL/ast-dump-sycl-kernel-call-stmt.cpp b/clang/test/ASTSYCL/ast-dump-sycl-kernel-call-stmt.cpp index 1047302d8f36e..5a31981a85ed4 100644 --- a/clang/test/ASTSYCL/ast-dump-sycl-kernel-call-stmt.cpp +++ b/clang/test/ASTSYCL/ast-dump-sycl-kernel-call-stmt.cpp @@ -37,6 +37,9 @@ template struct K { template void sycl_kernel_launch(const char *, Ts...) {} +template +void sycl_handle_special_kernel_parameters(Ts...) {} + [[clang::sycl_kernel_entry_point(KN<1>)]] void skep1() { } diff --git a/clang/test/ASTSYCL/ast-dump-sycl-kernel-decomposition.cpp b/clang/test/ASTSYCL/ast-dump-sycl-kernel-decomposition.cpp new file mode 100644 index 0000000000000..67e7b0266dc32 --- /dev/null +++ b/clang/test/ASTSYCL/ast-dump-sycl-kernel-decomposition.cpp @@ -0,0 +1,141 @@ +// Tests without serialization: +// RUN: %clang_cc1 -std=c++17 -triple spirv64-unknown-unknown -fsycl-is-device \ +// RUN: -ast-dump %s \ +// RUN: | FileCheck %s +// RUN: %clang_cc1 -std=c++17 -triple x86_64-unknown-unknown -fsycl-is-host \ +// RUN: -ast-dump %s \ +// RUN: | FileCheck %s + +// Thes test validates the AST body produced for functions declared with the +// sycl_kernel_entry_point attribute in case an argument of such function +// contains an object that requires decomposition. + +// CHECK: TranslationUnitDecl {{.*}} + +// A unique kernel name type is required for each declared kernel entry point. +template struct KN; + +struct [[clang::sycl_special_kernel_parameter]] EmptySpecial { + int data; +}; + +template +struct Wrapper { + T data; + int *data1; +}; + +template +auto set_kernel_arg(const T &t) { + return t; +} + +auto set_kernel_arg(EmptySpecial &a) { + return a.data; +} + +template +auto sycl_handle_special_kernel_parameters(Ts...) { + return [](auto ...Args){ return; }; +} + +template +struct type_list {}; + +template +auto sycl_kernel_launch(const char *, Ts...) { + + return [&](auto&&... extra_host_args) { + return type_list{}; + }; +} + + +template +[[clang::sycl_kernel_entry_point(KN)]] void k(KT Kernel) { + Kernel(); +} +// CHECK: |-FunctionTemplateDecl {{.*}} k{{.*}} +// CHECK-NEXT: | |-TemplateTypeParmDecl {{.*}} referenced typename depth 0 index 0 KN +// CHECK-NEXT: | |-TemplateTypeParmDecl {{.*}} referenced typename depth 0 index 1 KT +// CHECK-NEXT: | |-FunctionDecl {{.*}} k 'void (KT)' +// CHECK-NEXT: | | |-ParmVarDecl {{.*}} referenced Kernel 'KT' +// CHECK-NEXT: | | |-UnresolvedSYCLKernelCallStmt {{.*}} +// CHECK-NEXT: | | | |-CompoundStmt {{.*}} +// CHECK-NEXT: | | | | `-CallExpr {{.*}} '' +// CHECK-NEXT: | | | | `-DeclRefExpr {{.*}} 'KT' lvalue ParmVar {{.*}} 'Kernel' 'KT' +// CHECK-NEXT: | | | |-UnresolvedLookupExpr {{.*}} '' lvalue (ADL) = 'sycl_kernel_launch' {{.*}} +// CHECK-NEXT: | | | | `-TemplateArgument type 'KN':'type-parameter-0-0' +// CHECK-NEXT: | | | | `-TemplateTypeParmType {{.*}} 'KN' dependent depth 0 index 0 +// CHECK-NEXT: | | | | `-TemplateTypeParm {{.*}} 'KN' +// CHECK-NEXT: | | | `-UnresolvedLookupExpr {{.*}} '' lvalue (ADL) = 'sycl_handle_special_kernel_parameters' {{.*}} +// CHECK-NEXT: | | | `-TemplateArgument type 'KN':'type-parameter-0-0' +// CHECK-NEXT: | | | `-TemplateTypeParmType {{.*}} 'KN' dependent depth 0 index 0 +// CHECK-NEXT: | | | `-TemplateTypeParm {{.*}} 'KN' +// CHECK-NEXT: | | `-SYCLKernelEntryPointAttr {{.*}} KN +// CHECK-NEXT: | `-FunctionDecl {{.*}} used k {{.*}} implicit_instantiation instantiated_from {{.*}} +// CHECK-NEXT: | |-TemplateArgument type 'KN<0>' +// CHECK-NEXT: | | `-RecordType {{.*}} 'KN<0>' canonical +// CHECK-NEXT: | | `-ClassTemplateSpecialization {{.*}} 'KN' +// CHECK-NEXT: | |-TemplateArgument type '{{.*}}' +// CHECK-NEXT: | | `-RecordType {{.*}} canonical +// CHECK-NEXT: | | `-CXXRecord {{.*}} +// CHECK-NEXT: | |-ParmVarDecl {{.*}} used Kernel {{.*}} +// CHECK-NEXT: | |-SYCLKernelCallStmt {{.*}} +// CHECK-NEXT: | | |-CompoundStmt {{.*}} +// CHECK-NEXT: | | | `-CXXOperatorCallExpr {{.*}} 'void' '()' +// CHECK-NEXT: | | | |-ImplicitCastExpr {{.*}} 'void (*)() const' +// CHECK-NEXT: | | | | `-DeclRefExpr {{.*}} 'void () const' lvalue CXXMethod {{.*}} 'operator()' 'void () const' +// CHECK-NEXT: | | | `-ImplicitCastExpr {{.*}} 'const {{.*}}' lvalue +// CHECK-NEXT: | | | `-DeclRefExpr {{.*}} lvalue ParmVar {{.*}} 'Kernel' {{.*}} +// CHECK-NEXT: | | |-CompoundStmt {{.*}} +// CHECK-NEXT: | | | `-ExprWithCleanups {{.*}} 'type_list<{{.*}}>' +// CHECK-NEXT: | | | `-CXXOperatorCallExpr {{.*}} 'type_list<{{.*}}>' '()' +// CHECK-NEXT: | | | |-ImplicitCastExpr {{.*}} 'type_list<{{.*}}> (*)(EmptySpecial &) const' +// CHECK-NEXT: | | | | `-DeclRefExpr {{.*}} 'type_list<{{.*}}> (EmptySpecial &) const' lvalue CXXMethod {{.*}} 'operator()' '{{.*}}' +// CHECK-NEXT: | | | |-ImplicitCastExpr {{.*}} 'const {{.*}}' lvalue +// CHECK-NEXT: | | | | `-MaterializeTemporaryExpr {{.*}} '{{.*}}' lvalue +// CHECK-NEXT: | | | | `-CallExpr {{.*}} '{{.*}}' +// CHECK-NEXT: | | | | |-ImplicitCastExpr {{.*}} '{{.*}}' +// CHECK-NEXT: | | | | | `-DeclRefExpr {{.*}} '{{.*}}' lvalue Function {{.*}} 'sycl_kernel_launch' {{.*}} +// CHECK-NEXT: | | | | |-ImplicitCastExpr {{.*}} 'const char *' +// CHECK-NEXT: | | | | | `-StringLiteral {{.*}} 'const char[14]' lvalue "_ZTS2KNILi0EE" +// CHECK-NEXT: | | | | `-CXXConstructExpr {{.*}} '{{.*}}' 'void ({{.*}} &&) noexcept' +// CHECK-NEXT: | | | | `-ImplicitCastExpr {{.*}} '{{.*}}' xvalue +// CHECK-NEXT: | | | | `-DeclRefExpr {{.*}} lvalue ParmVar {{.*}} 'Kernel' {{.*}} +// CHECK-NEXT: | | | `-MemberExpr {{.*}} 'EmptySpecial' lvalue .data {{.*}} +// CHECK-NEXT: | | | `-MemberExpr {{.*}} 'Wrapper' lvalue . {{.*}} +// CHECK-NEXT: | | | `-DeclRefExpr {{.*}} lvalue ParmVar {{.*}} 'Kernel' {{.*}} +// CHECK-NEXT: | | `-OutlinedFunctionDecl {{.*}} +// CHECK-NEXT: | | |-ImplicitParamDecl {{.*}} implicit used Kernel {{.*}} +// CHECK-NEXT: | | |-ImplicitParamDecl {{.*}} implicit used idk {{.*}} +// CHECK-NEXT: | | `-CompoundStmt {{.*}} +// CHECK-NEXT: | | |-ExprWithCleanups {{.*}} 'void' +// CHECK-NEXT: | | | `-CXXOperatorCallExpr {{.*}} 'void' '()' +// CHECK-NEXT: | | | |-ImplicitCastExpr {{.*}} 'void (*)(int) const' +// CHECK-NEXT: | | | | `-DeclRefExpr {{.*}} 'void (int) const' lvalue CXXMethod {{.*}} 'operator()' '{{.*}}' +// CHECK-NEXT: | | | |-ImplicitCastExpr {{.*}} 'const {{.*}}' lvalue +// CHECK-NEXT: | | | | `-MaterializeTemporaryExpr {{.*}} '{{.*}}' lvalue +// CHECK-NEXT: | | | | `-CallExpr {{.*}} '{{.*}}' +// CHECK-NEXT: | | | | |-ImplicitCastExpr {{.*}} '{{.*}}' +// CHECK-NEXT: | | | | | `-DeclRefExpr {{.*}} '{{.*}}' lvalue Function {{.*}} 'sycl_handle_special_kernel_parameters' {{.*}} +// CHECK-NEXT: | | | | `-CXXConstructExpr {{.*}} 'EmptySpecial' 'void (const EmptySpecial &) noexcept' +// CHECK-NEXT: | | | | `-ImplicitCastExpr {{.*}} 'const EmptySpecial' lvalue +// CHECK-NEXT: | | | | `-MemberExpr {{.*}} 'EmptySpecial' lvalue .data {{.*}} +// CHECK-NEXT: | | | | `-MemberExpr {{.*}} 'Wrapper' lvalue . {{.*}} +// CHECK-NEXT: | | | | `-DeclRefExpr {{.*}} lvalue ImplicitParam {{.*}} 'Kernel' {{.*}} +// CHECK-NEXT: | | | `-ImplicitCastExpr {{.*}} 'int' +// CHECK-NEXT: | | | `-DeclRefExpr {{.*}} 'int' lvalue ImplicitParam {{.*}} 'idk' 'int' +// CHECK-NEXT: | | `-CompoundStmt {{.*}} +// CHECK-NEXT: | | `-CXXOperatorCallExpr {{.*}} 'void' '()' +// CHECK-NEXT: | | |-ImplicitCastExpr {{.*}} 'void (*)() const' +// CHECK-NEXT: | | | `-DeclRefExpr {{.*}} 'void () const' lvalue CXXMethod {{.*}} 'operator()' 'void () const' +// CHECK-NEXT: | | `-ImplicitCastExpr {{.*}} 'const {{.*}}' lvalue +// CHECK-NEXT: | | `-DeclRefExpr {{.*}} lvalue ImplicitParam {{.*}} 'Kernel' {{.*}} +// CHECK-NEXT: | `-SYCLKernelEntryPointAttr {{.*}} struct KN<0> + +void case1() { + Wrapper KernelArg; + k>([KernelArg](){}); +} +// CHECK: `-FunctionDecl {{.*}} case1 'void ()' diff --git a/clang/test/ASTSYCL/ast-dump-sycl-kernel-entry-point.cpp b/clang/test/ASTSYCL/ast-dump-sycl-kernel-entry-point.cpp index cf7db46b3567e..b8df861fb5e1e 100644 --- a/clang/test/ASTSYCL/ast-dump-sycl-kernel-entry-point.cpp +++ b/clang/test/ASTSYCL/ast-dump-sycl-kernel-entry-point.cpp @@ -31,6 +31,9 @@ template struct KN; template void sycl_kernel_launch(const char *, Ts... Args) {} +template +void sycl_handle_special_kernel_parameters(Ts...) {} + [[clang::sycl_kernel_entry_point(KN<1>)]] void skep1() { } diff --git a/clang/test/ASTSYCL/ast-print-sycl-kernel-call.cpp b/clang/test/ASTSYCL/ast-print-sycl-kernel-call.cpp index 5adaa367ed9c1..f32b931205042 100644 --- a/clang/test/ASTSYCL/ast-print-sycl-kernel-call.cpp +++ b/clang/test/ASTSYCL/ast-print-sycl-kernel-call.cpp @@ -1,6 +1,9 @@ // RUN: %clang_cc1 -fsycl-is-host -ast-print %s -o - | FileCheck %s // RUN: %clang_cc1 -fsycl-is-device -ast-print %s -o - | FileCheck %s +template +void sycl_handle_special_kernel_parameters(Ts...) {} + struct sycl_kernel_launcher { template void sycl_kernel_launch(const char *, Ts...) {} diff --git a/clang/test/CodeGenSYCL/function-attrs.cpp b/clang/test/CodeGenSYCL/function-attrs.cpp index 60d3cf10055ec..efb0081484522 100644 --- a/clang/test/CodeGenSYCL/function-attrs.cpp +++ b/clang/test/CodeGenSYCL/function-attrs.cpp @@ -29,6 +29,9 @@ int foo() { template void sycl_kernel_launch(Ts...) {} +template +void sycl_handle_special_kernel_parameters(Ts...) {} + template [[clang::sycl_kernel_entry_point(Name)]] void kernel_single_task(const Func &kernelFunc) { kernelFunc(); diff --git a/clang/test/CodeGenSYCL/kernel-arg-decomposition.cpp b/clang/test/CodeGenSYCL/kernel-arg-decomposition.cpp new file mode 100644 index 0000000000000..b7aa38cc9a1cb --- /dev/null +++ b/clang/test/CodeGenSYCL/kernel-arg-decomposition.cpp @@ -0,0 +1,96 @@ +// RUN: %clang_cc1 -fsycl-is-host -emit-llvm -triple x86_64-unknown-linux-gnu -std=c++17 %s -o - | FileCheck --check-prefixes=CHECK-HOST %s +// RUN: %clang_cc1 -fsycl-is-device -emit-llvm -aux-triple x86_64-unknown-linux-gnu -triple spirv64-unknown-unknown -std=c++17 %s -o - | FileCheck --check-prefixes=CHECK-DEVICE %s + +// A unique kernel name type is required for each declared kernel entry point. +template struct KN; + +struct [[clang::sycl_special_kernel_parameter]] EmptySpecial { + int data; +}; + +template +struct Wrapper { + T data; + int *data1; +}; + +template +auto set_kernel_arg(const T &t) { + return t; +} + +auto set_kernel_arg(EmptySpecial &a) { + return a.data; +} + +template +auto sycl_handle_special_kernel_parameters(Ts...) { + return [](auto ...Args){ return; }; +} + +template +struct type_list {}; + +template +auto sycl_kernel_launch(const char *, Ts...) { + + return [&](auto&&... extra_host_args) { + return type_list{}; + }; +} + + +template +[[clang::sycl_kernel_entry_point(KN)]] void kernel_entry_point(KT Kernel) { + Kernel(); +} + +void case1() { + Wrapper KernelArg; + kernel_entry_point>([KernelArg](){}); +} + +// CHECK-HOST-LABEL: define internal void @_Z18kernel_entry_pointI2KNILi0EEZ5case1vEUlvE_EvT0_( +// CHECK-HOST-SAME: i32 [[KERNEL_COERCE0:%.*]], ptr [[KERNEL_COERCE1:%.*]]) +// CHECK-HOST: [[ENTRY:.*:]] +// CHECK-HOST-NEXT: [[KERNEL:%.*]] = alloca [[CLASS_ANON:%.*]], align 8 +// CHECK-HOST-NEXT: [[REF_TMP:%.*]] = alloca [[CLASS_ANON_0:%.*]], align 1 +// CHECK-HOST-NEXT: [[AGG_TMP:%.*]] = alloca [[CLASS_ANON]], align 8 +// CHECK-HOST-NEXT: [[UNDEF_AGG_TMP:%.*]] = alloca [[CLASS_ANON_0]], align 1 +// CHECK-HOST-NEXT: [[UNDEF_AGG_TMP1:%.*]] = alloca [[STRUCT_TYPE_LIST:%.*]], align 1 +// CHECK-HOST-NEXT: [[TMP0:%.*]] = getelementptr inbounds nuw { i32, ptr }, ptr [[KERNEL]], i32 0, i32 0 +// CHECK-HOST-NEXT: store i32 [[KERNEL_COERCE0]], ptr [[TMP0]], align 8 +// CHECK-HOST-NEXT: [[TMP1:%.*]] = getelementptr inbounds nuw { i32, ptr }, ptr [[KERNEL]], i32 0, i32 1 +// CHECK-HOST-NEXT: store ptr [[KERNEL_COERCE1]], ptr [[TMP1]], align 8 +// CHECK-HOST-NEXT: call void @llvm.memcpy.p0.p0.i64(ptr align 8 [[AGG_TMP]], ptr align 8 [[KERNEL]], i64 16, i1 false) +// CHECK-HOST-NEXT: [[TMP2:%.*]] = getelementptr inbounds nuw { i32, ptr }, ptr [[AGG_TMP]], i32 0, i32 0 +// CHECK-HOST-NEXT: [[TMP3:%.*]] = load i32, ptr [[TMP2]], align 8 +// CHECK-HOST-NEXT: [[TMP4:%.*]] = getelementptr inbounds nuw { i32, ptr }, ptr [[AGG_TMP]], i32 0, i32 1 +// CHECK-HOST-NEXT: [[TMP5:%.*]] = load ptr, ptr [[TMP4]], align 8 +// CHECK-HOST-NEXT: call void @_Z18sycl_kernel_launchI2KNILi0EEJZ5case1vEUlvE_EEDaPKcDpT0_(ptr noundef @.str, i32 [[TMP3]], ptr [[TMP5]]) +// CHECK-HOST-NEXT: [[TMP6:%.*]] = getelementptr inbounds nuw [[CLASS_ANON]], ptr [[KERNEL]], i32 0, i32 0 +// CHECK-HOST-NEXT: [[DATA:%.*]] = getelementptr inbounds nuw [[STRUCT_WRAPPER:%.*]], ptr [[TMP6]], i32 0, i32 0 +// CHECK-HOST-NEXT: call void @_ZZ18sycl_kernel_launchI2KNILi0EEJZ5case1vEUlvE_EEDaPKcDpT0_ENKUlDpOT_E_clIJR12EmptySpecialEEEDaS9_(ptr noundef nonnull align 1 dereferenceable(1) [[REF_TMP]], ptr noundef nonnull align 4 dereferenceable(4) [[DATA]]) +// CHECK-HOST-NEXT: ret void + +// CHECK-DEVICE: define spir_kernel void @_ZTS2KNILi0EE(ptr noundef byval(%class.anon) align 8 [[KERNEL:%.*]], i32 noundef [[IDK:%.*]]) +// CHECK-DEVICE: [[ENTRY:.*:]] +// CHECK-DEVICE: [[IDK_ADDR:%.*]] = alloca i32, align 4 +// CHECK-DEVICE: [[REF_TMP:%.*]] = alloca [[CLASS_ANON_0:%.*]], align 1 +// CHECK-DEVICE: [[AGG_TMP:%.*]] = alloca [[STRUCT_EMPTYSPECIAL:%.*]], align 4 +// CHECK-DEVICE: [[IDK_ADDR_ASCAST:%.*]] = addrspacecast ptr [[IDK_ADDR]] to ptr addrspace(4) +// CHECK-DEVICE: [[REF_TMP_ASCAST:%.*]] = addrspacecast ptr [[REF_TMP]] to ptr addrspace(4) +// CHECK-DEVICE: [[AGG_TMP_ASCAST:%.*]] = addrspacecast ptr [[AGG_TMP]] to ptr addrspace(4) +// CHECK-DEVICE: [[KERNEL_ASCAST:%.*]] = addrspacecast ptr [[KERNEL]] to ptr addrspace(4) +// CHECK-DEVICE: store i32 [[IDK]], ptr addrspace(4) [[IDK_ADDR_ASCAST]], align 4 +// CHECK-DEVICE: [[REF_TMP_ASCAST_ASCAST:%.*]] = addrspacecast ptr addrspace(4) [[REF_TMP_ASCAST]] to ptr +// CHECK-DEVICE: [[TMP0:%.*]] = getelementptr inbounds nuw [[CLASS_ANON:%.*]], ptr addrspace(4) [[KERNEL_ASCAST]], i32 0, i32 0 +// CHECK-DEVICE: [[DATA:%.*]] = getelementptr inbounds nuw [[STRUCT_WRAPPER:%.*]], ptr addrspace(4) [[TMP0]], i32 0, i32 0 +// CHECK-DEVICE: call void @llvm.memcpy.p4.p4.i64(ptr addrspace(4) align 4 [[AGG_TMP_ASCAST]], ptr addrspace(4) align 8 [[DATA]], i64 4, i1 false) +// CHECK-DEVICE: [[AGG_TMP_ASCAST_ASCAST:%.*]] = addrspacecast ptr addrspace(4) [[AGG_TMP_ASCAST]] to ptr +// CHECK-DEVICE: call spir_func void @_Z37sycl_handle_special_kernel_parametersI2KNILi0EEJ12EmptySpecialEEDaDpT0_(ptr dead_on_unwind writable sret([[CLASS_ANON_0]]) align 1 [[REF_TMP_ASCAST_ASCAST]], ptr noundef byval([[STRUCT_EMPTYSPECIAL]]) align 4 [[AGG_TMP_ASCAST_ASCAST]]) +// CHECK-DEVICE: [[TMP1:%.*]] = load i32, ptr addrspace(4) [[IDK_ADDR_ASCAST]], align 4 +// CHECK-DEVICE: call spir_func void @_ZZ37sycl_handle_special_kernel_parametersI2KNILi0EEJ12EmptySpecialEEDaDpT0_ENKUlDpT_E_clIJiEEEDaS6_(ptr addrspace(4) noundef align 1 dereferenceable_or_null(1) [[REF_TMP_ASCAST]], i32 noundef [[TMP1]]) +// CHECK-DEVICE: call spir_func void @_ZZ5case1vENKUlvE_clEv(ptr addrspace(4) noundef align 8 dereferenceable_or_null(16) [[KERNEL_ASCAST]]) +// CHECK-DEVICE: ret void + diff --git a/clang/test/CodeGenSYCL/kernel-caller-entry-point.cpp b/clang/test/CodeGenSYCL/kernel-caller-entry-point.cpp index 7af4c83d1ba32..edb58ab57bdbd 100644 --- a/clang/test/CodeGenSYCL/kernel-caller-entry-point.cpp +++ b/clang/test/CodeGenSYCL/kernel-caller-entry-point.cpp @@ -33,6 +33,9 @@ template void sycl_kernel_launch(const char *, Ts...) {} +template +void sycl_handle_special_kernel_parameters(Ts...) {} + struct single_purpose_kernel_name; struct single_purpose_kernel { void operator()() const {} diff --git a/clang/test/CodeGenSYCL/sycl-kernel-entry-point-exceptions.cpp b/clang/test/CodeGenSYCL/sycl-kernel-entry-point-exceptions.cpp index 8fe7a148a2f61..56137b934eeca 100644 --- a/clang/test/CodeGenSYCL/sycl-kernel-entry-point-exceptions.cpp +++ b/clang/test/CodeGenSYCL/sycl-kernel-entry-point-exceptions.cpp @@ -16,6 +16,9 @@ struct KT { }; +template +void sycl_handle_special_kernel_parameters(Ts...) {} + // Validate that exception handling instructions are omitted when a // potentially throwing sycl_kernel_entry_point attributed function // calls a potentially throwing sycl_kernel_launch function (a thrown diff --git a/clang/test/CodeGenSYCL/unique_stable_name_windows_diff.cpp b/clang/test/CodeGenSYCL/unique_stable_name_windows_diff.cpp index c298593e2f1ab..be4fa5482552a 100644 --- a/clang/test/CodeGenSYCL/unique_stable_name_windows_diff.cpp +++ b/clang/test/CodeGenSYCL/unique_stable_name_windows_diff.cpp @@ -4,6 +4,9 @@ template void sycl_kernel_launch(const char *, Ts...) {} +template +void sycl_handle_special_kernel_parameters(Ts...) {} + template [[clang::sycl_kernel_entry_point(KN)]] void kernel(Func F){ F(); diff --git a/clang/test/SemaSYCL/sycl-kernel-entry-point-attr-appertainment.cpp b/clang/test/SemaSYCL/sycl-kernel-entry-point-attr-appertainment.cpp index 45da8c71348b2..49a6de671a8be 100644 --- a/clang/test/SemaSYCL/sycl-kernel-entry-point-attr-appertainment.cpp +++ b/clang/test/SemaSYCL/sycl-kernel-entry-point-attr-appertainment.cpp @@ -44,6 +44,9 @@ template struct KN; template void sycl_kernel_launch(const char *, Ts...) {} +template +void sycl_handle_special_kernel_parameters(Ts...) {} + //////////////////////////////////////////////////////////////////////////////// // Valid declarations. //////////////////////////////////////////////////////////////////////////////// diff --git a/clang/test/SemaSYCL/sycl-kernel-entry-point-attr-device-odr-use.cpp b/clang/test/SemaSYCL/sycl-kernel-entry-point-attr-device-odr-use.cpp index e2854983da552..94a497f6a9cd3 100644 --- a/clang/test/SemaSYCL/sycl-kernel-entry-point-attr-device-odr-use.cpp +++ b/clang/test/SemaSYCL/sycl-kernel-entry-point-attr-device-odr-use.cpp @@ -20,6 +20,9 @@ struct type_info { template void sycl_kernel_launch(const char *, Ts...) {} +template +void sycl_handle_special_kernel_parameters(Ts...) {} + // A kernel name type template. template struct KN; diff --git a/clang/test/SemaSYCL/sycl-kernel-entry-point-attr-grammar.cpp b/clang/test/SemaSYCL/sycl-kernel-entry-point-attr-grammar.cpp index b1c9e270a02b8..a8cacbadc009b 100644 --- a/clang/test/SemaSYCL/sycl-kernel-entry-point-attr-grammar.cpp +++ b/clang/test/SemaSYCL/sycl-kernel-entry-point-attr-grammar.cpp @@ -14,6 +14,9 @@ template using TTA = ST; // #TTA-decl template void sycl_kernel_launch(const char *, Ts...) {} +template +void sycl_handle_special_kernel_parameters(Ts...) {} + //////////////////////////////////////////////////////////////////////////////// // Valid declarations. //////////////////////////////////////////////////////////////////////////////// diff --git a/clang/test/SemaSYCL/sycl-kernel-entry-point-attr-kernel-name-module.cpp b/clang/test/SemaSYCL/sycl-kernel-entry-point-attr-kernel-name-module.cpp index 05a660e91e82c..a0cbbfad420d7 100644 --- a/clang/test/SemaSYCL/sycl-kernel-entry-point-attr-kernel-name-module.cpp +++ b/clang/test/SemaSYCL/sycl-kernel-entry-point-attr-kernel-name-module.cpp @@ -21,6 +21,9 @@ template struct KN; template void sycl_kernel_launch(const char *, Ts...) {} +template +void sycl_handle_special_kernel_parameters(Ts...) {} + [[clang::sycl_kernel_entry_point(KN<1>)]] void common_test1() {} diff --git a/clang/test/SemaSYCL/sycl-kernel-entry-point-attr-kernel-name-pch.cpp b/clang/test/SemaSYCL/sycl-kernel-entry-point-attr-kernel-name-pch.cpp index dcea60e016d12..047c6268cfe6c 100644 --- a/clang/test/SemaSYCL/sycl-kernel-entry-point-attr-kernel-name-pch.cpp +++ b/clang/test/SemaSYCL/sycl-kernel-entry-point-attr-kernel-name-pch.cpp @@ -19,6 +19,9 @@ template struct KN; template void sycl_kernel_launch(const char *, Ts...) {} +template +void sycl_handle_special_kernel_parameters(Ts...) {} + [[clang::sycl_kernel_entry_point(KN<1>)]] void pch_test1() {} // << expected previous declaration note here. diff --git a/clang/test/SemaSYCL/sycl-kernel-entry-point-attr-kernel-name.cpp b/clang/test/SemaSYCL/sycl-kernel-entry-point-attr-kernel-name.cpp index 2abb24cde6663..e8fbd422d4c61 100644 --- a/clang/test/SemaSYCL/sycl-kernel-entry-point-attr-kernel-name.cpp +++ b/clang/test/SemaSYCL/sycl-kernel-entry-point-attr-kernel-name.cpp @@ -14,6 +14,9 @@ struct S1; template void sycl_kernel_launch(const char *, Ts...) {} +template +void sycl_handle_special_kernel_parameters(Ts...) {} + // expected-warning@+3 {{redundant 'clang::sycl_kernel_entry_point' attribute}} // expected-note@+1 {{previous attribute is here}} [[clang::sycl_kernel_entry_point(S1), diff --git a/clang/test/SemaSYCL/sycl-kernel-entry-point-attr-sfinae.cpp b/clang/test/SemaSYCL/sycl-kernel-entry-point-attr-sfinae.cpp index b39a77bd35878..f4fca91ba1178 100644 --- a/clang/test/SemaSYCL/sycl-kernel-entry-point-attr-sfinae.cpp +++ b/clang/test/SemaSYCL/sycl-kernel-entry-point-attr-sfinae.cpp @@ -14,6 +14,9 @@ template void sycl_kernel_launch(const char *, Ts...) {} +template +void sycl_handle_special_kernel_parameters(Ts...) {} + // FIXME: C++23 [temp.expl.spec]p12 states: // FIXME: ... Similarly, attributes appearing in the declaration of a template // FIXME: have no effect on an explicit specialization of that template. diff --git a/clang/test/SemaSYCL/sycl-kernel-entry-point-attr-this.cpp b/clang/test/SemaSYCL/sycl-kernel-entry-point-attr-this.cpp index 2112733b41fc6..7bc4831fcad1d 100644 --- a/clang/test/SemaSYCL/sycl-kernel-entry-point-attr-this.cpp +++ b/clang/test/SemaSYCL/sycl-kernel-entry-point-attr-this.cpp @@ -26,6 +26,9 @@ struct type_info { template void sycl_kernel_launch(const char *, Ts...) {} +template +void sycl_handle_special_kernel_parameters(Ts...) {} + //////////////////////////////////////////////////////////////////////////////// // Valid declarations. //////////////////////////////////////////////////////////////////////////////// diff --git a/clang/test/SemaSYCL/sycl-kernel-launch-ms-compat.cpp b/clang/test/SemaSYCL/sycl-kernel-launch-ms-compat.cpp index cd186a833b024..612ec1b68ecfb 100644 --- a/clang/test/SemaSYCL/sycl-kernel-launch-ms-compat.cpp +++ b/clang/test/SemaSYCL/sycl-kernel-launch-ms-compat.cpp @@ -19,6 +19,9 @@ struct KT { }; +template +void sycl_handle_special_kernel_parameters(Ts...) {} + namespace ok1 { template struct base_handler { diff --git a/clang/test/SemaSYCL/sycl-kernel-launch.cpp b/clang/test/SemaSYCL/sycl-kernel-launch.cpp index 20d9becb81929..9b8b2696d7b3f 100644 --- a/clang/test/SemaSYCL/sycl-kernel-launch.cpp +++ b/clang/test/SemaSYCL/sycl-kernel-launch.cpp @@ -24,6 +24,9 @@ struct KT { }; +template +void sycl_handle_special_kernel_parameters(Ts...) {} + // sycl_kernel_launch as function template at namespace scope. namespace ok1 { template