@@ -338,6 +338,8 @@ class SyntacticElementConstraintGenerator
338338 hadError = true ;
339339 return ;
340340 }
341+
342+ caseItem->setPattern (pattern, /* resolved=*/ true );
341343 }
342344
343345 // Let's generate constraints for pattern + where clause.
@@ -780,8 +782,6 @@ class SyntacticElementConstraintGenerator
780782 }
781783 }
782784
783- bindSwitchCasePatternVars (context.getAsDeclContext (), caseStmt);
784-
785785 auto *caseLoc = cs.getConstraintLocator (
786786 locator, LocatorPathElt::SyntacticElement (caseStmt));
787787
@@ -805,10 +805,8 @@ class SyntacticElementConstraintGenerator
805805 locator->castLastElementTo <LocatorPathElt::SyntacticElement>()
806806 .asStmt ());
807807
808- for (auto caseBodyVar : caseStmt->getCaseBodyVariablesOrEmptyArray ()) {
809- auto parentVar = caseBodyVar->getParentVarDecl ();
810- assert (parentVar && " Case body variables always have parents" );
811- cs.setType (caseBodyVar, cs.getType (parentVar));
808+ if (recordInferredSwitchCasePatternVars (caseStmt)) {
809+ hadError = true ;
812810 }
813811 }
814812
@@ -935,6 +933,75 @@ class SyntacticElementConstraintGenerator
935933 locator->getLastElementAs <LocatorPathElt::SyntacticElement>();
936934 return parentElt ? parentElt->getElement ().isStmt (kind) : false ;
937935 }
936+
937+ bool recordInferredSwitchCasePatternVars (CaseStmt *caseStmt) {
938+ llvm::SmallDenseMap<Identifier, SmallVector<VarDecl *, 2 >, 4 > patternVars;
939+
940+ auto recordVar = [&](VarDecl *var) {
941+ if (!var->hasName ())
942+ return ;
943+ patternVars[var->getName ()].push_back (var);
944+ };
945+
946+ for (auto &caseItem : caseStmt->getMutableCaseLabelItems ()) {
947+ assert (caseItem.isPatternResolved ());
948+
949+ auto *pattern = caseItem.getPattern ();
950+ pattern->forEachVariable ([&](VarDecl *var) { recordVar (var); });
951+ }
952+
953+ for (auto bodyVar : caseStmt->getCaseBodyVariablesOrEmptyArray ()) {
954+ if (!bodyVar->hasName ())
955+ continue ;
956+
957+ const auto &variants = patternVars[bodyVar->getName ()];
958+
959+ auto getType = [&](VarDecl *var) {
960+ auto type = cs.simplifyType (cs.getType (var));
961+ assert (!type->hasTypeVariable ());
962+ return type;
963+ };
964+
965+ switch (variants.size ()) {
966+ case 0 :
967+ break ;
968+
969+ case 1 :
970+ // If there is only one choice here, let's use it directly.
971+ cs.setType (bodyVar, getType (variants.front ()));
972+ break ;
973+
974+ default : {
975+ // If there are multiple choices it could only mean multiple
976+ // patterns e.g. `.a(let x), .b(let x), ...:`. Let's join them.
977+ Type joinType = getType (variants.front ());
978+
979+ SmallVector<VarDecl *, 2 > conflicts;
980+ for (auto *var : llvm::drop_begin (variants)) {
981+ auto varType = getType (var);
982+ // Type mismatch between different patterns.
983+ if (!joinType->isEqual (varType))
984+ conflicts.push_back (var);
985+ }
986+
987+ if (!conflicts.empty ()) {
988+ if (!cs.shouldAttemptFixes ())
989+ return true ;
990+
991+ // dfdf
992+ auto *locator = cs.getConstraintLocator (bodyVar);
993+ if (cs.recordFix (RenameConflictingPatternVariables::create (
994+ cs, joinType, conflicts, locator)))
995+ return true ;
996+ }
997+
998+ cs.setType (bodyVar, joinType);
999+ }
1000+ }
1001+ }
1002+
1003+ return false ;
1004+ }
9381005};
9391006}
9401007
@@ -1342,6 +1409,8 @@ class SyntacticElementSolutionApplication
13421409 }
13431410 }
13441411
1412+ bindSwitchCasePatternVars (context.getAsDeclContext (), caseStmt);
1413+
13451414 for (auto *expected : caseStmt->getCaseBodyVariablesOrEmptyArray ()) {
13461415 assert (expected->hasName ());
13471416 auto prev = expected->getParentVarDecl ();
0 commit comments