@@ -1066,6 +1066,16 @@ class AssociatedTypeInference {
10661066 bool isBetterSolution (const InferredTypeWitnessesSolution &first,
10671067 const InferredTypeWitnessesSolution &second);
10681068
1069+ // / Find the best solution.
1070+ // /
1071+ // / \param solutions All of the solutions to consider. On success,
1072+ // / this will contain only the best solution.
1073+ // /
1074+ // / \returns \c false if there was a single best solution,
1075+ // / \c true if no single best solution exists.
1076+ bool findBestSolution (
1077+ SmallVectorImpl<InferredTypeWitnessesSolution> &solutions);
1078+
10691079 // / Emit a diagnostic for the case where there are no solutions at all
10701080 // / to consider.
10711081 // /
@@ -1902,19 +1912,20 @@ AssociatedTypeInference::inferTypeWitnessesViaAssociatedType(
19021912 else
19031913 continue ;
19041914
1915+ if (result.empty ()) {
1916+ // If we found at least one default candidate, we must allow for the
1917+ // possibility that no default is chosen by adding a tautological witness
1918+ // to our disjunction.
1919+ result.push_back (InferredAssociatedTypesByWitness ());
1920+ }
1921+
19051922 // Add this result.
19061923 InferredAssociatedTypesByWitness inferred;
19071924 inferred.Witness = typeDecl;
19081925 inferred.Inferred .push_back ({assocType, witnessType});
19091926 result.push_back (std::move (inferred));
19101927 }
19111928
1912- if (!result.empty ()) {
1913- // If we found at least one default candidate, we must allow for the
1914- // possibility that no default is chosen by adding a tautological witness
1915- // to our disjunction.
1916- result.push_back (InferredAssociatedTypesByWitness ());
1917- }
19181929 return result;
19191930}
19201931
@@ -3130,6 +3141,35 @@ void AssociatedTypeInference::findSolutionsRec(
31303141 known->first = replaced;
31313142 }
31323143
3144+ if (!ctx.LangOpts .EnableExperimentalAssociatedTypeInference ) {
3145+ // Check whether our current solution matches the given solution.
3146+ auto matchesSolution =
3147+ [&](const InferredTypeWitnessesSolution &solution) {
3148+ for (const auto &existingTypeWitness : solution.TypeWitnesses ) {
3149+ auto typeWitness = typeWitnesses.begin (existingTypeWitness.first );
3150+ if (!typeWitness->first ->isEqual (existingTypeWitness.second .first ))
3151+ return false ;
3152+ }
3153+
3154+ return true ;
3155+ };
3156+
3157+ // If we've seen this solution already, bail out; there's no point in
3158+ // checking further.
3159+ if (llvm::any_of (solutions, matchesSolution)) {
3160+ LLVM_DEBUG (llvm::dbgs () << std::string (valueWitnesses.size (), ' +' )
3161+ << " + Duplicate valid solution found\n " ;);
3162+ ++NumDuplicateSolutionStates;
3163+ return ;
3164+ }
3165+ if (llvm::any_of (nonViableSolutions, matchesSolution)) {
3166+ LLVM_DEBUG (llvm::dbgs () << std::string (valueWitnesses.size (), ' +' )
3167+ << " + Duplicate invalid solution found\n " ;);
3168+ ++NumDuplicateSolutionStates;
3169+ return ;
3170+ }
3171+ }
3172+
31333173 // / Check the current set of type witnesses.
31343174 bool invalid = checkCurrentTypeWitnesses (valueWitnesses);
31353175
@@ -3156,6 +3196,8 @@ void AssociatedTypeInference::findSolutionsRec(
31563196 = numValueWitnessesInProtocolExtensions;
31573197
31583198 // We fold away non-viable solutions that have the same type witnesses.
3199+ if (ctx.LangOpts .EnableExperimentalAssociatedTypeInference ) {
3200+
31593201 if (invalid) {
31603202 if (llvm::find (nonViableSolutions, solution) != nonViableSolutions.end ()) {
31613203 LLVM_DEBUG (llvm::dbgs () << std::string (valueWitnesses.size (), ' +' )
@@ -3168,6 +3210,22 @@ void AssociatedTypeInference::findSolutionsRec(
31683210 return ;
31693211 }
31703212
3213+ }
3214+
3215+ if (!ctx.LangOpts .EnableExperimentalAssociatedTypeInference ) {
3216+
3217+ auto &solutionList = invalid ? nonViableSolutions : solutions;
3218+ solutionList.push_back (solution);
3219+
3220+ // If this solution was clearly better than the previous best solution,
3221+ // swap them.
3222+ if (solutionList.back ().NumValueWitnessesInProtocolExtensions
3223+ < solutionList.front ().NumValueWitnessesInProtocolExtensions ) {
3224+ std::swap (solutionList.front (), solutionList.back ());
3225+ }
3226+
3227+ } else {
3228+
31713229 // For valid solutions, we want to find the best solution if one exists.
31723230 // We maintain the invariant that no viable solution is clearly worse than
31733231 // any other viable solution. If multiple viable solutions remain after
@@ -3197,6 +3255,8 @@ void AssociatedTypeInference::findSolutionsRec(
31973255 });
31983256
31993257 solutions.push_back (std::move (solution));
3258+
3259+ }
32003260 return ;
32013261 }
32023262
@@ -3565,6 +3625,58 @@ bool AssociatedTypeInference::isBetterSolution(
35653625 return firstBetter;
35663626}
35673627
3628+ bool AssociatedTypeInference::findBestSolution (
3629+ SmallVectorImpl<InferredTypeWitnessesSolution> &solutions) {
3630+ if (solutions.empty ()) return true ;
3631+ if (solutions.size () == 1 ) return false ;
3632+
3633+ // The solution at the front has the smallest number of value witnesses found
3634+ // in protocol extensions, by construction.
3635+ unsigned bestNumValueWitnessesInProtocolExtensions
3636+ = solutions.front ().NumValueWitnessesInProtocolExtensions ;
3637+
3638+ // Erase any solutions with more value witnesses in protocol
3639+ // extensions than the best.
3640+ solutions.erase (
3641+ std::remove_if (solutions.begin (), solutions.end (),
3642+ [&](const InferredTypeWitnessesSolution &solution) {
3643+ return solution.NumValueWitnessesInProtocolExtensions >
3644+ bestNumValueWitnessesInProtocolExtensions;
3645+ }),
3646+ solutions.end ());
3647+
3648+ // If we're down to one solution, success!
3649+ if (solutions.size () == 1 ) return false ;
3650+
3651+ // Find a solution that's at least as good as the solutions that follow it.
3652+ unsigned bestIdx = 0 ;
3653+ for (unsigned i = 1 , n = solutions.size (); i != n; ++i) {
3654+ if (isBetterSolution (solutions[i], solutions[bestIdx]))
3655+ bestIdx = i;
3656+ }
3657+
3658+ // Make sure that solution is better than any of the other solutions.
3659+ bool ambiguous = false ;
3660+ for (unsigned i = 1 , n = solutions.size (); i != n; ++i) {
3661+ if (i != bestIdx && !isBetterSolution (solutions[bestIdx], solutions[i])) {
3662+ ambiguous = true ;
3663+ break ;
3664+ }
3665+ }
3666+
3667+ // If the result was ambiguous, fail.
3668+ if (ambiguous) {
3669+ assert (solutions.size () != 1 && " should have succeeded somewhere above?" );
3670+ return true ;
3671+
3672+ }
3673+ // Keep the best solution, erasing all others.
3674+ if (bestIdx != 0 )
3675+ solutions[0 ] = std::move (solutions[bestIdx]);
3676+ solutions.erase (solutions.begin () + 1 , solutions.end ());
3677+ return false ;
3678+ }
3679+
35683680namespace {
35693681 // / A failed type witness binding.
35703682 struct FailedTypeWitness {
@@ -3971,7 +4083,9 @@ auto AssociatedTypeInference::solve()
39714083 }
39724084
39734085 // Happy case: we found exactly one unique viable solution.
3974- if (solutions.size () == 1 ) {
4086+ if (!findBestSolution (solutions)) {
4087+ assert (solutions.size () == 1 && " Not a unique best solution?" );
4088+
39754089 // Form the resulting solution.
39764090 auto &typeWitnesses = solutions.front ().TypeWitnesses ;
39774091 for (auto assocType : unresolvedAssocTypes) {
0 commit comments