Skip to content

Commit e9ad117

Browse files
committed
Allow custom _modify accessor derivative registration
1 parent 16e40e6 commit e9ad117

File tree

1 file changed

+30
-18
lines changed

1 file changed

+30
-18
lines changed

lib/Sema/TypeCheckAttr.cpp

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6239,6 +6239,10 @@ static bool checkFunctionSignature(
62396239
}))
62406240
return false;
62416241

6242+
// Check that either both are coroutines or none
6243+
if (required->isCoroutine() != candidateFnTy->isCoroutine())
6244+
return false;
6245+
62426246
// If required result type is not a function type, check that result types
62436247
// match exactly.
62446248
auto requiredResultFnTy = dyn_cast<AnyFunctionType>(required.getResult());
@@ -6264,19 +6268,17 @@ static bool checkFunctionSignature(
62646268
return checkFunctionSignature(requiredResultFnTy, candidateResultTy);
62656269
}
62666270

6267-
/// Returns an `AnyFunctionType` from the given parameters, result type, and
6268-
/// generic signature.
6271+
/// Returns an `AnyFunctionType` from the given parameters, result type,
6272+
/// generic signature, and `ExtInfo`
62696273
static AnyFunctionType *
62706274
makeFunctionType(ArrayRef<AnyFunctionType::Param> parameters, Type resultType,
6271-
GenericSignature genericSignature) {
6272-
// FIXME: Verify ExtInfo state is correct, not working by accident.
6275+
GenericSignature genericSignature,
6276+
AnyFunctionType::ExtInfo extInfo) {
62736277
if (genericSignature) {
6274-
GenericFunctionType::ExtInfo info;
62756278
return GenericFunctionType::get(genericSignature, parameters, resultType,
6276-
info);
6279+
extInfo);
62776280
}
6278-
FunctionType::ExtInfo info;
6279-
return FunctionType::get(parameters, resultType, info);
6281+
return FunctionType::get(parameters, resultType, extInfo);
62806282
}
62816283

62826284
/// Computes the original function type corresponding to the given derivative
@@ -6295,14 +6297,16 @@ getDerivativeOriginalFunctionType(AnyFunctionType *derivativeFnTy) {
62956297
currentLevel = currentLevel->getResult()->getAs<AnyFunctionType>();
62966298
}
62976299

6298-
auto derivativeResult = curryLevels.back()->getResult()->getAs<TupleType>();
6300+
AnyFunctionType *lastType = curryLevels.back();
6301+
auto derivativeResult = lastType->getResult()->getAs<TupleType>();
62996302
assert(derivativeResult && derivativeResult->getNumElements() == 2 &&
63006303
"Expected derivative result to be a two-element tuple");
63016304
auto originalResult = derivativeResult->getElement(0).getType();
63026305
auto *originalType = makeFunctionType(
6303-
curryLevels.back()->getParams(), originalResult,
6306+
lastType->getParams(), originalResult,
63046307
curryLevels.size() == 1 ? derivativeFnTy->getOptGenericSignature()
6305-
: nullptr);
6308+
: nullptr,
6309+
lastType->getExtInfo());
63066310

63076311
// Wrap the derivative function type in additional curry levels.
63086312
auto curryLevelsWithoutLast =
@@ -6314,7 +6318,8 @@ getDerivativeOriginalFunctionType(AnyFunctionType *derivativeFnTy) {
63146318
makeFunctionType(curryLevel->getParams(), originalType,
63156319
i == curryLevelsWithoutLast.size() - 1
63166320
? derivativeFnTy->getOptGenericSignature()
6317-
: nullptr);
6321+
: nullptr,
6322+
curryLevel->getExtInfo());
63186323
}
63196324
return originalType;
63206325
}
@@ -6331,10 +6336,12 @@ getTransposeOriginalFunctionType(AnyFunctionType *transposeFnType,
63316336
auto transposeParams = transposeFnType->getParams();
63326337
auto transposeResult = transposeFnType->getResult();
63336338
bool isCurried = transposeResult->is<AnyFunctionType>();
6339+
AnyFunctionType::ExtInfo innerInfo;
63346340
if (isCurried) {
63356341
auto methodType = transposeResult->castTo<AnyFunctionType>();
63366342
transposeParams = methodType->getParams();
63376343
transposeResult = methodType->getResult();
6344+
innerInfo = methodType->getExtInfo();
63386345
}
63396346

63406347
// Get the original function's result type.
@@ -6402,16 +6409,19 @@ getTransposeOriginalFunctionType(AnyFunctionType *transposeFnType,
64026409
// `(Self) -> (<original parameters>) -> <original result>`.
64036410
if (isCurried) {
64046411
assert(selfType && "`Self` type should be resolved");
6405-
originalType = makeFunctionType(originalParams, originalResult, nullptr);
6412+
originalType = makeFunctionType(originalParams, originalResult, nullptr,
6413+
innerInfo);
64066414
originalType =
64076415
makeFunctionType(AnyFunctionType::Param(selfType), originalType,
6408-
transposeFnType->getOptGenericSignature());
6416+
transposeFnType->getOptGenericSignature(),
6417+
transposeFnType->getExtInfo());
64096418
}
64106419
// Otherwise, the original function type is simply:
64116420
// `(<original parameters>) -> <original result>`.
64126421
else {
64136422
originalType = makeFunctionType(originalParams, originalResult,
6414-
transposeFnType->getOptGenericSignature());
6423+
transposeFnType->getOptGenericSignature(),
6424+
transposeFnType->getExtInfo());
64156425
}
64166426
return originalType;
64176427
}
@@ -6985,10 +6995,12 @@ static bool typeCheckDerivativeAttr(DerivativeAttr *attr) {
69856995
assert(derivativeTypeCtx);
69866996

69876997
// Diagnose unsupported original accessor kinds.
6988-
// Currently, only getters and setters are supported.
6998+
// Currently, only getters, setters and _modify accessor are supported.
6999+
// FIXME: Support modify accessors (aka Modify2)
69897000
if (originalName.AccessorKind.has_value()) {
6990-
if (*originalName.AccessorKind != AccessorKind::Get &&
6991-
*originalName.AccessorKind != AccessorKind::Set) {
7001+
AccessorKind kind = *originalName.AccessorKind;
7002+
if (kind != AccessorKind::Get && kind != AccessorKind::Set &&
7003+
kind != AccessorKind::Modify) {
69927004
attr->setInvalid();
69937005
diags.diagnose(
69947006
originalName.Loc, diag::derivative_attr_unsupported_accessor_kind,

0 commit comments

Comments
 (0)