Skip to content

Commit d2289c5

Browse files
committed
Correct pullback type calculatio in presence of yields
1 parent e9ad117 commit d2289c5

File tree

2 files changed

+42
-16
lines changed

2 files changed

+42
-16
lines changed

lib/AST/Type.cpp

Lines changed: 39 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4771,6 +4771,19 @@ TypeBase::getAutoDiffTangentSpace(LookupConformanceFn lookupConformance) {
47714771
return cache(TangentSpace::getTuple(tupleType));
47724772
}
47734773

4774+
// Yield result types are a bit special, but essentially tangent spaces of
4775+
// yields are yields of tangent space type.
4776+
if (auto *yieldResTy = getAs<YieldResultType>()) {
4777+
auto objectTanTy =
4778+
yieldResTy->getResultType()->getAutoDiffTangentSpace(lookupConformance);
4779+
if (!objectTanTy)
4780+
return cache(std::nullopt);
4781+
4782+
auto *yieldTanType = YieldResultType::get(objectTanTy->getType(),
4783+
yieldResTy->isInOut());
4784+
return cache(TangentSpace::getTangentVector(yieldTanType));
4785+
}
4786+
47744787
// For `Differentiable`-conforming types: the tangent space is the
47754788
// `TangentVector` associated type.
47764789
auto *differentiableProtocol =
@@ -4993,6 +5006,8 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
49935006

49945007
// Compute the result linear map function type.
49955008
FunctionType *linearMapType;
5009+
// FIXME: Verify ExtInfo state is correct, not working by accident.
5010+
FunctionType::ExtInfo info;
49965011
switch (kind) {
49975012
case AutoDiffLinearMapKind::Differential: {
49985013
// Compute the differential type, returned by JVP functions.
@@ -5051,6 +5066,9 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
50515066
// Case 2: original function has wrt `inout` parameters.
50525067
// - Original: `(T0, inout T1, ...) -> R`
50535068
// - Pullback: `(R.Tan, inout T1.Tan) -> (T0.Tan, ...)`
5069+
//
5070+
// Special case: yields. They act as parameters, so will
5071+
// always be on result side.
50545072
SmallVector<TupleTypeElt, 4> pullbackResults;
50555073
SmallVector<AnyFunctionType::Param, 2> semanticResultParams;
50565074
for (auto i : range(diffParams.size())) {
@@ -5073,34 +5091,42 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
50735091
}
50745092
pullbackResults.emplace_back(paramTan->getType());
50755093
}
5076-
Type pullbackResult;
5077-
if (pullbackResults.empty()) {
5078-
pullbackResult = ctx.TheEmptyTupleType;
5079-
} else if (pullbackResults.size() == 1) {
5080-
pullbackResult = pullbackResults.front().getType();
5081-
} else {
5082-
pullbackResult = TupleType::get(pullbackResults, ctx);
5083-
}
5084-
// First accumulate non-inout results as pullback parameters.
5094+
// First accumulate ordinary result (not-semantic result parameters) as
5095+
// pullback parameters.
50855096
SmallVector<FunctionType::Param, 2> pullbackParams;
50865097
for (auto i : range(resultTanTypes.size())) {
50875098
auto resultTanType = resultTanTypes[i];
50885099
auto flags = ParameterTypeFlags().withInOut(false);
5089-
pullbackParams.push_back(AnyFunctionType::Param(
5090-
resultTanType, Identifier(), flags));
5100+
if (resultTanType->is<YieldResultType>()) {
5101+
pullbackResults.emplace_back(resultTanType);
5102+
info = info.withCoroutine(true);
5103+
} else {
5104+
pullbackParams.push_back(AnyFunctionType::Param(
5105+
resultTanType, Identifier(), flags));
5106+
}
50915107
}
50925108
// Then append semantic result parameters.
50935109
for (auto i : range(semanticResultParams.size())) {
50945110
auto semanticResultParam = semanticResultParams[i];
50955111
auto semanticResultParamType = semanticResultParam.getPlainType();
50965112
auto semanticResultParamTan =
50975113
semanticResultParamType->getAutoDiffTangentSpace(lookupConformance);
5114+
assert(!semanticResultParamType->is<YieldResultType>() &&
5115+
"yields are always expected on result side");
50985116
auto flags = ParameterTypeFlags().withInOut(true);
50995117
pullbackParams.push_back(AnyFunctionType::Param(
51005118
semanticResultParamTan->getType(), Identifier(), flags));
51015119
}
5102-
// FIXME: Verify ExtInfo state is correct, not working by accident.
5103-
FunctionType::ExtInfo info;
5120+
5121+
Type pullbackResult;
5122+
if (pullbackResults.empty()) {
5123+
pullbackResult = ctx.TheEmptyTupleType;
5124+
} else if (pullbackResults.size() == 1) {
5125+
pullbackResult = pullbackResults.front().getType();
5126+
} else {
5127+
pullbackResult = TupleType::get(pullbackResults, ctx);
5128+
}
5129+
51045130
linearMapType = FunctionType::get(pullbackParams, pullbackResult, info);
51055131
break;
51065132
}

test/AutoDiff/Sema/derivative_attr_type_checking.swift

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -451,10 +451,10 @@ extension Struct where T: Differentiable & AdditiveArithmetic {
451451
fatalError()
452452
}
453453

454-
// expected-error @+1 {{cannot register derivative for _modify accessor}}
455454
@derivative(of: computedProperty._modify)
456-
mutating func vjpPropertyModify(_ newValue: T) -> (
457-
value: (), pullback: (inout TangentVector) -> T.TangentVector
455+
@yield_once
456+
mutating func vjpPropertyModify() -> (
457+
value: inout @yields T, pullback: @yield_once (inout TangentVector) -> inout @yields T.TangentVector
458458
) {
459459
fatalError()
460460
}

0 commit comments

Comments
 (0)