2020#include " swift/AST/ASTContext.h"
2121#include " swift/AST/ASTWalker.h"
2222#include " swift/AST/ASTMangler.h"
23+ #include " swift/AST/Attr.h"
2324#include " swift/AST/CaptureInfo.h"
2425#include " swift/AST/DiagnosticEngine.h"
2526#include " swift/AST/DiagnosticsSema.h"
@@ -8311,7 +8312,7 @@ void AbstractFunctionDecl::prepareDerivativeFunctionConfigurations() {
83118312}
83128313
83138314ArrayRef<AutoDiffConfig>
8314- AbstractFunctionDecl::getDerivativeFunctionConfigurations () {
8315+ AbstractFunctionDecl::getDerivativeFunctionConfigurations (bool lookInNonPrimarySources ) {
83158316 prepareDerivativeFunctionConfigurations ();
83168317
83178318 // Resolve derivative function configurations from `@differentiable`
@@ -8334,6 +8335,37 @@ AbstractFunctionDecl::getDerivativeFunctionConfigurations() {
83348335 ctx.loadDerivativeFunctionConfigurations (this , previousGeneration,
83358336 *DerivativeFunctionConfigs);
83368337 }
8338+
8339+ class DerivativeFinder : public ASTWalker {
8340+ const AbstractFunctionDecl *AFD;
8341+ public:
8342+ DerivativeFinder (const AbstractFunctionDecl *afd) : AFD(afd) {}
8343+
8344+ bool walkToDeclPre (Decl *D) override {
8345+ if (auto *afd = dyn_cast<AbstractFunctionDecl>(D)) {
8346+ for (auto *derAttr : afd->getAttrs ().getAttributes <DerivativeAttr>()) {
8347+ // Resolve derivative function configurations from `@derivative`
8348+ // attributes by type-checking them.
8349+ if (AFD->getName ().matchesRef (
8350+ derAttr->getOriginalFunctionName ().Name .getFullName ())) {
8351+ (void )derAttr->getOriginalFunction (afd->getASTContext ());
8352+ return false ;
8353+ }
8354+ }
8355+ }
8356+
8357+ return true ;
8358+ }
8359+ };
8360+
8361+ // Load derivative configurations from @derivative attributes defined in
8362+ // non-primary sources. Note that it might trigger lookup cycles if called
8363+ // from inside Sema stages.
8364+ if (lookInNonPrimarySources) {
8365+ DerivativeFinder finder (this );
8366+ getParent ()->walkContext (finder);
8367+ }
8368+
83378369 return DerivativeFunctionConfigs->getArrayRef ();
83388370}
83398371
0 commit comments