Skip to content

Commit 2ebe7bc

Browse files
committed
Properly substitute coroutines
1 parent 9acaa14 commit 2ebe7bc

File tree

9 files changed

+110
-45
lines changed

9 files changed

+110
-45
lines changed

include/swift/AST/Types.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3904,6 +3904,9 @@ class AnyFunctionType : public TypeBase {
39043904
/// Return the function type setting sendable to \p newValue.
39053905
AnyFunctionType *withSendable(bool newValue) const;
39063906

3907+
/// Return the function type without yields (and coroutine flag)
3908+
AnyFunctionType *getWithoutYields() const;
3909+
39073910
/// True if the parameter declaration it is attached to is guaranteed
39083911
/// to not persist the closure for longer than the duration of the call.
39093912
bool isNoEscape() const {

include/swift/SIL/AbstractionPattern.h

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1524,13 +1524,9 @@ class AbstractionPattern {
15241524
/// type, return the abstraction pattern for one of its argument types.
15251525
AbstractionPattern getParameterizedProtocolArgType(unsigned i) const;
15261526

1527-
/// Given that the value being abstracted is a yield result type,
1528-
/// return the abstraction pattern for corresponding type.
1529-
AbstractionPattern getYieldResultType() const;
1530-
15311527
/// Given that the value being abstracted is a function, return the
15321528
/// abstraction pattern for its result type.
1533-
AbstractionPattern getFunctionResultType() const;
1529+
AbstractionPattern getFunctionResultType(bool withoutYields = false) const;
15341530

15351531
/// Given that the value being abstracted is a function, return the
15361532
/// abstraction pattern for its thrown error type.

lib/AST/Decl.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11432,7 +11432,7 @@ Type FuncDecl::getResultInterfaceTypeWithoutYields() const {
1143211432
Type eltTy = elt.getType();
1143311433
if (eltTy->is<YieldResultType>())
1143411434
continue;
11435-
elements.push_back(eltTy);
11435+
elements.push_back(elt);
1143611436
}
1143711437

1143811438
// Handle vanishing tuples -- flatten to produce the

lib/AST/Type.cpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4680,6 +4680,39 @@ AnyFunctionType *AnyFunctionType::withSendable(bool newValue) const {
46804680
return withExtInfo(info);
46814681
}
46824682

4683+
AnyFunctionType *AnyFunctionType::getWithoutYields() const {
4684+
auto resultType = getResult();
4685+
4686+
if (auto *tupleResTy = resultType->getAs<TupleType>()) {
4687+
// Strip @yield results on the first level of tuple
4688+
SmallVector<TupleTypeElt, 4> elements;
4689+
for (const auto &elt : tupleResTy->getElements()) {
4690+
Type eltTy = elt.getType();
4691+
if (eltTy->is<YieldResultType>())
4692+
continue;
4693+
elements.push_back(elt);
4694+
}
4695+
4696+
// Handle vanishing tuples -- flatten to produce the
4697+
// normal coroutine result type
4698+
if (elements.size() == 1 && isCoroutine())
4699+
resultType = elements[0].getType();
4700+
else
4701+
resultType = TupleType::get(elements, getASTContext());
4702+
} else if (resultType->is<YieldResultType>()) {
4703+
resultType = TupleType::getEmpty(getASTContext());
4704+
}
4705+
4706+
auto noCoroExtInfo = getExtInfo().intoBuilder()
4707+
.withCoroutine(false)
4708+
.build();
4709+
if (isa<FunctionType>(this))
4710+
return FunctionType::get(getParams(), resultType, noCoroExtInfo);
4711+
assert(isa<GenericFunctionType>(this));
4712+
return GenericFunctionType::get(getOptGenericSignature(), getParams(),
4713+
resultType, noCoroExtInfo);
4714+
}
4715+
46834716
std::optional<Type> AnyFunctionType::getEffectiveThrownErrorType() const {
46844717
// A non-throwing function... has no thrown interface type.
46854718
if (!isThrowing())

lib/SIL/IR/AbstractionPattern.cpp

Lines changed: 27 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1038,13 +1038,6 @@ AbstractionPattern::getParameterizedProtocolArgType(unsigned argIndex) const {
10381038
cast<ParameterizedProtocolType>(getType()).getArgs()[argIndex]);
10391039
}
10401040

1041-
AbstractionPattern AbstractionPattern::getYieldResultType() const {
1042-
assert(getKind() == Kind::Type);
1043-
return AbstractionPattern(getGenericSubstitutions(),
1044-
getGenericSignature(),
1045-
cast<YieldResultType>(getType()).getResultType());
1046-
}
1047-
10481041
AbstractionPattern AbstractionPattern::removingMoveOnlyWrapper() const {
10491042
switch (getKind()) {
10501043
case Kind::Invalid:
@@ -1178,11 +1171,15 @@ AbstractionPattern::getCXXMethodSelfPattern(CanType selfType) const {
11781171
getGenericSignatureForFunctionComponent(), selfType);
11791172
}
11801173

1181-
static CanType getResultType(CanType type) {
1182-
return cast<AnyFunctionType>(type).getResult();
1174+
static CanType getResultType(CanType type, bool withoutYields) {
1175+
auto aft = cast<AnyFunctionType>(type);
1176+
if (withoutYields)
1177+
aft = CanAnyFunctionType(aft->getWithoutYields());
1178+
1179+
return aft.getResult();
11831180
}
11841181

1185-
AbstractionPattern AbstractionPattern::getFunctionResultType() const {
1182+
AbstractionPattern AbstractionPattern::getFunctionResultType(bool withoutYields) const {
11861183
switch (getKind()) {
11871184
case Kind::Invalid:
11881185
llvm_unreachable("querying invalid abstraction pattern!");
@@ -1196,7 +1193,7 @@ AbstractionPattern AbstractionPattern::getFunctionResultType() const {
11961193
return AbstractionPattern::getOpaque();
11971194
return AbstractionPattern(getGenericSubstitutions(),
11981195
getGenericSignatureForFunctionComponent(),
1199-
getResultType(getType()));
1196+
getResultType(getType(), withoutYields));
12001197
case Kind::Discard:
12011198
llvm_unreachable("don't need to discard function abstractions yet");
12021199
case Kind::ClangType:
@@ -1205,33 +1202,34 @@ AbstractionPattern AbstractionPattern::getFunctionResultType() const {
12051202
auto clangFunctionType = getClangFunctionType(getClangType());
12061203
return AbstractionPattern(getGenericSubstitutions(),
12071204
getGenericSignatureForFunctionComponent(),
1208-
getResultType(getType()),
1205+
getResultType(getType(), withoutYields),
12091206
clangFunctionType->getReturnType().getTypePtr());
12101207
}
12111208
case Kind::CXXMethodType:
12121209
case Kind::PartialCurriedCXXMethodType:
12131210
return AbstractionPattern(getGenericSubstitutions(),
12141211
getGenericSignatureForFunctionComponent(),
1215-
getResultType(getType()),
1212+
getResultType(getType(), withoutYields),
12161213
getCXXMethod()->getReturnType().getTypePtr());
12171214
case Kind::CurriedObjCMethodType:
12181215
return getPartialCurriedObjCMethod(
12191216
getGenericSubstitutions(),
12201217
getGenericSignatureForFunctionComponent(),
1221-
getResultType(getType()),
1218+
getResultType(getType(), withoutYields),
12221219
getObjCMethod(),
12231220
getEncodedForeignInfo());
12241221
case Kind::CurriedCFunctionAsMethodType:
12251222
return getPartialCurriedCFunctionAsMethod(
12261223
getGenericSubstitutions(),
12271224
getGenericSignatureForFunctionComponent(),
1228-
getResultType(getType()),
1225+
getResultType(getType(), withoutYields),
12291226
getClangType(),
12301227
getImportAsMemberStatus());
12311228
case Kind::CurriedCXXMethodType:
12321229
return getPartialCurriedCXXMethod(getGenericSubstitutions(),
12331230
getGenericSignatureForFunctionComponent(),
1234-
getResultType(getType()), getCXXMethod(),
1231+
getResultType(getType(), withoutYields),
1232+
getCXXMethod(),
12351233
getImportAsMemberStatus());
12361234
case Kind::PartialCurriedObjCMethodType:
12371235
case Kind::ObjCMethodType: {
@@ -1288,7 +1286,8 @@ AbstractionPattern AbstractionPattern::getFunctionResultType() const {
12881286

12891287
return AbstractionPattern(getGenericSubstitutions(),
12901288
getGenericSignatureForFunctionComponent(),
1291-
getResultType(getType()), clangResultType);
1289+
getResultType(getType(), withoutYields),
1290+
clangResultType);
12921291
}
12931292

12941293
default:
@@ -1298,14 +1297,15 @@ AbstractionPattern AbstractionPattern::getFunctionResultType() const {
12981297
return AbstractionPattern::getObjCCompletionHandlerArgumentsType(
12991298
getGenericSubstitutions(),
13001299
getGenericSignatureForFunctionComponent(),
1301-
getResultType(getType()), callbackParamTy,
1300+
getResultType(getType(), withoutYields),
1301+
callbackParamTy,
13021302
getEncodedForeignInfo());
13031303
}
13041304
}
13051305

13061306
return AbstractionPattern(getGenericSubstitutions(),
13071307
getGenericSignatureForFunctionComponent(),
1308-
getResultType(getType()),
1308+
getResultType(getType(), withoutYields),
13091309
getObjCMethod()->getReturnType().getTypePtr());
13101310
}
13111311
case Kind::OpaqueFunction:
@@ -2738,13 +2738,6 @@ class SubstFunctionTypePatternVisitor
27382738
llvm_unreachable("shouldn't encounter pack element by itself");
27392739
}
27402740

2741-
CanType visitYieldResultType(CanYieldResultType yield,
2742-
AbstractionPattern pattern) {
2743-
auto resultType = visit(yield.getResultType(), pattern.getYieldResultType());
2744-
return YieldResultType::get(resultType, yield->isInOut())
2745-
->getCanonicalType();
2746-
}
2747-
27482741
CanType handlePackExpansion(AbstractionPattern origExpansion,
27492742
CanType candidateSubstType) {
27502743
// When we're within a pack expansion, pack references matching that
@@ -2921,10 +2914,9 @@ class SubstFunctionTypePatternVisitor
29212914
addParam(param.getOrigFlags(), expansionType);
29222915
}
29232916
});
2924-
2925-
if (yieldType) {
2917+
2918+
if (yieldType)
29262919
substYieldType = visit(yieldType, yieldPattern);
2927-
}
29282920

29292921
CanType newErrorType;
29302922

@@ -2934,8 +2926,8 @@ class SubstFunctionTypePatternVisitor
29342926
newErrorType = visit(errorType, errorPattern);
29352927
}
29362928

2937-
auto newResultTy = visit(func.getResult(),
2938-
pattern.getFunctionResultType());
2929+
auto newResultTy = visit(func->getWithoutYields()->getResult()->getCanonicalType(),
2930+
pattern.getFunctionResultType(/* withoutYields */ true));
29392931

29402932
std::optional<FunctionType::ExtInfo> extInfo;
29412933
if (func->hasExtInfo())
@@ -2947,6 +2939,10 @@ class SubstFunctionTypePatternVisitor
29472939
extInfo = extInfo->withThrows(true, newErrorType);
29482940
}
29492941

2942+
// Yields were substituted separately
2943+
if (extInfo)
2944+
extInfo = extInfo->withCoroutine(false);
2945+
29502946
return CanFunctionType::get(FunctionType::CanParamArrayRef(newParams),
29512947
newResultTy, extInfo);
29522948
}

