4747#include " swift/SILOptimizer/Utils/SILOptFunctionBuilder.h"
4848#include " llvm/ADT/APSInt.h"
4949#include " llvm/ADT/BreadthFirstIterator.h"
50+ #include " llvm/ADT/DenseMap.h"
5051#include " llvm/ADT/DenseSet.h"
5152#include " llvm/ADT/SmallSet.h"
5253#include " llvm/Support/CommandLine.h"
@@ -84,6 +85,9 @@ class DifferentiationTransformer {
8485 // / Context necessary for performing the transformations.
8586 ADContext context;
8687
88+ // / Cache used in getUnwrappedCurryThunkFunction.
89+ llvm::DenseMap<AbstractFunctionDecl *, SILFunction *> afdToSILFn;
90+
8791 // / Promotes the given `differentiable_function` instruction to a valid
8892 // / `@differentiable` function-typed value.
8993 SILValue promoteToDifferentiableFunction (DifferentiableFunctionInst *inst,
@@ -96,6 +100,25 @@ class DifferentiationTransformer {
96100 SILBuilder &builder, SILLocation loc,
97101 DifferentiationInvoker invoker);
98102
103+ // / Emits a reference to a derivative function of `original`, differentiated
104+ // / with respect to a superset of `desiredIndices`. Returns the `SILValue` for
105+ // / the derivative function and the actual indices that the derivative
106+ // / function is with respect to.
107+ // /
108+ // / Returns `None` on failure, signifying that a diagnostic has been emitted
109+ // / using `invoker`.
110+ std::optional<std::pair<SILValue, AutoDiffConfig>>
111+ emitDerivativeFunctionReference (
112+ SILBuilder &builder, const AutoDiffConfig &desiredConfig,
113+ AutoDiffDerivativeFunctionKind kind, SILValue original,
114+ DifferentiationInvoker invoker,
115+ SmallVectorImpl<AllocStackInst *> &newBuffersToDealloc);
116+
117+ // / If the given function corresponds to AutoClosureExpr with either
118+ // / SingleCurryThunk or DoubleCurryThunk kind, get the SILFunction
119+ // / corresponding to the function being wrapped in the thunk.
120+ SILFunction *getUnwrappedCurryThunkFunction (SILFunction *originalFn);
121+
99122public:
100123 // / Construct an `DifferentiationTransformer` for the given module.
101124 explicit DifferentiationTransformer (SILModuleTransform &transform)
@@ -453,21 +476,63 @@ static SILValue reapplyFunctionConversion(
453476 llvm_unreachable (" Unhandled function conversion instruction" );
454477}
455478
456- // / Emits a reference to a derivative function of `original`, differentiated
457- // / with respect to a superset of `desiredIndices`. Returns the `SILValue` for
458- // / the derivative function and the actual indices that the derivative function
459- // / is with respect to.
460- // /
461- // / Returns `None` on failure, signifying that a diagnostic has been emitted
462- // / using `invoker`.
463- static std::optional<std::pair<SILValue, AutoDiffConfig>>
464- emitDerivativeFunctionReference (
465- DifferentiationTransformer &transformer, SILBuilder &builder,
466- const AutoDiffConfig &desiredConfig, AutoDiffDerivativeFunctionKind kind,
467- SILValue original, DifferentiationInvoker invoker,
468- SmallVectorImpl<AllocStackInst *> &newBuffersToDealloc) {
469- ADContext &context = transformer.getContext ();
479+ SILFunction *DifferentiationTransformer::getUnwrappedCurryThunkFunction (
480+ SILFunction *originalFn) {
481+ auto *autoCE = dyn_cast_or_null<AutoClosureExpr>(
482+ originalFn->getDeclRef ().getAbstractClosureExpr ());
483+ if (autoCE == nullptr )
484+ return nullptr ;
485+
486+ auto *ae = dyn_cast_or_null<ApplyExpr>(autoCE->getUnwrappedCurryThunkExpr ());
487+ if (ae == nullptr )
488+ return nullptr ;
470489
490+ AbstractFunctionDecl *afd = cast<AbstractFunctionDecl>(ae->getCalledValue (
491+ /* skipFunctionConversions=*/ true ));
492+ auto silFnIt = afdToSILFn.find (afd);
493+ if (silFnIt == afdToSILFn.end ()) {
494+ assert (afdToSILFn.empty () && " Expect all 'afdToSILFn' cache entries to be "
495+ " filled at once on the first access attempt" );
496+
497+ SILModule *module = getTransform ().getModule ();
498+ for (SILFunction ¤tFunc : module ->getFunctions ()) {
499+ if (auto *currentAFD =
500+ currentFunc.getDeclRef ().getAbstractFunctionDecl ()) {
501+ // Update cache only with AFDs which might be potentially wrapped by a
502+ // curry thunk. This includes member function references and references
503+ // to functions having external property wrapper parameters (see
504+ // ExprRewriter::buildDeclRef). If new use cases of curry thunks appear
505+ // in future, the assertion after the loop will be a trigger for such
506+ // cases being unhandled here.
507+ //
508+ // FIXME: References to functions having external property wrapper
509+ // parameters are not handled since we can't now construct a test case
510+ // for that due to the crash
511+ // https://github.com/swiftlang/swift/issues/77613
512+ if (currentAFD->hasCurriedSelf ()) {
513+ auto [_, wasEmplace] =
514+ afdToSILFn.try_emplace (currentAFD, ¤tFunc);
515+ assert (wasEmplace && " Expect all 'afdToSILFn' cache entries to be "
516+ " filled at once on the first access attempt" );
517+ }
518+ }
519+ }
520+
521+ silFnIt = afdToSILFn.find (afd);
522+ assert (silFnIt != afdToSILFn.end () &&
523+ " Expect present curry thunk to SIL function mapping after "
524+ " 'afdToSILFn' cache fill" );
525+ }
526+
527+ return silFnIt->second ;
528+ }
529+
530+ std::optional<std::pair<SILValue, AutoDiffConfig>>
531+ DifferentiationTransformer::emitDerivativeFunctionReference (
532+ SILBuilder &builder, const AutoDiffConfig &desiredConfig,
533+ AutoDiffDerivativeFunctionKind kind, SILValue original,
534+ DifferentiationInvoker invoker,
535+ SmallVectorImpl<AllocStackInst *> &newBuffersToDealloc) {
471536 // If `original` is itself an `DifferentiableFunctionExtractInst` whose kind
472537 // matches the given kind and desired differentiation parameter indices,
473538 // simply extract the derivative function of its function operand, retain the
@@ -610,26 +675,36 @@ emitDerivativeFunctionReference(
610675 DifferentiabilityKind::Reverse, desiredParameterIndices,
611676 desiredResultIndices, derivativeConstrainedGenSig, /* jvp*/ nullptr ,
612677 /* vjp*/ nullptr , /* isSerialized*/ false );
613- if (transformer. canonicalizeDifferentiabilityWitness (
614- minimalWitness, invoker, IsNotSerialized))
678+ if (canonicalizeDifferentiabilityWitness (minimalWitness, invoker,
679+ IsNotSerialized))
615680 return std::nullopt ;
616681 }
617682 assert (minimalWitness);
618- if (original->getFunction ()->isSerialized () &&
619- !hasPublicVisibility (minimalWitness->getLinkage ())) {
620- enum { Inlinable = 0 , DefaultArgument = 1 };
621- unsigned fragileKind = Inlinable;
622- // FIXME: This is not a very robust way of determining if the function is
623- // a default argument. Also, we have not exhaustively listed all the kinds
624- // of fragility.
625- if (original->getFunction ()->getLinkage () == SILLinkage::PublicNonABI)
626- fragileKind = DefaultArgument;
627- context.emitNondifferentiabilityError (
628- original, invoker, diag::autodiff_private_derivative_from_fragile,
629- fragileKind,
630- isa_and_nonnull<AbstractClosureExpr>(
631- originalFRI->getLoc ().getAsASTNode <Expr>()));
632- return std::nullopt ;
683+ if (original->getFunction ()->isSerialized ()) {
684+ // When dealing with curry thunk, look at the function being wrapped
685+ // inside implicit closure. If it has public visibility, the corresponding
686+ // differentiability witness also has public visibility. It should be OK
687+ // for implicit wrapper closure and its witness to have private linkage.
688+ SILFunction *unwrappedFn = getUnwrappedCurryThunkFunction (originalFn);
689+ bool isWitnessPublic =
690+ unwrappedFn == nullptr
691+ ? hasPublicVisibility (minimalWitness->getLinkage ())
692+ : hasPublicVisibility (unwrappedFn->getLinkage ());
693+ if (!isWitnessPublic) {
694+ enum { Inlinable = 0 , DefaultArgument = 1 };
695+ unsigned fragileKind = Inlinable;
696+ // FIXME: This is not a very robust way of determining if the function
697+ // is a default argument. Also, we have not exhaustively listed all the
698+ // kinds of fragility.
699+ if (original->getFunction ()->getLinkage () == SILLinkage::PublicNonABI)
700+ fragileKind = DefaultArgument;
701+ context.emitNondifferentiabilityError (
702+ original, invoker, diag::autodiff_private_derivative_from_fragile,
703+ fragileKind,
704+ isa_and_nonnull<AbstractClosureExpr>(
705+ originalFRI->getLoc ().getAsASTNode <Expr>()));
706+ return std::nullopt ;
707+ }
633708 }
634709 // TODO(TF-482): Move generic requirement checking logic to
635710 // `getExactDifferentiabilityWitness` and
@@ -1121,8 +1196,8 @@ SILValue DifferentiationTransformer::promoteToDifferentiableFunction(
11211196 for (auto derivativeFnKind : {AutoDiffDerivativeFunctionKind::JVP,
11221197 AutoDiffDerivativeFunctionKind::VJP}) {
11231198 auto derivativeFnAndIndices = emitDerivativeFunctionReference (
1124- * this , builder, desiredConfig, derivativeFnKind, origFnOperand,
1125- invoker, newBuffersToDealloc);
1199+ builder, desiredConfig, derivativeFnKind, origFnOperand, invoker ,
1200+ newBuffersToDealloc);
11261201 // Show an error at the operator, highlight the argument, and show a note
11271202 // at the definition site of the argument.
11281203 if (!derivativeFnAndIndices)
0 commit comments