@@ -338,6 +338,8 @@ class SyntacticElementConstraintGenerator
338338 hadError = true ;
339339 return ;
340340 }
341+
342+ caseItem->setPattern (pattern, /* resolved=*/ true );
341343 }
342344
343345 // Let's generate constraints for pattern + where clause.
@@ -774,8 +776,6 @@ class SyntacticElementConstraintGenerator
774776 }
775777 }
776778
777- bindSwitchCasePatternVars (context.getAsDeclContext (), caseStmt);
778-
779779 auto *caseLoc = cs.getConstraintLocator (
780780 locator, LocatorPathElt::SyntacticElement (caseStmt));
781781
@@ -799,10 +799,8 @@ class SyntacticElementConstraintGenerator
799799 locator->castLastElementTo <LocatorPathElt::SyntacticElement>()
800800 .asStmt ());
801801
802- for (auto caseBodyVar : caseStmt->getCaseBodyVariablesOrEmptyArray ()) {
803- auto parentVar = caseBodyVar->getParentVarDecl ();
804- assert (parentVar && " Case body variables always have parents" );
805- cs.setType (caseBodyVar, cs.getType (parentVar));
802+ if (recordInferredSwitchCasePatternVars (caseStmt)) {
803+ hadError = true ;
806804 }
807805 }
808806
@@ -929,6 +927,75 @@ class SyntacticElementConstraintGenerator
929927 locator->getLastElementAs <LocatorPathElt::SyntacticElement>();
930928 return parentElt ? parentElt->getElement ().isStmt (kind) : false ;
931929 }
930+
931+ bool recordInferredSwitchCasePatternVars (CaseStmt *caseStmt) {
932+ llvm::SmallDenseMap<Identifier, SmallVector<VarDecl *, 2 >, 4 > patternVars;
933+
934+ auto recordVar = [&](VarDecl *var) {
935+ if (!var->hasName ())
936+ return ;
937+ patternVars[var->getName ()].push_back (var);
938+ };
939+
940+ for (auto &caseItem : caseStmt->getMutableCaseLabelItems ()) {
941+ assert (caseItem.isPatternResolved ());
942+
943+ auto *pattern = caseItem.getPattern ();
944+ pattern->forEachVariable ([&](VarDecl *var) { recordVar (var); });
945+ }
946+
947+ for (auto bodyVar : caseStmt->getCaseBodyVariablesOrEmptyArray ()) {
948+ if (!bodyVar->hasName ())
949+ continue ;
950+
951+ const auto &variants = patternVars[bodyVar->getName ()];
952+
953+ auto getType = [&](VarDecl *var) {
954+ auto type = cs.simplifyType (cs.getType (var));
955+ assert (!type->hasTypeVariable ());
956+ return type;
957+ };
958+
959+ switch (variants.size ()) {
960+ case 0 :
961+ break ;
962+
963+ case 1 :
964+ // If there is only one choice here, let's use it directly.
965+ cs.setType (bodyVar, getType (variants.front ()));
966+ break ;
967+
968+ default : {
969+ // If there are multiple choices it could only mean multiple
970+ // patterns e.g. `.a(let x), .b(let x), ...:`. Let's join them.
971+ Type joinType = getType (variants.front ());
972+
973+ SmallVector<VarDecl *, 2 > conflicts;
974+ for (auto *var : llvm::drop_begin (variants)) {
975+ auto varType = getType (var);
976+ // Type mismatch between different patterns.
977+ if (!joinType->isEqual (varType))
978+ conflicts.push_back (var);
979+ }
980+
981+ if (!conflicts.empty ()) {
982+ if (!cs.shouldAttemptFixes ())
983+ return true ;
984+
985+ // dfdf
986+ auto *locator = cs.getConstraintLocator (bodyVar);
987+ if (cs.recordFix (RenameConflictingPatternVariables::create (
988+ cs, joinType, conflicts, locator)))
989+ return true ;
990+ }
991+
992+ cs.setType (bodyVar, joinType);
993+ }
994+ }
995+ }
996+
997+ return false ;
998+ }
932999};
9331000}
9341001
@@ -1336,6 +1403,8 @@ class SyntacticElementSolutionApplication
13361403 }
13371404 }
13381405
1406+ bindSwitchCasePatternVars (context.getAsDeclContext (), caseStmt);
1407+
13391408 for (auto *expected : caseStmt->getCaseBodyVariablesOrEmptyArray ()) {
13401409 assert (expected->hasName ());
13411410 auto prev = expected->getParentVarDecl ();
0 commit comments