lib/SIL/IR/SILFunctionType.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2528,8 +2528,7 @@ static CanSILFunctionType getSILFunctionType(
25282528
coroutineKind = SILCoroutineKind::YieldOnce;
25292529
}
25302530

2531-
// Coroutine accessors are always native, so fetch the native
2532-
// abstraction pattern.
2531+
// Coroutines are always native, so fetch the native abstraction pattern.
25332532
auto sig = origFd->getGenericSignatureOfContext()
25342533
.getCanonicalSignature();
25352534
auto origYieldType = origFd->getYieldsInterfaceType()->castTo<YieldResultType>();

test/SILGen/coroutine_subst_function_types.swift

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -114,9 +114,9 @@ extension ConcreteWithInt : ProtoWithAssoc {
114114
}
115115

116116
// CHECK-LABEL: sil_vtable ConcreteWithInt {
117-
// CHECK: #Generic.generic!modify: <T> (Generic<T>) -> () -> () : @$s3mod15ConcreteWithIntC7genericSivMAA7GenericCADxvMTV [override]
118-
// CHECK: #Generic.genericFunction!modify: <T> (Generic<T>) -> () -> () : @$s3mod15ConcreteWithIntC15genericFunctionSiycvMAA7GenericCADxycvMTV [override]
119-
// CHECK: #Generic.subscript!modify: <T><U> (Generic<T>) -> (U) -> () : @$s3mod15ConcreteWithIntC16returningGenericSix_tcluiMAA0F0CADxqd___tcluiMTV [override]
120-
// CHECK: #Generic.subscript!modify: <T><U> (Generic<T>) -> (U) -> () : @$s3mod15ConcreteWithIntC19returningOwnGenericxx_tcluiM [override]
121-
// CHECK: #Generic.complexTuple!modify: <T> (Generic<T>) -> () -> () : @$s3mod15ConcreteWithIntC12complexTupleSiSg_SDySSSiGtvMAA7GenericCADxSg_SDySSxGtvMTV [override]
117+
// CHECK: #Generic.generic!modify: <T> (Generic<T>) -> @yield_once () -> inout @yields T : @$s3mod15ConcreteWithIntC7genericSivMAA7GenericCADxvMTV [override]
118+
// CHECK: #Generic.genericFunction!modify: <T> (Generic<T>) -> @yield_once () -> inout @yields () -> T : @$s3mod15ConcreteWithIntC15genericFunctionSiycvMAA7GenericCADxycvMTV [override]
119+
// CHECK: #Generic.subscript!modify: <T><U> (Generic<T>) -> @yield_once (U) -> inout @yields T : @$s3mod15ConcreteWithIntC16returningGenericSix_tcluiMAA0F0CADxqd___tcluiMTV [override]
120+
// CHECK: #Generic.subscript!modify: <T><U> (Generic<T>) -> @yield_once (U) -> inout @yields U : @$s3mod15ConcreteWithIntC19returningOwnGenericxx_tcluiM [override]
121+
// CHECK: #Generic.complexTuple!modify: <T> (Generic<T>) -> @yield_once () -> inout @yields (T?, [String : T]) : @$s3mod15ConcreteWithIntC12complexTupleSiSg_SDySSSiGtvMAA7GenericCADxSg_SDySSxGtvMTV [override]
122122
// CHECK: }

