@@ -918,13 +918,13 @@ std::optional<BraceStmt *>
918918TypeChecker::applyResultBuilderBodyTransform (FuncDecl *func, Type builderType) {
919919 // First look for any return statements, and bail if we have any.
920920 auto &ctx = func->getASTContext ();
921- if (evaluateOrDefault (ctx.evaluator , BraceHasReturnRequest{func->getBody ()},
922- false )) {
921+
922+ SmallVector<ReturnStmt *> returnStmts;
923+ func->getExplicitReturnStmts (returnStmts);
924+
925+ if (!returnStmts.empty ()) {
923926 // One or more explicit 'return' statements were encountered, which
924927 // disables the result builder transform. Warn when we do this.
925- auto returnStmts = findReturnStatements (func);
926- assert (!returnStmts.empty ());
927-
928928 ctx.Diags .diagnose (
929929 returnStmts.front ()->getReturnLoc (),
930930 diag::result_builder_disabled_by_return_warn, builderType);
@@ -1126,8 +1126,7 @@ ConstraintSystem::matchResultBuilder(AnyFunctionRef fn, Type builderType,
11261126 // not apply the result builder transform if it contained an explicit return.
11271127 // To maintain source compatibility, we still need to check for HasReturnStmt.
11281128 // https://github.com/apple/swift/issues/64332.
1129- if (evaluateOrDefault (getASTContext ().evaluator ,
1130- BraceHasReturnRequest{fn.getBody ()}, false )) {
1129+ if (fn.bodyHasExplicitReturnStmt ()) {
11311130 // Diagnostic mode means that solver couldn't reach any viable
11321131 // solution, so let's diagnose presence of a `return` statement
11331132 // in the closure body.
@@ -1235,49 +1234,84 @@ void ConstraintSystem::removeResultBuilderTransform(AnyFunctionRef fn) {
12351234 ASSERT (erased);
12361235}
12371236
1238- namespace {
1239- class ReturnStmtFinder : public ASTWalker {
1240- std::vector<ReturnStmt *> ReturnStmts;
1237+ // / Walks the given brace statement and calls the given function reference on
1238+ // / every occurrence of an explicit `return` statement.
1239+ // /
1240+ // / \param callback A function reference that takes a `return` statement and
1241+ // / returns a boolean value indicating whether to abort the walk.
1242+ // /
1243+ // / \returns `true` if the walk was aborted, `false` otherwise.
1244+ static bool walkExplicitReturnStmts (const BraceStmt *BS,
1245+ function_ref<bool (ReturnStmt *)> callback) {
1246+ class Walker : public ASTWalker {
1247+ function_ref<bool (ReturnStmt *)> callback;
1248+
1249+ public:
1250+ Walker (decltype (Walker::callback) callback) : callback(callback) {}
1251+
1252+ MacroWalking getMacroWalkingBehavior () const override {
1253+ return MacroWalking::Arguments;
1254+ }
12411255
1242- public:
1243- static std::vector<ReturnStmt *> find (const BraceStmt *BS) {
1244- ReturnStmtFinder finder;
1245- const_cast <BraceStmt *>(BS)->walk (finder);
1246- return std::move (finder.ReturnStmts );
1247- }
1256+ PreWalkResult<Expr *> walkToExprPre (Expr *E) override {
1257+ return Action::SkipNode (E);
1258+ }
12481259
1249- MacroWalking getMacroWalkingBehavior () const override {
1250- return MacroWalking::Arguments;
1251- }
1260+ PreWalkResult<Stmt *> walkToStmtPre (Stmt *S) override {
1261+ if (S->isImplicit ()) {
1262+ return Action::SkipNode (S);
1263+ }
12521264
1253- PreWalkResult<Expr *> walkToExprPre (Expr *E) override {
1254- return Action::SkipNode (E);
1255- }
1265+ auto *returnStmt = dyn_cast<ReturnStmt>(S);
1266+ if (!returnStmt) {
1267+ return Action::Continue (S);
1268+ }
12561269
1257- PreWalkResult<Stmt *> walkToStmtPre (Stmt *S) override {
1258- // If we see a return statement, note it..
1259- auto *returnStmt = dyn_cast<ReturnStmt>(S);
1260- if (!returnStmt || returnStmt->isImplicit ())
1261- return Action::Continue (S);
1270+ if (callback (returnStmt)) {
1271+ return Action::Stop ();
1272+ }
12621273
1263- ReturnStmts. push_back (returnStmt);
1264- return Action::SkipNode (S);
1265- }
1274+ // Skip children & post walk and continue.
1275+ return Action::SkipNode (S);
1276+ }
12661277
1267- // / Ignore patterns.
1268- PreWalkResult<Pattern *> walkToPatternPre (Pattern *pat) override {
1269- return Action::SkipNode (pat);
1278+ // / Ignore patterns.
1279+ PreWalkResult<Pattern *> walkToPatternPre (Pattern *pat) override {
1280+ return Action::SkipNode (pat);
1281+ }
1282+ };
1283+
1284+ Walker walker (callback);
1285+
1286+ return const_cast <BraceStmt *>(BS)->walk (walker) == nullptr ;
1287+ }
1288+
1289+ bool BraceHasExplicitReturnStmtRequest::evaluate (Evaluator &evaluator,
1290+ const BraceStmt *BS) const {
1291+ return walkExplicitReturnStmts (BS, [](ReturnStmt *) { return true ; });
1292+ }
1293+
1294+ bool AnyFunctionRef::bodyHasExplicitReturnStmt () const {
1295+ auto *body = getBody ();
1296+ if (!body) {
1297+ return false ;
12701298 }
1271- };
1272- } // end anonymous namespace
12731299
1274- bool BraceHasReturnRequest::evaluate (Evaluator &evaluator,
1275- const BraceStmt *BS) const {
1276- return ! ReturnStmtFinder::find (BS). empty ( );
1300+ auto &ctx = getAsDeclContext ()-> getASTContext ();
1301+ return evaluateOrDefault (ctx. evaluator ,
1302+ BraceHasExplicitReturnStmtRequest{body}, false );
12771303}
12781304
1279- std::vector<ReturnStmt *> TypeChecker::findReturnStatements (AnyFunctionRef fn) {
1280- return ReturnStmtFinder::find (fn.getBody ());
1305+ void AnyFunctionRef::getExplicitReturnStmts (
1306+ SmallVectorImpl<ReturnStmt *> &results) const {
1307+ if (!bodyHasExplicitReturnStmt ()) {
1308+ return ;
1309+ }
1310+
1311+ walkExplicitReturnStmts (getBody (), [&results](ReturnStmt *RS) {
1312+ results.push_back (RS);
1313+ return false ;
1314+ });
12811315}
12821316
12831317ResultBuilderOpSupport TypeChecker::checkBuilderOpSupport (
0 commit comments