Skip to content

Commit 58bfe07

Browse files
committed
First cut of making coroutine AST type
1 parent d124554 commit 58bfe07

32 files changed

+265
-110
lines changed

include/swift/AST/AnyFunctionRef.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ class AnyFunctionRef {
9696
Type getBodyResultType() const {
9797
if (auto *AFD = TheFunction.dyn_cast<AbstractFunctionDecl *>()) {
9898
if (auto *FD = dyn_cast<FuncDecl>(AFD))
99-
return FD->mapTypeIntoContext(FD->getResultInterfaceType());
99+
return FD->mapTypeIntoContext(FD->getResultInterfaceTypeWithoutYields());
100100
return TupleType::getEmpty(AFD->getASTContext());
101101
}
102102
return cast<AbstractClosureExpr *>(TheFunction)->getResultType();

include/swift/AST/Decl.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8508,9 +8508,15 @@ class FuncDecl : public AbstractFunctionDecl {
85088508
return FnRetType.getSourceRange();
85098509
}
85108510

8511-
/// Retrieve the result interface type of this function.
8511+
/// Retrieve the full result interface type of this function, including yields
85128512
Type getResultInterfaceType() const;
85138513

8514+
/// Same as above, but without @yields
8515+
Type getResultInterfaceTypeWithoutYields() const;
8516+
8517+
/// Same as above, but only yields
8518+
Type getYieldsInterfaceType() const;
8519+
85148520
/// Returns the result interface type of this function if it has already been
85158521
/// computed, otherwise `nullopt`. This should only be used for dumping.
85168522
std::optional<Type> getCachedResultInterfaceType() const;

include/swift/AST/Types.h

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1724,10 +1724,13 @@ class YieldResultType : public TypeBase {
17241724
};
17251725

17261726
BEGIN_CAN_TYPE_WRAPPER(YieldResultType, Type)
1727-
PROXY_CAN_TYPE_SIMPLE_GETTER(getResultType)
1728-
static CanYieldResultType get(CanType type, bool InOut) {
1729-
return CanYieldResultType(YieldResultType::get(type, InOut));
1730-
}
1727+
PROXY_CAN_TYPE_SIMPLE_GETTER(getResultType)
1728+
bool isInOut() const {
1729+
return getPointer()->isInOut();
1730+
}
1731+
static CanYieldResultType get(CanType type, bool InOut) {
1732+
return CanYieldResultType(YieldResultType::get(type, InOut));
1733+
}
17311734
END_CAN_TYPE_WRAPPER(YieldResultType, Type)
17321735

17331736
/// BuiltinType - An abstract class for all the builtin types.

include/swift/SIL/AbstractionPattern.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1524,6 +1524,10 @@ class AbstractionPattern {
15241524
/// type, return the abstraction pattern for one of its argument types.
15251525
AbstractionPattern getParameterizedProtocolArgType(unsigned i) const;
15261526

1527+
/// Given that the value being abstracted is a yield result type,
1528+
/// return the abstraction pattern for corresponding type.
1529+
AbstractionPattern getYieldResultType() const;
1530+
15271531
/// Given that the value being abstracted is a function, return the
15281532
/// abstraction pattern for its result type.
15291533
AbstractionPattern getFunctionResultType() const;

lib/AST/ASTVerifier.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1098,7 +1098,7 @@ class Verifier : public ASTWalker {
10981098
auto func = Functions.back();
10991099
Type resultType;
11001100
if (auto *FD = dyn_cast<FuncDecl>(func)) {
1101-
resultType = FD->getResultInterfaceType();
1101+
resultType = FD->getResultInterfaceTypeWithoutYields();
11021102
resultType = FD->mapTypeIntoContext(resultType);
11031103
} else if (auto closure = dyn_cast<AbstractClosureExpr>(func)) {
11041104
resultType = closure->getResultType();

lib/AST/Decl.cpp

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1374,6 +1374,18 @@ bool AbstractFunctionDecl::isTransparent() const {
13741374
return false;
13751375
}
13761376

1377+
bool AbstractFunctionDecl::isCoroutine() const {
1378+
// Check if the declaration had the attribute.
1379+
if (getAttrs().hasAttribute<CoroutineAttr>())
1380+
return true;
1381+
1382+
// If this is an accessor, then check if its a coroutine.
1383+
if (const auto *AD = dyn_cast<AccessorDecl>(this))
1384+
return AD->isCoroutine();
1385+
1386+
return false;
1387+
}
1388+
13771389
bool ParameterList::hasInternalParameter(StringRef Prefix) const {
13781390
for (auto param : *this) {
13791391
if (param->hasName() && param->getNameStr().starts_with(Prefix))
@@ -11352,6 +11364,67 @@ std::optional<Type> FuncDecl::getCachedResultInterfaceType() const {
1135211364
return ResultTypeRequest{mutableThis}.getCachedResult();
1135311365
}
1135411366

11367+
Type FuncDecl::getResultInterfaceTypeWithoutYields() const {
11368+
auto resultType = getResultInterfaceType();
11369+
if (resultType->hasError())
11370+
return resultType;
11371+
11372+
// Coroutine result type should either be a yield result
11373+
// or a tuple containing both yielded and normal result types.
11374+
// In both cases, strip the @yield result types
11375+
if (isCoroutine()) {
11376+
if (auto *tupleResTy = resultType->getAs<TupleType>()) {
11377+
// Strip @yield results on the first level of tuple
11378+
SmallVector<TupleTypeElt, 4> elements;
11379+
for (const auto &elt : tupleResTy->getElements()) {
11380+
Type eltTy = elt.getType();
11381+
if (eltTy->is<YieldResultType>())
11382+
continue;
11383+
elements.push_back(eltTy);
11384+
}
11385+
11386+
// Handle vanishing tuples -- flatten to produce the
11387+
// element type.
11388+
if (elements.size() == 1)
11389+
resultType = elements[0].getType();
11390+
else
11391+
resultType = TupleType::get(elements, getASTContext());
11392+
} else {
11393+
assert(resultType->is<YieldResultType>());
11394+
resultType = TupleType::getEmpty(getASTContext());
11395+
}
11396+
}
11397+
11398+
return resultType;
11399+
}
11400+
11401+
Type FuncDecl::getYieldsInterfaceType() const {
11402+
auto resultType = getResultInterfaceType();
11403+
if (resultType->hasError())
11404+
return resultType;
11405+
11406+
if (!isCoroutine())
11407+
return TupleType::getEmpty(getASTContext());
11408+
11409+
// Coroutine result type should either be a yield result
11410+
// or a tuple containing both yielded and normal result types.
11411+
// In both cases, strip the @yield result types
11412+
if (auto *tupleResTy = resultType->getAs<TupleType>()) {
11413+
// Keep @yield results on the first level of tuple
11414+
for (const auto &elt : tupleResTy->getElements()) {
11415+
Type eltTy = elt.getType();
11416+
if (eltTy->is<YieldResultType>())
11417+
return eltTy;
11418+
}
11419+
11420+
llvm_unreachable("coroutine must have a yield result");
11421+
} else {
11422+
assert(resultType->is<YieldResultType>());
11423+
}
11424+
11425+
return resultType;
11426+
}
11427+
1135511428
bool FuncDecl::isUnaryOperator() const {
1135611429
if (!isOperator())
1135711430
return false;

lib/SIL/IR/AbstractionPattern.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1038,6 +1038,13 @@ AbstractionPattern::getParameterizedProtocolArgType(unsigned argIndex) const {
10381038
cast<ParameterizedProtocolType>(getType()).getArgs()[argIndex]);
10391039
}
10401040

1041+
AbstractionPattern AbstractionPattern::getYieldResultType() const {
1042+
assert(getKind() == Kind::Type);
1043+
return AbstractionPattern(getGenericSubstitutions(),
1044+
getGenericSignature(),
1045+
cast<YieldResultType>(getType()).getResultType());
1046+
}
1047+
10411048
AbstractionPattern AbstractionPattern::removingMoveOnlyWrapper() const {
10421049
switch (getKind()) {
10431050
case Kind::Invalid:
@@ -2731,6 +2738,13 @@ class SubstFunctionTypePatternVisitor
27312738
llvm_unreachable("shouldn't encounter pack element by itself");
27322739
}
27332740

2741+
CanType visitYieldResultType(CanYieldResultType yield,
2742+
AbstractionPattern pattern) {
2743+
auto resultType = visit(yield.getResultType(), pattern.getYieldResultType());
2744+
return YieldResultType::get(resultType, yield->isInOut())
2745+
->getCanonicalType();
2746+
}
2747+
27342748
CanType handlePackExpansion(AbstractionPattern origExpansion,
27352749
CanType candidateSubstType) {
27362750
// When we're within a pack expansion, pack references matching that

lib/SIL/IR/SILFunctionType.cpp

Lines changed: 48 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1417,6 +1417,10 @@ class DestructureResults {
14171417
return;
14181418
}
14191419

1420+
// Skip yields, they should've already processed elsewhere
1421+
if (isa<YieldResultType>(substType))
1422+
return;
1423+
14201424
auto &substResultTLForConvention = TC.getTypeLowering(
14211425
origType, substType, TypeExpansionContext::minimal());
14221426
auto &substResultTL = TC.getTypeLowering(origType, substType,
@@ -2317,23 +2321,22 @@ lowerCaptureContextParameters(TypeConverter &TC, SILDeclRef function,
23172321
"iterating over loweredCaptures.getCaptures().");
23182322
}
23192323

2320-
static AccessorDecl *
2321-
getAsCoroutineAccessor(std::optional<SILDeclRef> constant) {
2324+
static FuncDecl *getAsCoroutine(std::optional<SILDeclRef> constant) {
23222325
if (!constant || !constant->hasDecl())
23232326
return nullptr;;
23242327

2325-
auto accessor = dyn_cast<AccessorDecl>(constant->getDecl());
2326-
if (!accessor || !accessor->isCoroutine())
2328+
auto fd = dyn_cast<FuncDecl>(constant->getDecl());
2329+
if (!fd || !fd->isCoroutine())
23272330
return nullptr;
23282331

2329-
return accessor;
2332+
return fd;
23302333
}
23312334

23322335
static void destructureYieldsForReadAccessor(TypeConverter &TC,
2333-
TypeExpansionContext expansion,
2334-
AbstractionPattern origType,
2335-
CanType valueType,
2336-
SmallVectorImpl<SILYieldInfo> &yields){
2336+
TypeExpansionContext expansion,
2337+
AbstractionPattern origType,
2338+
CanType valueType,
2339+
SmallVectorImpl<SILYieldInfo> &yields){
23372340
// Recursively destructure tuples.
23382341
if (origType.isTuple()) {
23392342
auto valueTupleType = cast<TupleType>(valueType);
@@ -2365,25 +2368,19 @@ static void destructureYieldsForReadAccessor(TypeConverter &TC,
23652368

23662369
static void destructureYieldsForCoroutine(TypeConverter &TC,
23672370
TypeExpansionContext expansion,
2368-
std::optional<SILDeclRef> constant,
23692371
AbstractionPattern origType,
23702372
CanType canValueType,
2371-
SmallVectorImpl<SILYieldInfo> &yields,
2372-
SILCoroutineKind &coroutineKind) {
2373-
auto accessor = getAsCoroutineAccessor(constant);
2374-
if (!accessor)
2375-
return;
2376-
2373+
bool isInOutYield,
2374+
SmallVectorImpl<SILYieldInfo> &yields) {
23772375
// 'modify' yields an inout of the target type.
2378-
if (isYieldingMutableAccessor(accessor->getAccessorKind())) {
2376+
if (isInOutYield) {
23792377
auto loweredValueTy =
23802378
TC.getLoweredType(origType, canValueType, expansion);
23812379
yields.push_back(SILYieldInfo(loweredValueTy.getASTType(),
23822380
ParameterConvention::Indirect_Inout));
23832381
} else {
23842382
// 'read' yields a borrowed value of the target type, destructuring
23852383
// tuples as necessary.
2386-
assert(isYieldingImmutableAccessor(accessor->getAccessorKind()));
23872384
destructureYieldsForReadAccessor(TC, expansion, origType, canValueType,
23882385
yields);
23892386
}
@@ -2511,37 +2508,44 @@ static CanSILFunctionType getSILFunctionType(
25112508

25122509
bool hasSendingResult = substFnInterfaceType->getExtInfo().hasSendingResult();
25132510

2514-
// Get the yield type for an accessor coroutine.
2511+
// Get the yield type for coroutine.
25152512
SILCoroutineKind coroutineKind = SILCoroutineKind::None;
25162513
AbstractionPattern coroutineOrigYieldType = AbstractionPattern::getInvalid();
25172514
CanType coroutineSubstYieldType;
25182515

2519-
if (auto accessor = getAsCoroutineAccessor(constant)) {
2520-
auto origAccessor = cast<AccessorDecl>(origConstant->getDecl());
2521-
auto &ctx = origAccessor->getASTContext();
2522-
coroutineKind =
2516+
bool isInOutYield = false;
2517+
if (auto fd = getAsCoroutine(constant)) {
2518+
auto origFd = cast<FuncDecl>(origConstant->getDecl());
2519+
auto &ctx = origFd->getASTContext();
2520+
if (auto accessor = dyn_cast<AccessorDecl>(origFd)) {
2521+
coroutineKind =
25232522
(requiresFeatureCoroutineAccessors(accessor->getAccessorKind()) &&
25242523
ctx.SILOpts.CoroutineAccessorsUseYieldOnce2)
2525-
? SILCoroutineKind::YieldOnce2
2526-
: SILCoroutineKind::YieldOnce;
2527-
2524+
? SILCoroutineKind::YieldOnce2
2525+
: SILCoroutineKind::YieldOnce;
2526+
} else {
2527+
// FIXME: Decide we'd directly go to YieldOnce2 for non-accessor coroutines
2528+
coroutineKind = SILCoroutineKind::YieldOnce;
2529+
}
2530+
25282531
// Coroutine accessors are always native, so fetch the native
25292532
// abstraction pattern.
2530-
auto origStorage = origAccessor->getStorage();
2531-
coroutineOrigYieldType = TC.getAbstractionPattern(origStorage,
2532-
/*nonobjc*/ true)
2533-
.getReferenceStorageReferentType();
2534-
2535-
auto storage = accessor->getStorage();
2536-
auto valueType = storage->getValueInterfaceType();
2533+
auto sig = origFd->getGenericSignatureOfContext()
2534+
.getCanonicalSignature();
2535+
auto origYieldType = origFd->getYieldsInterfaceType()->castTo<YieldResultType>();
2536+
auto reducedYieldType = sig.getReducedType(origYieldType->getResultType());
2537+
coroutineOrigYieldType = AbstractionPattern(sig, reducedYieldType);
25372538

2539+
auto yieldType = fd->getYieldsInterfaceType()->castTo<YieldResultType>();
2540+
auto valueType = yieldType->getResultType();
2541+
isInOutYield = yieldType->isInOut();
25382542
if (reqtSubs) {
25392543
valueType = valueType.subst(*reqtSubs);
25402544
coroutineSubstYieldType = valueType->getReducedType(
25412545
genericSig);
25422546
} else {
25432547
coroutineSubstYieldType = valueType->getReducedType(
2544-
accessor->getGenericSignature());
2548+
fd->getGenericSignature());
25452549
}
25462550
}
25472551

@@ -2566,7 +2570,8 @@ static CanSILFunctionType getSILFunctionType(
25662570
// for class override thunks. This is required to make the yields
25672571
// match in abstraction to the base method's yields, which is necessary
25682572
// to make the extracted continuation-function signatures match.
2569-
if (constant != origConstant && getAsCoroutineAccessor(constant))
2573+
if (constant != origConstant &&
2574+
coroutineKind != SILCoroutineKind::None)
25702575
return true;
25712576

25722577
// We don't currently use substituted function types for generic function
@@ -2670,10 +2675,12 @@ static CanSILFunctionType getSILFunctionType(
26702675

26712676
// Destructure the coroutine yields.
26722677
SmallVector<SILYieldInfo, 8> yields;
2673-
destructureYieldsForCoroutine(TC, expansionContext, constant,
2674-
coroutineOrigYieldType, coroutineSubstYieldType,
2675-
yields, coroutineKind);
2676-
2678+
if (coroutineKind != SILCoroutineKind::None) {
2679+
destructureYieldsForCoroutine(TC, expansionContext,
2680+
coroutineOrigYieldType, coroutineSubstYieldType,
2681+
isInOutYield, yields);
2682+
}
2683+
26772684
// Destructure the result tuple type.
26782685
SmallVector<SILResultInfo, 8> results;
26792686
{
@@ -4977,6 +4984,8 @@ TypeConverter::getLoweredFormalTypes(SILDeclRef constant,
49774984
extInfo = extInfo.withThrows(true, innerExtInfo.getThrownError());
49784985
if (innerExtInfo.isAsync())
49794986
extInfo = extInfo.withAsync(true);
4987+
if (innerExtInfo.isCoroutine())
4988+
extInfo = extInfo.withCoroutine(true);
49804989

49814990
// Distributed thunks are always `async throws`
49824991
if (constant.isDistributedThunk()) {

lib/SILGen/SILGenFunction.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1121,15 +1121,15 @@ void SILGenFunction::emitFunction(FuncDecl *fd) {
11211121

11221122
auto captureInfo = SGM.M.Types.getLoweredLocalCaptures(SILDeclRef(fd));
11231123
emitProlog(fd, captureInfo, fd->getParameters(), fd->getImplicitSelfDecl(),
1124-
fd->getResultInterfaceType(), fd->getEffectiveThrownErrorType(),
1124+
fd->getResultInterfaceTypeWithoutYields(), fd->getEffectiveThrownErrorType(),
11251125
fd->getThrowsLoc());
11261126

11271127
if (fd->isDistributedActorFactory()) {
11281128
// Synthesize the factory function body
11291129
emitDistributedActorFactory(fd);
11301130
} else {
11311131
prepareEpilog(fd,
1132-
fd->getResultInterfaceType(),
1132+
fd->getResultInterfaceTypeWithoutYields(),
11331133
fd->getEffectiveThrownErrorType(),
11341134
CleanupLocation(fd));
11351135

lib/SILGen/SILGenProlog.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1568,6 +1568,10 @@ static void emitIndirectResultParameters(SILGenFunction &SGF,
15681568

15691569
assert(!resultType->is<PackExpansionType>());
15701570

1571+
// Skip yields, they are emitted elsewhere
1572+
if (resultType->is<YieldResultType>())
1573+
return;
1574+
15711575
// If the return type is address-only, emit the indirect return argument.
15721576

15731577
// The calling convention always uses minimal resilience expansion.

0 commit comments

Comments
 (0)