3030#include " swift/AST/ParameterList.h"
3131#include " swift/AST/PropertyWrappers.h"
3232#include " swift/AST/SourceFile.h"
33+ #include " swift/AST/StorageImpl.h"
3334#include " swift/AST/TypeCheckRequests.h"
3435#include " swift/AST/Types.h"
3536#include " swift/Parse/Lexer.h"
@@ -3609,12 +3610,14 @@ static IndexSubset *computeDifferentiabilityParameters(
36093610// If the function declaration cannot be resolved, emits a diagnostic and
36103611// returns nullptr.
36113612static AbstractFunctionDecl *findAbstractFunctionDecl (
3612- DeclNameRef funcName, SourceLoc funcNameLoc, Type baseType,
3613+ DeclNameRef funcName, SourceLoc funcNameLoc,
3614+ Optional<AccessorKind> accessorKind, Type baseType,
36133615 DeclContext *lookupContext,
36143616 const std::function<bool (AbstractFunctionDecl *)> &isValidCandidate,
36153617 const std::function<void()> &noneValidDiagnostic,
36163618 const std::function<void()> &ambiguousDiagnostic,
36173619 const std::function<void()> ¬FunctionDiagnostic,
3620+ const std::function<void()> &missingAccessorDiagnostic,
36183621 NameLookupOptions lookupOptions,
36193622 const Optional<std::function<bool(AbstractFunctionDecl *)>>
36203623 &hasValidTypeCtx,
@@ -3640,6 +3643,7 @@ static AbstractFunctionDecl *findAbstractFunctionDecl(
36403643 bool wrongTypeContext = false ;
36413644 bool ambiguousFuncDecl = false ;
36423645 bool foundInvalid = false ;
3646+ bool missingAccessor = false ;
36433647
36443648 // Filter lookup results.
36453649 for (auto choice : results) {
@@ -3648,10 +3652,21 @@ static AbstractFunctionDecl *findAbstractFunctionDecl(
36483652 continue ;
36493653 // Cast the candidate to an `AbstractFunctionDecl`.
36503654 auto *candidate = dyn_cast<AbstractFunctionDecl>(decl);
3651- // If the candidate is an `AbstractStorageDecl`, use its getter as the
3652- // candidate.
3653- if (auto *asd = dyn_cast<AbstractStorageDecl>(decl))
3654- candidate = asd->getOpaqueAccessor (AccessorKind::Get);
3655+ // If the candidate is an `AbstractStorageDecl`, use one of its accessors as
3656+ // the candidate.
3657+ if (auto *asd = dyn_cast<AbstractStorageDecl>(decl)) {
3658+ // If accessor kind is specified, use corresponding accessor from the
3659+ // candidate. Otherwise, use the getter by default.
3660+ if (accessorKind != None) {
3661+ candidate = asd->getOpaqueAccessor (accessorKind.getValue ());
3662+ // Error if candidate is missing the requested accessor.
3663+ if (!candidate)
3664+ missingAccessor = true ;
3665+ } else
3666+ candidate = asd->getOpaqueAccessor (AccessorKind::Get);
3667+ } else if (accessorKind != None) {
3668+ missingAccessor = true ;
3669+ }
36553670 if (!candidate) {
36563671 notFunction = true ;
36573672 continue ;
@@ -3671,8 +3686,9 @@ static AbstractFunctionDecl *findAbstractFunctionDecl(
36713686 }
36723687 resolvedCandidate = candidate;
36733688 }
3689+
36743690 // If function declaration was resolved, return it.
3675- if (resolvedCandidate)
3691+ if (resolvedCandidate && !missingAccessor )
36763692 return resolvedCandidate;
36773693
36783694 // Otherwise, emit the appropriate diagnostic and return nullptr.
@@ -3685,6 +3701,10 @@ static AbstractFunctionDecl *findAbstractFunctionDecl(
36853701 ambiguousDiagnostic ();
36863702 return nullptr ;
36873703 }
3704+ if (missingAccessor) {
3705+ missingAccessorDiagnostic ();
3706+ return nullptr ;
3707+ }
36883708 if (wrongTypeContext) {
36893709 assert (invalidTypeCtxDiagnostic &&
36903710 " Type context diagnostic should've been specified" );
@@ -4429,6 +4449,13 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D,
44294449 diag::autodiff_attr_original_decl_invalid_kind,
44304450 originalName.Name );
44314451 };
4452+ auto missingAccessorDiagnostic = [&]() {
4453+ auto accessorKind = originalName.AccessorKind .getValueOr (AccessorKind::Get);
4454+ auto accessorLabel = getAccessorLabel (accessorKind);
4455+ diags.diagnose (originalName.Loc , diag::autodiff_attr_accessor_not_found,
4456+ originalName.Name , accessorLabel);
4457+ };
4458+
44324459 std::function<void ()> invalidTypeContextDiagnostic = [&]() {
44334460 diags.diagnose (originalName.Loc ,
44344461 diag::autodiff_attr_original_decl_not_same_type_context,
@@ -4473,15 +4500,17 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D,
44734500
44744501 // Look up original function.
44754502 auto *originalAFD = findAbstractFunctionDecl (
4476- originalName.Name , originalName.Loc .getBaseNameLoc (), baseType,
4477- derivativeTypeCtx, isValidOriginal, noneValidDiagnostic,
4478- ambiguousDiagnostic, notFunctionDiagnostic, lookupOptions,
4479- hasValidTypeContext, invalidTypeContextDiagnostic);
4503+ originalName.Name , originalName.Loc .getBaseNameLoc (),
4504+ originalName.AccessorKind , baseType, derivativeTypeCtx, isValidOriginal,
4505+ noneValidDiagnostic, ambiguousDiagnostic, notFunctionDiagnostic,
4506+ missingAccessorDiagnostic, lookupOptions, hasValidTypeContext,
4507+ invalidTypeContextDiagnostic);
44804508 if (!originalAFD)
44814509 return true ;
4482- // Diagnose original stored properties. Stored properties cannot have custom
4483- // registered derivatives.
4510+
44844511 if (auto *accessorDecl = dyn_cast<AccessorDecl>(originalAFD)) {
4512+ // Diagnose original stored properties. Stored properties cannot have custom
4513+ // registered derivatives.
44854514 auto *asd = accessorDecl->getStorage ();
44864515 if (asd->hasStorage ()) {
44874516 diags.diagnose (originalName.Loc ,
@@ -4491,6 +4520,17 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D,
44914520 asd->getName ());
44924521 return true ;
44934522 }
4523+ // Diagnose original class property and subscript setters.
4524+ // TODO(SR-13096): Fix derivative function typing results regarding
4525+ // class-typed function parameters.
4526+ if (asd->getDeclContext ()->getSelfClassDecl () &&
4527+ accessorDecl->getAccessorKind () == AccessorKind::Set) {
4528+ diags.diagnose (originalName.Loc ,
4529+ diag::derivative_attr_class_setter_unsupported);
4530+ diags.diagnose (originalAFD->getLoc (), diag::decl_declared_here,
4531+ asd->getName ());
4532+ return true ;
4533+ }
44944534 }
44954535 // Diagnose if original function is an invalid class member.
44964536 bool isOriginalClassMember =
@@ -4998,6 +5038,13 @@ void AttributeChecker::visitTransposeAttr(TransposeAttr *attr) {
49985038 diag::autodiff_attr_original_decl_invalid_kind,
49995039 originalName.Name );
50005040 };
5041+ auto missingAccessorDiagnostic = [&]() {
5042+ auto accessorKind = originalName.AccessorKind .getValueOr (AccessorKind::Get);
5043+ auto accessorLabel = getAccessorLabel (accessorKind);
5044+ diagnose (originalName.Loc , diag::autodiff_attr_accessor_not_found,
5045+ originalName.Name , accessorLabel);
5046+ };
5047+
50015048 std::function<void ()> invalidTypeContextDiagnostic = [&]() {
50025049 diagnose (originalName.Loc ,
50035050 diag::autodiff_attr_original_decl_not_same_type_context,
@@ -5028,8 +5075,9 @@ void AttributeChecker::visitTransposeAttr(TransposeAttr *attr) {
50285075 if (attr->getBaseTypeRepr ())
50295076 funcLoc = attr->getBaseTypeRepr ()->getLoc ();
50305077 auto *originalAFD = findAbstractFunctionDecl (
5031- originalName.Name , funcLoc, baseType, transposeTypeCtx, isValidOriginal,
5032- noneValidDiagnostic, ambiguousDiagnostic, notFunctionDiagnostic,
5078+ originalName.Name , funcLoc, originalName.AccessorKind , baseType,
5079+ transposeTypeCtx, isValidOriginal, noneValidDiagnostic,
5080+ ambiguousDiagnostic, notFunctionDiagnostic, missingAccessorDiagnostic,
50335081 lookupOptions, hasValidTypeContext, invalidTypeContextDiagnostic);
50345082 if (!originalAFD) {
50355083 attr->setInvalid ();
0 commit comments