@@ -1292,8 +1292,12 @@ class SyntacticElementConstraintGenerator
12921292
12931293 // First check to make sure the ThenStmt is in a valid position.
12941294 SmallVector<ThenStmt *, 4 > validThenStmts;
1295- if (auto SVE = context.getAsSingleValueStmtExpr ())
1295+ if (auto SVE = context.getAsSingleValueStmtExpr ()) {
12961296 (void )SVE.get ()->getThenStmts (validThenStmts);
1297+ if (SVE.get ()->getStmtKind () == SingleValueStmtExpr::Kind::For) {
1298+ contextInfo = std::nullopt ;
1299+ }
1300+ }
12971301
12981302 if (!llvm::is_contained (validThenStmts, thenStmt)) {
12991303 auto *thenLoc = cs.getConstraintLocator (thenStmt);
@@ -1488,8 +1492,37 @@ bool ConstraintSystem::generateConstraints(SingleValueStmtExpr *E) {
14881492 auto &ctx = getASTContext ();
14891493
14901494 auto *loc = getConstraintLocator (E);
1491- Type resultTy = createTypeVariable (loc, /* options*/ 0 );
1492- setType (E, resultTy);
1495+ Type resultType = createTypeVariable (loc, /* options*/ 0 );
1496+ setType (E, resultType);
1497+
1498+ if (E->getStmtKind () == SingleValueStmtExpr::Kind::For) {
1499+ auto *rrcProtocol =
1500+ ctx.getProtocol (KnownProtocolKind::RangeReplaceableCollection);
1501+ auto *sequenceProtocol = ctx.getProtocol (KnownProtocolKind::Sequence);
1502+
1503+ addConstraint (ConstraintKind::ConformsTo, resultType,
1504+ rrcProtocol->getDeclaredInterfaceType (), loc);
1505+ Type elementTypeVar = createTypeVariable (loc, /* options*/ 0 );
1506+ Type elementType = DependentMemberType::get (
1507+ resultType, sequenceProtocol->getAssociatedType (ctx.Id_Element ));
1508+
1509+ addConstraint (ConstraintKind::Bind, elementTypeVar, elementType, loc);
1510+ addConstraint (ConstraintKind::Defaultable, resultType,
1511+ ArraySliceType::get (elementTypeVar), loc);
1512+
1513+ auto *binding = E->getForExpressionPreamble ()->ForAccumulatorBinding ;
1514+
1515+ auto *initializer = binding->getInit (0 );
1516+ auto target = SyntacticElementTarget::forInitialization (initializer, Type (),
1517+ binding, 0 , false );
1518+ setTargetFor ({binding, 0 }, target);
1519+
1520+ if (generateConstraints (target)) {
1521+ return true ;
1522+ }
1523+
1524+ addConstraint (ConstraintKind::Bind, getType (initializer), resultType, loc);
1525+ }
14931526
14941527 // Propagate the implied result kind from the if/switch expression itself
14951528 // into the branches.
@@ -1513,21 +1546,24 @@ bool ConstraintSystem::generateConstraints(SingleValueStmtExpr *E) {
15131546 auto *loc = getConstraintLocator (
15141547 E, {LocatorPathElt::SingleValueStmtResult (idx), ctpElt});
15151548
1516- ContextualTypeInfo info (resultTy , CTP_SingleValueStmtBranch, loc);
1549+ ContextualTypeInfo info (resultType , CTP_SingleValueStmtBranch, loc);
15171550 setContextualInfo (result, info);
15181551 }
15191552
15201553 TypeJoinExpr *join = nullptr ;
1521- if (branches.empty ()) {
1522- // If we only have statement branches, the expression is typed as Void. This
1523- // should only be the case for 'if' and 'switch' statements that must be
1524- // expressions that have branches that all end in a throw, and we'll warn
1525- // that we've inferred Void.
1526- addConstraint (ConstraintKind::Bind, resultTy, ctx.getVoidType (), loc);
1527- } else {
1528- // Otherwise, we join the result types for each of the branches.
1529- join = TypeJoinExpr::forBranchesOfSingleValueStmtExpr (
1530- ctx, resultTy, E, AllocationArena::ConstraintSolver);
1554+
1555+ if (E->getStmtKind () != SingleValueStmtExpr::Kind::For) {
1556+ if (branches.empty ()) {
1557+ // If we only have statement branches, the expression is typed as Void.
1558+ // This should only be the case for 'if' and 'switch' statements that must
1559+ // be expressions that have branches that all end in a throw, and we'll
1560+ // warn that we've inferred Void.
1561+ addConstraint (ConstraintKind::Bind, resultType, ctx.getVoidType (), loc);
1562+ } else {
1563+ // Otherwise, we join the result types for each of the branches.
1564+ join = TypeJoinExpr::forBranchesOfSingleValueStmtExpr (
1565+ ctx, resultType, E, AllocationArena::ConstraintSolver);
1566+ }
15311567 }
15321568
15331569 // If this is an implied return in a closure, we need to account for the fact
@@ -1568,11 +1604,11 @@ bool ConstraintSystem::generateConstraints(SingleValueStmtExpr *E) {
15681604 if (auto *closureTy = getClosureTypeIfAvailable (CE)) {
15691605 auto closureResultTy = closureTy->getResult ();
15701606 auto *bindToClosure = Constraint::create (
1571- *this , ConstraintKind::Bind, resultTy , closureResultTy, loc);
1607+ *this , ConstraintKind::Bind, resultType , closureResultTy, loc);
15721608 bindToClosure->setFavored ();
15731609
1574- auto *bindToVoid = Constraint::create (* this , ConstraintKind::Bind,
1575- resultTy , ctx.getVoidType (), loc);
1610+ auto *bindToVoid = Constraint::create (
1611+ * this , ConstraintKind::Bind, resultType , ctx.getVoidType (), loc);
15761612
15771613 addDisjunctionConstraint ({bindToClosure, bindToVoid}, loc);
15781614 }
@@ -2221,7 +2257,9 @@ class SyntacticElementSolutionApplication
22212257 // not the branch result type. This is necessary as there may be
22222258 // an additional conversion required for the branch.
22232259 auto target = solution.getTargetFor (thenStmt->getResult ());
2224- target->setExprConversionType (ty);
2260+ if (SVE.get ()->getStmtKind () != SingleValueStmtExpr::Kind::For) {
2261+ target->setExprConversionType (ty);
2262+ }
22252263
22262264 auto *resultExpr = thenStmt->getResult ();
22272265 if (auto newResultTarget = rewriter.rewriteTarget (*target))
@@ -2663,6 +2701,18 @@ bool ConstraintSystem::applySolutionToSingleValueStmt(
26632701 if (!stmt || application.hadError )
26642702 return true ;
26652703
2704+ if (SVE->getStmtKind () == SingleValueStmtExpr::Kind::For) {
2705+ auto *binding = SVE->getForExpressionPreamble ()->ForAccumulatorBinding ;
2706+ auto target = getTargetFor ({binding, 0 }).value ();
2707+
2708+ auto newTarget = rewriter.rewriteTarget (target);
2709+ if (!newTarget) {
2710+ return true ;
2711+ }
2712+
2713+ binding->setInit (0 , newTarget->getAsExpr ());
2714+ }
2715+
26662716 SVE->setStmt (stmt);
26672717 return false ;
26682718}
0 commit comments