@@ -282,6 +282,71 @@ struct TypeReprCycleCheckWalker : ASTWalker {
282282
283283}
284284
285+ static bool isExtensionUsableForInference (const ExtensionDecl *extension,
286+ NormalProtocolConformance *conformance) {
287+ // The context the conformance being checked is declared on.
288+ const auto conformanceDC = conformance->getDeclContext ();
289+ if (extension == conformanceDC)
290+ return true ;
291+
292+ // Invalid case.
293+ const auto extendedNominal = extension->getExtendedNominal ();
294+ if (extendedNominal == nullptr )
295+ return true ;
296+
297+ auto *proto = dyn_cast<ProtocolDecl>(extendedNominal);
298+
299+ // If the extension is bound to the nominal the conformance is
300+ // declared on, it is viable for inference when its conditional
301+ // requirements are satisfied by those of the conformance context.
302+ if (!proto) {
303+ // Retrieve the generic signature of the extension.
304+ const auto extensionSig = extension->getGenericSignature ();
305+ return extensionSig
306+ .requirementsNotSatisfiedBy (
307+ conformanceDC->getGenericSignatureOfContext ())
308+ .empty ();
309+ }
310+
311+ // The condition here is a bit more fickle than
312+ // `isExtensionApplied`. That check would prematurely reject
313+ // extensions like `P where AssocType == T` if we're relying on a
314+ // default implementation inside the extension to infer `AssocType == T`
315+ // in the first place. Only check conformances on the `Self` type,
316+ // because those have to be explicitly declared on the type somewhere
317+ // so won't be affected by whatever answer inference comes up with.
318+ auto *module = conformanceDC->getParentModule ();
319+ auto checkConformance = [&](ProtocolDecl *proto) {
320+ auto typeInContext = conformanceDC->mapTypeIntoContext (conformance->getType ());
321+ auto otherConf = TypeChecker::conformsToProtocol (
322+ typeInContext, proto, module );
323+ return !otherConf.isInvalid ();
324+ };
325+
326+ // First check the extended protocol itself.
327+ if (!checkConformance (proto))
328+ return false ;
329+
330+ // Source file and module file have different ways to get self bounds.
331+ // Source file extension will have trailing where clause which can avoid
332+ // computing a generic signature. Module file will not have
333+ // trailing where clause, so it will compute generic signature to get
334+ // self bounds which might result in slow performance.
335+ SelfBounds bounds;
336+ if (extension->getParentSourceFile () != nullptr )
337+ bounds = getSelfBoundsFromWhereClause (extension);
338+ else
339+ bounds = getSelfBoundsFromGenericSignature (extension);
340+ for (auto *decl : bounds.decls ) {
341+ if (auto *proto = dyn_cast<ProtocolDecl>(decl)) {
342+ if (!checkConformance (proto))
343+ return false ;
344+ }
345+ }
346+
347+ return true ;
348+ }
349+
285350InferredAssociatedTypesByWitnesses
286351AssociatedTypeInference::inferTypeWitnessesViaValueWitnesses (
287352 ConformanceChecker &checker,
@@ -301,70 +366,6 @@ AssociatedTypeInference::inferTypeWitnessesViaValueWitnesses(
301366
302367 InferredAssociatedTypesByWitnesses result;
303368
304- auto isExtensionUsableForInference = [&](const ExtensionDecl *extension) {
305- // The context the conformance being checked is declared on.
306- const auto conformanceCtx = conformance->getDeclContext ();
307- if (extension == conformanceCtx)
308- return true ;
309-
310- // Invalid case.
311- const auto extendedNominal = extension->getExtendedNominal ();
312- if (extendedNominal == nullptr )
313- return true ;
314-
315- auto *proto = dyn_cast<ProtocolDecl>(extendedNominal);
316-
317- // If the extension is bound to the nominal the conformance is
318- // declared on, it is viable for inference when its conditional
319- // requirements are satisfied by those of the conformance context.
320- if (!proto) {
321- // Retrieve the generic signature of the extension.
322- const auto extensionSig = extension->getGenericSignature ();
323- return extensionSig
324- .requirementsNotSatisfiedBy (
325- conformanceCtx->getGenericSignatureOfContext ())
326- .empty ();
327- }
328-
329- // The condition here is a bit more fickle than
330- // `isExtensionApplied`. That check would prematurely reject
331- // extensions like `P where AssocType == T` if we're relying on a
332- // default implementation inside the extension to infer `AssocType == T`
333- // in the first place. Only check conformances on the `Self` type,
334- // because those have to be explicitly declared on the type somewhere
335- // so won't be affected by whatever answer inference comes up with.
336- auto *module = dc->getParentModule ();
337- auto checkConformance = [&](ProtocolDecl *proto) {
338- auto typeInContext = dc->mapTypeIntoContext (conformance->getType ());
339- auto otherConf = TypeChecker::conformsToProtocol (
340- typeInContext, proto, module );
341- return !otherConf.isInvalid ();
342- };
343-
344- // First check the extended protocol itself.
345- if (!checkConformance (proto))
346- return false ;
347-
348- // Source file and module file have different ways to get self bounds.
349- // Source file extension will have trailing where clause which can avoid
350- // computing a generic signature. Module file will not have
351- // trailing where clause, so it will compute generic signature to get
352- // self bounds which might result in slow performance.
353- SelfBounds bounds;
354- if (extension->getParentSourceFile () != nullptr )
355- bounds = getSelfBoundsFromWhereClause (extension);
356- else
357- bounds = getSelfBoundsFromGenericSignature (extension);
358- for (auto *decl : bounds.decls ) {
359- if (auto *proto = dyn_cast<ProtocolDecl>(decl)) {
360- if (!checkConformance (proto))
361- return false ;
362- }
363- }
364-
365- return true ;
366- };
367-
368369 for (auto witness :
369370 checker.lookupValueWitnesses (req, /* ignoringNames=*/ nullptr )) {
370371 LLVM_DEBUG (llvm::dbgs () << " Inferring associated types from decl:\n " ;
@@ -374,7 +375,7 @@ AssociatedTypeInference::inferTypeWitnessesViaValueWitnesses(
374375 // type can't use it regardless of what associated types we end up
375376 // inferring, skip the witness.
376377 if (auto extension = dyn_cast<ExtensionDecl>(witness->getDeclContext ())) {
377- if (!isExtensionUsableForInference (extension)) {
378+ if (!isExtensionUsableForInference (extension, conformance )) {
378379 LLVM_DEBUG (llvm::dbgs () << " Extension not usable for inference\n " );
379380 continue ;
380381 }
0 commit comments