|
16 | 16 |
|
17 | 17 | #include "swift/AST/Stmt.h" |
18 | 18 | #include "swift/AST/ASTContext.h" |
| 19 | +#include "swift/AST/ASTWalker.h" |
19 | 20 | #include "swift/AST/Decl.h" |
20 | 21 | #include "swift/AST/Expr.h" |
21 | 22 | #include "swift/AST/Pattern.h" |
@@ -155,6 +156,65 @@ BraceStmt *BraceStmt::create(ASTContext &ctx, SourceLoc lbloc, |
155 | 156 | return ::new(Buffer) BraceStmt(lbloc, elts, rbloc, implicit); |
156 | 157 | } |
157 | 158 |
|
| 159 | +ASTNode BraceStmt::findAsyncNode() { |
| 160 | + // TODO: Statements don't track their ASTContext/evaluator, so I am not making |
| 161 | + // this a request. It probably should be a request at some point. |
| 162 | + // |
| 163 | + // While we're at it, it would be very nice if this could be a const |
| 164 | + // operation, but the AST-walking is not a const operation. |
| 165 | + |
| 166 | + // A walker that looks for 'async' and 'await' expressions |
| 167 | + // that aren't nested within closures or nested declarations. |
| 168 | + class FindInnerAsync : public ASTWalker { |
| 169 | + ASTNode AsyncNode; |
| 170 | + |
| 171 | + std::pair<bool, Expr *> walkToExprPre(Expr *expr) override { |
| 172 | + // If we've found an 'await', record it and terminate the traversal. |
| 173 | + if (isa<AwaitExpr>(expr)) { |
| 174 | + AsyncNode = expr; |
| 175 | + return {false, nullptr}; |
| 176 | + } |
| 177 | + |
| 178 | + // Do not recurse into other closures. |
| 179 | + if (isa<ClosureExpr>(expr)) |
| 180 | + return {false, expr}; |
| 181 | + |
| 182 | + return {true, expr}; |
| 183 | + } |
| 184 | + |
| 185 | + bool walkToDeclPre(Decl *decl) override { |
| 186 | + // Do not walk into function or type declarations. |
| 187 | + if (auto *patternBinding = dyn_cast<PatternBindingDecl>(decl)) { |
| 188 | + if (patternBinding->isAsyncLet()) |
| 189 | + AsyncNode = patternBinding; |
| 190 | + |
| 191 | + return true; |
| 192 | + } |
| 193 | + |
| 194 | + return false; |
| 195 | + } |
| 196 | + |
| 197 | + std::pair<bool, Stmt *> walkToStmtPre(Stmt *stmt) override { |
| 198 | + if (auto forEach = dyn_cast<ForEachStmt>(stmt)) { |
| 199 | + if (forEach->getAwaitLoc().isValid()) { |
| 200 | + AsyncNode = forEach; |
| 201 | + return {false, nullptr}; |
| 202 | + } |
| 203 | + } |
| 204 | + |
| 205 | + return {true, stmt}; |
| 206 | + } |
| 207 | + |
| 208 | + public: |
| 209 | + ASTNode getAsyncNode() { return AsyncNode; } |
| 210 | + }; |
| 211 | + |
| 212 | + FindInnerAsync asyncFinder; |
| 213 | + walk(asyncFinder); |
| 214 | + |
| 215 | + return asyncFinder.getAsyncNode(); |
| 216 | +} |
| 217 | + |
158 | 218 | SourceLoc ReturnStmt::getStartLoc() const { |
159 | 219 | if (ReturnLoc.isInvalid() && Result) |
160 | 220 | return Result->getStartLoc(); |
|
0 commit comments