@@ -8159,50 +8159,6 @@ SourceRange AbstractStorageDecl::getTypeSourceRangeForDiagnostics() const {
81598159 return SourceRange ();
81608160}
81618161
8162- static std::optional<std::pair<CaseStmt *, Pattern *>>
8163- findParentPatternCaseStmtAndPattern (const VarDecl *inputVD) {
8164- auto getMatchingPattern = [&](CaseStmt *cs) -> Pattern * {
8165- // Check if inputVD is in our case body var decls if we have any. If we do,
8166- // treat its pattern as our first case label item pattern.
8167- for (auto *vd : cs->getCaseBodyVariablesOrEmptyArray ()) {
8168- if (vd == inputVD) {
8169- return cs->getMutableCaseLabelItems ().front ().getPattern ();
8170- }
8171- }
8172-
8173- // Then check the rest of our case label items.
8174- for (auto &item : cs->getMutableCaseLabelItems ()) {
8175- if (item.getPattern ()->containsVarDecl (inputVD)) {
8176- return item.getPattern ();
8177- }
8178- }
8179-
8180- // Otherwise return false if we do not find anything.
8181- return nullptr ;
8182- };
8183-
8184- // First find our canonical var decl. This is the VarDecl corresponding to the
8185- // first case label item of the first case block in the fallthrough chain that
8186- // our case block is within. Grab the case stmt associated with that var decl
8187- // and start traveling down the fallthrough chain looking for the case
8188- // statement that the input VD belongs to by using getMatchingPattern().
8189- auto *canonicalVD = inputVD->getCanonicalVarDecl ();
8190- auto *caseStmt =
8191- dyn_cast_or_null<CaseStmt>(canonicalVD->getParentPatternStmt ());
8192- if (!caseStmt)
8193- return std::nullopt ;
8194-
8195- if (auto *p = getMatchingPattern (caseStmt))
8196- return std::make_pair (caseStmt, p);
8197-
8198- while ((caseStmt = caseStmt->getFallthroughDest ().getPtrOrNull ())) {
8199- if (auto *p = getMatchingPattern (caseStmt))
8200- return std::make_pair (caseStmt, p);
8201- }
8202-
8203- return std::nullopt ;
8204- }
8205-
82068162VarDecl *VarDecl::getCanonicalVarDecl () const {
82078163 // Any var decl without a parent var decl is canonical. This means that before
82088164 // type checking, all var decls are canonical.
@@ -8227,16 +8183,7 @@ VarDecl *VarDecl::getCanonicalVarDecl() const {
82278183}
82288184
82298185Stmt *VarDecl::getRecursiveParentPatternStmt () const {
8230- // If our parent is already a pattern stmt, just return that.
8231- if (auto *stmt = getParentPatternStmt ())
8232- return stmt;
8233-
8234- // Otherwise, see if we have a parent var decl. If we do not, then return
8235- // nullptr. Otherwise, return the case stmt that we found.
8236- auto result = findParentPatternCaseStmtAndPattern (this );
8237- if (!result.has_value ())
8238- return nullptr ;
8239- return result->first ;
8186+ return getCanonicalVarDecl ()->getParentPatternStmt ();
82408187}
82418188
82428189// / Return the Pattern involved in initializing this VarDecl. Recall that the
@@ -8256,17 +8203,34 @@ Pattern *VarDecl::getParentPattern() const {
82568203 }
82578204
82588205 // If this is a statement parent, dig the pattern out of it.
8259- if (auto *stmt = getParentPatternStmt ()) {
8206+ const auto *canonicalVD = getCanonicalVarDecl ();
8207+ if (auto *stmt = canonicalVD->getParentPatternStmt ()) {
82608208 if (auto *FES = dyn_cast<ForEachStmt>(stmt))
82618209 return FES->getPattern ();
82628210
82638211 if (auto *cs = dyn_cast<CaseStmt>(stmt)) {
8264- // In a case statement, search for the pattern that contains it. This is
8265- // a bit silly, because you can't have something like "case x, y:" anyway.
8266- for (auto items : cs->getCaseLabelItems ()) {
8267- if (items.getPattern ()->containsVarDecl (this ))
8268- return items.getPattern ();
8212+ // In a case statement, search for the pattern that contains it.
8213+ auto findPattern = [](CaseStmt *cs, const VarDecl *VD) -> Pattern * {
8214+ for (auto items : cs->getCaseLabelItems ()) {
8215+ if (items.getPattern ()->containsVarDecl (VD))
8216+ return items.getPattern ();
8217+ }
8218+ return nullptr ;
8219+ };
8220+ if (auto *P = findPattern (cs, this ))
8221+ return P;
8222+
8223+ // If it's not in the CaseStmt, check its fallthrough destination.
8224+ if (auto fallthrough = cs->getFallthroughDest ()) {
8225+ if (auto *P = findPattern (fallthrough.get (), this ))
8226+ return P;
82698227 }
8228+
8229+ // Finally, check the canonical variable, this is necessary to correctly
8230+ // handle case body vars, we just want to take the first pattern that
8231+ // declares it in that case.
8232+ if (auto *P = findPattern (cs, canonicalVD))
8233+ return P;
82708234 }
82718235
82728236 if (auto *LCS = dyn_cast<LabeledConditionalStmt>(stmt)) {
@@ -8277,14 +8241,6 @@ Pattern *VarDecl::getParentPattern() const {
82778241 }
82788242 }
82798243
8280- // Otherwise, check if we have to walk our case stmt's var decl list to find
8281- // the pattern.
8282- if (auto caseStmtPatternPair = findParentPatternCaseStmtAndPattern (this )) {
8283- return caseStmtPatternPair->second ;
8284- }
8285-
8286- // Otherwise, this is a case we do not know or understand. Return nullptr to
8287- // signal we do not have any information.
82888244 return nullptr ;
82898245}
82908246
@@ -8345,7 +8301,7 @@ bool VarDecl::isCaseBodyVariable() const {
83458301 auto *caseStmt = dyn_cast_or_null<CaseStmt>(getRecursiveParentPatternStmt ());
83468302 if (!caseStmt)
83478303 return false ;
8348- return llvm::any_of (caseStmt->getCaseBodyVariablesOrEmptyArray (),
8304+ return llvm::any_of (caseStmt->getCaseBodyVariables (),
83498305 [&](VarDecl *vd) { return vd == this ; });
83508306}
83518307
0 commit comments