@@ -31,6 +31,13 @@ using namespace swift;
3131using namespace constraints ;
3232using namespace inference ;
3333
34+ void ConstraintGraphNode::initBindingSet () {
35+ ASSERT (!hasBindingSet ());
36+ ASSERT (forRepresentativeVar ());
37+
38+ Set.emplace (CG.getConstraintSystem (), TypeVar, Potential);
39+ }
40+
3441// / Check whether there exists a type that could be implicitly converted
3542// / to a given type i.e. is the given type is Double or Optional<..> this
3643// / function is going to return true because CGFloat could be converted
@@ -278,8 +285,7 @@ bool BindingSet::isPotentiallyIncomplete() const {
278285 return false ;
279286}
280287
281- void BindingSet::inferTransitiveProtocolRequirements (
282- llvm::SmallDenseMap<TypeVariableType *, BindingSet> &inferredBindings) {
288+ void BindingSet::inferTransitiveProtocolRequirements () {
283289 if (TransitiveProtocols)
284290 return ;
285291
@@ -314,13 +320,13 @@ void BindingSet::inferTransitiveProtocolRequirements(
314320 do {
315321 auto *currentVar = workList.back ().second ;
316322
317- auto cachedBindings = inferredBindings. find (currentVar) ;
318- if (cachedBindings == inferredBindings. end ()) {
323+ auto &node = CS. getConstraintGraph ()[currentVar] ;
324+ if (!node. hasBindingSet ()) {
319325 workList.pop_back ();
320326 continue ;
321327 }
322328
323- auto &bindings = cachedBindings-> getSecond ();
329+ auto &bindings = node. getBindingSet ();
324330
325331 // If current variable already has transitive protocol
326332 // conformances inferred, there is no need to look deeper
@@ -352,11 +358,11 @@ void BindingSet::inferTransitiveProtocolRequirements(
352358 if (!equivalenceClass.insert (typeVar))
353359 continue ;
354360
355- auto bindingSet = inferredBindings. find (typeVar) ;
356- if (bindingSet == inferredBindings. end ())
361+ auto &node = CS. getConstraintGraph ()[typeVar] ;
362+ if (!node. hasBindingSet ())
357363 continue ;
358364
359- auto &equivalences = bindingSet-> getSecond ().Info .EquivalentTo ;
365+ auto &equivalences = node. getBindingSet ().Info .EquivalentTo ;
360366 for (const auto &eqVar : equivalences) {
361367 workList.push_back (eqVar.first );
362368 }
@@ -367,11 +373,11 @@ void BindingSet::inferTransitiveProtocolRequirements(
367373 if (memberVar == currentVar)
368374 continue ;
369375
370- auto eqBindings = inferredBindings. find (memberVar) ;
371- if (eqBindings == inferredBindings. end ())
376+ auto &node = CS. getConstraintGraph ()[memberVar] ;
377+ if (!node. hasBindingSet ())
372378 continue ;
373379
374- const auto &bindings = eqBindings-> getSecond ();
380+ const auto &bindings = node. getBindingSet ();
375381
376382 llvm::SmallPtrSet<Constraint *, 2 > placeholder;
377383 // Add any direct protocols from members of the
@@ -423,9 +429,9 @@ void BindingSet::inferTransitiveProtocolRequirements(
423429 // Propagate inferred protocols to all of the members of the
424430 // equivalence class.
425431 for (const auto &equivalence : bindings.Info .EquivalentTo ) {
426- auto eqBindings = inferredBindings. find ( equivalence.first ) ;
427- if (eqBindings != inferredBindings. end ()) {
428- auto &bindings = eqBindings-> getSecond ();
432+ auto &node = CS. getConstraintGraph ()[ equivalence.first ] ;
433+ if (node. hasBindingSet ()) {
434+ auto &bindings = node. getBindingSet ();
429435 bindings.TransitiveProtocols .emplace (protocolsForEquivalence.begin (),
430436 protocolsForEquivalence.end ());
431437 }
@@ -438,9 +444,7 @@ void BindingSet::inferTransitiveProtocolRequirements(
438444 } while (!workList.empty ());
439445}
440446
441- void BindingSet::inferTransitiveBindings (
442- const llvm::SmallDenseMap<TypeVariableType *, BindingSet>
443- &inferredBindings) {
447+ void BindingSet::inferTransitiveBindings () {
444448 using BindingKind = AllowedBindingKind;
445449
446450 // If the current type variable represents a key path root type
@@ -450,9 +454,9 @@ void BindingSet::inferTransitiveBindings(
450454 auto *locator = TypeVar->getImpl ().getLocator ();
451455 if (auto *keyPathTy =
452456 CS.getType (locator->getAnchor ())->getAs <TypeVariableType>()) {
453- auto keyPathBindings = inferredBindings. find (keyPathTy) ;
454- if (keyPathBindings != inferredBindings. end ()) {
455- auto &bindings = keyPathBindings-> getSecond ();
457+ auto &node = CS. getConstraintGraph ()[keyPathTy] ;
458+ if (node. hasBindingSet ()) {
459+ auto &bindings = node. getBindingSet ();
456460
457461 for (auto &binding : bindings.Bindings ) {
458462 auto bindingTy = binding.BindingType ->lookThroughAllOptionalTypes ();
@@ -476,9 +480,9 @@ void BindingSet::inferTransitiveBindings(
476480 // transitively used because conversions between generic arguments
477481 // are not allowed.
478482 if (auto *contextualRootVar = inferredRootTy->getAs <TypeVariableType>()) {
479- auto rootBindings = inferredBindings. find (contextualRootVar) ;
480- if (rootBindings != inferredBindings. end ()) {
481- auto &bindings = rootBindings-> getSecond ();
483+ auto &node = CS. getConstraintGraph ()[contextualRootVar] ;
484+ if (node. hasBindingSet ()) {
485+ auto &bindings = node. getBindingSet ();
482486
483487 // Don't infer if root is not yet fully resolved.
484488 if (bindings.isDelayed ())
@@ -507,11 +511,11 @@ void BindingSet::inferTransitiveBindings(
507511 }
508512
509513 for (const auto &entry : Info.SupertypeOf ) {
510- auto relatedBindings = inferredBindings. find ( entry.first ) ;
511- if (relatedBindings == inferredBindings. end ())
514+ auto &node = CS. getConstraintGraph ()[ entry.first ] ;
515+ if (!node. hasBindingSet ())
512516 continue ;
513517
514- auto &bindings = relatedBindings-> getSecond ();
518+ auto &bindings = node. getBindingSet ();
515519
516520 // FIXME: This is a workaround necessary because solver doesn't filter
517521 // bindings based on protocol requirements placed on a type variable.
@@ -610,9 +614,9 @@ static Type getKeyPathType(ASTContext &ctx, KeyPathCapability capability,
610614 return keyPathTy;
611615}
612616
613- bool BindingSet::finalize (
614- llvm::SmallDenseMap<TypeVariableType *, BindingSet> &inferredBindings) {
615- inferTransitiveBindings (inferredBindings );
617+ bool BindingSet::finalize (bool transitive) {
618+ if (transitive)
619+ inferTransitiveBindings ();
616620
617621 determineLiteralCoverage ();
618622
@@ -628,8 +632,8 @@ bool BindingSet::finalize(
628632 // func foo<T: P>(_: T) {}
629633 // foo(.bar) <- `.bar` should be a static member of `P`.
630634 // \endcode
631- if (!hasViableBindings ()) {
632- inferTransitiveProtocolRequirements (inferredBindings );
635+ if (transitive && !hasViableBindings ()) {
636+ inferTransitiveProtocolRequirements ();
633637
634638 if (TransitiveProtocols.has_value ()) {
635639 for (auto *constraint : *TransitiveProtocols) {
@@ -979,14 +983,14 @@ BindingSet::BindingScore BindingSet::formBindingScore(const BindingSet &b) {
979983std::optional<BindingSet> ConstraintSystem::determineBestBindings (
980984 llvm::function_ref<void (const BindingSet &)> onCandidate) {
981985 // Look for potential type variable bindings.
982- std::optional<BindingSet> bestBindings;
983- llvm::SmallDenseMap<TypeVariableType *, BindingSet> cache;
986+ BindingSet *bestBindings = nullptr ;
984987
985988 // First, let's collect all of the possible bindings.
986989 for (auto *typeVar : getTypeVariables ()) {
987- if (!typeVar->getImpl ().hasRepresentativeOrFixed ()) {
988- cache.insert ({typeVar, getBindingsFor (typeVar, /* finalize=*/ false )});
989- }
990+ auto &node = CG[typeVar];
991+ node.resetBindingSet ();
992+ if (!typeVar->getImpl ().hasRepresentativeOrFixed ())
993+ node.initBindingSet ();
990994 }
991995
992996 // Determine whether given type variable with its set of bindings is
@@ -1023,11 +1027,12 @@ std::optional<BindingSet> ConstraintSystem::determineBestBindings(
10231027 // Now let's see if we could infer something for related type
10241028 // variables based on other bindings.
10251029 for (auto *typeVar : getTypeVariables ()) {
1026- auto cachedBindings = cache. find ( typeVar) ;
1027- if (cachedBindings == cache. end ())
1030+ auto &node = CG[ typeVar] ;
1031+ if (!node. hasBindingSet ())
10281032 continue ;
10291033
1030- auto &bindings = cachedBindings->getSecond ();
1034+ auto &bindings = node.getBindingSet ();
1035+
10311036 // Before attempting to infer transitive bindings let's check
10321037 // whether there are any viable "direct" bindings associated with
10331038 // current type variable, if there are none - it means that this type
@@ -1040,7 +1045,7 @@ std::optional<BindingSet> ConstraintSystem::determineBestBindings(
10401045 // produce a default type.
10411046 bool isViable = isViableForRanking (bindings);
10421047
1043- if (!bindings.finalize (cache ))
1048+ if (!bindings.finalize (true ))
10441049 continue ;
10451050
10461051 if (!bindings || !isViable)
@@ -1051,10 +1056,13 @@ std::optional<BindingSet> ConstraintSystem::determineBestBindings(
10511056 // If these are the first bindings, or they are better than what
10521057 // we saw before, use them instead.
10531058 if (!bestBindings || bindings < *bestBindings)
1054- bestBindings. emplace ( bindings) ;
1059+ bestBindings = & bindings;
10551060 }
10561061
1057- return bestBindings;
1062+ if (!bestBindings)
1063+ return std::nullopt ;
1064+
1065+ return std::optional (*bestBindings);
10581066}
10591067
10601068// / Find the set of type variables that are inferable from the given type.
@@ -1435,18 +1443,13 @@ bool BindingSet::favoredOverConjunction(Constraint *conjunction) const {
14351443 return true ;
14361444}
14371445
1438- BindingSet ConstraintSystem::getBindingsFor (TypeVariableType *typeVar,
1439- bool finalize) {
1446+ BindingSet ConstraintSystem::getBindingsFor (TypeVariableType *typeVar) {
14401447 assert (typeVar->getImpl ().getRepresentative (nullptr ) == typeVar &&
14411448 " not a representative" );
14421449 assert (!typeVar->getImpl ().getFixedType (nullptr ) && " has a fixed type" );
14431450
1444- BindingSet bindings (*this , typeVar, CG[typeVar].getCurrentBindings ());
1445-
1446- if (finalize) {
1447- llvm::SmallDenseMap<TypeVariableType *, BindingSet> cache;
1448- bindings.finalize (cache);
1449- }
1451+ BindingSet bindings (*this , typeVar, CG[typeVar].getPotentialBindings ());
1452+ bindings.finalize (false );
14501453
14511454 return bindings;
14521455}
0 commit comments