@@ -6239,6 +6239,10 @@ static bool checkFunctionSignature(
62396239 }))
62406240 return false ;
62416241
6242+ // Check that either both are coroutines or none
6243+ if (required->isCoroutine () != candidateFnTy->isCoroutine ())
6244+ return false ;
6245+
62426246 // If required result type is not a function type, check that result types
62436247 // match exactly.
62446248 auto requiredResultFnTy = dyn_cast<AnyFunctionType>(required.getResult ());
@@ -6264,19 +6268,17 @@ static bool checkFunctionSignature(
62646268 return checkFunctionSignature (requiredResultFnTy, candidateResultTy);
62656269}
62666270
6267- // / Returns an `AnyFunctionType` from the given parameters, result type, and
6268- // / generic signature.
6271+ // / Returns an `AnyFunctionType` from the given parameters, result type,
6272+ // / generic signature, and `ExtInfo`
62696273static AnyFunctionType *
62706274makeFunctionType (ArrayRef<AnyFunctionType::Param> parameters, Type resultType,
6271- GenericSignature genericSignature) {
6272- // FIXME: Verify ExtInfo state is correct, not working by accident.
6275+ GenericSignature genericSignature,
6276+ AnyFunctionType::ExtInfo extInfo) {
62736277 if (genericSignature) {
6274- GenericFunctionType::ExtInfo info;
62756278 return GenericFunctionType::get (genericSignature, parameters, resultType,
6276- info );
6279+ extInfo );
62776280 }
6278- FunctionType::ExtInfo info;
6279- return FunctionType::get (parameters, resultType, info);
6281+ return FunctionType::get (parameters, resultType, extInfo);
62806282}
62816283
62826284// / Computes the original function type corresponding to the given derivative
@@ -6295,14 +6297,16 @@ getDerivativeOriginalFunctionType(AnyFunctionType *derivativeFnTy) {
62956297 currentLevel = currentLevel->getResult ()->getAs <AnyFunctionType>();
62966298 }
62976299
6298- auto derivativeResult = curryLevels.back ()->getResult ()->getAs <TupleType>();
6300+ AnyFunctionType *lastType = curryLevels.back ();
6301+ auto derivativeResult = lastType->getResult ()->getAs <TupleType>();
62996302 assert (derivativeResult && derivativeResult->getNumElements () == 2 &&
63006303 " Expected derivative result to be a two-element tuple" );
63016304 auto originalResult = derivativeResult->getElement (0 ).getType ();
63026305 auto *originalType = makeFunctionType (
6303- curryLevels. back () ->getParams (), originalResult,
6306+ lastType ->getParams (), originalResult,
63046307 curryLevels.size () == 1 ? derivativeFnTy->getOptGenericSignature ()
6305- : nullptr );
6308+ : nullptr ,
6309+ lastType->getExtInfo ());
63066310
63076311 // Wrap the derivative function type in additional curry levels.
63086312 auto curryLevelsWithoutLast =
@@ -6314,7 +6318,8 @@ getDerivativeOriginalFunctionType(AnyFunctionType *derivativeFnTy) {
63146318 makeFunctionType (curryLevel->getParams (), originalType,
63156319 i == curryLevelsWithoutLast.size () - 1
63166320 ? derivativeFnTy->getOptGenericSignature ()
6317- : nullptr );
6321+ : nullptr ,
6322+ curryLevel->getExtInfo ());
63186323 }
63196324 return originalType;
63206325}
@@ -6331,10 +6336,12 @@ getTransposeOriginalFunctionType(AnyFunctionType *transposeFnType,
63316336 auto transposeParams = transposeFnType->getParams ();
63326337 auto transposeResult = transposeFnType->getResult ();
63336338 bool isCurried = transposeResult->is <AnyFunctionType>();
6339+ AnyFunctionType::ExtInfo innerInfo;
63346340 if (isCurried) {
63356341 auto methodType = transposeResult->castTo <AnyFunctionType>();
63366342 transposeParams = methodType->getParams ();
63376343 transposeResult = methodType->getResult ();
6344+ innerInfo = methodType->getExtInfo ();
63386345 }
63396346
63406347 // Get the original function's result type.
@@ -6402,16 +6409,19 @@ getTransposeOriginalFunctionType(AnyFunctionType *transposeFnType,
64026409 // `(Self) -> (<original parameters>) -> <original result>`.
64036410 if (isCurried) {
64046411 assert (selfType && " `Self` type should be resolved" );
6405- originalType = makeFunctionType (originalParams, originalResult, nullptr );
6412+ originalType = makeFunctionType (originalParams, originalResult, nullptr ,
6413+ innerInfo);
64066414 originalType =
64076415 makeFunctionType (AnyFunctionType::Param (selfType), originalType,
6408- transposeFnType->getOptGenericSignature ());
6416+ transposeFnType->getOptGenericSignature (),
6417+ transposeFnType->getExtInfo ());
64096418 }
64106419 // Otherwise, the original function type is simply:
64116420 // `(<original parameters>) -> <original result>`.
64126421 else {
64136422 originalType = makeFunctionType (originalParams, originalResult,
6414- transposeFnType->getOptGenericSignature ());
6423+ transposeFnType->getOptGenericSignature (),
6424+ transposeFnType->getExtInfo ());
64156425 }
64166426 return originalType;
64176427}
@@ -6985,10 +6995,12 @@ static bool typeCheckDerivativeAttr(DerivativeAttr *attr) {
69856995 assert (derivativeTypeCtx);
69866996
69876997 // Diagnose unsupported original accessor kinds.
6988- // Currently, only getters and setters are supported.
6998+ // Currently, only getters, setters and _modify accessor are supported.
6999+ // FIXME: Support modify accessors (aka Modify2)
69897000 if (originalName.AccessorKind .has_value ()) {
6990- if (*originalName.AccessorKind != AccessorKind::Get &&
6991- *originalName.AccessorKind != AccessorKind::Set) {
7001+ AccessorKind kind = *originalName.AccessorKind ;
7002+ if (kind != AccessorKind::Get && kind != AccessorKind::Set &&
7003+ kind != AccessorKind::Modify) {
69927004 attr->setInvalid ();
69937005 diags.diagnose (
69947006 originalName.Loc , diag::derivative_attr_unsupported_accessor_kind,
0 commit comments