Skip to content

Commit 89a7663

Browse files
authored
[AutoDiff] Initial support for differentiation of throwing functions (#82653)
This adds initial support for differentiation of functions that may produce `Error` result. Essentially we wrap the pullback into `Optional` and emit a diamond-shape control flow pattern depending on whether the pullback value is available or not. VJP emission was modified to accommodate for this. In addition to this, some additional tricks are required as `try_apply` result is not available in the instruction parent block, it is available in normal successor basic block. As a result we can now: - differentiate an active `try_apply` result (that would be produced from `do ... try .. catch` constructions) - `try_apply` when error result is unreachable (usually `try!` and similar source code constructs) - Support (some) throwing functions with builtin differentiation operators. stdlib change will follow. Though we cannot support typed throws here (yet) - Correctly propagate error types during currying around differentiable functions as well as type-checking for `@derivative(of:)` attribute, so we can register custom derivatives for functions producing error result - Added custom derivative for `Optional.??` operator (note that support here is not yet complete as we cannot differentiate through autoclosures, so `x ?? y` works only if `y` is not active, e.g. a constant value). Some fixes here and there
1 parent b7a07db commit 89a7663

File tree

14 files changed

+725
-74
lines changed

14 files changed

+725
-74
lines changed

include/swift/SILOptimizer/Differentiation/Common.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,14 @@ inline void createEntryArguments(SILFunction *f) {
277277
indResTy = indResTy.mapTypeOutOfContext();
278278
createFunctionArgument(f->mapTypeIntoContext(indResTy).getAddressType());
279279
}
280+
if (auto indErrorResTy =
281+
conv.getIndirectErrorResultType(f->getTypeExpansionContext())) {
282+
if (indErrorResTy.hasArchetype())
283+
indErrorResTy = indErrorResTy.mapTypeOutOfContext();
284+
createFunctionArgument(
285+
f->mapTypeIntoContext(indErrorResTy).getAddressType());
286+
}
287+
280288
for (auto paramTy : conv.getParameterSILTypes(f->getTypeExpansionContext())) {
281289
if (paramTy.hasArchetype())
282290
paramTy = paramTy.mapTypeOutOfContext();

lib/AST/Builtins.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2920,6 +2920,7 @@ ValueDecl *swift::getBuiltinValueDecl(ASTContext &Context, Identifier Id) {
29202920
if (!autodiff::getBuiltinApplyDerivativeConfig(
29212921
OperationName, kind, arity, throws))
29222922
return nullptr;
2923+
// TODO: Support somehow typed throws
29232924
return getAutoDiffApplyDerivativeFunction(Context, Id, kind, arity,
29242925
throws, /*thrownType=*/Type());
29252926
}
@@ -2929,6 +2930,7 @@ ValueDecl *swift::getBuiltinValueDecl(ASTContext &Context, Identifier Id) {
29292930
if (!autodiff::getBuiltinApplyTransposeConfig(
29302931
OperationName, arity, throws))
29312932
return nullptr;
2933+
// TODO: Support somehow typed throws
29322934
return getAutoDiffApplyTransposeFunction(Context, Id, arity, throws,
29332935
/*thrownType=*/Type());
29342936
}

lib/SILGen/SILGenBuiltin.cpp

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1190,12 +1190,9 @@ static ManagedValue emitBuiltinTypeTrait(SILGenFunction &SGF,
11901190
}
11911191

11921192
static ManagedValue emitBuiltinAutoDiffApplyDerivativeFunction(
1193-
AutoDiffDerivativeFunctionKind kind, unsigned arity,
1194-
bool throws, SILGenFunction &SGF, SILLocation loc,
1195-
SubstitutionMap substitutions, ArrayRef<ManagedValue> args, SGFContext C) {
1196-
// FIXME(https://github.com/apple/swift/issues/54259): Support throwing functions.
1197-
assert(!throws && "Throwing functions are not yet supported");
1198-
1193+
AutoDiffDerivativeFunctionKind kind, unsigned arity, SILGenFunction &SGF,
1194+
SILLocation loc, SubstitutionMap substitutions, ArrayRef<ManagedValue> args,
1195+
SGFContext C) {
11991196
auto origFnVal = args[0];
12001197
SmallVector<SILValue, 2> origFnArgVals;
12011198
for (auto& arg : args.drop_front(1))
@@ -1213,7 +1210,8 @@ static ManagedValue emitBuiltinAutoDiffApplyDerivativeFunction(
12131210
origFnVal = SGF.B.createBeginBorrow(loc, origFnVal);
12141211
SILValue derivativeFn = SGF.B.createDifferentiableFunctionExtract(
12151212
loc, kind, origFnVal.getValue());
1216-
auto derivativeFnType = derivativeFn->getType().castTo<SILFunctionType>();
1213+
SILType derivativeType = derivativeFn->getType();
1214+
auto derivativeFnType = derivativeType.castTo<SILFunctionType>();
12171215
assert(derivativeFnType->getNumResults() == 2);
12181216
assert(derivativeFnType->getNumParameters() == origFnArgVals.size());
12191217

@@ -1240,8 +1238,8 @@ static ManagedValue emitBuiltinAutoDiffApplyDerivativeFunction(
12401238
applyArgs.push_back(SGF.B.createTupleElementAddr(loc, indResBuffer, 0));
12411239
for (auto origFnArgVal : origFnArgVals)
12421240
applyArgs.push_back(origFnArgVal);
1243-
auto differential = SGF.B.createApply(loc, derivativeFn, SubstitutionMap(),
1244-
applyArgs);
1241+
auto differential = SGF.emitApplyWithRethrow(
1242+
loc, derivativeFn, derivativeType, SubstitutionMap(), applyArgs);
12451243

12461244
derivativeFn = SILValue();
12471245

@@ -1253,8 +1251,8 @@ static ManagedValue emitBuiltinAutoDiffApplyDerivativeFunction(
12531251
}
12541252

12551253
// Do the apply for the direct result case.
1256-
auto resultTuple = SGF.B.createApply(
1257-
loc, derivativeFn, SubstitutionMap(), origFnArgVals);
1254+
auto resultTuple = SGF.emitApplyWithRethrow(loc, derivativeFn, derivativeType,
1255+
SubstitutionMap(), origFnArgVals);
12581256

12591257
derivativeFn = SILValue();
12601258

@@ -1323,8 +1321,8 @@ static ManagedValue emitBuiltinApplyDerivative(
13231321
auto successfullyParsed = autodiff::getBuiltinApplyDerivativeConfig(
13241322
builtinName, kind, arity, throws);
13251323
assert(successfullyParsed);
1326-
return emitBuiltinAutoDiffApplyDerivativeFunction(
1327-
kind, arity, throws, SGF, loc, substitutions, args, C);
1324+
return emitBuiltinAutoDiffApplyDerivativeFunction(kind, arity, SGF, loc,
1325+
substitutions, args, C);
13281326
}
13291327

13301328
static ManagedValue emitBuiltinApplyTranspose(

lib/SILOptimizer/Differentiation/LinearMapInfo.cpp

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -146,14 +146,15 @@ void LinearMapInfo::populateBranchingTraceDecl(SILBasicBlock *originalBB,
146146
heapAllocatedContext = true;
147147
decl->setInterfaceType(astCtx.TheRawPointerType);
148148
} else { // Otherwise the payload is the linear map tuple.
149-
auto *linearMapStructTy = getLinearMapTupleType(predBB);
149+
auto *linearMapTupleTy = getLinearMapTupleType(predBB);
150150
// Do not create entries for unreachable predecessors
151-
if (!linearMapStructTy)
151+
if (!linearMapTupleTy)
152152
continue;
153-
auto canLinearMapStructTy = linearMapStructTy->getCanonicalType();
154-
decl->setInterfaceType(
155-
canLinearMapStructTy->hasArchetype()
156-
? canLinearMapStructTy->mapTypeOutOfContext() : canLinearMapStructTy);
153+
154+
auto canLinearMapTupleTy = linearMapTupleTy->getCanonicalType();
155+
decl->setInterfaceType(canLinearMapTupleTy->hasArchetype()
156+
? canLinearMapTupleTy->mapTypeOutOfContext()
157+
: canLinearMapTupleTy);
157158
}
158159
// Create enum element and enum case declarations.
159160
auto *paramList = ParameterList::create(astCtx, {decl});
@@ -185,6 +186,7 @@ Type LinearMapInfo::getLinearMapType(ADContext &context, FullApplySite fai) {
185186
auto hasActiveResults = llvm::any_of(allResults, [&](SILValue res) {
186187
return activityInfo.isActive(res, config);
187188
});
189+
188190
bool hasActiveSemanticResultArgument = false;
189191
bool hasActiveArguments = false;
190192
auto numIndirectResults = fai.getNumIndirectSILResults();
@@ -313,10 +315,12 @@ Type LinearMapInfo::getLinearMapType(ADContext &context, FullApplySite fai) {
313315
params, silFnTy->getAllResultsInterfaceType().getASTType(), info);
314316
}
315317

316-
if (astFnTy->hasArchetype())
317-
return astFnTy->mapTypeOutOfContext();
318+
Type resultType =
319+
astFnTy->hasArchetype() ? astFnTy->mapTypeOutOfContext() : astFnTy;
320+
if (fai.getKind() == FullApplySiteKind::TryApplyInst)
321+
resultType = resultType->wrapInOptionalType();
318322

319-
return astFnTy;
323+
return resultType;
320324
}
321325

322326
void LinearMapInfo::generateDifferentiationDataStructures(

0 commit comments

Comments
 (0)