Skip to content

Commit e658e86

Browse files
committed
Properly substitute coroutines
1 parent 4f702fb commit e658e86

File tree

9 files changed

+110
-45
lines changed

9 files changed

+110
-45
lines changed

include/swift/AST/Types.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3828,6 +3828,9 @@ class AnyFunctionType : public TypeBase {
38283828
/// Return the function type without the throwing.
38293829
AnyFunctionType *getWithoutThrowing() const;
38303830

3831+
/// Return the function type without yields (and coroutine flag)
3832+
AnyFunctionType *getWithoutYields() const;
3833+
38313834
/// True if the parameter declaration it is attached to is guaranteed
38323835
/// to not persist the closure for longer than the duration of the call.
38333836
bool isNoEscape() const {

include/swift/SIL/AbstractionPattern.h

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1524,13 +1524,9 @@ class AbstractionPattern {
15241524
/// type, return the abstraction pattern for one of its argument types.
15251525
AbstractionPattern getParameterizedProtocolArgType(unsigned i) const;
15261526

1527-
/// Given that the value being abstracted is a yield result type,
1528-
/// return the abstraction pattern for corresponding type.
1529-
AbstractionPattern getYieldResultType() const;
1530-
15311527
/// Given that the value being abstracted is a function, return the
15321528
/// abstraction pattern for its result type.
1533-
AbstractionPattern getFunctionResultType() const;
1529+
AbstractionPattern getFunctionResultType(bool withoutYields = false) const;
15341530

15351531
/// Given that the value being abstracted is a function, return the
15361532
/// abstraction pattern for its thrown error type.

lib/AST/Decl.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10813,7 +10813,7 @@ Type FuncDecl::getResultInterfaceTypeWithoutYields() const {
1081310813
Type eltTy = elt.getType();
1081410814
if (eltTy->is<YieldResultType>())
1081510815
continue;
10816-
elements.push_back(eltTy);
10816+
elements.push_back(elt);
1081710817
}
1081810818

1081910819
// Handle vanishing tuples -- flatten to produce the

lib/AST/Type.cpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4461,6 +4461,39 @@ AnyFunctionType *AnyFunctionType::getWithoutThrowing() const {
44614461
return withExtInfo(info);
44624462
}
44634463

4464+
AnyFunctionType *AnyFunctionType::getWithoutYields() const {
4465+
auto resultType = getResult();
4466+
4467+
if (auto *tupleResTy = resultType->getAs<TupleType>()) {
4468+
// Strip @yield results on the first level of tuple
4469+
SmallVector<TupleTypeElt, 4> elements;
4470+
for (const auto &elt : tupleResTy->getElements()) {
4471+
Type eltTy = elt.getType();
4472+
if (eltTy->is<YieldResultType>())
4473+
continue;
4474+
elements.push_back(elt);
4475+
}
4476+
4477+
// Handle vanishing tuples -- flatten to produce the
4478+
// normal coroutine result type
4479+
if (elements.size() == 1 && isCoroutine())
4480+
resultType = elements[0].getType();
4481+
else
4482+
resultType = TupleType::get(elements, getASTContext());
4483+
} else if (resultType->is<YieldResultType>()) {
4484+
resultType = TupleType::getEmpty(getASTContext());
4485+
}
4486+
4487+
auto noCoroExtInfo = getExtInfo().intoBuilder()
4488+
.withCoroutine(false)
4489+
.build();
4490+
if (isa<FunctionType>(this))
4491+
return FunctionType::get(getParams(), resultType, noCoroExtInfo);
4492+
assert(isa<GenericFunctionType>(this));
4493+
return GenericFunctionType::get(getOptGenericSignature(), getParams(),
4494+
resultType, noCoroExtInfo);
4495+
}
4496+
44644497
std::optional<Type> AnyFunctionType::getEffectiveThrownErrorType() const {
44654498
// A non-throwing function... has no thrown interface type.
44664499
if (!isThrowing())

lib/SIL/IR/AbstractionPattern.cpp

Lines changed: 27 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1038,13 +1038,6 @@ AbstractionPattern::getParameterizedProtocolArgType(unsigned argIndex) const {
10381038
cast<ParameterizedProtocolType>(getType()).getArgs()[argIndex]);
10391039
}
10401040

1041-
AbstractionPattern AbstractionPattern::getYieldResultType() const {
1042-
assert(getKind() == Kind::Type);
1043-
return AbstractionPattern(getGenericSubstitutions(),
1044-
getGenericSignature(),
1045-
cast<YieldResultType>(getType()).getResultType());
1046-
}
1047-
10481041
AbstractionPattern AbstractionPattern::removingMoveOnlyWrapper() const {
10491042
switch (getKind()) {
10501043
case Kind::Invalid:
@@ -1178,11 +1171,15 @@ AbstractionPattern::getCXXMethodSelfPattern(CanType selfType) const {
11781171
getGenericSignatureForFunctionComponent(), selfType);
11791172
}
11801173

1181-
static CanType getResultType(CanType type) {
1182-
return cast<AnyFunctionType>(type).getResult();
1174+
static CanType getResultType(CanType type, bool withoutYields) {
1175+
auto aft = cast<AnyFunctionType>(type);
1176+
if (withoutYields)
1177+
aft = CanAnyFunctionType(aft->getWithoutYields());
1178+
1179+
return aft.getResult();
11831180
}
11841181

1185-
AbstractionPattern AbstractionPattern::getFunctionResultType() const {
1182+
AbstractionPattern AbstractionPattern::getFunctionResultType(bool withoutYields) const {
11861183
switch (getKind()) {
11871184
case Kind::Invalid:
11881185
llvm_unreachable("querying invalid abstraction pattern!");
@@ -1196,7 +1193,7 @@ AbstractionPattern AbstractionPattern::getFunctionResultType() const {
11961193
return AbstractionPattern::getOpaque();
11971194
return AbstractionPattern(getGenericSubstitutions(),
11981195
getGenericSignatureForFunctionComponent(),
1199-
getResultType(getType()));
1196+
getResultType(getType(), withoutYields));
12001197
case Kind::Discard:
12011198
llvm_unreachable("don't need to discard function abstractions yet");
12021199
case Kind::ClangType:
@@ -1205,33 +1202,34 @@ AbstractionPattern AbstractionPattern::getFunctionResultType() const {
12051202
auto clangFunctionType = getClangFunctionType(getClangType());
12061203
return AbstractionPattern(getGenericSubstitutions(),
12071204
getGenericSignatureForFunctionComponent(),
1208-
getResultType(getType()),
1205+
getResultType(getType(), withoutYields),
12091206
clangFunctionType->getReturnType().getTypePtr());
12101207
}
12111208
case Kind::CXXMethodType:
12121209
case Kind::PartialCurriedCXXMethodType:
12131210
return AbstractionPattern(getGenericSubstitutions(),
12141211
getGenericSignatureForFunctionComponent(),
1215-
getResultType(getType()),
1212+
getResultType(getType(), withoutYields),
12161213
getCXXMethod()->getReturnType().getTypePtr());
12171214
case Kind::CurriedObjCMethodType:
12181215
return getPartialCurriedObjCMethod(
12191216
getGenericSubstitutions(),
12201217
getGenericSignatureForFunctionComponent(),
1221-
getResultType(getType()),
1218+
getResultType(getType(), withoutYields),
12221219
getObjCMethod(),
12231220
getEncodedForeignInfo());
12241221
case Kind::CurriedCFunctionAsMethodType:
12251222
return getPartialCurriedCFunctionAsMethod(
12261223
getGenericSubstitutions(),
12271224
getGenericSignatureForFunctionComponent(),
1228-
getResultType(getType()),
1225+
getResultType(getType(), withoutYields),
12291226
getClangType(),
12301227
getImportAsMemberStatus());
12311228
case Kind::CurriedCXXMethodType:
12321229
return getPartialCurriedCXXMethod(getGenericSubstitutions(),
12331230
getGenericSignatureForFunctionComponent(),
1234-
getResultType(getType()), getCXXMethod(),
1231+
getResultType(getType(), withoutYields),
1232+
getCXXMethod(),
12351233
getImportAsMemberStatus());
12361234
case Kind::PartialCurriedObjCMethodType:
12371235
case Kind::ObjCMethodType: {
@@ -1288,7 +1286,8 @@ AbstractionPattern AbstractionPattern::getFunctionResultType() const {
12881286

12891287
return AbstractionPattern(getGenericSubstitutions(),
12901288
getGenericSignatureForFunctionComponent(),
1291-
getResultType(getType()), clangResultType);
1289+
getResultType(getType(), withoutYields),
1290+
clangResultType);
12921291
}
12931292

12941293
default:
@@ -1298,14 +1297,15 @@ AbstractionPattern AbstractionPattern::getFunctionResultType() const {
12981297
return AbstractionPattern::getObjCCompletionHandlerArgumentsType(
12991298
getGenericSubstitutions(),
13001299
getGenericSignatureForFunctionComponent(),
1301-
getResultType(getType()), callbackParamTy,
1300+
getResultType(getType(), withoutYields),
1301+
callbackParamTy,
13021302
getEncodedForeignInfo());
13031303
}
13041304
}
13051305

13061306
return AbstractionPattern(getGenericSubstitutions(),
13071307
getGenericSignatureForFunctionComponent(),
1308-
getResultType(getType()),
1308+
getResultType(getType(), withoutYields),
13091309
getObjCMethod()->getReturnType().getTypePtr());
13101310
}
13111311
case Kind::OpaqueFunction:
@@ -2722,13 +2722,6 @@ class SubstFunctionTypePatternVisitor
27222722
llvm_unreachable("shouldn't encounter pack element by itself");
27232723
}
27242724

2725-
CanType visitYieldResultType(CanYieldResultType yield,
2726-
AbstractionPattern pattern) {
2727-
auto resultType = visit(yield.getResultType(), pattern.getYieldResultType());
2728-
return YieldResultType::get(resultType, yield->isInOut())
2729-
->getCanonicalType();
2730-
}
2731-
27322725
CanType handlePackExpansion(AbstractionPattern origExpansion,
27332726
CanType candidateSubstType) {
27342727
// When we're within a pack expansion, pack references matching that
@@ -2905,10 +2898,9 @@ class SubstFunctionTypePatternVisitor
29052898
addParam(param.getOrigFlags(), expansionType);
29062899
}
29072900
});
2908-
2909-
if (yieldType) {
2901+
2902+
if (yieldType)
29102903
substYieldType = visit(yieldType, yieldPattern);
2911-
}
29122904

29132905
CanType newErrorType;
29142906

@@ -2918,8 +2910,8 @@ class SubstFunctionTypePatternVisitor
29182910
newErrorType = visit(errorType, errorPattern);
29192911
}
29202912

2921-
auto newResultTy = visit(func.getResult(),
2922-
pattern.getFunctionResultType());
2913+
auto newResultTy = visit(func->getWithoutYields()->getResult()->getCanonicalType(),
2914+
pattern.getFunctionResultType(/* withoutYields */ true));
29232915

29242916
std::optional<FunctionType::ExtInfo> extInfo;
29252917
if (func->hasExtInfo())
@@ -2931,6 +2923,10 @@ class SubstFunctionTypePatternVisitor
29312923
extInfo = extInfo->withThrows(true, newErrorType);
29322924
}
29332925

2926+
// Yields were substituted separately
2927+
if (extInfo)
2928+
extInfo = extInfo->withCoroutine(false);
2929+
29342930
return CanFunctionType::get(FunctionType::CanParamArrayRef(newParams),
29352931
newResultTy, extInfo);
29362932
}

