@@ -3731,18 +3731,23 @@ enum class AbstractFunctionDeclLookupErrorKind {
37313731 CandidateNotFunctionDeclaration
37323732};
37333733
3734- // / Returns the function declaration corresponding to the given base type
3735- // / (optional), function name, and lookup context.
3734+ // / Returns the original function (in the context of a derivative or transpose
3735+ // / function) declaration corresponding to the given base type (optional),
3736+ // / function name, lookup context, and the expected original function type.
37363737// /
37373738// / If the base type of the function is specified, member lookup is performed.
37383739// / Otherwise, unqualified lookup is performed.
37393740// /
3741+ // / If the expected original function type has a generic signature, any
3742+ // / candidate with a less constrained type signature than the expected original
3743+ // / function type will be treated as a viable candidate.
3744+ // /
37403745// / If the function declaration cannot be resolved, emits a diagnostic and
37413746// / returns nullptr.
37423747// /
37433748// / Used for resolving the referenced declaration in `@derivative` and
37443749// / `@transpose` attributes.
3745- static AbstractFunctionDecl *findAbstractFunctionDecl (
3750+ static AbstractFunctionDecl *findAutoDiffOriginalFunctionDecl (
37463751 DeclAttribute *attr, Type baseType, DeclNameRefWithLoc funcNameWithLoc,
37473752 DeclContext *lookupContext, NameLookupOptions lookupOptions,
37483753 const llvm::function_ref<Optional<AbstractFunctionDeclLookupErrorKind>(
@@ -4671,7 +4676,7 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D,
46714676 }
46724677
46734678 // Look up original function.
4674- auto *originalAFD = findAbstractFunctionDecl (
4679+ auto *originalAFD = findAutoDiffOriginalFunctionDecl (
46754680 attr, baseType, originalName, derivativeTypeCtx, lookupOptions,
46764681 isValidOriginalCandidate, originalFnType);
46774682 if (!originalAFD) {
@@ -5230,7 +5235,7 @@ void AttributeChecker::visitTransposeAttr(TransposeAttr *attr) {
52305235 auto funcLoc = originalName.Loc .getBaseNameLoc ();
52315236 if (attr->getBaseTypeRepr ())
52325237 funcLoc = attr->getBaseTypeRepr ()->getLoc ();
5233- auto *originalAFD = findAbstractFunctionDecl (
5238+ auto *originalAFD = findAutoDiffOriginalFunctionDecl (
52345239 attr, baseType, originalName, transposeTypeCtx, lookupOptions,
52355240 isValidOriginalCandidate, expectedOriginalFnType);
52365241 if (!originalAFD) {
0 commit comments