Skip to content

Commit 2524be2

Browse files
committed
First cut of making coroutine AST type
1 parent 5f75587 commit 2524be2

32 files changed

+266
-111
lines changed

include/swift/AST/AnyFunctionRef.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ class AnyFunctionRef {
9696
Type getBodyResultType() const {
9797
if (auto *AFD = TheFunction.dyn_cast<AbstractFunctionDecl *>()) {
9898
if (auto *FD = dyn_cast<FuncDecl>(AFD))
99-
return FD->mapTypeIntoContext(FD->getResultInterfaceType());
99+
return FD->mapTypeIntoContext(FD->getResultInterfaceTypeWithoutYields());
100100
return TupleType::getEmpty(AFD->getASTContext());
101101
}
102102
return TheFunction.get<AbstractClosureExpr *>()->getResultType();

include/swift/AST/Decl.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8333,9 +8333,15 @@ class FuncDecl : public AbstractFunctionDecl {
83338333
return FnRetType.getSourceRange();
83348334
}
83358335

8336-
/// Retrieve the result interface type of this function.
8336+
/// Retrieve the full result interface type of this function, including yields
83378337
Type getResultInterfaceType() const;
83388338

8339+
/// Same as above, but without @yields
8340+
Type getResultInterfaceTypeWithoutYields() const;
8341+
8342+
/// Same as above, but only yields
8343+
Type getYieldsInterfaceType() const;
8344+
83398345
/// isUnaryOperator - Determine whether this is a unary operator
83408346
/// implementation. This check is a syntactic rather than type-based check,
83418347
/// which looks at the number of parameters specified, in order to allow

include/swift/AST/Types.h

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1673,10 +1673,13 @@ class YieldResultType : public TypeBase {
16731673
};
16741674

16751675
BEGIN_CAN_TYPE_WRAPPER(YieldResultType, Type)
1676-
PROXY_CAN_TYPE_SIMPLE_GETTER(getResultType)
1677-
static CanYieldResultType get(CanType type, bool InOut) {
1678-
return CanYieldResultType(YieldResultType::get(type, InOut));
1679-
}
1676+
PROXY_CAN_TYPE_SIMPLE_GETTER(getResultType)
1677+
bool isInOut() const {
1678+
return getPointer()->isInOut();
1679+
}
1680+
static CanYieldResultType get(CanType type, bool InOut) {
1681+
return CanYieldResultType(YieldResultType::get(type, InOut));
1682+
}
16801683
END_CAN_TYPE_WRAPPER(YieldResultType, Type)
16811684

16821685
/// BuiltinType - An abstract class for all the builtin types.

include/swift/SIL/AbstractionPattern.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1524,6 +1524,10 @@ class AbstractionPattern {
15241524
/// type, return the abstraction pattern for one of its argument types.
15251525
AbstractionPattern getParameterizedProtocolArgType(unsigned i) const;
15261526

1527+
/// Given that the value being abstracted is a yield result type,
1528+
/// return the abstraction pattern for corresponding type.
1529+
AbstractionPattern getYieldResultType() const;
1530+
15271531
/// Given that the value being abstracted is a function, return the
15281532
/// abstraction pattern for its result type.
15291533
AbstractionPattern getFunctionResultType() const;

lib/AST/ASTVerifier.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1098,7 +1098,7 @@ class Verifier : public ASTWalker {
10981098
auto func = Functions.back();
10991099
Type resultType;
11001100
if (auto *FD = dyn_cast<FuncDecl>(func)) {
1101-
resultType = FD->getResultInterfaceType();
1101+
resultType = FD->getResultInterfaceTypeWithoutYields();
11021102
resultType = FD->mapTypeIntoContext(resultType);
11031103
} else if (auto closure = dyn_cast<AbstractClosureExpr>(func)) {
11041104
resultType = closure->getResultType();

lib/AST/Decl.cpp

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1154,6 +1154,18 @@ bool AbstractFunctionDecl::isTransparent() const {
11541154
return false;
11551155
}
11561156

1157+
bool AbstractFunctionDecl::isCoroutine() const {
1158+
// Check if the declaration had the attribute.
1159+
if (getAttrs().hasAttribute<CoroutineAttr>())
1160+
return true;
1161+
1162+
// If this is an accessor, then check if its a coroutine.
1163+
if (const auto *AD = dyn_cast<AccessorDecl>(this))
1164+
return AD->isCoroutine();
1165+
1166+
return false;
1167+
}
1168+
11571169
bool ParameterList::hasInternalParameter(StringRef Prefix) const {
11581170
for (auto param : *this) {
11591171
if (param->hasName() && param->getNameStr().starts_with(Prefix))
@@ -10733,6 +10745,67 @@ Type FuncDecl::getResultInterfaceType() const {
1073310745
return ErrorType::get(ctx);
1073410746
}
1073510747

10748+
Type FuncDecl::getResultInterfaceTypeWithoutYields() const {
10749+
auto resultType = getResultInterfaceType();
10750+
if (resultType->hasError())
10751+
return resultType;
10752+
10753+
// Coroutine result type should either be a yield result
10754+
// or a tuple containing both yielded and normal result types.
10755+
// In both cases, strip the @yield result types
10756+
if (isCoroutine()) {
10757+
if (auto *tupleResTy = resultType->getAs<TupleType>()) {
10758+
// Strip @yield results on the first level of tuple
10759+
SmallVector<TupleTypeElt, 4> elements;
10760+
for (const auto &elt : tupleResTy->getElements()) {
10761+
Type eltTy = elt.getType();
10762+
if (eltTy->is<YieldResultType>())
10763+
continue;
10764+
elements.push_back(eltTy);
10765+
}
10766+
10767+
// Handle vanishing tuples -- flatten to produce the
10768+
// element type.
10769+
if (elements.size() == 1)
10770+
resultType = elements[0].getType();
10771+
else
10772+
resultType = TupleType::get(elements, getASTContext());
10773+
} else {
10774+
assert(resultType->is<YieldResultType>());
10775+
resultType = TupleType::getEmpty(getASTContext());
10776+
}
10777+
}
10778+
10779+
return resultType;
10780+
}
10781+
10782+
Type FuncDecl::getYieldsInterfaceType() const {
10783+
auto resultType = getResultInterfaceType();
10784+
if (resultType->hasError())
10785+
return resultType;
10786+
10787+
if (!isCoroutine())
10788+
return TupleType::getEmpty(getASTContext());
10789+
10790+
// Coroutine result type should either be a yield result
10791+
// or a tuple containing both yielded and normal result types.
10792+
// In both cases, strip the @yield result types
10793+
if (auto *tupleResTy = resultType->getAs<TupleType>()) {
10794+
// Keep @yield results on the first level of tuple
10795+
for (const auto &elt : tupleResTy->getElements()) {
10796+
Type eltTy = elt.getType();
10797+
if (eltTy->is<YieldResultType>())
10798+
return eltTy;
10799+
}
10800+
10801+
llvm_unreachable("coroutine must have a yield result");
10802+
} else {
10803+
assert(resultType->is<YieldResultType>());
10804+
}
10805+
10806+
return resultType;
10807+
}
10808+
1073610809
bool FuncDecl::isUnaryOperator() const {
1073710810
if (!isOperator())
1073810811
return false;

lib/SIL/IR/AbstractionPattern.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1038,6 +1038,13 @@ AbstractionPattern::getParameterizedProtocolArgType(unsigned argIndex) const {
10381038
cast<ParameterizedProtocolType>(getType()).getArgs()[argIndex]);
10391039
}
10401040

1041+
AbstractionPattern AbstractionPattern::getYieldResultType() const {
1042+
assert(getKind() == Kind::Type);
1043+
return AbstractionPattern(getGenericSubstitutions(),
1044+
getGenericSignature(),
1045+
cast<YieldResultType>(getType()).getResultType());
1046+
}
1047+
10411048
AbstractionPattern AbstractionPattern::removingMoveOnlyWrapper() const {
10421049
switch (getKind()) {
10431050
case Kind::Invalid:
@@ -2715,6 +2722,13 @@ class SubstFunctionTypePatternVisitor
27152722
llvm_unreachable("shouldn't encounter pack element by itself");
27162723
}
27172724

2725+
CanType visitYieldResultType(CanYieldResultType yield,
2726+
AbstractionPattern pattern) {
2727+
auto resultType = visit(yield.getResultType(), pattern.getYieldResultType());
2728+
return YieldResultType::get(resultType, yield->isInOut())
2729+
->getCanonicalType();
2730+
}
2731+
27182732
CanType handlePackExpansion(AbstractionPattern origExpansion,
27192733
CanType candidateSubstType) {
27202734
// When we're within a pack expansion, pack references matching that

lib/SIL/IR/SILFunctionType.cpp

Lines changed: 47 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1378,6 +1378,10 @@ class DestructureResults {
13781378
return;
13791379
}
13801380

1381+
// Skip yields, they should've already processed elsewhere
1382+
if (isa<YieldResultType>(substType))
1383+
return;
1384+
13811385
auto &substResultTLForConvention = TC.getTypeLowering(
13821386
origType, substType, TypeExpansionContext::minimal());
13831387
auto &substResultTL = TC.getTypeLowering(origType, substType,
@@ -2221,23 +2225,22 @@ lowerCaptureContextParameters(TypeConverter &TC, SILDeclRef function,
22212225
"iterating over loweredCaptures.getCaptures().");
22222226
}
22232227

2224-
static AccessorDecl *
2225-
getAsCoroutineAccessor(std::optional<SILDeclRef> constant) {
2228+
static FuncDecl *getAsCoroutine(std::optional<SILDeclRef> constant) {
22262229
if (!constant || !constant->hasDecl())
22272230
return nullptr;;
22282231

2229-
auto accessor = dyn_cast<AccessorDecl>(constant->getDecl());
2230-
if (!accessor || !accessor->isCoroutine())
2232+
auto fd = dyn_cast<FuncDecl>(constant->getDecl());
2233+
if (!fd || !fd->isCoroutine())
22312234
return nullptr;
22322235

2233-
return accessor;
2236+
return fd;
22342237
}
22352238

22362239
static void destructureYieldsForReadAccessor(TypeConverter &TC,
2237-
TypeExpansionContext expansion,
2238-
AbstractionPattern origType,
2239-
CanType valueType,
2240-
SmallVectorImpl<SILYieldInfo> &yields){
2240+
TypeExpansionContext expansion,
2241+
AbstractionPattern origType,
2242+
CanType valueType,
2243+
SmallVectorImpl<SILYieldInfo> &yields){
22412244
// Recursively destructure tuples.
22422245
if (origType.isTuple()) {
22432246
auto valueTupleType = cast<TupleType>(valueType);
@@ -2269,25 +2272,19 @@ static void destructureYieldsForReadAccessor(TypeConverter &TC,
22692272

22702273
static void destructureYieldsForCoroutine(TypeConverter &TC,
22712274
TypeExpansionContext expansion,
2272-
std::optional<SILDeclRef> constant,
22732275
AbstractionPattern origType,
22742276
CanType canValueType,
2275-
SmallVectorImpl<SILYieldInfo> &yields,
2276-
SILCoroutineKind &coroutineKind) {
2277-
auto accessor = getAsCoroutineAccessor(constant);
2278-
if (!accessor)
2279-
return;
2280-
2277+
bool isInOutYield,
2278+
SmallVectorImpl<SILYieldInfo> &yields) {
22812279
// 'modify' yields an inout of the target type.
2282-
if (isYieldingMutableAccessor(accessor->getAccessorKind())) {
2280+
if (isInOutYield) {
22832281
auto loweredValueTy =
22842282
TC.getLoweredType(origType, canValueType, expansion);
22852283
yields.push_back(SILYieldInfo(loweredValueTy.getASTType(),
22862284
ParameterConvention::Indirect_Inout));
22872285
} else {
22882286
// 'read' yields a borrowed value of the target type, destructuring
22892287
// tuples as necessary.
2290-
assert(isYieldingImmutableAccessor(accessor->getAccessorKind()));
22912288
destructureYieldsForReadAccessor(TC, expansion, origType, canValueType,
22922289
yields);
22932290
}
@@ -2377,35 +2374,42 @@ static CanSILFunctionType getSILFunctionType(
23772374

23782375
bool hasSendingResult = substFnInterfaceType->getExtInfo().hasSendingResult();
23792376

2380-
// Get the yield type for an accessor coroutine.
2377+
// Get the yield type for coroutine.
23812378
SILCoroutineKind coroutineKind = SILCoroutineKind::None;
23822379
AbstractionPattern coroutineOrigYieldType = AbstractionPattern::getInvalid();
23832380
CanType coroutineSubstYieldType;
23842381

2385-
if (auto accessor = getAsCoroutineAccessor(constant)) {
2386-
auto origAccessor = cast<AccessorDecl>(origConstant->getDecl());
2387-
coroutineKind =
2382+
bool isInOutYield = false;
2383+
if (auto fd = getAsCoroutine(constant)) {
2384+
auto origFd = cast<FuncDecl>(origConstant->getDecl());
2385+
if (auto accessor = dyn_cast<AccessorDecl>(origFd)) {
2386+
coroutineKind =
23882387
requiresFeatureCoroutineAccessors(accessor->getAccessorKind())
2389-
? SILCoroutineKind::YieldOnce2
2390-
: SILCoroutineKind::YieldOnce;
2391-
2388+
? SILCoroutineKind::YieldOnce2
2389+
: SILCoroutineKind::YieldOnce;
2390+
} else {
2391+
// FIXME: Decide we'd directly go to YieldOnce2 for non-accessor coroutines
2392+
coroutineKind = SILCoroutineKind::YieldOnce;
2393+
}
2394+
23922395
// Coroutine accessors are always native, so fetch the native
23932396
// abstraction pattern.
2394-
auto origStorage = origAccessor->getStorage();
2395-
coroutineOrigYieldType = TC.getAbstractionPattern(origStorage,
2396-
/*nonobjc*/ true)
2397-
.getReferenceStorageReferentType();
2398-
2399-
auto storage = accessor->getStorage();
2400-
auto valueType = storage->getValueInterfaceType();
2397+
auto sig = origFd->getGenericSignatureOfContext()
2398+
.getCanonicalSignature();
2399+
auto origYieldType = origFd->getYieldsInterfaceType()->castTo<YieldResultType>();
2400+
auto reducedYieldType = sig.getReducedType(origYieldType->getResultType());
2401+
coroutineOrigYieldType = AbstractionPattern(sig, reducedYieldType);
24012402

2403+
auto yieldType = fd->getYieldsInterfaceType()->castTo<YieldResultType>();
2404+
auto valueType = yieldType->getResultType();
2405+
isInOutYield = yieldType->isInOut();
24022406
if (reqtSubs) {
24032407
valueType = valueType.subst(*reqtSubs);
24042408
coroutineSubstYieldType = valueType->getReducedType(
24052409
genericSig);
24062410
} else {
24072411
coroutineSubstYieldType = valueType->getReducedType(
2408-
accessor->getGenericSignature());
2412+
fd->getGenericSignature());
24092413
}
24102414
}
24112415

@@ -2430,7 +2434,8 @@ static CanSILFunctionType getSILFunctionType(
24302434
// for class override thunks. This is required to make the yields
24312435
// match in abstraction to the base method's yields, which is necessary
24322436
// to make the extracted continuation-function signatures match.
2433-
if (constant != origConstant && getAsCoroutineAccessor(constant))
2437+
if (constant != origConstant &&
2438+
coroutineKind != SILCoroutineKind::None)
24342439
return true;
24352440

24362441
// We don't currently use substituted function types for generic function
@@ -2546,10 +2551,12 @@ static CanSILFunctionType getSILFunctionType(
25462551

25472552
// Destructure the coroutine yields.
25482553
SmallVector<SILYieldInfo, 8> yields;
2549-
destructureYieldsForCoroutine(TC, expansionContext, constant,
2550-
coroutineOrigYieldType, coroutineSubstYieldType,
2551-
yields, coroutineKind);
2552-
2554+
if (coroutineKind != SILCoroutineKind::None) {
2555+
destructureYieldsForCoroutine(TC, expansionContext,
2556+
coroutineOrigYieldType, coroutineSubstYieldType,
2557+
isInOutYield, yields);
2558+
}
2559+
25532560
// Destructure the result tuple type.
25542561
SmallVector<SILResultInfo, 8> results;
25552562
{
@@ -4724,6 +4731,8 @@ TypeConverter::getLoweredFormalTypes(SILDeclRef constant,
47244731
extInfo = extInfo.withThrows(true, innerExtInfo.getThrownError());
47254732
if (innerExtInfo.isAsync())
47264733
extInfo = extInfo.withAsync(true);
4734+
if (innerExtInfo.isCoroutine())
4735+
extInfo = extInfo.withCoroutine(true);
47274736

47284737
// Distributed thunks are always `async throws`
47294738
if (constant.isDistributedThunk()) {

lib/SILGen/SILGenFunction.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1133,15 +1133,15 @@ void SILGenFunction::emitFunction(FuncDecl *fd) {
11331133

11341134
auto captureInfo = SGM.M.Types.getLoweredLocalCaptures(SILDeclRef(fd));
11351135
emitProlog(fd, captureInfo, fd->getParameters(), fd->getImplicitSelfDecl(),
1136-
fd->getResultInterfaceType(), fd->getEffectiveThrownErrorType(),
1136+
fd->getResultInterfaceTypeWithoutYields(), fd->getEffectiveThrownErrorType(),
11371137
fd->getThrowsLoc());
11381138

11391139
if (fd->isDistributedActorFactory()) {
11401140
// Synthesize the factory function body
11411141
emitDistributedActorFactory(fd);
11421142
} else {
11431143
prepareEpilog(fd,
1144-
fd->getResultInterfaceType(),
1144+
fd->getResultInterfaceTypeWithoutYields(),
11451145
fd->getEffectiveThrownErrorType(),
11461146
CleanupLocation(fd));
11471147

lib/SILGen/SILGenProlog.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1539,6 +1539,10 @@ static void emitIndirectResultParameters(SILGenFunction &SGF,
15391539

15401540
assert(!resultType->is<PackExpansionType>());
15411541

1542+
// Skip yields, they are emitted elsewhere
1543+
if (resultType->is<YieldResultType>())
1544+
return;
1545+
15421546
// If the return type is address-only, emit the indirect return argument.
15431547

15441548
// The calling convention always uses minimal resilience expansion.
@@ -1625,7 +1629,7 @@ uint16_t SILGenFunction::emitBasicProlog(
16251629
? origClosureType->getFunctionResultType()
16261630
: AbstractionPattern(genericSig.getCanonicalSignature(),
16271631
resultType->getCanonicalType());
1628-
1632+
16291633
emitIndirectResultParameters(*this, resultType, origResultType, DC);
16301634

16311635
std::optional<AbstractionPattern> origErrorType;

0 commit comments

Comments
 (0)