2525#include " swift/AST/DiagnosticsParse.h"
2626#include " swift/AST/Effects.h"
2727#include " swift/AST/GenericEnvironment.h"
28- #include " swift/AST/GenericSignatureBuilder.h"
2928#include " swift/AST/ImportCache.h"
3029#include " swift/AST/ModuleNameLookup.h"
3130#include " swift/AST/NameLookup.h"
@@ -2231,28 +2230,17 @@ void AttributeChecker::visitSpecializeAttr(SpecializeAttr *attr) {
22312230 return ;
22322231 }
22332232
2234- // Form a new generic signature based on the old one.
2235- GenericSignatureBuilder Builder (D->getASTContext ());
2233+ InferredGenericSignatureRequest request{
2234+ DC->getParentModule (),
2235+ genericSig.getPointer (),
2236+ /* genericParams=*/ nullptr ,
2237+ WhereClauseOwner (FD, attr),
2238+ /* addedRequirements=*/ {},
2239+ /* inferenceSources=*/ {},
2240+ /* allowConcreteGenericParams=*/ true };
22362241
2237- // First, add the old generic signature.
2238- Builder.addGenericSignature (genericSig);
2239-
2240- // Go over the set of requirements, adding them to the builder.
2241- WhereClauseOwner (FD, attr).visitRequirements (TypeResolutionStage::Interface,
2242- [&](const Requirement &req, RequirementRepr *reqRepr) {
2243- // Add the requirement to the generic signature builder.
2244- using FloatingRequirementSource =
2245- GenericSignatureBuilder::FloatingRequirementSource;
2246- Builder.addRequirement (req, reqRepr,
2247- FloatingRequirementSource::forExplicit (
2248- reqRepr->getSeparatorLoc ()),
2249- nullptr , DC->getParentModule ());
2250- return false ;
2251- });
2252-
2253- // Check the result.
2254- auto specializedSig = std::move (Builder).computeGenericSignature (
2255- /* allowConcreteGenericParams=*/ true );
2242+ auto specializedSig = evaluateOrDefault (Ctx.evaluator , request,
2243+ GenericSignature ());
22562244
22572245 // Check the validity of provided requirements.
22582246 checkSpecializeAttrRequirements (attr, genericSig, specializedSig, Ctx);
@@ -4266,7 +4254,8 @@ bool resolveDifferentiableAttrDerivativeGenericSignature(
42664254 // - If the `@differentiable` attribute has a `where` clause, use it to
42674255 // compute the derivative generic signature.
42684256 // - Otherwise, use the original function's generic signature by default.
4269- derivativeGenSig = original->getGenericSignature ();
4257+ auto originalGenSig = original->getGenericSignature ();
4258+ derivativeGenSig = originalGenSig;
42704259
42714260 // Handle the `where` clause, if it exists.
42724261 // - Resolve attribute where clause requirements and store in the attribute
@@ -4291,7 +4280,6 @@ bool resolveDifferentiableAttrDerivativeGenericSignature(
42914280 return true ;
42924281 }
42934282
4294- auto originalGenSig = original->getGenericSignature ();
42954283 if (!originalGenSig) {
42964284 // `where` clauses are valid only when the original function is generic.
42974285 diags
@@ -4304,51 +4292,34 @@ bool resolveDifferentiableAttrDerivativeGenericSignature(
43044292 return true ;
43054293 }
43064294
4307- // Build a new generic signature for autodiff derivative functions.
4308- GenericSignatureBuilder builder (ctx);
4309- // Add the original function's generic signature.
4310- builder.addGenericSignature (originalGenSig);
4311-
4312- using FloatingRequirementSource =
4313- GenericSignatureBuilder::FloatingRequirementSource;
4314-
4315- bool errorOccurred = false ;
4316- WhereClauseOwner (original, attr)
4317- .visitRequirements (
4318- TypeResolutionStage::Structural,
4319- [&](const Requirement &req, RequirementRepr *reqRepr) {
4320- switch (req.getKind ()) {
4321- case RequirementKind::SameType:
4322- case RequirementKind::Superclass:
4323- case RequirementKind::Conformance:
4324- break ;
4325-
4326- // Layout requirements are not supported.
4327- case RequirementKind::Layout:
4328- diags
4329- .diagnose (attr->getLocation (),
4330- diag::differentiable_attr_layout_req_unsupported)
4331- .highlight (reqRepr->getSourceRange ());
4332- errorOccurred = true ;
4333- return false ;
4334- }
4295+ InferredGenericSignatureRequest request{
4296+ original->getParentModule (),
4297+ originalGenSig.getPointer (),
4298+ /* genericParams=*/ nullptr ,
4299+ WhereClauseOwner (original, attr),
4300+ /* addedRequirements=*/ {},
4301+ /* inferenceSources=*/ {},
4302+ /* allowConcreteParams=*/ true };
4303+
4304+ // Compute generic signature for derivative functions.
4305+ derivativeGenSig = evaluateOrDefault (ctx.evaluator , request,
4306+ GenericSignature ());
43354307
4336- // Add requirement to generic signature builder.
4337- builder.addRequirement (
4338- req, reqRepr, FloatingRequirementSource::forExplicit (
4339- reqRepr->getSeparatorLoc ()),
4340- nullptr , original->getModuleContext ());
4341- return false ;
4342- });
4308+ bool hadInvalidRequirements = false ;
4309+ for (auto req : derivativeGenSig.requirementsNotSatisfiedBy (originalGenSig)) {
4310+ if (req.getKind () == RequirementKind::Layout) {
4311+ // Layout requirements are not supported.
4312+ diags
4313+ .diagnose (attr->getLocation (),
4314+ diag::differentiable_attr_layout_req_unsupported);
4315+ hadInvalidRequirements = true ;
4316+ }
4317+ }
43434318
4344- if (errorOccurred ) {
4319+ if (hadInvalidRequirements ) {
43454320 attr->setInvalid ();
43464321 return true ;
43474322 }
4348-
4349- // Compute generic signature for derivative functions.
4350- derivativeGenSig = std::move (builder).computeGenericSignature (
4351- /* allowConcreteGenericParams=*/ true );
43524323 }
43534324
43544325 attr->setDerivativeGenericSignature (derivativeGenSig);
0 commit comments