@@ -5214,49 +5214,6 @@ getTransposeOriginalFunctionType(AnyFunctionType *transposeFnType,
52145214 return originalType;
52155215}
52165216
5217- // / Given a `@differentiable` attribute, attempts to resolve the original
5218- // / `AbstractFunctionDecl` for which it is registered, using the declaration
5219- // / on which it is actually declared. On error, emits diagnostic and returns
5220- // / `nullptr`.
5221- AbstractFunctionDecl *
5222- resolveDifferentiableAttrOriginalFunction (DifferentiableAttr *attr) {
5223- auto *D = attr->getOriginalDeclaration ();
5224- assert (D &&
5225- " Original declaration should be resolved by parsing/deserialization" );
5226- auto *original = dyn_cast<AbstractFunctionDecl>(D);
5227- if (auto *asd = dyn_cast<AbstractStorageDecl>(D)) {
5228- // If `@differentiable` attribute is declared directly on a
5229- // `AbstractStorageDecl` (a stored/computed property or subscript),
5230- // forward the attribute to the storage's getter.
5231- // TODO(TF-129): Forward `@differentiable` attributes to setters after
5232- // differentiation supports inout parameters.
5233- // TODO(TF-1080): Forward `@differentiable` attributes to `read` and
5234- // `modify` accessors after differentiation supports `inout` parameters.
5235- if (!asd->getDeclContext ()->isModuleScopeContext ()) {
5236- original = asd->getSynthesizedAccessor (AccessorKind::Get);
5237- } else {
5238- original = nullptr ;
5239- }
5240- }
5241- // Non-`get` accessors are not yet supported: `set`, `read`, and `modify`.
5242- // TODO(TF-1080): Enable `read` and `modify` when differentiation supports
5243- // coroutines.
5244- if (auto *accessor = dyn_cast_or_null<AccessorDecl>(original))
5245- if (!accessor->isGetter () && !accessor->isSetter ())
5246- original = nullptr ;
5247- // Diagnose if original `AbstractFunctionDecl` could not be resolved.
5248- if (!original) {
5249- diagnoseAndRemoveAttr (D, attr, diag::invalid_decl_attribute, attr);
5250- attr->setInvalid ();
5251- return nullptr ;
5252- }
5253- // If the original function has an error interface type, return.
5254- // A diagnostic should have already been emitted.
5255- if (original->getInterfaceType ()->hasError ())
5256- return nullptr ;
5257- return original;
5258- }
5259-
52605217// / Given a `@differentiable` attribute, attempts to resolve the derivative
52615218// / generic signature. The derivative generic signature is returned as
52625219// / `derivativeGenSig`. On error, emits diagnostic, assigns `nullptr` to
@@ -5435,11 +5392,11 @@ bool resolveDifferentiableAttrDifferentiabilityParameters(
54355392
54365393// / Checks whether differentiable programming is enabled for the given
54375394// / differentiation-related attribute. Returns true on error.
5438- bool checkIfDifferentiableProgrammingEnabled (ASTContext &ctx ,
5439- DeclAttribute *attr,
5440- DeclContext *DC) {
5395+ static bool checkIfDifferentiableProgrammingEnabled (DeclAttribute *attr ,
5396+ Decl *D) {
5397+ auto &ctx = D-> getASTContext ();
54415398 auto &diags = ctx.Diags ;
5442- auto *SF = DC ->getParentSourceFile ();
5399+ auto *SF = D-> getDeclContext () ->getParentSourceFile ();
54435400 assert (SF && " Source file not found" );
54445401 // The `Differentiable` protocol must be available.
54455402 // If unavailable, the `_Differentiation` module should be imported.
@@ -5452,31 +5409,36 @@ bool checkIfDifferentiableProgrammingEnabled(ASTContext &ctx,
54525409 return true ;
54535410}
54545411
5455- IndexSubset *DifferentiableAttributeTypeCheckRequest::evaluate (
5456- Evaluator &evaluator, DifferentiableAttr *attr) const {
5457- // Skip type-checking for implicit `@differentiable` attributes. We currently
5458- // assume that all implicit `@differentiable` attributes are valid.
5459- //
5460- // Motivation: some implicit attributes do not have a `where` clause, and this
5461- // function assumes that the `where` clauses exist. Propagating `where`
5462- // clauses and requirements consistently is a larger problem, to be revisited.
5463- if (attr->isImplicit ())
5464- return nullptr ;
5412+ static IndexSubset *
5413+ resolveDiffParamIndices (AbstractFunctionDecl *original,
5414+ DifferentiableAttr *attr,
5415+ GenericSignature derivativeGenSig) {
5416+ auto *derivativeGenEnv = derivativeGenSig.getGenericEnvironment ();
54655417
5466- auto *D = attr->getOriginalDeclaration ();
5467- auto &ctx = D->getASTContext ();
5468- auto &diags = ctx.Diags ;
5469- // `@differentiable` attribute requires experimental differentiable
5470- // programming to be enabled.
5471- if (checkIfDifferentiableProgrammingEnabled (ctx, attr, D->getDeclContext ()))
5472- return nullptr ;
5418+ // Compute the derivative function type.
5419+ auto originalFnRemappedTy = original->getInterfaceType ()->castTo <AnyFunctionType>();
5420+ if (derivativeGenEnv)
5421+ originalFnRemappedTy =
5422+ derivativeGenEnv->mapTypeIntoContext (originalFnRemappedTy)
5423+ ->castTo <AnyFunctionType>();
54735424
5474- // Resolve the original `AbstractFunctionDecl`.
5475- auto *original = resolveDifferentiableAttrOriginalFunction (attr);
5476- if (!original)
5425+ // Resolve and validate the differentiability parameters.
5426+ IndexSubset *resolvedDiffParamIndices = nullptr ;
5427+ if (resolveDifferentiableAttrDifferentiabilityParameters (
5428+ attr, original, originalFnRemappedTy, derivativeGenEnv,
5429+ resolvedDiffParamIndices))
54775430 return nullptr ;
54785431
5479- auto *originalFnTy = original->getInterfaceType ()->castTo <AnyFunctionType>();
5432+ return resolvedDiffParamIndices;
5433+ }
5434+
5435+
5436+ static IndexSubset *
5437+ typecheckDifferentiableAttrforDecl (AbstractFunctionDecl *original,
5438+ DifferentiableAttr *attr,
5439+ IndexSubset *resolvedDiffParamIndices = nullptr ) {
5440+ auto &ctx = original->getASTContext ();
5441+ auto &diags = ctx.Diags ;
54805442
54815443 // Diagnose if original function has opaque result types.
54825444 if (auto *opaqueResultTypeDecl = original->getOpaqueResultTypeDecl ()) {
@@ -5523,69 +5485,161 @@ IndexSubset *DifferentiableAttributeTypeCheckRequest::evaluate(
55235485 }
55245486
55255487 // Resolve the derivative generic signature.
5526- GenericSignature derivativeGenSig = nullptr ;
5527- if (resolveDifferentiableAttrDerivativeGenericSignature (attr, original,
5488+ GenericSignature derivativeGenSig = attr->getDerivativeGenericSignature ();
5489+ if (!derivativeGenSig &&
5490+ resolveDifferentiableAttrDerivativeGenericSignature (attr, original,
55285491 derivativeGenSig))
55295492 return nullptr ;
5530- auto *derivativeGenEnv = derivativeGenSig.getGenericEnvironment ();
5531-
5532- // Compute the derivative function type.
5533- auto originalFnRemappedTy = originalFnTy;
5534- if (derivativeGenEnv)
5535- originalFnRemappedTy =
5536- derivativeGenEnv->mapTypeIntoContext (originalFnRemappedTy)
5537- ->castTo <AnyFunctionType>();
55385493
55395494 // Resolve and validate the differentiability parameters.
5540- IndexSubset * resolvedDiffParamIndices = nullptr ;
5541- if ( resolveDifferentiableAttrDifferentiabilityParameters (
5542- attr, original, originalFnRemappedTy, derivativeGenEnv,
5543- resolvedDiffParamIndices) )
5495+ if (! resolvedDiffParamIndices)
5496+ resolvedDiffParamIndices = resolveDiffParamIndices (original, attr,
5497+ derivativeGenSig);
5498+ if (! resolvedDiffParamIndices)
55445499 return nullptr ;
55455500
5546- if (auto *asd = dyn_cast<AbstractStorageDecl>(D)) {
5547- // Remove `@differentiable` attribute from storage declaration to prevent
5548- // duplicate attribute registration during SILGen.
5549- D->getAttrs ().removeAttribute (attr);
5550- // Transfer `@differentiable` attribute from storage declaration to
5551- // getter accessor.
5552- auto *getterDecl = asd->getOpaqueAccessor (AccessorKind::Get);
5553- auto *newAttr = DifferentiableAttr::create (
5554- getterDecl, /* implicit*/ true , attr->AtLoc , attr->getRange (),
5555- attr->getDifferentiabilityKind (), resolvedDiffParamIndices,
5556- attr->getDerivativeGenericSignature ());
5557- auto insertion = ctx.DifferentiableAttrs .try_emplace (
5558- {getterDecl, resolvedDiffParamIndices}, newAttr);
5559- // Reject duplicate `@differentiable` attributes.
5560- if (!insertion.second ) {
5561- diagnoseAndRemoveAttr (D, attr, diag::differentiable_attr_duplicate);
5562- diags.diagnose (insertion.first ->getSecond ()->getLocation (),
5563- diag::differentiable_attr_duplicate_note);
5564- return nullptr ;
5565- }
5566- getterDecl->getAttrs ().add (newAttr);
5567- // Register derivative function configuration.
5568- auto *resultIndices = IndexSubset::get (ctx, 1 , {0 });
5569- getterDecl->addDerivativeFunctionConfiguration (
5570- {resolvedDiffParamIndices, resultIndices, derivativeGenSig});
5571- return resolvedDiffParamIndices;
5572- }
55735501 // Reject duplicate `@differentiable` attributes.
55745502 auto insertion =
5575- ctx.DifferentiableAttrs .try_emplace ({D , resolvedDiffParamIndices}, attr);
5503+ ctx.DifferentiableAttrs .try_emplace ({original , resolvedDiffParamIndices}, attr);
55765504 if (!insertion.second && insertion.first ->getSecond () != attr) {
5577- diagnoseAndRemoveAttr (D , attr, diag::differentiable_attr_duplicate);
5505+ diagnoseAndRemoveAttr (original , attr, diag::differentiable_attr_duplicate);
55785506 diags.diagnose (insertion.first ->getSecond ()->getLocation (),
55795507 diag::differentiable_attr_duplicate_note);
55805508 return nullptr ;
55815509 }
5510+
55825511 // Register derivative function configuration.
55835512 auto *resultIndices = IndexSubset::get (ctx, 1 , {0 });
55845513 original->addDerivativeFunctionConfiguration (
55855514 {resolvedDiffParamIndices, resultIndices, derivativeGenSig});
55865515 return resolvedDiffParamIndices;
55875516}
55885517
5518+ // / Given a `@differentiable` attribute, attempts to resolve the original
5519+ // / `AbstractFunctionDecl` for which it is registered, using the declaration
5520+ // / on which it is actually declared. On error, emits diagnostic and returns
5521+ // / `nullptr`.
5522+ static AbstractFunctionDecl *
5523+ resolveDifferentiableAttrOriginalFunction (DifferentiableAttr *attr) {
5524+ auto *D = attr->getOriginalDeclaration ();
5525+ auto *original = dyn_cast<AbstractFunctionDecl>(D);
5526+
5527+ // Non-`get`/`set` accessors are not yet supported: `read`, and `modify`.
5528+ // TODO(TF-1080): Enable `read` and `modify` when differentiation supports
5529+ // coroutines.
5530+ if (auto *accessor = dyn_cast_or_null<AccessorDecl>(original))
5531+ if (!accessor->isGetter () && !accessor->isSetter ())
5532+ original = nullptr ;
5533+
5534+ // Diagnose if original `AbstractFunctionDecl` could not be resolved.
5535+ if (!original) {
5536+ diagnoseAndRemoveAttr (D, attr, diag::invalid_decl_attribute, attr);
5537+ attr->setInvalid ();
5538+ return nullptr ;
5539+ }
5540+
5541+ // If the original function has an error interface type, return.
5542+ // A diagnostic should have already been emitted.
5543+ if (original->getInterfaceType ()->hasError ())
5544+ return nullptr ;
5545+
5546+ return original;
5547+ }
5548+
5549+ static IndexSubset *
5550+ resolveDifferentiableAccessors (DifferentiableAttr *attr,
5551+ AbstractStorageDecl *asd) {
5552+ auto typecheckAccessor = [&](AccessorDecl *ad) -> IndexSubset* {
5553+ GenericSignature derivativeGenSig = nullptr ;
5554+ if (resolveDifferentiableAttrDerivativeGenericSignature (attr, ad,
5555+ derivativeGenSig))
5556+ return nullptr ;
5557+
5558+ IndexSubset *resolvedDiffParamIndices = resolveDiffParamIndices (ad, attr,
5559+ derivativeGenSig);
5560+ if (!resolvedDiffParamIndices)
5561+ return nullptr ;
5562+
5563+ auto *newAttr = DifferentiableAttr::create (
5564+ ad, /* implicit*/ true , attr->AtLoc , attr->getRange (),
5565+ attr->getDifferentiabilityKind (), resolvedDiffParamIndices,
5566+ attr->getDerivativeGenericSignature ());
5567+ ad->getAttrs ().add (newAttr);
5568+
5569+ if (!typecheckDifferentiableAttrforDecl (ad, attr,
5570+ resolvedDiffParamIndices))
5571+ return nullptr ;
5572+
5573+ return resolvedDiffParamIndices;
5574+ };
5575+
5576+ // No getters / setters for global variables
5577+ if (asd->getDeclContext ()->isModuleScopeContext ()) {
5578+ diagnoseAndRemoveAttr (asd, attr, diag::invalid_decl_attribute, attr);
5579+ attr->setInvalid ();
5580+ return nullptr ;
5581+ }
5582+
5583+ if (!typecheckAccessor (asd->getSynthesizedAccessor (AccessorKind::Get)))
5584+ return nullptr ;
5585+
5586+ if (asd->supportsMutation ()) {
5587+ // FIXME: Class-typed values have reference semantics and can be freely
5588+ // mutated. Thus, they should be treated like inout parameters for the
5589+ // purposes of @differentiable and @derivative type-checking. Until
5590+ // https://github.com/apple/swift/issues/55542 is fixed, check if setter has
5591+ // computed semantic results and do not typecheck if they are none
5592+ // (class-typed `self' parameter is not treated as a "semantic result"
5593+ // currently)
5594+ if (!asd->getDeclContext ()->getSelfClassDecl ())
5595+ if (!typecheckAccessor (asd->getSynthesizedAccessor (AccessorKind::Set)))
5596+ return nullptr ;
5597+ }
5598+
5599+ // Remove `@differentiable` attribute from storage declaration to prevent
5600+ // duplicate attribute registration during SILGen.
5601+ asd->getAttrs ().removeAttribute (attr);
5602+
5603+ // Here we are effectively removing attribute from original decl, therefore no
5604+ // index subset for us
5605+ return nullptr ;
5606+ }
5607+
5608+
5609+ IndexSubset *DifferentiableAttributeTypeCheckRequest::evaluate (
5610+ Evaluator &evaluator, DifferentiableAttr *attr) const {
5611+ // Skip type-checking for implicit `@differentiable` attributes. We currently
5612+ // assume that all implicit `@differentiable` attributes are valid.
5613+ //
5614+ // Motivation: some implicit attributes do not have a `where` clause, and this
5615+ // function assumes that the `where` clauses exist. Propagating `where`
5616+ // clauses and requirements consistently is a larger problem, to be revisited.
5617+ if (attr->isImplicit ())
5618+ return nullptr ;
5619+
5620+ auto *D = attr->getOriginalDeclaration ();
5621+ assert (D &&
5622+ " Original declaration should be resolved by parsing/deserialization" );
5623+
5624+ // `@differentiable` attribute requires experimental differentiable
5625+ // programming to be enabled.
5626+ if (checkIfDifferentiableProgrammingEnabled (attr, D))
5627+ return nullptr ;
5628+
5629+ // If `@differentiable` attribute is declared directly on a
5630+ // `AbstractStorageDecl` (a stored/computed property or subscript),
5631+ // forward the attribute to the storage's getter / setter
5632+ if (auto *asd = dyn_cast<AbstractStorageDecl>(D))
5633+ return resolveDifferentiableAccessors (attr, asd);
5634+
5635+ // Resolve the original `AbstractFunctionDecl`.
5636+ auto *original = resolveDifferentiableAttrOriginalFunction (attr);
5637+ if (!original)
5638+ return nullptr ;
5639+
5640+ return typecheckDifferentiableAttrforDecl (original, attr);
5641+ }
5642+
55895643void AttributeChecker::visitDifferentiableAttr (DifferentiableAttr *attr) {
55905644 // Call `getParameterIndices` to trigger
55915645 // `DifferentiableAttributeTypeCheckRequest`.
@@ -5608,7 +5662,7 @@ static bool typeCheckDerivativeAttr(DerivativeAttr *attr) {
56085662 auto &diags = Ctx.Diags ;
56095663 // `@derivative` attribute requires experimental differentiable programming
56105664 // to be enabled.
5611- if (checkIfDifferentiableProgrammingEnabled (Ctx, attr, D-> getDeclContext () ))
5665+ if (checkIfDifferentiableProgrammingEnabled (attr, D))
56125666 return true ;
56135667 auto *derivative = cast<FuncDecl>(D);
56145668 auto originalName = attr->getOriginalFunctionName ();
0 commit comments