lib/SIL/IR/SILFunctionType.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2392,8 +2392,7 @@ static CanSILFunctionType getSILFunctionType(
23922392
coroutineKind = SILCoroutineKind::YieldOnce;
23932393
}
23942394

2395-
// Coroutine accessors are always native, so fetch the native
2396-
// abstraction pattern.
2395+
// Coroutines are always native, so fetch the native abstraction pattern.
23972396
auto sig = origFd->getGenericSignatureOfContext()
23982397
.getCanonicalSignature();
23992398
auto origYieldType = origFd->getYieldsInterfaceType()->castTo<YieldResultType>();

test/SILGen/coroutine_subst_function_types.swift

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -114,9 +114,9 @@ extension ConcreteWithInt : ProtoWithAssoc {
114114
}
115115

116116
// CHECK-LABEL: sil_vtable ConcreteWithInt {
117-
// CHECK: #Generic.generic!modify: <T> (Generic<T>) -> () -> () : @$s3mod15ConcreteWithIntC7genericSivMAA7GenericCADxvMTV [override]
118-
// CHECK: #Generic.genericFunction!modify: <T> (Generic<T>) -> () -> () : @$s3mod15ConcreteWithIntC15genericFunctionSiycvMAA7GenericCADxycvMTV [override]
119-
// CHECK: #Generic.subscript!modify: <T><U> (Generic<T>) -> (U) -> () : @$s3mod15ConcreteWithIntC16returningGenericSix_tcluiMAA0F0CADxqd___tcluiMTV [override]
120-
// CHECK: #Generic.subscript!modify: <T><U> (Generic<T>) -> (U) -> () : @$s3mod15ConcreteWithIntC19returningOwnGenericxx_tcluiM [override]
121-
// CHECK: #Generic.complexTuple!modify: <T> (Generic<T>) -> () -> () : @$s3mod15ConcreteWithIntC12complexTupleSiSg_SDySSSiGtvMAA7GenericCADxSg_SDySSxGtvMTV [override]
117+
// CHECK: #Generic.generic!modify: <T> (Generic<T>) -> @yield_once () -> inout @yields T : @$s3mod15ConcreteWithIntC7genericSivMAA7GenericCADxvMTV [override]
118+
// CHECK: #Generic.genericFunction!modify: <T> (Generic<T>) -> @yield_once () -> inout @yields () -> T : @$s3mod15ConcreteWithIntC15genericFunctionSiycvMAA7GenericCADxycvMTV [override]
119+
// CHECK: #Generic.subscript!modify: <T><U> (Generic<T>) -> @yield_once (U) -> inout @yields T : @$s3mod15ConcreteWithIntC16returningGenericSix_tcluiMAA0F0CADxqd___tcluiMTV [override]
120+
// CHECK: #Generic.subscript!modify: <T><U> (Generic<T>) -> @yield_once (U) -> inout @yields U : @$s3mod15ConcreteWithIntC19returningOwnGenericxx_tcluiM [override]
121+
// CHECK: #Generic.complexTuple!modify: <T> (Generic<T>) -> @yield_once () -> inout @yields (T?, [String : T]) : @$s3mod15ConcreteWithIntC12complexTupleSiSg_SDySSSiGtvMAA7GenericCADxSg_SDySSxGtvMTV [override]
122122
// CHECK: }

