@@ -4552,6 +4552,19 @@ TypeBase::getAutoDiffTangentSpace(LookupConformanceFn lookupConformance) {
45524552 return cache (TangentSpace::getTuple (tupleType));
45534553 }
45544554
4555+ // Yield result types are a bit special, but essentially tangent spaces of
4556+ // yields are yields of tangent space type.
4557+ if (auto *yieldResTy = getAs<YieldResultType>()) {
4558+ auto objectTanTy =
4559+ yieldResTy->getResultType ()->getAutoDiffTangentSpace (lookupConformance);
4560+ if (!objectTanTy)
4561+ return cache (std::nullopt );
4562+
4563+ auto *yieldTanType = YieldResultType::get (objectTanTy->getType (),
4564+ yieldResTy->isInOut ());
4565+ return cache (TangentSpace::getTangentVector (yieldTanType));
4566+ }
4567+
45554568 // For `Differentiable`-conforming types: the tangent space is the
45564569 // `TangentVector` associated type.
45574570 auto *differentiableProtocol =
@@ -4756,6 +4769,8 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
47564769
47574770 // Compute the result linear map function type.
47584771 FunctionType *linearMapType;
4772+ // FIXME: Verify ExtInfo state is correct, not working by accident.
4773+ FunctionType::ExtInfo info;
47594774 switch (kind) {
47604775 case AutoDiffLinearMapKind::Differential: {
47614776 // Compute the differential type, returned by JVP functions.
@@ -4814,6 +4829,9 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
48144829 // Case 2: original function has wrt `inout` parameters.
48154830 // - Original: `(T0, inout T1, ...) -> R`
48164831 // - Pullback: `(R.Tan, inout T1.Tan) -> (T0.Tan, ...)`
4832+ //
4833+ // Special case: yields. They act as parameters, so will
4834+ // always be on result side.
48174835 SmallVector<TupleTypeElt, 4 > pullbackResults;
48184836 SmallVector<AnyFunctionType::Param, 2 > semanticResultParams;
48194837 for (auto i : range (diffParams.size ())) {
@@ -4836,34 +4854,42 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
48364854 }
48374855 pullbackResults.emplace_back (paramTan->getType ());
48384856 }
4839- Type pullbackResult;
4840- if (pullbackResults.empty ()) {
4841- pullbackResult = ctx.TheEmptyTupleType ;
4842- } else if (pullbackResults.size () == 1 ) {
4843- pullbackResult = pullbackResults.front ().getType ();
4844- } else {
4845- pullbackResult = TupleType::get (pullbackResults, ctx);
4846- }
4847- // First accumulate non-inout results as pullback parameters.
4857+ // First accumulate ordinary result (not-semantic result parameters) as
4858+ // pullback parameters.
48484859 SmallVector<FunctionType::Param, 2 > pullbackParams;
48494860 for (auto i : range (resultTanTypes.size ())) {
48504861 auto resultTanType = resultTanTypes[i];
48514862 auto flags = ParameterTypeFlags ().withInOut (false );
4852- pullbackParams.push_back (AnyFunctionType::Param (
4853- resultTanType, Identifier (), flags));
4863+ if (resultTanType->is <YieldResultType>()) {
4864+ pullbackResults.emplace_back (resultTanType);
4865+ info = info.withCoroutine (true );
4866+ } else {
4867+ pullbackParams.push_back (AnyFunctionType::Param (
4868+ resultTanType, Identifier (), flags));
4869+ }
48544870 }
48554871 // Then append semantic result parameters.
48564872 for (auto i : range (semanticResultParams.size ())) {
48574873 auto semanticResultParam = semanticResultParams[i];
48584874 auto semanticResultParamType = semanticResultParam.getPlainType ();
48594875 auto semanticResultParamTan =
48604876 semanticResultParamType->getAutoDiffTangentSpace (lookupConformance);
4877+ assert (!semanticResultParamType->is <YieldResultType>() &&
4878+ " yields are always expected on result side" );
48614879 auto flags = ParameterTypeFlags ().withInOut (true );
48624880 pullbackParams.push_back (AnyFunctionType::Param (
48634881 semanticResultParamTan->getType (), Identifier (), flags));
48644882 }
4865- // FIXME: Verify ExtInfo state is correct, not working by accident.
4866- FunctionType::ExtInfo info;
4883+
4884+ Type pullbackResult;
4885+ if (pullbackResults.empty ()) {
4886+ pullbackResult = ctx.TheEmptyTupleType ;
4887+ } else if (pullbackResults.size () == 1 ) {
4888+ pullbackResult = pullbackResults.front ().getType ();
4889+ } else {
4890+ pullbackResult = TupleType::get (pullbackResults, ctx);
4891+ }
4892+
48674893 linearMapType = FunctionType::get (pullbackParams, pullbackResult, info);
48684894 break ;
48694895 }
0 commit comments