Skip to content

Commit 186f88b

Browse files
authored
[AutoDiff] Fix two issues related to emission of differentiability witnesses (#80983)
1. When a differentiable nested function (closure) is specialized by the capture promotion pass, ensure we generate a differentiability witness for the specialized function as well. Ensure the original witness is removed if the original function becomes dead. 2. Differentiability witnesses for a function can originate either from its `@differentiable` attribute or from an explicit `@derivative(of:)` attribute on the derivative. In the latter case the derivative itself might not be emitted while the original function is (e.g. the original function is `@inlinable`, but the derivative is `@usableFromInline`). Previously both cases were handled only when the function body was emitted. As a result we missed the witness in the aforementioned case. Ensure the differentiability witness originating from `@derivative(of:)` is emitted even if we're not going to emit the body of the derivative. Fixes #59135
1 parent 49fdf83 commit 186f88b

File tree

7 files changed

+164
-29
lines changed

7 files changed

+164
-29
lines changed

include/swift/SIL/SILModule.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -591,6 +591,12 @@ class SILModule {
591591
/// Erase a global SIL variable from the module.
592592
void eraseGlobalVariable(SILGlobalVariable *G);
593593

594+
/// Erase a differentiability witness from the module.
595+
void eraseDifferentiabilityWitness(SILDifferentiabilityWitness *dw);
596+
597+
/// Erase all differentiability witnesses for function f.
598+
void eraseAllDifferentiabilityWitnesses(SILFunction *f);
599+
594600
/// Create and return an empty SIL module suitable for generating or parsing
595601
/// SIL into.
596602
///

lib/SIL/IR/SILModule.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -504,6 +504,33 @@ void SILModule::eraseGlobalVariable(SILGlobalVariable *gv) {
504504
getSILGlobalList().erase(gv);
505505
}
506506

507+
/// Erase the single differentiability witness \p dw from this module.
///
/// The witness is dropped from every place this function knows about: the SIL
/// loader is told to invalidate it, its mangled-name entry is removed from
/// DifferentiabilityWitnessMap, it is removed from the per-original-function
/// list in DifferentiabilityWitnessesByFunction, and finally it is erased from
/// the module's differentiability witness list.
void SILModule::eraseDifferentiabilityWitness(SILDifferentiabilityWitness *dw) {
  // Notify the loader before touching any of the module-side bookkeeping.
  getSILLoader()->invalidateDifferentiabilityWitness(dw);

  // The witness is keyed by the mangling of (original function name, kind,
  // config); rebuild that key to find its entry in the name map.
  auto origFnName = dw->getOriginalFunction()->getName();
  Mangle::ASTMangler witnessMangler(getASTContext());
  DifferentiabilityWitnessMap.erase(
      witnessMangler.mangleSILDifferentiabilityWitness(
          origFnName, dw->getKind(), dw->getConfig()));

  // Remove the witness from the list kept for its original function.
  auto &witnessesOfOriginal = DifferentiabilityWitnessesByFunction[origFnName];
  llvm::erase(witnessesOfOriginal, dw);

  // Erase from the module's witness list last, after every use of `dw` above.
  getDifferentiabilityWitnessList().erase(dw);
}
519+
520+
/// Erase every differentiability witness registered for function \p f.
///
/// For each witness recorded under the name of \p f, the SIL loader is told
/// to invalidate it, its mangled-name entry is removed from
/// DifferentiabilityWitnessMap, and it is erased from the module's
/// differentiability witness list. Afterwards the per-function entry for \p f
/// is removed from DifferentiabilityWitnessesByFunction.
///
/// Calling this for a function that has no registered witnesses is a no-op.
void SILModule::eraseAllDifferentiabilityWitnesses(SILFunction *f) {
  // Use find() rather than at(): at() does not tolerate a missing key, so a
  // function without any witnesses must be handled explicitly here.
  auto found = DifferentiabilityWitnessesByFunction.find(f->getName());
  if (found == DifferentiabilityWitnessesByFunction.end())
    return;

  Mangle::ASTMangler mangler(getASTContext());

  for (auto *dw : found->second) {
    getSILLoader()->invalidateDifferentiabilityWitness(dw);
    // Witnesses are keyed by the mangling of (original function name, kind,
    // config); rebuild that key to drop the entry from the name map.
    auto mangledKey = mangler.mangleSILDifferentiabilityWitness(
        f->getName(), dw->getKind(), dw->getConfig());
    DifferentiabilityWitnessMap.erase(mangledKey);
    getDifferentiabilityWitnessList().erase(dw);
  }

  // Drop the per-function entry only after the loop has finished walking it.
  DifferentiabilityWitnessesByFunction.erase(f->getName());
}
533+
507534
SILVTable *SILModule::lookUpVTable(const ClassDecl *C,
508535
bool deserializeLazily) {
509536
if (!C)

lib/SILGen/SILGen.cpp

Lines changed: 38 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1409,7 +1409,7 @@ void SILGenModule::postEmitFunction(SILDeclRef constant,
14091409

14101410
void SILGenModule::emitDifferentiabilityWitnessesForFunction(
14111411
SILDeclRef constant, SILFunction *F) {
1412-
// Visit `@derivative` attributes and generate SIL differentiability
1412+
// Visit `@differentiable` attributes and generate SIL differentiability
14131413
// witnesses.
14141414
// Skip if the SILDeclRef is a:
14151415
// - Default argument generator function.
@@ -1439,33 +1439,6 @@ void SILGenModule::emitDifferentiabilityWitnessesForFunction(
14391439
config, /*jvp*/ nullptr,
14401440
/*vjp*/ nullptr, diffAttr);
14411441
}
1442-
for (auto *derivAttr : Attrs.getAttributes<DerivativeAttr>()) {
1443-
SILFunction *jvp = nullptr;
1444-
SILFunction *vjp = nullptr;
1445-
switch (derivAttr->getDerivativeKind()) {
1446-
case AutoDiffDerivativeFunctionKind::JVP:
1447-
jvp = F;
1448-
break;
1449-
case AutoDiffDerivativeFunctionKind::VJP:
1450-
vjp = F;
1451-
break;
1452-
}
1453-
auto *origAFD = derivAttr->getOriginalFunction(getASTContext());
1454-
auto origDeclRef =
1455-
SILDeclRef(origAFD).asForeign(requiresForeignEntryPoint(origAFD));
1456-
auto *origFn = getFunction(origDeclRef, NotForDefinition);
1457-
auto witnessGenSig =
1458-
autodiff::getDifferentiabilityWitnessGenericSignature(
1459-
origAFD->getGenericSignature(), AFD->getGenericSignature());
1460-
auto *resultIndices =
1461-
autodiff::getFunctionSemanticResultIndices(origAFD,
1462-
derivAttr->getParameterIndices());
1463-
AutoDiffConfig config(derivAttr->getParameterIndices(), resultIndices,
1464-
witnessGenSig);
1465-
emitDifferentiabilityWitness(origAFD, origFn,
1466-
DifferentiabilityKind::Reverse, config, jvp,
1467-
vjp, derivAttr);
1468-
}
14691442
};
14701443
if (auto *accessor = dyn_cast<AccessorDecl>(AFD))
14711444
if (accessor->isGetter())
@@ -1582,6 +1555,43 @@ void SILGenModule::emitAbstractFuncDecl(AbstractFunctionDecl *AFD) {
15821555
SILDeclRef::BackDeploymentKind::Thunk);
15831556
emitBackDeploymentThunk(thunk);
15841557
}
1558+
1559+
// Emit differentiability witness for the function referenced in
1560+
// @derivative(of:) attribute registering current function as VJP / JVP.
1561+
// Differentiability witnesses for a function could originate either from its
1562+
// @differentiable attribute or from explicit @derivative(of:) attribute on
1563+
// the derivative. In the latter case the derivative itself might not be
1564+
// emitted, while original function is (e.g. original function is @inlinable,
1565+
// but derivative is @usableFromInline). Ensure the differentiability witness
1566+
// originating from @derivative(of:) is emitted even if we're not going to
1567+
// emit body of the derivative.
1568+
for (auto *derivAttr : AFD->getAttrs().getAttributes<DerivativeAttr>()) {
1569+
auto *f = getFunction(SILDeclRef(AFD), NotForDefinition);
1570+
SILFunction *jvp = nullptr, *vjp = nullptr;
1571+
switch (derivAttr->getDerivativeKind()) {
1572+
case AutoDiffDerivativeFunctionKind::JVP:
1573+
jvp = f;
1574+
break;
1575+
case AutoDiffDerivativeFunctionKind::VJP:
1576+
vjp = f;
1577+
break;
1578+
}
1579+
auto *origAFD = derivAttr->getOriginalFunction(getASTContext());
1580+
auto origDeclRef =
1581+
SILDeclRef(origAFD).asForeign(requiresForeignEntryPoint(origAFD));
1582+
auto *origFn = getFunction(origDeclRef, NotForDefinition);
1583+
auto witnessGenSig =
1584+
autodiff::getDifferentiabilityWitnessGenericSignature(
1585+
origAFD->getGenericSignature(), AFD->getGenericSignature());
1586+
auto *resultIndices =
1587+
autodiff::getFunctionSemanticResultIndices(origAFD,
1588+
derivAttr->getParameterIndices());
1589+
AutoDiffConfig config(derivAttr->getParameterIndices(), resultIndices,
1590+
witnessGenSig);
1591+
emitDifferentiabilityWitness(origAFD, origFn,
1592+
DifferentiabilityKind::Reverse, config, jvp,
1593+
vjp, derivAttr);
1594+
}
15851595
}
15861596

15871597
void SILGenModule::emitFunction(FuncDecl *fd) {

lib/SILOptimizer/Mandatory/CapturePromotion.cpp

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1473,9 +1473,25 @@ processPartialApplyInst(SILOptFunctionBuilder &funcBuilder,
14731473
funcBuilder, pai, fri, promotableIndices, f->getResilienceExpansion());
14741474
worklist.push_back(clonedFn);
14751475

1476+
SILFunction *origFn = fri->getReferencedFunction();
1477+
for (const auto *w : mod.lookUpDifferentiabilityWitnessesForFunction(
1478+
origFn->getName())) {
1479+
// @derivative(of:) attribute could only be applied at global scope, therefore
1480+
// local functions cannot have custom derivatives registered
1481+
assert(!w->getJVP() && !w->getVJP() && "does not expect custom derivatives here");
1482+
auto linkage = stripExternalFromLinkage(clonedFn->getLinkage());
1483+
SILDifferentiabilityWitness::createDefinition(
1484+
mod, linkage, clonedFn,
1485+
w->getKind(), w->getParameterIndices(), w->getResultIndices(),
1486+
w->getDerivativeGenericSignature(),
1487+
/*jvp*/ nullptr, /*vjp*/ nullptr,
1488+
/*isSerialized*/ hasPublicVisibility(clonedFn->getLinkage()),
1489+
w->getAttribute());
1490+
}
1491+
14761492
// Mark the original partial apply function as deletable if it doesn't have
14771493
// uses later.
1478-
fri->getReferencedFunction()->addSemanticsAttr(semantics::DELETE_IF_UNUSED);
1494+
origFn->addSemanticsAttr(semantics::DELETE_IF_UNUSED);
14791495

14801496
// Initialize a SILBuilder and create a function_ref referencing the cloned
14811497
// closure.

lib/SILOptimizer/Mandatory/DiagnosticDeadFunctionElimination.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ namespace {
3636
struct DiagnosticDeadFunctionEliminator : SILFunctionTransform {
3737
void run() override {
3838
auto *fn = getFunction();
39+
auto &mod = fn->getModule();
3940

4041
// If an earlier pass asked us to eliminate the function body if it's
4142
// unused, and the function is in fact unused, do that now.
@@ -67,6 +68,10 @@ struct DiagnosticDeadFunctionEliminator : SILFunctionTransform {
6768
b.createUnreachable(loc);
6869
}
6970

71+
// Drop differentiability witnesses, if any
72+
if (!mod.lookUpDifferentiabilityWitnessesForFunction(fn->getName()).empty())
73+
mod.eraseAllDifferentiabilityWitnesses(fn);
74+
7075
// If the function has shared linkage, reduce this version to private
7176
// linkage, because we don't want the deleted-body form to win in any
7277
// ODR shootouts.
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// RUN: %target-swift-emit-sil -Xllvm -debug-only=differentiation -o /dev/null 2>&1 %s | %FileCheck %s
2+
3+
// The differentiability witness for y in s(h:) will be generated by silgen. However, later the capture
4+
// promotion pass would specialize it since it only captures an integer and therefore does not need to
5+
// box the capture. Ensure we create differentiability witness for specialized function. In addition to
6+
// this, since the original function is not used anymore, the body of it is removed (with only unreachable
7+
// terminator inside). Remove original differentiability witness as it would lead to non-differentiable
8+
// diagnostics further on.
9+
10+
// CHECK-LABEL: differentiability witness for specialized y #1 (_:) in s(h:)
11+
// CHECK: sil_differentiability_witness private [reverse] [parameters 0] [results 0] @$s{{.*}} : $@convention(thin) (@guaranteed W, Int) -> @owned W {
12+
// CHECK-NOT: sil_differentiability_witness private [reverse] [parameters 0] [results 0] @$s{{.*}} : $@convention(thin) (@guaranteed W, @guaranteed { var Int }) -> @owned W {
13+
14+
import _Differentiation
15+
struct B: Differentiable{}
16+
struct X { var j = [Float]()}
17+
struct W: Differentiable {
18+
@noDerivative var z: X
19+
var h: B
20+
}
21+
func o<T, R>(_ x: T, _ f: @differentiable(reverse) (T) -> R) -> R {f(x)}
22+
func m<T, R>(_ f: @escaping @differentiable(reverse) (T) -> R) -> @differentiable(reverse) (T) -> R {{ x in o(x, f) }}
23+
@differentiable(reverse)
24+
func s(h: B) -> B {
25+
var (_, e) = (0,0)
26+
@differentiable(reverse)
27+
func y(_ i: W) -> W {
28+
let _ = e;
29+
return i
30+
}
31+
let w = m(y)
32+
return B()
33+
}
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
// RUN: %empty-directory(%t)
2+
// RUN: %target-swift-emit-sil -emit-module -module-name M -emit-module-path %t/M.swiftmodule 2>&1 %s | %FileCheck %s
3+
4+
// The original function Tensor.subscriptIndexPath() is not marked as @differentiable. As a result, no explicit differentiability witness is generated for it.
5+
// However, the witness is generated as a side effect of providing a derivative via @derivative(of: subscriptIndexPath) on _vjpSubscriptIndexPath.
6+
// Since _vjpSubscriptIndexPath is not emitted when -emit-module is used, we need to ensure we still generate a witness.
7+
8+
import _Differentiation
9+
10+
// CHECK-LABEL: differentiability witness for Tensor.subscriptIndexPath()
11+
// CHECK: sil_differentiability_witness [serialized] [reverse] [parameters 0] [results 0] @$s1M6TensorV18subscriptIndexPathACyF : $@convention(method) (Tensor) -> Tensor {
12+
// CHECK: vjp: @$s1M6TensorV18subscriptIndexPathACyFTJrSpSr : $@convention(method) (Tensor) -> (Tensor, @owned @callee_guaranteed (Tensor) -> Tensor)
13+
14+
// CHECK-LABEL: reverse-mode derivative of Tensor.subscriptIndexPath()
15+
// CHECK: @$s1M6TensorV18subscriptIndexPathACyFTJrSpSr : $@convention(method) (Tensor) -> (Tensor, @owned @callee_guaranteed (Tensor) -> Tensor) {
16+
// CHECK: function_ref Tensor._vjpSubscriptIndexPath()
17+
// CHECK: function_ref @$s1M6TensorV22_vjpSubscriptIndexPathAC5value_A2Cc8pullbacktyF : $@convention(method) (Tensor) -> (Tensor, @owned @callee_guaranteed (Tensor) -> Tensor)
18+
19+
public struct Tensor: Differentiable & AdditiveArithmetic {
20+
@inlinable
21+
func subscriptIndexPath() -> Tensor {
22+
fatalError()
23+
}
24+
25+
@inlinable
26+
@differentiable(reverse, wrt: self)
27+
func subscriptRanges() -> Tensor {
28+
subscriptIndexPath()
29+
}
30+
31+
@usableFromInline
32+
@derivative(of: subscriptIndexPath)
33+
func _vjpSubscriptIndexPath() -> (
34+
value: Tensor, pullback: (Tensor) -> Tensor
35+
) {
36+
fatalError()
37+
}
38+
}

0 commit comments

Comments
 (0)