Skip to content

Commit e38dae3

Browse files
committed
Initial & rudimentary support for coroutine function types. Some fixes while here
1 parent e658e86 commit e38dae3

File tree

11 files changed

+93
-50
lines changed

11 files changed

+93
-50
lines changed

lib/AST/Decl.cpp

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1171,8 +1171,8 @@ AnyFunctionRef::getYieldResultsImpl(SmallVectorImpl<AnyFunctionType::Yield> &buf
11711171
bool mapIntoContext) const {
11721172
assert(buffer.empty());
11731173
if (auto *AFD = getAbstractFunctionDecl()) {
1174+
// FIXME: AccessorDecl case is not necessary
11741175
if (auto *AD = dyn_cast<AccessorDecl>(AFD)) {
1175-
// FIXME: AccessorDecl case is not necessary
11761176
if (AD->isCoroutine()) {
11771177
auto valueTy = AD->getStorage()->getValueInterfaceType()
11781178
->getReferenceStorageReferent();
@@ -1185,13 +1185,14 @@ AnyFunctionRef::getYieldResultsImpl(SmallVectorImpl<AnyFunctionType::Yield> &buf
11851185
return buffer;
11861186
}
11871187
} else if (AFD->isCoroutine()) {
1188-
auto resType = AFD->getInterfaceType()->castTo<FunctionType>()->getResult();
1189-
if (auto *resFnType = resType->getAs<FunctionType>())
1190-
resType = resFnType->getResult();
1191-
1192-
if (resType->hasError())
1188+
auto fnType = AFD->getInterfaceType();
1189+
if (fnType->hasError())
11931190
return {};
11941191

1192+
auto resType = fnType->castTo<AnyFunctionType>()->getResult();
1193+
if (auto *resFnType = resType->getAs<AnyFunctionType>())
1194+
resType = resFnType->getResult();
1195+
11951196
auto addYieldInfo =
11961197
[&](const YieldResultType *yieldResultTy) {
11971198
Type valueTy = yieldResultTy->getResultType();
@@ -1208,8 +1209,8 @@ AnyFunctionRef::getYieldResultsImpl(SmallVectorImpl<AnyFunctionType::Yield> &buf
12081209
if (auto *yieldResTy = eltTy->getAs<YieldResultType>())
12091210
addYieldInfo(yieldResTy);
12101211
}
1211-
else
1212-
addYieldInfo(resType->castTo<YieldResultType>());
1212+
else if (auto *yieldResTy = resType->getAs<YieldResultType>())
1213+
addYieldInfo(yieldResTy);
12131214

12141215
return buffer;
12151216
}
@@ -10822,8 +10823,7 @@ Type FuncDecl::getResultInterfaceTypeWithoutYields() const {
1082210823
resultType = elements[0].getType();
1082310824
else
1082410825
resultType = TupleType::get(elements, getASTContext());
10825-
} else {
10826-
assert(resultType->is<YieldResultType>());
10826+
} else if (resultType->is<YieldResultType>()) {
1082710827
resultType = TupleType::getEmpty(getASTContext());
1082810828
}
1082910829
}
@@ -10851,8 +10851,8 @@ Type FuncDecl::getYieldsInterfaceType() const {
1085110851
}
1085210852

1085310853
llvm_unreachable("coroutine must have a yield result");
10854-
} else {
10855-
assert(resultType->is<YieldResultType>());
10854+
} else if (!resultType->is<YieldResultType>()) {
10855+
resultType = TupleType::getEmpty(getASTContext());
1085610856
}
1085710857

1085810858
return resultType;

lib/IRGen/IRGenSIL.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4787,13 +4787,13 @@ void IRGenSILFunction::visitEndApply(BeginApplyInst *i, EndApplyInst *ei) {
47874787

47884788
if (!isAbort) {
47894789
auto resultType = call->getType();
4790+
Explosion e;
47904791
if (!resultType->isVoidTy()) {
4791-
Explosion e;
47924792
// FIXME: Do we need to handle ABI-related conversions here?
47934793
// It seems we cannot have C function convention for coroutines, etc.
47944794
extractScalarResults(*this, resultType, call, e);
4795-
setLoweredExplosion(ei, e);
47964795
}
4796+
setLoweredExplosion(ei, e);
47974797
}
47984798

47994799
coroutine.Temporaries.destroyAll(*this);

lib/SIL/IR/SILFunctionType.cpp

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2380,7 +2380,7 @@ static CanSILFunctionType getSILFunctionType(
23802380
CanType coroutineSubstYieldType;
23812381

23822382
bool isInOutYield = false;
2383-
if (auto fd = getAsCoroutine(constant)) {
2383+
if (auto fd = getAsCoroutine(constant)) { // Derive yield type for declaration
23842384
auto origFd = cast<FuncDecl>(origConstant->getDecl());
23852385
if (auto accessor = dyn_cast<AccessorDecl>(origFd)) {
23862386
coroutineKind =
@@ -2410,6 +2410,19 @@ static CanSILFunctionType getSILFunctionType(
24102410
coroutineSubstYieldType = valueType->getReducedType(
24112411
fd->getGenericSignature());
24122412
}
2413+
} else if (substFnInterfaceType->isCoroutine()) { // Derive yield type for function type
2414+
coroutineKind = SILCoroutineKind::YieldOnce;
2415+
auto origYieldType = origType.getFunctionResultType().getType()->castTo<YieldResultType>();
2416+
auto reducedYieldType = genericSig.getReducedType(origYieldType->getResultType());
2417+
coroutineOrigYieldType = AbstractionPattern(genericSig, reducedYieldType);
2418+
2419+
auto yieldType = substFnInterfaceType->getResult()->castTo<YieldResultType>();
2420+
auto valueType = yieldType->getResultType();
2421+
isInOutYield = yieldType->isInOut();
2422+
if (reqtSubs)
2423+
valueType = valueType.subst(*reqtSubs);
2424+
2425+
coroutineSubstYieldType = valueType->getReducedType(genericSig);
24132426
}
24142427

24152428
bool shouldBuildSubstFunctionType = [&]{
@@ -2433,8 +2446,7 @@ static CanSILFunctionType getSILFunctionType(
24332446
// for class override thunks. This is required to make the yields
24342447
// match in abstraction to the base method's yields, which is necessary
24352448
// to make the extracted continuation-function signatures match.
2436-
if (constant != origConstant &&
2437-
coroutineKind != SILCoroutineKind::None)
2449+
if (constant != origConstant && getAsCoroutine(constant))
24382450
return true;
24392451

24402452
// We don't currently use substituted function types for generic function

lib/SILGen/SILGenApply.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6115,17 +6115,19 @@ SILGenFunction::emitBeginApplyWithRethrow(SILLocation loc, SILValue fn,
61156115
return {token, abortCleanup, allocation, deallocCleanup};
61166116
}
61176117

6118-
void SILGenFunction::emitEndApplyWithRethrow(
6118+
SILValue SILGenFunction::emitEndApplyWithRethrow(
61196119
SILLocation loc, MultipleValueInstructionResult *token,
6120-
SILValue allocation) {
6120+
SILValue allocation,
6121+
SILType resultType) {
61216122
// TODO: adjust this to handle results of TryBeginApplyInst.
61226123
assert(token->isBeginApplyToken());
61236124

6124-
B.createEndApply(loc, token,
6125-
SILType::getEmptyTupleType(getASTContext()));
6125+
SILValue result =
6126+
B.createEndApply(loc, token, resultType);
61266127
if (allocation) {
61276128
B.createDeallocStack(loc, allocation);
61286129
}
6130+
return result;
61296131
}
61306132

61316133
void SILGenFunction::emitYield(SILLocation loc,

lib/SILGen/SILGenFunction.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2196,9 +2196,10 @@ class LLVM_LIBRARY_VISIBILITY SILGenFunction
21962196
emitBeginApplyWithRethrow(SILLocation loc, SILValue fn, SILType substFnType,
21972197
SubstitutionMap subs, ArrayRef<SILValue> args,
21982198
SmallVectorImpl<SILValue> &yields);
2199-
void emitEndApplyWithRethrow(SILLocation loc,
2200-
MultipleValueInstructionResult *token,
2201-
SILValue allocation);
2199+
SILValue emitEndApplyWithRethrow(SILLocation loc,
2200+
MultipleValueInstructionResult *token,
2201+
SILValue allocation,
2202+
SILType resultType);
22022203

22032204
ManagedValue emitExtractFunctionIsolation(SILLocation loc,
22042205
ArgumentSource &&fnValue);

lib/SILGen/SILGenPoly.cpp

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5541,6 +5541,9 @@ static ManagedValue createThunk(SILGenFunction &SGF,
55415541
assert(expectedType->getLanguage() ==
55425542
fn.getType().castTo<SILFunctionType>()->getLanguage() &&
55435543
"bridging in re-abstraction thunk?");
5544+
// We cannot reabstract coroutines (yet)
5545+
assert(!expectedType->isCoroutine() && !sourceType->isCoroutine() &&
5546+
"cannot reabstract a coroutine");
55445547

55455548
// Declare the thunk.
55465549
SubstitutionMap interfaceSubs;
@@ -5554,6 +5557,7 @@ static ManagedValue createThunk(SILGenFunction &SGF,
55545557
genericEnv,
55555558
interfaceSubs,
55565559
dynamicSelfType);
5560+
55575561
// An actor-isolated non-async function can be converted to an async function
55585562
// by inserting a hop to the global actor.
55595563
CanType globalActorForThunk;
@@ -6849,9 +6853,9 @@ SILGenFunction::emitVTableThunk(SILDeclRef base,
68496853
}
68506854

68516855
// End the inner coroutine normally.
6852-
emitEndApplyWithRethrow(loc, token, allocation);
6856+
result = emitEndApplyWithRethrow(loc, token, allocation,
6857+
SILType::getEmptyTupleType(getASTContext()));
68536858

6854-
result = B.createTuple(loc, {});
68556859
break;
68566860
}
68576861

@@ -7241,9 +7245,8 @@ void SILGenFunction::emitProtocolWitness(
72417245
}
72427246

72437247
// End the inner coroutine normally.
7244-
emitEndApplyWithRethrow(loc, token, allocation);
7245-
7246-
reqtResultValue = B.createTuple(loc, {});
7248+
reqtResultValue = emitEndApplyWithRethrow(loc, token, allocation,
7249+
SILType::getEmptyTupleType(getASTContext()));
72477250
break;
72487251
}
72497252

lib/SILOptimizer/Utils/SILInliner.cpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,11 +255,22 @@ class BeginApplySite {
255255
for (auto calleeYield : BeginApply->getYieldedValues()) {
256256
calleeYield->replaceAllUsesWith(SILUndef::get(calleeYield));
257257
}
258+
259+
if (EndApply)
260+
EndApply->replaceAllUsesWith(SILUndef::get(EndApply));
258261
}
259262

260263
// Remove the resumption sites.
261-
if (EndApply)
264+
if (EndApply) {
265+
// All potential users of end_apply should've been replaced above. The only
266+
// case where we might end with more users is when end_apply itself is
267+
// unreachable. Make sure the function is well-formed and replace the
268+
// results with undef.
269+
if (!EndApply->use_empty())
270+
EndApply->replaceAllUsesWith(SILUndef::get(EndApply));
271+
262272
EndApply->eraseFromParent();
273+
}
263274
if (AbortApply)
264275
AbortApply->eraseFromParent();
265276
for (auto *EndBorrow : EndBorrows)

lib/Sema/CSSimplify.cpp

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7211,9 +7211,26 @@ ConstraintSystem::matchTypes(Type type1, Type type2, ConstraintKind kind,
72117211

72127212
case TypeKind::Error:
72137213
case TypeKind::Unresolved:
7214-
case TypeKind::YieldResult:
72157214
return getTypeMatchFailure(locator);
72167215

7216+
case TypeKind::YieldResult: {
7217+
if (simplifyType(desugar1)->isEqual(simplifyType(desugar2)))
7218+
return getTypeMatchSuccess();
7219+
7220+
if (kind != ConstraintKind::Bind)
7221+
return getTypeMatchFailure(locator);
7222+
7223+
auto *yield1 = cast<YieldResultType>(desugar1);
7224+
auto *yield2 = cast<YieldResultType>(desugar2);
7225+
7226+
if (yield1->isInOut() != yield2->isInOut())
7227+
return getTypeMatchFailure(locator);
7228+
7229+
return matchTypes(yield1->getResultType(), yield2->getResultType(),
7230+
ConstraintKind::Bind, subflags,
7231+
locator.withPathElement(ConstraintLocator::LValueConversion));
7232+
}
7233+
72177234
case TypeKind::Placeholder: {
72187235
// If it's allowed to attempt fixes, let's delegate
72197236
// decision to `repairFailures`, since depending on
@@ -8213,7 +8230,6 @@ ConstraintSystem::simplifyConstructionConstraint(
82138230
case TypeKind::Unresolved:
82148231
case TypeKind::Error:
82158232
case TypeKind::Placeholder:
8216-
case TypeKind::YieldResult:
82178233
return SolutionKind::Error;
82188234

82198235
case TypeKind::GenericFunction:
@@ -8313,6 +8329,7 @@ ConstraintSystem::simplifyConstructionConstraint(
83138329
case TypeKind::Function:
83148330
case TypeKind::LValue:
83158331
case TypeKind::InOut:
8332+
case TypeKind::YieldResult:
83168333
case TypeKind::Module:
83178334
case TypeKind::Pack:
83188335
case TypeKind::PackExpansion:

lib/Sema/TypeCheckGeneric.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -777,8 +777,12 @@ GenericSignatureRequest::evaluate(Evaluator &evaluator,
777777
}
778778
}();
779779
if (resultTypeRepr && !resultTypeRepr->hasOpaque()) {
780+
bool isCoroutine = func ? func->isCoroutine() : false;
781+
TypeResolutionOptions resultOptions(TypeResolverContext::FunctionResult);
782+
if (isCoroutine)
783+
resultOptions |= TypeResolutionFlags::Coroutine;
780784
const auto resultType =
781-
resolution.withOptions(TypeResolverContext::FunctionResult)
785+
resolution.withOptions(resultOptions)
782786
.resolveType(resultTypeRepr);
783787

784788
inferenceSources.push_back(resultType.getPointer());

test/SILGen/modify.swift

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -80,10 +80,9 @@ extension Derived : Abstractable {}
8080
// CHECK-NEXT: [[T1:%.*]] = partial_apply [callee_guaranteed] [[REABSTRACTOR]]([[CVT_FN]])
8181
// CHECK-NEXT: store [[T1]] to [init] [[SUPER_ADDR]]
8282
// CHECK-NEXT: dealloc_stack [[SUB_ADDR]]
83-
// CHECK-NEXT: end_apply [[TOKEN]]
84-
// CHECK-NEXT: tuple ()
83+
// CHECK-NEXT: [[TUPLE:%.*]] = end_apply [[TOKEN]] as $()
8584
// CHECK-NEXT: end_borrow [[T0]]
86-
// CHECK-NEXT: return
85+
// CHECK-NEXT: return [[TUPLE]]
8786

8887
// CHECK-LABEL: sil private [transparent] [thunk] [ossa] @$s6modify7DerivedCAA12AbstractableA2aDP19finalStoredFunction6ResultQzycvMTW
8988
// CHECK: bb0(%0 : $*Derived):
@@ -108,10 +107,9 @@ extension Derived : Abstractable {}
108107
// CHECK-NEXT: [[T1:%.*]] = partial_apply [callee_guaranteed] [[REABSTRACTOR]]([[CVT_FN]])
109108
// CHECK-NEXT: store [[T1]] to [init] [[SUPER_ADDR]]
110109
// CHECK-NEXT: dealloc_stack [[SUB_ADDR]]
111-
// CHECK-NEXT: end_apply [[TOKEN]]
112-
// CHECK-NEXT: tuple ()
110+
// CHECK-NEXT: [[TUPLE:%.*]] = end_apply [[TOKEN]] as $()
113111
// CHECK-NEXT: end_borrow [[T0]]
114-
// CHECK-NEXT: return
112+
// CHECK-NEXT: return [[TUPLE]]
115113

116114
// CHECK-LABEL: sil private [transparent] [thunk] [ossa] @$s6modify7DerivedCAA12AbstractableA2aDP14staticFunction6ResultQzycvMZTW
117115
// CHECK: bb0(%0 : $@thick Derived.Type):
@@ -135,9 +133,8 @@ extension Derived : Abstractable {}
135133
// CHECK-NEXT: [[T1:%.*]] = partial_apply [callee_guaranteed] [[REABSTRACTOR]]([[CVT_FN]])
136134
// CHECK-NEXT: store [[T1]] to [init] [[SUPER_ADDR]]
137135
// CHECK-NEXT: dealloc_stack [[SUB_ADDR]]
138-
// CHECK-NEXT: end_apply [[TOKEN]]
139-
// CHECK-NEXT: tuple ()
140-
// CHECK-NEXT: return
136+
// CHECK-NEXT: [[TUPLE:%.*]] = end_apply [[TOKEN]] as $()
137+
// CHECK-NEXT: return [[TUPLE]]
141138

142139
protocol ClassAbstractable : class {
143140
associatedtype Result
@@ -298,8 +295,7 @@ struct Bill : Totalled {
298295
// CHECK-NEXT: ([[T1:%.*]], [[TOKEN:%.*]]) = begin_apply [[T0]]([[SELF]])
299296
// CHECK-NEXT: yield [[T1]] : $*Int, resume bb1, unwind bb2
300297
// CHECK: bb1:
301-
// CHECK-NEXT: end_apply [[TOKEN]]
302-
// CHECK-NEXT: [[T1:%.*]] = tuple ()
298+
// CHECK-NEXT: [[T1:%.*]] = end_apply [[TOKEN]] as $()
303299
// CHECK-NEXT: return [[T1]] :
304300

305301
protocol AddressOnlySubscript {

0 commit comments

Comments
 (0)