test/api-digester/Outputs/cake-abi.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2357,4 +2357,4 @@
23572357
],
23582358
"json_format_version": 8
23592359
}
2360-
}
2360+
}

test/api-digester/stability-concurrency-abi.test

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,44 @@ Func withTaskGroup(of:returning:body:) has mangled name changing from '_Concurre
130130

131131
Func pthread_main_np() is a new API without '@available'
132132

133+
Accessor AsyncCompactMapSequence.Iterator.baseIterator.Modify() has return type change from () to inout @yields τ_0_0.AsyncIterator
134+
Accessor AsyncDropFirstSequence.Iterator.baseIterator.Modify() has return type change from () to inout @yields τ_0_0.AsyncIterator
135+
Accessor AsyncDropFirstSequence.Iterator.count.Modify() has return type change from () to inout @yields Swift.Int
136+
Accessor AsyncDropWhileSequence.Iterator.baseIterator.Modify() has return type change from () to inout @yields τ_0_0.AsyncIterator
137+
Accessor AsyncDropWhileSequence.Iterator.predicate.Modify() has return type change from () to inout @yields ((τ_0_0.Element) async -> Swift.Bool)?
138+
Accessor AsyncFilterSequence.Iterator.baseIterator.Modify() has return type change from () to inout @yields τ_0_0.AsyncIterator
139+
Accessor AsyncFlatMapSequence.Iterator.baseIterator.Modify() has return type change from () to inout @yields τ_0_0.AsyncIterator
140+
Accessor AsyncFlatMapSequence.Iterator.currentIterator.Modify() has return type change from () to inout @yields τ_0_1.AsyncIterator?
141+
Accessor AsyncFlatMapSequence.Iterator.finished.Modify() has return type change from () to inout @yields Swift.Bool
142+
Accessor AsyncMapSequence.Iterator.baseIterator.Modify() has return type change from () to inout @yields τ_0_0.AsyncIterator
143+
Accessor AsyncPrefixSequence.Iterator.baseIterator.Modify() has return type change from () to inout @yields τ_0_0.AsyncIterator
144+
Accessor AsyncPrefixSequence.Iterator.remaining.Modify() has return type change from () to inout @yields Swift.Int
145+
Accessor AsyncPrefixWhileSequence.Iterator.baseIterator.Modify() has return type change from () to inout @yields τ_0_0.AsyncIterator
146+
Accessor AsyncPrefixWhileSequence.Iterator.predicateHasFailed.Modify() has return type change from () to inout @yields Swift.Bool
147+
Accessor AsyncStream.Continuation.onTermination.Modify() has return type change from () to inout @yields ((_Concurrency.AsyncStream<τ_0_0>.Continuation.Termination) -> ())?
148+
Accessor AsyncThrowingCompactMapSequence.Iterator.baseIterator.Modify() has return type change from () to inout @yields τ_0_0.AsyncIterator
149+
Accessor AsyncThrowingCompactMapSequence.Iterator.finished.Modify() has return type change from () to inout @yields Swift.Bool
150+
Accessor AsyncThrowingDropWhileSequence.Iterator.baseIterator.Modify() has return type change from () to inout @yields τ_0_0.AsyncIterator
151+
Accessor AsyncThrowingDropWhileSequence.Iterator.doneDropping.Modify() has return type change from () to inout @yields Swift.Bool
152+
Accessor AsyncThrowingDropWhileSequence.Iterator.finished.Modify() has return type change from () to inout @yields Swift.Bool
153+
Accessor AsyncThrowingFilterSequence.Iterator.baseIterator.Modify() has return type change from () to inout @yields τ_0_0.AsyncIterator
154+
Accessor AsyncThrowingFilterSequence.Iterator.finished.Modify() has return type change from () to inout @yields Swift.Bool
155+
Accessor AsyncThrowingFlatMapSequence.Iterator.baseIterator.Modify() has return type change from () to inout @yields τ_0_0.AsyncIterator
156+
Accessor AsyncThrowingFlatMapSequence.Iterator.currentIterator.Modify() has return type change from () to inout @yields τ_0_1.AsyncIterator?
157+
Accessor AsyncThrowingFlatMapSequence.Iterator.finished.Modify() has return type change from () to inout @yields Swift.Bool
158+
Accessor AsyncThrowingMapSequence.Iterator.baseIterator.Modify() has return type change from () to inout @yields τ_0_0.AsyncIterator
159+
Accessor AsyncThrowingMapSequence.Iterator.finished.Modify() has return type change from () to inout @yields Swift.Bool
160+
Accessor AsyncThrowingPrefixWhileSequence.Iterator.baseIterator.Modify() has return type change from () to inout @yields τ_0_0.AsyncIterator
161+
Accessor AsyncThrowingPrefixWhileSequence.Iterator.predicateHasFailed.Modify() has return type change from () to inout @yields Swift.Bool
162+
Accessor AsyncThrowingStream.Continuation.onTermination.Modify() has return type change from () to inout @yields ((_Concurrency.AsyncThrowingStream<τ_0_0, τ_0_1>.Continuation.Termination) -> ())?
163+
Accessor TaskGroup.Iterator.finished.Modify() has return type change from () to inout @yields Swift.Bool
164+
Accessor TaskGroup.Iterator.group.Modify() has return type change from () to inout @yields _Concurrency.TaskGroup<τ_0_0>
165+
Accessor TaskPriority.rawValue.Modify() has return type change from () to inout @yields Swift.UInt8
166+
Accessor ThrowingTaskGroup.Iterator.finished.Modify() has return type change from () to inout @yields Swift.Bool
167+
Accessor ThrowingTaskGroup.Iterator.group.Modify() has return type change from () to inout @yields _Concurrency.ThrowingTaskGroup<τ_0_0, τ_0_1>
168+
Accessor UnownedSerialExecutor.executor.Modify() has return type change from () to inout @yields Builtin.Executor
169+
Accessor UnsafeContinuation.context.Modify() has return type change from () to inout @yields Builtin.RawUnsafeContinuation
170+
133171
// *** DO NOT DISABLE OR XFAIL THIS TEST. *** (See comment above.)
134172

135173

0 commit comments

Comments
 (0)