Skip to content

Commit 16fef90

Browse files
committed
Correct pullback type calculatio in presence of yields
1 parent b3e6915 commit 16fef90

File tree

2 files changed

+42
-16
lines changed

2 files changed

+42
-16
lines changed

lib/AST/Type.cpp

Lines changed: 39 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4552,6 +4552,19 @@ TypeBase::getAutoDiffTangentSpace(LookupConformanceFn lookupConformance) {
45524552
return cache(TangentSpace::getTuple(tupleType));
45534553
}
45544554

4555+
// Yield result types are a bit special, but essentially tangent spaces of
4556+
// yields are yields of tangent space type.
4557+
if (auto *yieldResTy = getAs<YieldResultType>()) {
4558+
auto objectTanTy =
4559+
yieldResTy->getResultType()->getAutoDiffTangentSpace(lookupConformance);
4560+
if (!objectTanTy)
4561+
return cache(std::nullopt);
4562+
4563+
auto *yieldTanType = YieldResultType::get(objectTanTy->getType(),
4564+
yieldResTy->isInOut());
4565+
return cache(TangentSpace::getTangentVector(yieldTanType));
4566+
}
4567+
45554568
// For `Differentiable`-conforming types: the tangent space is the
45564569
// `TangentVector` associated type.
45574570
auto *differentiableProtocol =
@@ -4756,6 +4769,8 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
47564769

47574770
// Compute the result linear map function type.
47584771
FunctionType *linearMapType;
4772+
// FIXME: Verify ExtInfo state is correct, not working by accident.
4773+
FunctionType::ExtInfo info;
47594774
switch (kind) {
47604775
case AutoDiffLinearMapKind::Differential: {
47614776
// Compute the differential type, returned by JVP functions.
@@ -4814,6 +4829,9 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
48144829
// Case 2: original function has wrt `inout` parameters.
48154830
// - Original: `(T0, inout T1, ...) -> R`
48164831
// - Pullback: `(R.Tan, inout T1.Tan) -> (T0.Tan, ...)`
4832+
//
4833+
// Special case: yields. They act as parameters, so will
4834+
// always be on result side.
48174835
SmallVector<TupleTypeElt, 4> pullbackResults;
48184836
SmallVector<AnyFunctionType::Param, 2> semanticResultParams;
48194837
for (auto i : range(diffParams.size())) {
@@ -4836,34 +4854,42 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
48364854
}
48374855
pullbackResults.emplace_back(paramTan->getType());
48384856
}
4839-
Type pullbackResult;
4840-
if (pullbackResults.empty()) {
4841-
pullbackResult = ctx.TheEmptyTupleType;
4842-
} else if (pullbackResults.size() == 1) {
4843-
pullbackResult = pullbackResults.front().getType();
4844-
} else {
4845-
pullbackResult = TupleType::get(pullbackResults, ctx);
4846-
}
4847-
// First accumulate non-inout results as pullback parameters.
4857+
// First accumulate ordinary result (not-semantic result parameters) as
4858+
// pullback parameters.
48484859
SmallVector<FunctionType::Param, 2> pullbackParams;
48494860
for (auto i : range(resultTanTypes.size())) {
48504861
auto resultTanType = resultTanTypes[i];
48514862
auto flags = ParameterTypeFlags().withInOut(false);
4852-
pullbackParams.push_back(AnyFunctionType::Param(
4853-
resultTanType, Identifier(), flags));
4863+
if (resultTanType->is<YieldResultType>()) {
4864+
pullbackResults.emplace_back(resultTanType);
4865+
info = info.withCoroutine(true);
4866+
} else {
4867+
pullbackParams.push_back(AnyFunctionType::Param(
4868+
resultTanType, Identifier(), flags));
4869+
}
48544870
}
48554871
// Then append semantic result parameters.
48564872
for (auto i : range(semanticResultParams.size())) {
48574873
auto semanticResultParam = semanticResultParams[i];
48584874
auto semanticResultParamType = semanticResultParam.getPlainType();
48594875
auto semanticResultParamTan =
48604876
semanticResultParamType->getAutoDiffTangentSpace(lookupConformance);
4877+
assert(!semanticResultParamType->is<YieldResultType>() &&
4878+
"yields are always expected on result side");
48614879
auto flags = ParameterTypeFlags().withInOut(true);
48624880
pullbackParams.push_back(AnyFunctionType::Param(
48634881
semanticResultParamTan->getType(), Identifier(), flags));
48644882
}
4865-
// FIXME: Verify ExtInfo state is correct, not working by accident.
4866-
FunctionType::ExtInfo info;
4883+
4884+
Type pullbackResult;
4885+
if (pullbackResults.empty()) {
4886+
pullbackResult = ctx.TheEmptyTupleType;
4887+
} else if (pullbackResults.size() == 1) {
4888+
pullbackResult = pullbackResults.front().getType();
4889+
} else {
4890+
pullbackResult = TupleType::get(pullbackResults, ctx);
4891+
}
4892+
48674893
linearMapType = FunctionType::get(pullbackParams, pullbackResult, info);
48684894
break;
48694895
}

test/AutoDiff/Sema/derivative_attr_type_checking.swift

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -451,10 +451,10 @@ extension Struct where T: Differentiable & AdditiveArithmetic {
451451
fatalError()
452452
}
453453

454-
// expected-error @+1 {{cannot register derivative for _modify accessor}}
455454
@derivative(of: computedProperty._modify)
456-
mutating func vjpPropertyModify(_ newValue: T) -> (
457-
value: (), pullback: (inout TangentVector) -> T.TangentVector
455+
@yield_once
456+
mutating func vjpPropertyModify() -> (
457+
value: inout @yields T, pullback: @yield_once (inout TangentVector) -> inout @yields T.TangentVector
458458
) {
459459
fatalError()
460460
}

0 commit comments

Comments
 (0)