Skip to content

Commit d124554

Browse files
committed
Add basic boilerplate for AST coroutines and yields
1 parent f244c8f commit d124554

28 files changed

+203
-9
lines changed

include/swift/AST/Decl.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7866,6 +7866,8 @@ class AbstractFunctionDecl : public GenericContext, public ValueDecl {
78667866
/// attribute.
78677867
bool isTransparent() const;
78687868

7869+
bool isCoroutine() const;
7870+
78697871
// Expose our import as member status
78707872
ImportAsMemberStatus getImportAsMemberStatus() const {
78717873
return ImportAsMemberStatus(Bits.AbstractFunctionDecl.IAMStatus);

include/swift/AST/DeclAttr.def

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -899,7 +899,12 @@ DECL_ATTR(specialized, Specialized,
899899
AllowMultipleAttributes | LongAttribute | UserInaccessible | ABIStableToAdd | ABIStableToRemove | APIStableToAdd | APIStableToRemove | ForbiddenInABIAttr,
900900
172)
901901

902-
LAST_DECL_ATTR(Specialized)
902+
SIMPLE_DECL_ATTR(yield_once, Coroutine,
903+
OnFunc,
904+
UserInaccessible | ABIBreakingToAdd | ABIBreakingToRemove | APIBreakingToAdd | APIBreakingToRemove | EquivalentInABIAttr,
905+
173)
906+
907+
LAST_DECL_ATTR(Coroutine)
903908

904909
#undef DECL_ATTR_ALIAS
905910
#undef CONTEXTUAL_DECL_ATTR_ALIAS

include/swift/AST/ExtInfo.h

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -323,7 +323,7 @@ enum class SILFunctionTypeRepresentation : uint8_t {
323323
CFunctionPointer = uint8_t(FunctionTypeRepresentation::CFunctionPointer),
324324

325325
/// The value of the greatest AST function representation.
326-
LastAST = CFunctionPointer,
326+
LastAST = uint8_t(FunctionTypeRepresentation::Last),
327327

328328
/// The value of the least SIL-only function representation.
329329
FirstSIL = 8,
@@ -523,8 +523,8 @@ class ASTExtInfoBuilder {
523523
// If bits are added or removed, then TypeBase::NumAFTExtInfoBits
524524
// and NumMaskBits must be updated, and they must match.
525525
//
526-
// |representation|noEscape|concurrent|async|throws|isolation|differentiability| SendingResult |
527-
// | 0 .. 3 | 4 | 5 | 6 | 7 | 8 .. 10 | 11 .. 13 | 14 |
526+
// |representation|noEscape|concurrent|async|throws|isolation|differentiability| SendingResult | coroutine |
527+
// | 0 .. 3 | 4 | 5 | 6 | 7 | 8 .. 10 | 11 .. 13 | 14 | 15 |
528528
//
529529
enum : unsigned {
530530
RepresentationMask = 0xF << 0,
@@ -537,7 +537,8 @@ class ASTExtInfoBuilder {
537537
DifferentiabilityMaskOffset = 11,
538538
DifferentiabilityMask = 0x7 << DifferentiabilityMaskOffset,
539539
SendingResultMask = 1 << 14,
540-
NumMaskBits = 15
540+
CoroutineMask = 1 << 15,
541+
NumMaskBits = 16
541542
};
542543

543544
static_assert(FunctionTypeIsolation::Mask == 0x7, "update mask manually");
@@ -616,6 +617,8 @@ class ASTExtInfoBuilder {
616617

617618
constexpr bool hasSendingResult() const { return bits & SendingResultMask; }
618619

620+
constexpr bool isCoroutine() const { return bits & CoroutineMask; }
621+
619622
constexpr DifferentiabilityKind getDifferentiabilityKind() const {
620623
return DifferentiabilityKind((bits & DifferentiabilityMask) >>
621624
DifferentiabilityMaskOffset);
@@ -732,6 +735,13 @@ class ASTExtInfoBuilder {
732735
clangTypeInfo, globalActor, thrownError, lifetimeDependencies);
733736
}
734737

738+
[[nodiscard]]
739+
ASTExtInfoBuilder withCoroutine(bool coroutine = true) const {
740+
return ASTExtInfoBuilder(
741+
coroutine ? (bits | CoroutineMask) : (bits & ~CoroutineMask),
742+
clangTypeInfo, globalActor, thrownError, lifetimeDependencies);
743+
}
744+
735745
[[nodiscard]]
736746
ASTExtInfoBuilder
737747
withDifferentiabilityKind(DifferentiabilityKind differentiability) const {
@@ -854,6 +864,8 @@ class ASTExtInfo {
854864

855865
constexpr bool isThrowing() const { return builder.isThrowing(); }
856866

867+
constexpr bool isCoroutine() const { return builder.isCoroutine(); }
868+
857869
constexpr bool hasSendingResult() const { return builder.hasSendingResult(); }
858870

859871
constexpr DifferentiabilityKind getDifferentiabilityKind() const {
@@ -917,6 +929,14 @@ class ASTExtInfo {
917929
return builder.withThrows(true, Type()).build();
918930
}
919931

932+
/// Helper method for changing only the coroutine field.
933+
///
934+
/// Prefer using \c ASTExtInfoBuilder::withCoroutine for chaining.
935+
[[nodiscard]]
936+
ASTExtInfo withCoroutine(bool coroutine = true) const {
937+
return builder.withCoroutine(coroutine).build();
938+
}
939+
920940
/// Helper method for changing only the async field.
921941
///
922942
/// Prefer using \c ASTExtInfoBuilder::withAsync for chaining.

include/swift/AST/TypeDifferenceVisitor.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,10 @@ class CanTypeDifferenceVisitor : public CanTypePairVisitor<Impl, bool> {
172172
type1->getElements(), type2->getElements());
173173
}
174174

175+
bool visitYieldResultType(CanYieldResultType type1, CanYieldResultType type2) {
176+
return asImpl().visit(type1.getResultType(), type2.getResultType());
177+
}
178+
175179
bool visitComponent(CanType type1, CanType type2,
176180
const TupleTypeElt &elt1, const TupleTypeElt &elt2) {
177181
if (elt1.getName() != elt2.getName())

include/swift/AST/TypeMatcher.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,17 @@ class TypeMatcher {
153153
return mismatch(firstTuple.getPointer(), secondType, sugaredFirstType);
154154
}
155155

156+
bool visitYieldResultType(CanYieldResultType firstType, Type secondType,
157+
Type sugaredFirstType) {
158+
if (auto secondYieldType = secondType->getAs<YieldResultType>())
159+
if (!this->visit(firstType.getResultType(),
160+
secondYieldType->getResultType(),
161+
sugaredFirstType->getAs<YieldResultType>()->getResultType()))
162+
return false;
163+
164+
return mismatch(firstType.getPointer(), secondType, sugaredFirstType);
165+
}
166+
156167
bool visitSILPackType(CanSILPackType firstPack, Type secondType,
157168
Type sugaredFirstType) {
158169
if (auto secondPack = secondType->getAs<SILPackType>()) {

include/swift/AST/TypeNodes.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,7 @@ TYPE(InOut, Type)
205205
TYPE(Pack, Type)
206206
TYPE(PackExpansion, Type)
207207
TYPE(PackElement, Type)
208+
TYPE(YieldResult, Type)
208209
UNCHECKED_TYPE(TypeVariable, Type)
209210
UNCHECKED_TYPE(ErrorUnion, Type)
210211
TYPE(Integer, Type)

include/swift/AST/TypeTransform.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1040,6 +1040,16 @@ case TypeKind::Id:
10401040
t : InOutType::get(objectTy);
10411041
}
10421042

1043+
case TypeKind::YieldResult: {
1044+
auto yield = cast<YieldResultType>(base);
1045+
auto objectTy = doIt(yield->getResultType(), TypePosition::Invariant);
1046+
if (!objectTy || objectTy->hasError())
1047+
return objectTy;
1048+
1049+
return objectTy.getPointer() == yield->getResultType().getPointer() ?
1050+
t : YieldResultType::get(objectTy, yield->isInOut());
1051+
}
1052+
10431053
case TypeKind::Existential: {
10441054
auto *existential = cast<ExistentialType>(base);
10451055
auto constraint = doIt(existential->getConstraintType(), pos);

include/swift/AST/Types.h

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -407,7 +407,7 @@ class alignas(1 << TypeAlignInBits) TypeBase
407407
}
408408

409409
protected:
410-
enum { NumAFTExtInfoBits = 15 };
410+
enum { NumAFTExtInfoBits = 16 };
411411
enum { NumSILExtInfoBits = 14 };
412412

413413
// clang-format off
@@ -437,15 +437,15 @@ class alignas(1 << TypeAlignInBits) TypeBase
437437
HasCachedType : 1
438438
);
439439

440-
SWIFT_INLINE_BITFIELD_FULL(AnyFunctionType, TypeBase, NumAFTExtInfoBits+1+1+1+1+16,
440+
SWIFT_INLINE_BITFIELD_FULL(AnyFunctionType, TypeBase, NumAFTExtInfoBits+1+1+1+1+14,
441441
/// Extra information which affects how the function is called, like
442442
/// regparm and the calling convention.
443443
ExtInfoBits : NumAFTExtInfoBits,
444444
HasExtInfo : 1,
445445
HasClangTypeInfo : 1,
446446
HasThrownError : 1,
447447
HasLifetimeDependencies : 1,
448-
NumParams : 15
448+
NumParams : 14
449449
);
450450

451451
SWIFT_INLINE_BITFIELD_FULL(ArchetypeType, TypeBase, 1+1+16,
@@ -1702,6 +1702,33 @@ class UnresolvedType : public TypeBase {
17021702
};
17031703
DEFINE_EMPTY_CAN_TYPE_WRAPPER(UnresolvedType, Type)
17041704

1705+
class YieldResultType : public TypeBase {
1706+
Type ResultType;
1707+
bool InOut = false;
1708+
1709+
YieldResultType(Type objectTy, bool InOut, const ASTContext *canonicalContext,
1710+
RecursiveTypeProperties properties)
1711+
: TypeBase(TypeKind::YieldResult, canonicalContext, properties),
1712+
ResultType(objectTy), InOut(InOut) {}
1713+
1714+
public:
1715+
static YieldResultType *get(Type originalType, bool InOut);
1716+
1717+
Type getResultType() const { return ResultType; }
1718+
bool isInOut() const { return InOut; }
1719+
1720+
// Implement isa/cast/dyncast/etc.
1721+
static bool classof(const TypeBase *T) {
1722+
return T->getKind() == TypeKind::YieldResult;
1723+
}
1724+
};
1725+
1726+
BEGIN_CAN_TYPE_WRAPPER(YieldResultType, Type)
1727+
PROXY_CAN_TYPE_SIMPLE_GETTER(getResultType)
1728+
static CanYieldResultType get(CanType type, bool InOut) {
1729+
return CanYieldResultType(YieldResultType::get(type, InOut));
1730+
}
1731+
END_CAN_TYPE_WRAPPER(YieldResultType, Type)
17051732

17061733
/// BuiltinType - An abstract class for all the builtin types.
17071734
class BuiltinType : public TypeBase {
@@ -3886,6 +3913,8 @@ class AnyFunctionType : public TypeBase {
38863913

38873914
bool isThrowing() const { return getExtInfo().isThrowing(); }
38883915

3916+
bool isCoroutine() const { return getExtInfo().isCoroutine(); }
3917+
38893918
bool hasSendingResult() const { return getExtInfo().hasSendingResult(); }
38903919

38913920
bool hasEffect(EffectKind kind) const;

lib/AST/ASTContext.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -555,6 +555,8 @@ struct ASTContext::Implementation {
555555
llvm::DenseMap<uintptr_t, ReferenceStorageType*> ReferenceStorageTypes;
556556
llvm::DenseMap<Type, LValueType*> LValueTypes;
557557
llvm::DenseMap<Type, InOutType*> InOutTypes;
558+
llvm::DenseMap<llvm::PointerIntPair<TypeBase*, 1, bool>,
559+
YieldResultType*> YieldResultTypes;
558560
llvm::DenseMap<std::pair<Type, void*>, DependentMemberType *>
559561
DependentMemberTypes;
560562
llvm::FoldingSet<ErrorUnionType> ErrorUnionTypes;
@@ -3336,6 +3338,7 @@ size_t ASTContext::Implementation::Arena::getTotalMemory() const {
33363338
llvm::capacity_in_bytes(ReferenceStorageTypes) +
33373339
llvm::capacity_in_bytes(LValueTypes) +
33383340
llvm::capacity_in_bytes(InOutTypes) +
3341+
llvm::capacity_in_bytes(YieldResultTypes) +
33393342
llvm::capacity_in_bytes(DependentMemberTypes) +
33403343
llvm::capacity_in_bytes(EnumTypes) +
33413344
llvm::capacity_in_bytes(StructTypes) +
@@ -3374,6 +3377,7 @@ void ASTContext::Implementation::Arena::dump(llvm::raw_ostream &os) const {
33743377
SIZE_AND_BYTES(ReferenceStorageTypes);
33753378
SIZE_AND_BYTES(LValueTypes);
33763379
SIZE_AND_BYTES(InOutTypes);
3380+
SIZE_AND_BYTES(YieldResultTypes);
33773381
SIZE_AND_BYTES(DependentMemberTypes);
33783382
SIZE(ErrorUnionTypes);
33793383
SIZE_AND_BYTES(PlaceholderTypes);
@@ -5683,6 +5687,28 @@ InOutType *InOutType::get(Type objectTy) {
56835687
properties);
56845688
}
56855689

5690+
YieldResultType *YieldResultType::get(Type objectTy, bool InOut) {
5691+
auto properties = objectTy->getRecursiveProperties();
5692+
if (InOut) {
5693+
assert(!objectTy->is<LValueType>() && !objectTy->is<InOutType>() &&
5694+
"cannot have 'inout' or @lvalue wrapped inside an 'inout yield'");
5695+
properties &= ~RecursiveTypeProperties::IsLValue;
5696+
}
5697+
5698+
auto arena = getArena(properties);
5699+
5700+
auto &C = objectTy->getASTContext();
5701+
auto pair = llvm::PointerIntPair<TypeBase*, 1, bool>(objectTy.getPointer(),
5702+
InOut);
5703+
auto &entry = C.getImpl().getArena(arena).YieldResultTypes[pair];
5704+
if (entry)
5705+
return entry;
5706+
5707+
const ASTContext *canonicalContext = objectTy->isCanonical() ? &C : nullptr;
5708+
return entry = new (C, arena) YieldResultType(objectTy, InOut, canonicalContext,
5709+
properties);
5710+
}
5711+
56865712
DependentMemberType *DependentMemberType::get(Type base, Identifier name) {
56875713
auto properties = base->getRecursiveProperties();
56885714
properties |= RecursiveTypeProperties::HasDependentMember;

lib/AST/ASTDumper.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4967,6 +4967,7 @@ class PrintAttribute : public AttributeVisitor<PrintAttribute, void, Label>,
49674967
TRIVIAL_ATTR_PRINTER(CompilerInitialized, compiler_initialized)
49684968
TRIVIAL_ATTR_PRINTER(Consuming, consuming)
49694969
TRIVIAL_ATTR_PRINTER(Convenience, convenience)
4970+
TRIVIAL_ATTR_PRINTER(Coroutine, coroutine)
49704971
TRIVIAL_ATTR_PRINTER(DiscardableResult, discardable_result)
49714972
TRIVIAL_ATTR_PRINTER(DisfavoredOverload, disfavored_overload)
49724973
TRIVIAL_ATTR_PRINTER(DistributedActor, distributed_actor)
@@ -6065,6 +6066,13 @@ namespace {
60656066
printFoot();
60666067
}
60676068

6069+
void visitYieldResultType(YieldResultType *T, Label label) {
6070+
printCommon("yield", label);
6071+
printFlag(T->isInOut(), "inout");
6072+
printRec(T->getResultType(), Label::always("type"));
6073+
printFoot();
6074+
}
6075+
60686076
TRIVIAL_TYPE_PRINTER(Unresolved, unresolved)
60696077

60706078
void visitPlaceholderType(PlaceholderType *T, Label label) {
@@ -6485,6 +6493,7 @@ namespace {
64856493
printFlag(T->isAsync(), "async");
64866494
printFlag(T->isThrowing(), "throws");
64876495
printFlag(T->hasSendingResult(), "sending_result");
6496+
printFlag(T->isCoroutine(), "@yield_once");
64886497
if (T->isDifferentiable()) {
64896498
switch (T->getDifferentiabilityKind()) {
64906499
default:

0 commit comments

Comments
 (0)