@@ -5849,6 +5849,10 @@ static bool checkFunctionSignature(
58495849 }))
58505850 return false ;
58515851
5852+ // Check that either both are coroutines or none
5853+ if (required->isCoroutine () != candidateFnTy->isCoroutine ())
5854+ return false ;
5855+
58525856 // If required result type is not a function type, check that result types
58535857 // match exactly.
58545858 auto requiredResultFnTy = dyn_cast<AnyFunctionType>(required.getResult ());
@@ -5874,19 +5878,17 @@ static bool checkFunctionSignature(
58745878 return checkFunctionSignature (requiredResultFnTy, candidateResultTy);
58755879}
58765880
5877- // / Returns an `AnyFunctionType` from the given parameters, result type, and
5878- // / generic signature.
5881+ // / Returns an `AnyFunctionType` from the given parameters, result type,
5882+ // / generic signature, and `ExtInfo`
58795883static AnyFunctionType *
58805884makeFunctionType (ArrayRef<AnyFunctionType::Param> parameters, Type resultType,
5881- GenericSignature genericSignature) {
5882- // FIXME: Verify ExtInfo state is correct, not working by accident.
5885+ GenericSignature genericSignature,
5886+ AnyFunctionType::ExtInfo extInfo) {
58835887 if (genericSignature) {
5884- GenericFunctionType::ExtInfo info;
58855888 return GenericFunctionType::get (genericSignature, parameters, resultType,
5886- info );
5889+ extInfo );
58875890 }
5888- FunctionType::ExtInfo info;
5889- return FunctionType::get (parameters, resultType, info);
5891+ return FunctionType::get (parameters, resultType, extInfo);
58905892}
58915893
58925894// / Computes the original function type corresponding to the given derivative
@@ -5905,14 +5907,16 @@ getDerivativeOriginalFunctionType(AnyFunctionType *derivativeFnTy) {
59055907 currentLevel = currentLevel->getResult ()->getAs <AnyFunctionType>();
59065908 }
59075909
5908- auto derivativeResult = curryLevels.back ()->getResult ()->getAs <TupleType>();
5910+ AnyFunctionType *lastType = curryLevels.back ();
5911+ auto derivativeResult = lastType->getResult ()->getAs <TupleType>();
59095912 assert (derivativeResult && derivativeResult->getNumElements () == 2 &&
59105913 " Expected derivative result to be a two-element tuple" );
59115914 auto originalResult = derivativeResult->getElement (0 ).getType ();
59125915 auto *originalType = makeFunctionType (
5913- curryLevels. back () ->getParams (), originalResult,
5916+ lastType ->getParams (), originalResult,
59145917 curryLevels.size () == 1 ? derivativeFnTy->getOptGenericSignature ()
5915- : nullptr );
5918+ : nullptr ,
5919+ lastType->getExtInfo ());
59165920
59175921 // Wrap the derivative function type in additional curry levels.
59185922 auto curryLevelsWithoutLast =
@@ -5924,7 +5928,8 @@ getDerivativeOriginalFunctionType(AnyFunctionType *derivativeFnTy) {
59245928 makeFunctionType (curryLevel->getParams (), originalType,
59255929 i == curryLevelsWithoutLast.size () - 1
59265930 ? derivativeFnTy->getOptGenericSignature ()
5927- : nullptr );
5931+ : nullptr ,
5932+ curryLevel->getExtInfo ());
59285933 }
59295934 return originalType;
59305935}
@@ -5941,10 +5946,12 @@ getTransposeOriginalFunctionType(AnyFunctionType *transposeFnType,
59415946 auto transposeParams = transposeFnType->getParams ();
59425947 auto transposeResult = transposeFnType->getResult ();
59435948 bool isCurried = transposeResult->is <AnyFunctionType>();
5949+ AnyFunctionType::ExtInfo innerInfo;
59445950 if (isCurried) {
59455951 auto methodType = transposeResult->castTo <AnyFunctionType>();
59465952 transposeParams = methodType->getParams ();
59475953 transposeResult = methodType->getResult ();
5954+ innerInfo = methodType->getExtInfo ();
59485955 }
59495956
59505957 // Get the original function's result type.
@@ -6012,16 +6019,19 @@ getTransposeOriginalFunctionType(AnyFunctionType *transposeFnType,
60126019 // `(Self) -> (<original parameters>) -> <original result>`.
60136020 if (isCurried) {
60146021 assert (selfType && " `Self` type should be resolved" );
6015- originalType = makeFunctionType (originalParams, originalResult, nullptr );
6022+ originalType = makeFunctionType (originalParams, originalResult, nullptr ,
6023+ innerInfo);
60166024 originalType =
60176025 makeFunctionType (AnyFunctionType::Param (selfType), originalType,
6018- transposeFnType->getOptGenericSignature ());
6026+ transposeFnType->getOptGenericSignature (),
6027+ transposeFnType->getExtInfo ());
60196028 }
60206029 // Otherwise, the original function type is simply:
60216030 // `(<original parameters>) -> <original result>`.
60226031 else {
60236032 originalType = makeFunctionType (originalParams, originalResult,
6024- transposeFnType->getOptGenericSignature ());
6033+ transposeFnType->getOptGenericSignature (),
6034+ transposeFnType->getExtInfo ());
60256035 }
60266036 return originalType;
60276037}
@@ -6590,10 +6600,12 @@ static bool typeCheckDerivativeAttr(DerivativeAttr *attr) {
65906600 assert (derivativeTypeCtx);
65916601
65926602 // Diagnose unsupported original accessor kinds.
6593- // Currently, only getters and setters are supported.
6603+ // Currently, only getters, setters and _modify accessor are supported.
6604+ // FIXME: Support modify accessors (aka Modify2)
65946605 if (originalName.AccessorKind .has_value ()) {
6595- if (*originalName.AccessorKind != AccessorKind::Get &&
6596- *originalName.AccessorKind != AccessorKind::Set) {
6606+ AccessorKind kind = *originalName.AccessorKind ;
6607+ if (kind != AccessorKind::Get && kind != AccessorKind::Set &&
6608+ kind != AccessorKind::Modify) {
65976609 attr->setInvalid ();
65986610 diags.diagnose (
65996611 originalName.Loc , diag::derivative_attr_unsupported_accessor_kind,
0 commit comments