@@ -4771,6 +4771,19 @@ TypeBase::getAutoDiffTangentSpace(LookupConformanceFn lookupConformance) {
47714771 return cache (TangentSpace::getTuple (tupleType));
47724772 }
47734773
4774+ // Yield result types are a bit special, but essentially tangent spaces of
4775+ // yields are yields of tangent space type.
4776+ if (auto *yieldResTy = getAs<YieldResultType>()) {
4777+ auto objectTanTy =
4778+ yieldResTy->getResultType ()->getAutoDiffTangentSpace (lookupConformance);
4779+ if (!objectTanTy)
4780+ return cache (std::nullopt );
4781+
4782+ auto *yieldTanType = YieldResultType::get (objectTanTy->getType (),
4783+ yieldResTy->isInOut ());
4784+ return cache (TangentSpace::getTangentVector (yieldTanType));
4785+ }
4786+
47744787 // For `Differentiable`-conforming types: the tangent space is the
47754788 // `TangentVector` associated type.
47764789 auto *differentiableProtocol =
@@ -4993,6 +5006,8 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
49935006
49945007 // Compute the result linear map function type.
49955008 FunctionType *linearMapType;
5009+ // FIXME: Verify ExtInfo state is correct, not working by accident.
5010+ FunctionType::ExtInfo info;
49965011 switch (kind) {
49975012 case AutoDiffLinearMapKind::Differential: {
49985013 // Compute the differential type, returned by JVP functions.
@@ -5051,6 +5066,9 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
50515066 // Case 2: original function has wrt `inout` parameters.
50525067 // - Original: `(T0, inout T1, ...) -> R`
50535068 // - Pullback: `(R.Tan, inout T1.Tan) -> (T0.Tan, ...)`
5069+ //
5070+ // Special case: yields. They act as parameters, so will
5071+ // always be on result side.
50545072 SmallVector<TupleTypeElt, 4 > pullbackResults;
50555073 SmallVector<AnyFunctionType::Param, 2 > semanticResultParams;
50565074 for (auto i : range (diffParams.size ())) {
@@ -5073,34 +5091,42 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
50735091 }
50745092 pullbackResults.emplace_back (paramTan->getType ());
50755093 }
5076- Type pullbackResult;
5077- if (pullbackResults.empty ()) {
5078- pullbackResult = ctx.TheEmptyTupleType ;
5079- } else if (pullbackResults.size () == 1 ) {
5080- pullbackResult = pullbackResults.front ().getType ();
5081- } else {
5082- pullbackResult = TupleType::get (pullbackResults, ctx);
5083- }
5084- // First accumulate non-inout results as pullback parameters.
5094+ // First accumulate ordinary result (not-semantic result parameters) as
5095+ // pullback parameters.
50855096 SmallVector<FunctionType::Param, 2 > pullbackParams;
50865097 for (auto i : range (resultTanTypes.size ())) {
50875098 auto resultTanType = resultTanTypes[i];
50885099 auto flags = ParameterTypeFlags ().withInOut (false );
5089- pullbackParams.push_back (AnyFunctionType::Param (
5090- resultTanType, Identifier (), flags));
5100+ if (resultTanType->is <YieldResultType>()) {
5101+ pullbackResults.emplace_back (resultTanType);
5102+ info = info.withCoroutine (true );
5103+ } else {
5104+ pullbackParams.push_back (AnyFunctionType::Param (
5105+ resultTanType, Identifier (), flags));
5106+ }
50915107 }
50925108 // Then append semantic result parameters.
50935109 for (auto i : range (semanticResultParams.size ())) {
50945110 auto semanticResultParam = semanticResultParams[i];
50955111 auto semanticResultParamType = semanticResultParam.getPlainType ();
50965112 auto semanticResultParamTan =
50975113 semanticResultParamType->getAutoDiffTangentSpace (lookupConformance);
5114+ assert (!semanticResultParamType->is <YieldResultType>() &&
5115+ " yields are always expected on result side" );
50985116 auto flags = ParameterTypeFlags ().withInOut (true );
50995117 pullbackParams.push_back (AnyFunctionType::Param (
51005118 semanticResultParamTan->getType (), Identifier (), flags));
51015119 }
5102- // FIXME: Verify ExtInfo state is correct, not working by accident.
5103- FunctionType::ExtInfo info;
5120+
5121+ Type pullbackResult;
5122+ if (pullbackResults.empty ()) {
5123+ pullbackResult = ctx.TheEmptyTupleType ;
5124+ } else if (pullbackResults.size () == 1 ) {
5125+ pullbackResult = pullbackResults.front ().getType ();
5126+ } else {
5127+ pullbackResult = TupleType::get (pullbackResults, ctx);
5128+ }
5129+
51045130 linearMapType = FunctionType::get (pullbackParams, pullbackResult, info);
51055131 break ;
51065132 }
0 commit comments