Skip to content

Commit 5f75587

Browse files
committed
Add basic boilerplate for AST coroutines and yields
1 parent 6abfe56 commit 5f75587

28 files changed

+203
-9
lines changed

include/swift/AST/Decl.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7713,6 +7713,8 @@ class AbstractFunctionDecl : public GenericContext, public ValueDecl {
77137713
/// attribute.
77147714
bool isTransparent() const;
77157715

7716+
bool isCoroutine() const;
7717+
77167718
// Expose our import as member status
77177719
ImportAsMemberStatus getImportAsMemberStatus() const {
77187720
return ImportAsMemberStatus(Bits.AbstractFunctionDecl.IAMStatus);

include/swift/AST/DeclAttr.def

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -542,7 +542,11 @@ DECL_ATTR(abi, ABI,
542542
165)
543543
DECL_ATTR_FEATURE_REQUIREMENT(ABI, ABIAttribute)
544544

545-
LAST_DECL_ATTR(ABI)
545+
SIMPLE_DECL_ATTR(yield_once, Coroutine,
546+
OnFunc | UserInaccessible | ABIBreakingToAdd | ABIBreakingToRemove | APIBreakingToAdd | APIBreakingToRemove,
547+
166)
548+
549+
LAST_DECL_ATTR(Coroutine)
546550

547551
#undef DECL_ATTR_ALIAS
548552
#undef CONTEXTUAL_DECL_ATTR_ALIAS

include/swift/AST/ExtInfo.h

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,7 @@ enum class SILFunctionTypeRepresentation : uint8_t {
238238
CFunctionPointer = uint8_t(FunctionTypeRepresentation::CFunctionPointer),
239239

240240
/// The value of the greatest AST function representation.
241-
LastAST = CFunctionPointer,
241+
LastAST = uint8_t(FunctionTypeRepresentation::Last),
242242

243243
/// The value of the least SIL-only function representation.
244244
FirstSIL = 8,
@@ -438,8 +438,8 @@ class ASTExtInfoBuilder {
438438
// If bits are added or removed, then TypeBase::NumAFTExtInfoBits
439439
// and NumMaskBits must be updated, and they must match.
440440
//
441-
// |representation|noEscape|concurrent|async|throws|isolation|differentiability| SendingResult |
442-
// | 0 .. 3 | 4 | 5 | 6 | 7 | 8 .. 10 | 11 .. 13 | 14 |
441+
// |representation|noEscape|concurrent|async|throws|isolation|differentiability| SendingResult | coroutine |
442+
// | 0 .. 3 | 4 | 5 | 6 | 7 | 8 .. 10 | 11 .. 13 | 14 | 15 |
443443
//
444444
enum : unsigned {
445445
RepresentationMask = 0xF << 0,
@@ -452,7 +452,8 @@ class ASTExtInfoBuilder {
452452
DifferentiabilityMaskOffset = 11,
453453
DifferentiabilityMask = 0x7 << DifferentiabilityMaskOffset,
454454
SendingResultMask = 1 << 14,
455-
NumMaskBits = 15
455+
CoroutineMask = 1 << 15,
456+
NumMaskBits = 16
456457
};
457458

458459
static_assert(FunctionTypeIsolation::Mask == 0x7, "update mask manually");
@@ -531,6 +532,8 @@ class ASTExtInfoBuilder {
531532

532533
constexpr bool hasSendingResult() const { return bits & SendingResultMask; }
533534

535+
constexpr bool isCoroutine() const { return bits & CoroutineMask; }
536+
534537
constexpr DifferentiabilityKind getDifferentiabilityKind() const {
535538
return DifferentiabilityKind((bits & DifferentiabilityMask) >>
536539
DifferentiabilityMaskOffset);
@@ -647,6 +650,13 @@ class ASTExtInfoBuilder {
647650
clangTypeInfo, globalActor, thrownError, lifetimeDependencies);
648651
}
649652

653+
[[nodiscard]]
654+
ASTExtInfoBuilder withCoroutine(bool coroutine = true) const {
655+
return ASTExtInfoBuilder(
656+
coroutine ? (bits | CoroutineMask) : (bits & ~CoroutineMask),
657+
clangTypeInfo, globalActor, thrownError, lifetimeDependencies);
658+
}
659+
650660
[[nodiscard]]
651661
ASTExtInfoBuilder
652662
withDifferentiabilityKind(DifferentiabilityKind differentiability) const {
@@ -762,6 +772,8 @@ class ASTExtInfo {
762772

763773
constexpr bool isThrowing() const { return builder.isThrowing(); }
764774

775+
constexpr bool isCoroutine() const { return builder.isCoroutine(); }
776+
765777
constexpr bool hasSendingResult() const { return builder.hasSendingResult(); }
766778

767779
constexpr DifferentiabilityKind getDifferentiabilityKind() const {
@@ -825,6 +837,14 @@ class ASTExtInfo {
825837
return builder.withThrows(true, Type()).build();
826838
}
827839

840+
/// Helper method for changing only the coroutine field.
841+
///
842+
/// Prefer using \c ASTExtInfoBuilder::withCoroutine for chaining.
843+
[[nodiscard]]
844+
ASTExtInfo withCoroutine(bool coroutine = true) const {
845+
return builder.withCoroutine(coroutine).build();
846+
}
847+
828848
/// Helper method for changing only the async field.
829849
///
830850
/// Prefer using \c ASTExtInfoBuilder::withAsync for chaining.

include/swift/AST/TypeDifferenceVisitor.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,10 @@ class CanTypeDifferenceVisitor : public CanTypePairVisitor<Impl, bool> {
172172
type1->getElements(), type2->getElements());
173173
}
174174

175+
bool visitYieldResultType(CanYieldResultType type1, CanYieldResultType type2) {
176+
return asImpl().visit(type1.getResultType(), type2.getResultType());
177+
}
178+
175179
bool visitComponent(CanType type1, CanType type2,
176180
const TupleTypeElt &elt1, const TupleTypeElt &elt2) {
177181
if (elt1.getName() != elt2.getName())

include/swift/AST/TypeMatcher.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,17 @@ class TypeMatcher {
153153
return mismatch(firstTuple.getPointer(), secondType, sugaredFirstType);
154154
}
155155

156+
bool visitYieldResultType(CanYieldResultType firstType, Type secondType,
157+
Type sugaredFirstType) {
158+
if (auto secondYieldType = secondType->getAs<YieldResultType>())
159+
if (!this->visit(firstType.getResultType(),
160+
secondYieldType->getResultType(),
161+
sugaredFirstType->getAs<YieldResultType>()->getResultType()))
162+
return false;
163+
164+
return mismatch(firstType.getPointer(), secondType, sugaredFirstType);
165+
}
166+
156167
bool visitSILPackType(CanSILPackType firstPack, Type secondType,
157168
Type sugaredFirstType) {
158169
if (auto secondPack = secondType->getAs<SILPackType>()) {

include/swift/AST/TypeNodes.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,7 @@ TYPE(InOut, Type)
205205
TYPE(Pack, Type)
206206
TYPE(PackExpansion, Type)
207207
TYPE(PackElement, Type)
208+
TYPE(YieldResult, Type)
208209
UNCHECKED_TYPE(TypeVariable, Type)
209210
UNCHECKED_TYPE(ErrorUnion, Type)
210211
ALWAYS_CANONICAL_TYPE(Integer, Type)

include/swift/AST/TypeTransform.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
#include "swift/AST/GenericEnvironment.h"
2222
#include "swift/AST/SILLayout.h"
23+
#include "swift/AST/Types.h"
2324

2425
namespace swift {
2526

@@ -955,6 +956,16 @@ case TypeKind::Id:
955956
t : InOutType::get(objectTy);
956957
}
957958

959+
case TypeKind::YieldResult: {
960+
auto yield = cast<YieldResultType>(base);
961+
auto objectTy = doIt(yield->getResultType(), TypePosition::Invariant);
962+
if (!objectTy || objectTy->hasError())
963+
return objectTy;
964+
965+
return objectTy.getPointer() == yield->getResultType().getPointer() ?
966+
t : YieldResultType::get(objectTy, yield->isInOut());
967+
}
968+
958969
case TypeKind::Existential: {
959970
auto *existential = cast<ExistentialType>(base);
960971
auto constraint = doIt(existential->getConstraintType(), pos);

include/swift/AST/Types.h

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -402,7 +402,7 @@ class alignas(1 << TypeAlignInBits) TypeBase
402402
}
403403

404404
protected:
405-
enum { NumAFTExtInfoBits = 15 };
405+
enum { NumAFTExtInfoBits = 16 };
406406
enum { NumSILExtInfoBits = 14 };
407407

408408
// clang-format off
@@ -428,7 +428,7 @@ class alignas(1 << TypeAlignInBits) TypeBase
428428
HasCachedType : 1
429429
);
430430

431-
SWIFT_INLINE_BITFIELD_FULL(AnyFunctionType, TypeBase, NumAFTExtInfoBits+1+1+1+1+16,
431+
SWIFT_INLINE_BITFIELD_FULL(AnyFunctionType, TypeBase, NumAFTExtInfoBits+1+1+1+1+15,
432432
/// Extra information which affects how the function is called, like
433433
/// regparm and the calling convention.
434434
ExtInfoBits : NumAFTExtInfoBits,
@@ -437,7 +437,7 @@ class alignas(1 << TypeAlignInBits) TypeBase
437437
HasThrownError : 1,
438438
HasLifetimeDependencies : 1,
439439
: NumPadBits,
440-
NumParams : 16
440+
NumParams : 15
441441
);
442442

443443
SWIFT_INLINE_BITFIELD_FULL(ArchetypeType, TypeBase, 1+1+16,
@@ -1651,6 +1651,33 @@ class UnresolvedType : public TypeBase {
16511651
};
16521652
DEFINE_EMPTY_CAN_TYPE_WRAPPER(UnresolvedType, Type)
16531653

1654+
class YieldResultType : public TypeBase {
1655+
Type ResultType;
1656+
bool InOut = false;
1657+
1658+
YieldResultType(Type objectTy, bool InOut, const ASTContext *canonicalContext,
1659+
RecursiveTypeProperties properties)
1660+
: TypeBase(TypeKind::YieldResult, canonicalContext, properties),
1661+
ResultType(objectTy), InOut(InOut) {}
1662+
1663+
public:
1664+
static YieldResultType *get(Type originalType, bool InOut);
1665+
1666+
Type getResultType() const { return ResultType; }
1667+
bool isInOut() const { return InOut; }
1668+
1669+
// Implement isa/cast/dyncast/etc.
1670+
static bool classof(const TypeBase *T) {
1671+
return T->getKind() == TypeKind::YieldResult;
1672+
}
1673+
};
1674+
1675+
BEGIN_CAN_TYPE_WRAPPER(YieldResultType, Type)
1676+
PROXY_CAN_TYPE_SIMPLE_GETTER(getResultType)
1677+
static CanYieldResultType get(CanType type, bool InOut) {
1678+
return CanYieldResultType(YieldResultType::get(type, InOut));
1679+
}
1680+
END_CAN_TYPE_WRAPPER(YieldResultType, Type)
16541681

16551682
/// BuiltinType - An abstract class for all the builtin types.
16561683
class BuiltinType : public TypeBase {
@@ -3810,6 +3837,8 @@ class AnyFunctionType : public TypeBase {
38103837

38113838
bool isThrowing() const { return getExtInfo().isThrowing(); }
38123839

3840+
bool isCoroutine() const { return getExtInfo().isCoroutine(); }
3841+
38133842
bool hasSendingResult() const { return getExtInfo().hasSendingResult(); }
38143843

38153844
bool hasEffect(EffectKind kind) const;

lib/AST/ASTContext.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -546,6 +546,8 @@ struct ASTContext::Implementation {
546546
llvm::DenseMap<uintptr_t, ReferenceStorageType*> ReferenceStorageTypes;
547547
llvm::DenseMap<Type, LValueType*> LValueTypes;
548548
llvm::DenseMap<Type, InOutType*> InOutTypes;
549+
llvm::DenseMap<llvm::PointerIntPair<TypeBase*, 1, bool>,
550+
YieldResultType*> YieldResultTypes;
549551
llvm::DenseMap<std::pair<Type, void*>, DependentMemberType *>
550552
DependentMemberTypes;
551553
llvm::FoldingSet<ErrorUnionType> ErrorUnionTypes;
@@ -3247,6 +3249,7 @@ size_t ASTContext::Implementation::Arena::getTotalMemory() const {
32473249
llvm::capacity_in_bytes(ReferenceStorageTypes) +
32483250
llvm::capacity_in_bytes(LValueTypes) +
32493251
llvm::capacity_in_bytes(InOutTypes) +
3252+
llvm::capacity_in_bytes(YieldResultTypes) +
32503253
llvm::capacity_in_bytes(DependentMemberTypes) +
32513254
llvm::capacity_in_bytes(EnumTypes) +
32523255
llvm::capacity_in_bytes(StructTypes) +
@@ -3286,6 +3289,7 @@ void ASTContext::Implementation::Arena::dump(llvm::raw_ostream &os) const {
32863289
SIZE_AND_BYTES(ReferenceStorageTypes);
32873290
SIZE_AND_BYTES(LValueTypes);
32883291
SIZE_AND_BYTES(InOutTypes);
3292+
SIZE_AND_BYTES(YieldResultTypes);
32893293
SIZE_AND_BYTES(DependentMemberTypes);
32903294
SIZE(ErrorUnionTypes);
32913295
SIZE_AND_BYTES(PlaceholderTypes);
@@ -5477,6 +5481,28 @@ InOutType *InOutType::get(Type objectTy) {
54775481
properties);
54785482
}
54795483

5484+
YieldResultType *YieldResultType::get(Type objectTy, bool InOut) {
5485+
auto properties = objectTy->getRecursiveProperties();
5486+
if (InOut) {
5487+
assert(!objectTy->is<LValueType>() && !objectTy->is<InOutType>() &&
5488+
"cannot have 'inout' or @lvalue wrapped inside an 'inout yield'");
5489+
properties &= ~RecursiveTypeProperties::IsLValue;
5490+
}
5491+
5492+
auto arena = getArena(properties);
5493+
5494+
auto &C = objectTy->getASTContext();
5495+
auto pair = llvm::PointerIntPair<TypeBase*, 1, bool>(objectTy.getPointer(),
5496+
InOut);
5497+
auto &entry = C.getImpl().getArena(arena).YieldResultTypes[pair];
5498+
if (entry)
5499+
return entry;
5500+
5501+
const ASTContext *canonicalContext = objectTy->isCanonical() ? &C : nullptr;
5502+
return entry = new (C, arena) YieldResultType(objectTy, InOut, canonicalContext,
5503+
properties);
5504+
}
5505+
54805506
DependentMemberType *DependentMemberType::get(Type base, Identifier name) {
54815507
auto properties = base->getRecursiveProperties();
54825508
properties |= RecursiveTypeProperties::HasDependentMember;

lib/AST/ASTDumper.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3768,6 +3768,7 @@ class PrintAttribute : public AttributeVisitor<PrintAttribute, void, StringRef>,
37683768
TRIVIAL_ATTR_PRINTER(CompilerInitialized, compiler_initialized)
37693769
TRIVIAL_ATTR_PRINTER(Consuming, consuming)
37703770
TRIVIAL_ATTR_PRINTER(Convenience, convenience)
3771+
TRIVIAL_ATTR_PRINTER(Coroutine, coroutine)
37713772
TRIVIAL_ATTR_PRINTER(DiscardableResult, discardable_result)
37723773
TRIVIAL_ATTR_PRINTER(DisfavoredOverload, disfavored_overload)
37733774
TRIVIAL_ATTR_PRINTER(DistributedActor, distributed_actor)
@@ -4678,6 +4679,13 @@ namespace {
46784679
printFoot();
46794680
}
46804681

4682+
void visitYieldResultType(YieldResultType *T, StringRef label) {
4683+
printCommon("yield", label);
4684+
printFlag(T->isInOut(), "inout");
4685+
printRec(T->getResultType(), "type");
4686+
printFoot();
4687+
}
4688+
46814689
TRIVIAL_TYPE_PRINTER(Unresolved, unresolved)
46824690

46834691
void visitPlaceholderType(PlaceholderType *T, StringRef label) {
@@ -5069,6 +5077,7 @@ namespace {
50695077
printFlag(T->isAsync(), "async");
50705078
printFlag(T->isThrowing(), "throws");
50715079
printFlag(T->hasSendingResult(), "sending_result");
5080+
printFlag(T->isCoroutine(), "@yield_once");
50725081
if (T->isDifferentiable()) {
50735082
switch (T->getDifferentiabilityKind()) {
50745083
default:

0 commit comments

Comments
 (0)