@@ -1382,46 +1382,65 @@ enum class InferenceCandidateKind {
13821382static InferenceCandidateKind checkInferenceCandidate (
13831383 std::pair<AssociatedTypeDecl *, Type> *result,
13841384 NormalProtocolConformance *conformance,
1385- DeclContext *witnessDC ,
1385+ ValueDecl *witness ,
13861386 Type selfTy) {
13871387 auto &ctx = selfTy->getASTContext ();
13881388
1389+ // The unbound form of `Self.A`.
1390+ auto selfAssocTy = DependentMemberType::get (selfTy, result->first ->getName ());
1391+ auto genericSig = witness->getInnermostDeclContext ()
1392+ ->getGenericSignatureOfContext ();
1393+
1394+ if (ctx.LangOpts .EnableExperimentalAssociatedTypeInference ) {
1395+ // If the witness is in a protocol extension for a completely unrelated
1396+ // protocol that doesn't declare an associated type with the same name as
1397+ // the one we are trying to infer, then it will never be tautological.
1398+ if (!genericSig->isValidTypeParameter (selfAssocTy))
1399+ return InferenceCandidateKind::Good;
1400+ }
1401+
1402+ // A tautological binding is one where the left-hand side has the same
1403+ // reduced type as the right-hand side in the generic signature of the
1404+ // witness.
13891405 auto isTautological = [&](Type t) -> bool {
1406+ if (ctx.LangOpts .EnableExperimentalAssociatedTypeInference ) {
1407+
13901408 auto dmt = t->getAs <DependentMemberType>();
13911409 if (!dmt)
13921410 return false ;
1393- if (!associatedTypesAreSameEquivalenceClass (dmt->getAssocType (),
1394- result->first ))
1395- return false ;
13961411
1397- Type typeInContext;
1398- if (ctx.LangOpts .EnableExperimentalAssociatedTypeInference ) {
1399-
1400- typeInContext = selfTy;
1412+ return genericSig->areReducedTypeParametersEqual (dmt, selfAssocTy);
14011413
14021414 } else {
14031415
1404- typeInContext =
1405- conformance->getDeclContext ()->mapTypeIntoContext (conformance->getType ());
1406-
1407- }
1416+ auto dmt = t->getAs <DependentMemberType>();
1417+ if (!dmt)
1418+ return false ;
1419+ if (!associatedTypesAreSameEquivalenceClass (dmt->getAssocType (),
1420+ result->first ))
1421+ return false ;
14081422
1423+ Type typeInContext =
1424+ conformance->getDeclContext ()->mapTypeIntoContext (conformance->getType ());
14091425 if (!dmt->getBase ()->isEqual (typeInContext))
14101426 return false ;
14111427
14121428 return true ;
1429+
1430+ }
14131431 };
14141432
14151433 // Self.X == Self.X doesn't give us any new information, nor does it
14161434 // immediately fail.
14171435 if (isTautological (result->second )) {
1418- auto *dmt = result->second ->castTo <DependentMemberType>();
1419-
1420- auto selfAssocTy = DependentMemberType::get (selfTy, dmt->getAssocType ());
1436+ // FIXME: This should be getInnermostDeclContext()->getGenericSignature(),
1437+ // but that might introduce new ambiguities in existing code so we need
1438+ // to be careful.
1439+ auto genericSig = witness->getDeclContext ()->getGenericSignatureOfContext ();
14211440
14221441 // If we have a same-type requirement `Self.X == Self.Y`,
14231442 // introduce a binding `Self.X := Self.Y`.
1424- for (auto &reqt : witnessDC-> getGenericSignatureOfContext () .getRequirements ()) {
1443+ for (auto &reqt : genericSig .getRequirements ()) {
14251444 switch (reqt.getKind ()) {
14261445 case RequirementKind::SameShape:
14271446 llvm_unreachable (" Same-shape requirement not supported here" );
@@ -1432,6 +1451,45 @@ static InferenceCandidateKind checkInferenceCandidate(
14321451 break ;
14331452
14341453 case RequirementKind::SameType:
1454+ if (ctx.LangOpts .EnableExperimentalAssociatedTypeInference ) {
1455+
1456+ auto matches = [&](Type t) {
1457+ if (auto *dmt = t->getAs <DependentMemberType>()) {
1458+ return (dmt->getName () == result->first ->getName () &&
1459+ dmt->getBase ()->isEqual (selfTy));
1460+ }
1461+
1462+ return false ;
1463+ };
1464+
1465+ // If we have a tautological binding, check if the witness generic
1466+ // signature has a same-type requirement `Self.A == Self.X` or
1467+ // `Self.X == Self.A`, where `A` is an associated type with the same
1468+ // name as the one we're trying to infer, and `X` is some other type
1469+ // parameter.
1470+ Type other;
1471+ if (matches (reqt.getFirstType ())) {
1472+ other = reqt.getSecondType ();
1473+ } else if (matches (reqt.getSecondType ())) {
1474+ other = reqt.getFirstType ();
1475+ } else {
1476+ break ;
1477+ }
1478+
1479+ if (other->isTypeParameter () &&
1480+ other->getRootGenericParam ()->isEqual (selfTy)) {
1481+ result->second = other;
1482+ LLVM_DEBUG (llvm::dbgs () << " ++ we can same-type to:\n " ;
1483+ result->second ->dump (llvm::dbgs ()));
1484+ return InferenceCandidateKind::Good;
1485+
1486+ }
1487+
1488+ } else {
1489+
1490+ auto *dmt = result->second ->castTo <DependentMemberType>();
1491+ auto selfAssocTy = DependentMemberType::get (selfTy, dmt->getAssocType ());
1492+
14351493 Type other;
14361494 if (reqt.getFirstType ()->isEqual (selfAssocTy)) {
14371495 other = reqt.getSecondType ();
@@ -1443,18 +1501,9 @@ static InferenceCandidateKind checkInferenceCandidate(
14431501
14441502 if (auto otherAssoc = other->getAs <DependentMemberType>()) {
14451503 if (otherAssoc->getBase ()->isEqual (selfTy)) {
1446- DependentMemberType *otherDMT;
1447- if (ctx.LangOpts .EnableExperimentalAssociatedTypeInference ) {
1448-
1449- otherDMT = otherAssoc;
1450-
1451- } else {
1452-
1453- otherDMT = DependentMemberType::get (dmt->getBase (),
1504+ auto *otherDMT = DependentMemberType::get (dmt->getBase (),
14541505 otherAssoc->getAssocType ());
14551506
1456- }
1457-
14581507 result->second = result->second .transform ([&](Type t) -> Type{
14591508 if (t->isEqual (dmt))
14601509 return otherDMT;
@@ -1465,6 +1514,8 @@ static InferenceCandidateKind checkInferenceCandidate(
14651514 return InferenceCandidateKind::Good;
14661515 }
14671516 }
1517+
1518+ }
14681519 break ;
14691520 }
14701521 }
@@ -1609,8 +1660,7 @@ AssociatedTypeInference::getPotentialTypeWitnessesFromRequirement(
16091660 // itself involve unresolved type witnesses.
16101661 if (selfTy) {
16111662 // Handle Self.X := Self.X and Self.X := G<Self.X>.
1612- switch (checkInferenceCandidate (&result, conformance,
1613- witness->getDeclContext (), selfTy)) {
1663+ switch (checkInferenceCandidate (&result, conformance, witness, selfTy)) {
16141664 case InferenceCandidateKind::Good:
16151665 // The "good" case is something like `Self.X := Self.Y`.
16161666 break ;
@@ -1864,37 +1914,27 @@ static Type getWitnessTypeForMatching(NormalProtocolConformance *conformance,
18641914 auto proto = conformance->getProtocol ();
18651915 auto selfTy = proto->getSelfInterfaceType ();
18661916
1867- // Get the reduced type of the witness. This rules our certain tautological
1868- // inferences below.
1869- if (ctx.LangOpts .EnableExperimentalAssociatedTypeInference ) {
1870- if (auto genericSig = witness->getInnermostDeclContext ()
1871- ->getGenericSignatureOfContext ()) {
1872- type = genericSig.getReducedType (type);
1873- type = genericSig->getSugaredType (type);
1874- }
1875- }
1876-
1877- // Remap associated types that reference other protocols into this
1878- // protocol.
1879- type = type.transformRec ([proto](TypeBase *type) -> llvm::Optional<Type> {
1880- if (auto depMemTy = dyn_cast<DependentMemberType>(type)) {
1881- if (depMemTy->getAssocType () &&
1882- depMemTy->getAssocType ()->getProtocol () != proto) {
1883- if (auto *assocType = proto->getAssociatedType (depMemTy->getName ())) {
1884- auto origProto = depMemTy->getAssocType ()->getProtocol ();
1885- if (proto->inheritsFrom (origProto))
1886- return Type (DependentMemberType::get (depMemTy->getBase (),
1887- assocType));
1917+ if (!ctx.LangOpts .EnableExperimentalAssociatedTypeInference ) {
1918+ // Remap associated types that reference other protocols into this
1919+ // protocol.
1920+ auto resultType = Type (type).transformRec ([proto](TypeBase *type)
1921+ -> llvm::Optional<Type> {
1922+ if (auto depMemTy = dyn_cast<DependentMemberType>(type)) {
1923+ if (depMemTy->getAssocType () &&
1924+ depMemTy->getAssocType ()->getProtocol () != proto) {
1925+ if (auto *assocType = proto->getAssociatedType (depMemTy->getName ())) {
1926+ auto origProto = depMemTy->getAssocType ()->getProtocol ();
1927+ if (proto->inheritsFrom (origProto))
1928+ return Type (DependentMemberType::get (depMemTy->getBase (),
1929+ assocType));
1930+ }
18881931 }
18891932 }
1890- }
18911933
1892- return llvm::None;
1893- });
1894-
1895- if (!ctx.LangOpts .EnableExperimentalAssociatedTypeInference ) {
1896- auto resultType = type.subst (QueryTypeSubstitutionMap{substitutions},
1897- LookUpConformanceInModule (module ));
1934+ return llvm::None;
1935+ });
1936+ resultType = resultType.subst (QueryTypeSubstitutionMap{substitutions},
1937+ LookUpConformanceInModule (module ));
18981938 if (!resultType->hasError ()) return resultType;
18991939
19001940 // Map error types with original types *back* to the original, dependent type.
@@ -1916,9 +1956,28 @@ static Type getWitnessTypeForMatching(NormalProtocolConformance *conformance,
19161956 if (!rootParam->isEqual (selfTy))
19171957 return type;
19181958
1959+ // Remap associated types that reference other protocols into this
1960+ // protocol.
1961+ auto substType = Type (type).transformRec ([proto](TypeBase *type)
1962+ -> llvm::Optional<Type> {
1963+ if (auto depMemTy = dyn_cast<DependentMemberType>(type)) {
1964+ if (depMemTy->getAssocType () &&
1965+ depMemTy->getAssocType ()->getProtocol () != proto) {
1966+ if (auto *assocType = proto->getAssociatedType (depMemTy->getName ())) {
1967+ auto origProto = depMemTy->getAssocType ()->getProtocol ();
1968+ if (proto->inheritsFrom (origProto))
1969+ return Type (DependentMemberType::get (depMemTy->getBase (),
1970+ assocType));
1971+ }
1972+ }
1973+ }
1974+
1975+ return llvm::None;
1976+ });
1977+
19191978 // Replace Self with the concrete conforming type.
1920- auto substType = Type (type) .subst (QueryTypeSubstitutionMap{substitutions},
1921- LookUpConformanceInModule (module ));
1979+ substType = substType .subst (QueryTypeSubstitutionMap{substitutions},
1980+ LookUpConformanceInModule (module ));
19221981
19231982 // If we don't have enough type witnesses, leave it abstract.
19241983 if (substType->hasError ())
0 commit comments