test/api-digester/Outputs/cake-abi.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2336,4 +2336,4 @@
23362336
],
23372337
"json_format_version": 8
23382338
}
2339-
}
2339+
}

test/api-digester/stability-concurrency-abi.test

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,44 @@ Func withThrowingTaskGroup(of:returning:body:) has parameter 2 type change from
125125
Func withTaskGroup(of:returning:body:) has been renamed to Func withTaskGroup(of:returning:isolation:body:)
126126
Func withTaskGroup(of:returning:body:) has mangled name changing from '_Concurrency.withTaskGroup<A, B where A: Swift.Sendable>(of: A.Type, returning: B.Type, body: (inout Swift.TaskGroup<A>) async -> B) async -> B' to '_Concurrency.withTaskGroup<A, B where A: Swift.Sendable>(of: A.Type, returning: B.Type, isolation: isolated Swift.Optional<Swift.Actor>, body: (inout Swift.TaskGroup<A>) async -> B) async -> B'
127127

128+
Accessor AsyncCompactMapSequence.Iterator.baseIterator.Modify() has return type change from () to inout @yields τ_0_0.AsyncIterator
129+
Accessor AsyncDropFirstSequence.Iterator.baseIterator.Modify() has return type change from () to inout @yields τ_0_0.AsyncIterator
130+
Accessor AsyncDropFirstSequence.Iterator.count.Modify() has return type change from () to inout @yields Swift.Int
131+
Accessor AsyncDropWhileSequence.Iterator.baseIterator.Modify() has return type change from () to inout @yields τ_0_0.AsyncIterator
132+
Accessor AsyncDropWhileSequence.Iterator.predicate.Modify() has return type change from () to inout @yields ((τ_0_0.Element) async -> Swift.Bool)?
133+
Accessor AsyncFilterSequence.Iterator.baseIterator.Modify() has return type change from () to inout @yields τ_0_0.AsyncIterator
134+
Accessor AsyncFlatMapSequence.Iterator.baseIterator.Modify() has return type change from () to inout @yields τ_0_0.AsyncIterator
135+
Accessor AsyncFlatMapSequence.Iterator.currentIterator.Modify() has return type change from () to inout @yields τ_0_1.AsyncIterator?
136+
Accessor AsyncFlatMapSequence.Iterator.finished.Modify() has return type change from () to inout @yields Swift.Bool
137+
Accessor AsyncMapSequence.Iterator.baseIterator.Modify() has return type change from () to inout @yields τ_0_0.AsyncIterator
138+
Accessor AsyncPrefixSequence.Iterator.baseIterator.Modify() has return type change from () to inout @yields τ_0_0.AsyncIterator
139+
Accessor AsyncPrefixSequence.Iterator.remaining.Modify() has return type change from () to inout @yields Swift.Int
140+
Accessor AsyncPrefixWhileSequence.Iterator.baseIterator.Modify() has return type change from () to inout @yields τ_0_0.AsyncIterator
141+
Accessor AsyncPrefixWhileSequence.Iterator.predicateHasFailed.Modify() has return type change from () to inout @yields Swift.Bool
142+
Accessor AsyncStream.Continuation.onTermination.Modify() has return type change from () to inout @yields ((_Concurrency.AsyncStream<τ_0_0>.Continuation.Termination) -> ())?
143+
Accessor AsyncThrowingCompactMapSequence.Iterator.baseIterator.Modify() has return type change from () to inout @yields τ_0_0.AsyncIterator
144+
Accessor AsyncThrowingCompactMapSequence.Iterator.finished.Modify() has return type change from () to inout @yields Swift.Bool
145+
Accessor AsyncThrowingDropWhileSequence.Iterator.baseIterator.Modify() has return type change from () to inout @yields τ_0_0.AsyncIterator
146+
Accessor AsyncThrowingDropWhileSequence.Iterator.doneDropping.Modify() has return type change from () to inout @yields Swift.Bool
147+
Accessor AsyncThrowingDropWhileSequence.Iterator.finished.Modify() has return type change from () to inout @yields Swift.Bool
148+
Accessor AsyncThrowingFilterSequence.Iterator.baseIterator.Modify() has return type change from () to inout @yields τ_0_0.AsyncIterator
149+
Accessor AsyncThrowingFilterSequence.Iterator.finished.Modify() has return type change from () to inout @yields Swift.Bool
150+
Accessor AsyncThrowingFlatMapSequence.Iterator.baseIterator.Modify() has return type change from () to inout @yields τ_0_0.AsyncIterator
151+
Accessor AsyncThrowingFlatMapSequence.Iterator.currentIterator.Modify() has return type change from () to inout @yields τ_0_1.AsyncIterator?
152+
Accessor AsyncThrowingFlatMapSequence.Iterator.finished.Modify() has return type change from () to inout @yields Swift.Bool
153+
Accessor AsyncThrowingMapSequence.Iterator.baseIterator.Modify() has return type change from () to inout @yields τ_0_0.AsyncIterator
154+
Accessor AsyncThrowingMapSequence.Iterator.finished.Modify() has return type change from () to inout @yields Swift.Bool
155+
Accessor AsyncThrowingPrefixWhileSequence.Iterator.baseIterator.Modify() has return type change from () to inout @yields τ_0_0.AsyncIterator
156+
Accessor AsyncThrowingPrefixWhileSequence.Iterator.predicateHasFailed.Modify() has return type change from () to inout @yields Swift.Bool
157+
Accessor AsyncThrowingStream.Continuation.onTermination.Modify() has return type change from () to inout @yields ((_Concurrency.AsyncThrowingStream<τ_0_0, τ_0_1>.Continuation.Termination) -> ())?
158+
Accessor TaskGroup.Iterator.finished.Modify() has return type change from () to inout @yields Swift.Bool
159+
Accessor TaskGroup.Iterator.group.Modify() has return type change from () to inout @yields _Concurrency.TaskGroup<τ_0_0>
160+
Accessor TaskPriority.rawValue.Modify() has return type change from () to inout @yields Swift.UInt8
161+
Accessor ThrowingTaskGroup.Iterator.finished.Modify() has return type change from () to inout @yields Swift.Bool
162+
Accessor ThrowingTaskGroup.Iterator.group.Modify() has return type change from () to inout @yields _Concurrency.ThrowingTaskGroup<τ_0_0, τ_0_1>
163+
Accessor UnownedSerialExecutor.executor.Modify() has return type change from () to inout @yields Builtin.Executor
164+
Accessor UnsafeContinuation.context.Modify() has return type change from () to inout @yields Builtin.RawUnsafeContinuation
165+
128166
// *** DO NOT DISABLE OR XFAIL THIS TEST. *** (See comment above.)
129167

130168

0 commit comments

Comments
 (0)