@@ -161,6 +161,192 @@ static bool associatedTypesAreSameEquivalenceClass(AssociatedTypeDecl *a,
161161 return false ;
162162}
163163
164+ namespace {
165+
166+ // / Try to avoid situations where resolving the type of a witness calls back
167+ // / into associated type inference.
168+ struct TypeReprCycleCheckWalker : ASTWalker {
169+ llvm::SmallDenseSet<Identifier, 2 > circularNames;
170+ ValueDecl *witness;
171+ bool found;
172+
173+ TypeReprCycleCheckWalker (
174+ const llvm::SetVector<AssociatedTypeDecl *> &allUnresolved)
175+ : witness(nullptr ), found(false ) {
176+ for (auto *assocType : allUnresolved) {
177+ circularNames.insert (assocType->getName ());
178+ }
179+ }
180+
181+ PreWalkAction walkToTypeReprPre (TypeRepr *T) override {
182+ // FIXME: We should still visit any generic arguments of this member type.
183+ // However, we want to skip 'Foo.Element' because the 'Element' reference is
184+ // not unqualified.
185+ if (auto *memberTyR = dyn_cast<MemberTypeRepr>(T)) {
186+ return Action::SkipChildren ();
187+ }
188+
189+ if (auto *identTyR = dyn_cast<SimpleIdentTypeRepr>(T)) {
190+ if (circularNames.count (identTyR->getNameRef ().getBaseIdentifier ()) > 0 ) {
191+ // If unqualified lookup can find a type with this name without looking
192+ // into protocol members, don't skip the witness, since this type might
193+ // be a candidate witness.
194+ auto desc = UnqualifiedLookupDescriptor (
195+ identTyR->getNameRef (), witness->getDeclContext (),
196+ identTyR->getLoc (), UnqualifiedLookupOptions ());
197+
198+ auto &ctx = witness->getASTContext ();
199+ auto results =
200+ evaluateOrDefault (ctx.evaluator , UnqualifiedLookupRequest{desc}, {});
201+
202+ // Ok, resolving this name would trigger associated type inference
203+ // recursively. We're going to skip this witness.
204+ if (results.allResults ().empty ()) {
205+ found = true ;
206+ return Action::Stop ();
207+ }
208+ }
209+ }
210+
211+ return Action::Continue ();
212+ }
213+
214+ bool checkForPotentialCycle (ValueDecl *witness) {
215+ // Don't do this for protocol extension members, because we have a
216+ // mini "solver" that avoids similar issues instead.
217+ if (witness->getDeclContext ()->getSelfProtocolDecl () != nullptr )
218+ return false ;
219+
220+ // If we already have an interface type, don't bother trying to
221+ // avoid a cycle.
222+ if (witness->hasInterfaceType ())
223+ return false ;
224+
225+ // We call checkForPotentailCycle() multiple times with
226+ // different witnesses.
227+ found = false ;
228+ this ->witness = witness;
229+
230+ auto walkInto = [&](TypeRepr *tyR) {
231+ if (tyR)
232+ tyR->walk (*this );
233+ return found;
234+ };
235+
236+ if (auto *AFD = dyn_cast<AbstractFunctionDecl>(witness)) {
237+ for (auto *param : *AFD->getParameters ()) {
238+ if (walkInto (param->getTypeRepr ()))
239+ return true ;
240+ }
241+
242+ if (auto *FD = dyn_cast<FuncDecl>(witness)) {
243+ if (walkInto (FD->getResultTypeRepr ()))
244+ return true ;
245+ }
246+
247+ return false ;
248+ }
249+
250+ if (auto *SD = dyn_cast<SubscriptDecl>(witness)) {
251+ for (auto *param : *SD->getIndices ()) {
252+ if (walkInto (param->getTypeRepr ()))
253+ return true ;
254+ }
255+
256+ if (walkInto (SD->getElementTypeRepr ()))
257+ return true ;
258+
259+ return false ;
260+ }
261+
262+ if (auto *VD = dyn_cast<VarDecl>(witness)) {
263+ if (walkInto (VD->getTypeReprOrParentPatternTypeRepr ()))
264+ return true ;
265+
266+ return false ;
267+ }
268+
269+ if (auto *EED = dyn_cast<EnumElementDecl>(witness)) {
270+ for (auto *param : *EED->getParameterList ()) {
271+ if (walkInto (param->getTypeRepr ()))
272+ return true ;
273+ }
274+
275+ return false ;
276+ }
277+
278+ assert (false && " Should be exhaustive" );
279+ return false ;
280+ }
281+ };
282+
283+ }
284+
285+ static bool isExtensionUsableForInference (const ExtensionDecl *extension,
286+ NormalProtocolConformance *conformance) {
287+ // The context the conformance being checked is declared on.
288+ const auto conformanceDC = conformance->getDeclContext ();
289+ if (extension == conformanceDC)
290+ return true ;
291+
292+ // Invalid case.
293+ const auto extendedNominal = extension->getExtendedNominal ();
294+ if (extendedNominal == nullptr )
295+ return true ;
296+
297+ auto *proto = dyn_cast<ProtocolDecl>(extendedNominal);
298+
299+ // If the extension is bound to the nominal the conformance is
300+ // declared on, it is viable for inference when its conditional
301+ // requirements are satisfied by those of the conformance context.
302+ if (!proto) {
303+ // Retrieve the generic signature of the extension.
304+ const auto extensionSig = extension->getGenericSignature ();
305+ return extensionSig
306+ .requirementsNotSatisfiedBy (
307+ conformanceDC->getGenericSignatureOfContext ())
308+ .empty ();
309+ }
310+
311+ // The condition here is a bit more fickle than
312+ // `isExtensionApplied`. That check would prematurely reject
313+ // extensions like `P where AssocType == T` if we're relying on a
314+ // default implementation inside the extension to infer `AssocType == T`
315+ // in the first place. Only check conformances on the `Self` type,
316+ // because those have to be explicitly declared on the type somewhere
317+ // so won't be affected by whatever answer inference comes up with.
318+ auto *module = conformanceDC->getParentModule ();
319+ auto checkConformance = [&](ProtocolDecl *proto) {
320+ auto typeInContext = conformanceDC->mapTypeIntoContext (conformance->getType ());
321+ auto otherConf = TypeChecker::conformsToProtocol (
322+ typeInContext, proto, module );
323+ return !otherConf.isInvalid ();
324+ };
325+
326+ // First check the extended protocol itself.
327+ if (!checkConformance (proto))
328+ return false ;
329+
330+ // Source file and module file have different ways to get self bounds.
331+ // Source file extension will have trailing where clause which can avoid
332+ // computing a generic signature. Module file will not have
333+ // trailing where clause, so it will compute generic signature to get
334+ // self bounds which might result in slow performance.
335+ SelfBounds bounds;
336+ if (extension->getParentSourceFile () != nullptr )
337+ bounds = getSelfBoundsFromWhereClause (extension);
338+ else
339+ bounds = getSelfBoundsFromGenericSignature (extension);
340+ for (auto *decl : bounds.decls ) {
341+ if (auto *proto = dyn_cast<ProtocolDecl>(decl)) {
342+ if (!checkConformance (proto))
343+ return false ;
344+ }
345+ }
346+
347+ return true ;
348+ }
349+
164350InferredAssociatedTypesByWitnesses
165351AssociatedTypeInference::inferTypeWitnessesViaValueWitnesses (
166352 ConformanceChecker &checker,
@@ -176,71 +362,9 @@ AssociatedTypeInference::inferTypeWitnessesViaValueWitnesses(
176362 abort ();
177363 }
178364
179- InferredAssociatedTypesByWitnesses result ;
365+ TypeReprCycleCheckWalker cycleCheck (allUnresolved) ;
180366
181- auto isExtensionUsableForInference = [&](const ExtensionDecl *extension) {
182- // The context the conformance being checked is declared on.
183- const auto conformanceCtx = checker.Conformance ->getDeclContext ();
184- if (extension == conformanceCtx)
185- return true ;
186-
187- // Invalid case.
188- const auto extendedNominal = extension->getExtendedNominal ();
189- if (extendedNominal == nullptr )
190- return true ;
191-
192- auto *proto = dyn_cast<ProtocolDecl>(extendedNominal);
193-
194- // If the extension is bound to the nominal the conformance is
195- // declared on, it is viable for inference when its conditional
196- // requirements are satisfied by those of the conformance context.
197- if (!proto) {
198- // Retrieve the generic signature of the extension.
199- const auto extensionSig = extension->getGenericSignature ();
200- return extensionSig
201- .requirementsNotSatisfiedBy (
202- conformanceCtx->getGenericSignatureOfContext ())
203- .empty ();
204- }
205-
206- // The condition here is a bit more fickle than
207- // `isExtensionApplied`. That check would prematurely reject
208- // extensions like `P where AssocType == T` if we're relying on a
209- // default implementation inside the extension to infer `AssocType == T`
210- // in the first place. Only check conformances on the `Self` type,
211- // because those have to be explicitly declared on the type somewhere
212- // so won't be affected by whatever answer inference comes up with.
213- auto *module = dc->getParentModule ();
214- auto checkConformance = [&](ProtocolDecl *proto) {
215- auto typeInContext = dc->mapTypeIntoContext (conformance->getType ());
216- auto otherConf = TypeChecker::conformsToProtocol (
217- typeInContext, proto, module );
218- return !otherConf.isInvalid ();
219- };
220-
221- // First check the extended protocol itself.
222- if (!checkConformance (proto))
223- return false ;
224-
225- // Source file and module file have different ways to get self bounds.
226- // Source file extension will have trailing where clause which can avoid
227- // computing a generic signature. Module file will not have
228- // trailing where clause, so it will compute generic signature to get
229- // self bounds which might result in slow performance.
230- SelfBounds bounds;
231- if (extension->getParentSourceFile () != nullptr )
232- bounds = getSelfBoundsFromWhereClause (extension);
233- else
234- bounds = getSelfBoundsFromGenericSignature (extension);
235- for (auto *decl : bounds.decls ) {
236- if (auto *proto = dyn_cast<ProtocolDecl>(decl)) {
237- if (!checkConformance (proto))
238- return false ;
239- }
240- }
241-
242- return true ;
243- };
367+ InferredAssociatedTypesByWitnesses result;
244368
245369 for (auto witness :
246370 checker.lookupValueWitnesses (req, /* ignoringNames=*/ nullptr )) {
@@ -250,11 +374,17 @@ AssociatedTypeInference::inferTypeWitnessesViaValueWitnesses(
250374 // If the potential witness came from an extension, and our `Self`
251375 // type can't use it regardless of what associated types we end up
252376 // inferring, skip the witness.
253- if (auto extension = dyn_cast<ExtensionDecl>(witness->getDeclContext ()))
254- if (!isExtensionUsableForInference (extension)) {
377+ if (auto extension = dyn_cast<ExtensionDecl>(witness->getDeclContext ())) {
378+ if (!isExtensionUsableForInference (extension, conformance )) {
255379 LLVM_DEBUG (llvm::dbgs () << " Extension not usable for inference\n " );
256380 continue ;
257381 }
382+ }
383+
384+ if (cycleCheck.checkForPotentialCycle (witness)) {
385+ LLVM_DEBUG (llvm::dbgs () << " Skipping witness to avoid request cycle\n " );
386+ continue ;
387+ }
258388
259389 // Try to resolve the type witness via this value witness.
260390 auto witnessResult = inferTypeWitnessesViaValueWitness (req, witness);
0 commit comments