@@ -37,6 +37,9 @@ Expr *getVoidExpr(ASTContext &ctx) {
3737
3838// / Find any type variable references inside of an AST node.
3939class TypeVariableRefFinder : public ASTWalker {
40+ // / A stack of all closures the walker encountered so far.
41+ SmallVector<DeclContext *> ClosureDCs;
42+
4043 ConstraintSystem &CS;
4144 ASTNode Parent;
4245
@@ -46,9 +49,16 @@ class TypeVariableRefFinder : public ASTWalker {
4649 TypeVariableRefFinder (
4750 ConstraintSystem &cs, ASTNode parent,
4851 llvm::SmallPtrSetImpl<TypeVariableType *> &referencedVars)
49- : CS(cs), Parent(parent), ReferencedVars(referencedVars) {}
52+ : CS(cs), Parent(parent), ReferencedVars(referencedVars) {
53+ if (auto *closure = getAsExpr<ClosureExpr>(Parent))
54+ ClosureDCs.push_back (closure);
55+ }
5056
5157 std::pair<bool , Expr *> walkToExprPre (Expr *expr) override {
58+ if (auto *closure = dyn_cast<ClosureExpr>(expr)) {
59+ ClosureDCs.push_back (closure);
60+ }
61+
5262 if (auto *DRE = dyn_cast<DeclRefExpr>(expr)) {
5363 auto *decl = DRE->getDecl ();
5464
@@ -81,20 +91,33 @@ class TypeVariableRefFinder : public ASTWalker {
8191 return {true , expr};
8292 }
8393
94+ Expr *walkToExprPost (Expr *expr) override {
95+ if (auto *closure = dyn_cast<ClosureExpr>(expr)) {
96+ ClosureDCs.pop_back ();
97+ }
98+ return expr;
99+ }
100+
84101 std::pair<bool , Stmt *> walkToStmtPre (Stmt *stmt) override {
85102 // Return statements have to reference outside result type
86103 // since all of them are joined by it if it's not specified
87104 // explicitly.
88105 if (isa<ReturnStmt>(stmt)) {
89106 if (auto *closure = getAsExpr<ClosureExpr>(Parent)) {
90- inferVariables (CS.getClosureType (closure)->getResult ());
107+ // Return is only viable if it belongs to a parent closure.
108+ if (currentClosureDC () == closure)
109+ inferVariables (CS.getClosureType (closure)->getResult ());
91110 }
92111 }
93112
94113 return {true , stmt};
95114 }
96115
97116private:
117+ DeclContext *currentClosureDC () const {
118+ return ClosureDCs.empty () ? nullptr : ClosureDCs.back ();
119+ }
120+
98121 void inferVariables (Type type) {
99122 type = type->getWithoutSpecifierType ();
100123 // Record the type variable itself because it has to
0 commit comments