Skip to content

Commit b3e6915

Browse files
committed
Allow custom _modify accessor derivative registration
1 parent 6bc2faa commit b3e6915

File tree

1 file changed

+30
-18
lines changed

1 file changed

+30
-18
lines changed

lib/Sema/TypeCheckAttr.cpp

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5849,6 +5849,10 @@ static bool checkFunctionSignature(
58495849
}))
58505850
return false;
58515851

5852+
// Check that either both are coroutines or none
5853+
if (required->isCoroutine() != candidateFnTy->isCoroutine())
5854+
return false;
5855+
58525856
// If required result type is not a function type, check that result types
58535857
// match exactly.
58545858
auto requiredResultFnTy = dyn_cast<AnyFunctionType>(required.getResult());
@@ -5874,19 +5878,17 @@ static bool checkFunctionSignature(
58745878
return checkFunctionSignature(requiredResultFnTy, candidateResultTy);
58755879
}
58765880

5877-
/// Returns an `AnyFunctionType` from the given parameters, result type, and
5878-
/// generic signature.
5881+
/// Returns an `AnyFunctionType` from the given parameters, result type,
5882+
/// generic signature, and `ExtInfo`
58795883
static AnyFunctionType *
58805884
makeFunctionType(ArrayRef<AnyFunctionType::Param> parameters, Type resultType,
5881-
GenericSignature genericSignature) {
5882-
// FIXME: Verify ExtInfo state is correct, not working by accident.
5885+
GenericSignature genericSignature,
5886+
AnyFunctionType::ExtInfo extInfo) {
58835887
if (genericSignature) {
5884-
GenericFunctionType::ExtInfo info;
58855888
return GenericFunctionType::get(genericSignature, parameters, resultType,
5886-
info);
5889+
extInfo);
58875890
}
5888-
FunctionType::ExtInfo info;
5889-
return FunctionType::get(parameters, resultType, info);
5891+
return FunctionType::get(parameters, resultType, extInfo);
58905892
}
58915893

58925894
/// Computes the original function type corresponding to the given derivative
@@ -5905,14 +5907,16 @@ getDerivativeOriginalFunctionType(AnyFunctionType *derivativeFnTy) {
59055907
currentLevel = currentLevel->getResult()->getAs<AnyFunctionType>();
59065908
}
59075909

5908-
auto derivativeResult = curryLevels.back()->getResult()->getAs<TupleType>();
5910+
AnyFunctionType *lastType = curryLevels.back();
5911+
auto derivativeResult = lastType->getResult()->getAs<TupleType>();
59095912
assert(derivativeResult && derivativeResult->getNumElements() == 2 &&
59105913
"Expected derivative result to be a two-element tuple");
59115914
auto originalResult = derivativeResult->getElement(0).getType();
59125915
auto *originalType = makeFunctionType(
5913-
curryLevels.back()->getParams(), originalResult,
5916+
lastType->getParams(), originalResult,
59145917
curryLevels.size() == 1 ? derivativeFnTy->getOptGenericSignature()
5915-
: nullptr);
5918+
: nullptr,
5919+
lastType->getExtInfo());
59165920

59175921
// Wrap the derivative function type in additional curry levels.
59185922
auto curryLevelsWithoutLast =
@@ -5924,7 +5928,8 @@ getDerivativeOriginalFunctionType(AnyFunctionType *derivativeFnTy) {
59245928
makeFunctionType(curryLevel->getParams(), originalType,
59255929
i == curryLevelsWithoutLast.size() - 1
59265930
? derivativeFnTy->getOptGenericSignature()
5927-
: nullptr);
5931+
: nullptr,
5932+
curryLevel->getExtInfo());
59285933
}
59295934
return originalType;
59305935
}
@@ -5941,10 +5946,12 @@ getTransposeOriginalFunctionType(AnyFunctionType *transposeFnType,
59415946
auto transposeParams = transposeFnType->getParams();
59425947
auto transposeResult = transposeFnType->getResult();
59435948
bool isCurried = transposeResult->is<AnyFunctionType>();
5949+
AnyFunctionType::ExtInfo innerInfo;
59445950
if (isCurried) {
59455951
auto methodType = transposeResult->castTo<AnyFunctionType>();
59465952
transposeParams = methodType->getParams();
59475953
transposeResult = methodType->getResult();
5954+
innerInfo = methodType->getExtInfo();
59485955
}
59495956

59505957
// Get the original function's result type.
@@ -6012,16 +6019,19 @@ getTransposeOriginalFunctionType(AnyFunctionType *transposeFnType,
60126019
// `(Self) -> (<original parameters>) -> <original result>`.
60136020
if (isCurried) {
60146021
assert(selfType && "`Self` type should be resolved");
6015-
originalType = makeFunctionType(originalParams, originalResult, nullptr);
6022+
originalType = makeFunctionType(originalParams, originalResult, nullptr,
6023+
innerInfo);
60166024
originalType =
60176025
makeFunctionType(AnyFunctionType::Param(selfType), originalType,
6018-
transposeFnType->getOptGenericSignature());
6026+
transposeFnType->getOptGenericSignature(),
6027+
transposeFnType->getExtInfo());
60196028
}
60206029
// Otherwise, the original function type is simply:
60216030
// `(<original parameters>) -> <original result>`.
60226031
else {
60236032
originalType = makeFunctionType(originalParams, originalResult,
6024-
transposeFnType->getOptGenericSignature());
6033+
transposeFnType->getOptGenericSignature(),
6034+
transposeFnType->getExtInfo());
60256035
}
60266036
return originalType;
60276037
}
@@ -6590,10 +6600,12 @@ static bool typeCheckDerivativeAttr(DerivativeAttr *attr) {
65906600
assert(derivativeTypeCtx);
65916601

65926602
// Diagnose unsupported original accessor kinds.
6593-
// Currently, only getters and setters are supported.
6603+
// Currently, only getters, setters and _modify accessor are supported.
6604+
// FIXME: Support modify accessors (aka Modify2)
65946605
if (originalName.AccessorKind.has_value()) {
6595-
if (*originalName.AccessorKind != AccessorKind::Get &&
6596-
*originalName.AccessorKind != AccessorKind::Set) {
6606+
AccessorKind kind = *originalName.AccessorKind;
6607+
if (kind != AccessorKind::Get && kind != AccessorKind::Set &&
6608+
kind != AccessorKind::Modify) {
65976609
attr->setInvalid();
65986610
diags.diagnose(
65996611
originalName.Loc, diag::derivative_attr_unsupported_accessor_kind,

0 commit comments

Comments
 (0)