Skip to content

Commit ecf7747

Browse files
committed
autodiff: support the new array literal initialization pattern
1 parent 3309809 commit ecf7747

File tree

5 files changed

+213
-180
lines changed

5 files changed

+213
-180
lines changed

lib/SILOptimizer/Analysis/DifferentiableActivityAnalysis.cpp

Lines changed: 68 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "swift/SILOptimizer/Differentiation/Common.h"
1717

1818
#include "swift/Basic/Assertions.h"
19+
#include "swift/SIL/NodeDatastructures.h"
1920
#include "swift/SIL/Projection.h"
2021
#include "swift/SIL/SILArgument.h"
2122
#include "swift/SILOptimizer/Analysis/DominanceAnalysis.h"
@@ -435,67 +436,75 @@ void DifferentiableActivityInfo::setUsefulThroughArrayInitialization(
435436
auto *dti = dyn_cast<DestructureTupleInst>(use->getUser());
436437
if (!dti)
437438
continue;
438-
// The second tuple field of the return value is the `RawPointer`.
439-
for (auto use : dti->getResult(1)->getUses()) {
440-
// The `RawPointer` passes through a `mark_dependence(pointer_to_address`.
441-
// That instruction's first use is a `store` whose source is useful; its
442-
// subsequent uses are `index_addr`s whose only use is a useful `store`.
443-
auto *mdi = dyn_cast<MarkDependenceInst>(use->getUser());
444-
assert(
445-
mdi &&
446-
"Expected a mark_dependence user for uninitialized array intrinsic.");
447-
auto *ptai = dyn_cast<PointerToAddressInst>(getSingleNonDebugUser(mdi));
448-
assert(ptai && "Expected a pointer_to_address.");
449-
setUseful(ptai, dependentVariableIndex);
450-
// Propagate usefulness through array element addresses:
451-
// `pointer_to_address` and `index_addr` instructions.
452-
//
453-
// - Set all array element addresses as useful.
454-
// - Find instructions with array element addresses as "result":
455-
// - `store` and `copy_addr` with array element address as destination.
456-
// - `apply` with array element address as an indirect result.
457-
// - For each instruction, propagate usefulness through "arguments":
458-
// - `store` and `copy_addr`: propagate to source.
459-
// - `apply`: propagate to arguments.
460-
//
461-
// NOTE: `propagateUseful(use->getUser(), ...)` is intentionally not used
462-
// because it marks more values than necessary as useful, including:
463-
// - The `RawPointer` result of the intrinsic.
464-
// - `integer_literal` operands to `index_addr` for indexing the
465-
// `RawPointer`.
466-
// It is also blocked by TF-1032: control flow differentiation crash for
467-
// active values with no tangent space.
468-
for (auto use : ptai->getUses()) {
469-
auto *user = use->getUser();
470-
if (auto *si = dyn_cast<StoreInst>(user)) {
471-
setUseful(si->getDest(), dependentVariableIndex);
472-
setUsefulAndPropagateToOperands(si->getSrc(), dependentVariableIndex);
473-
} else if (auto *cai = dyn_cast<CopyAddrInst>(user)) {
474-
setUseful(cai->getDest(), dependentVariableIndex);
475-
setUsefulAndPropagateToOperands(cai->getSrc(),
476-
dependentVariableIndex);
477-
} else if (auto *ai = dyn_cast<ApplyInst>(user)) {
478-
if (FullApplySite(ai).isIndirectResultOperand(*use))
479-
for (auto arg : ai->getArgumentsWithoutIndirectResults())
480-
setUsefulAndPropagateToOperands(arg, dependentVariableIndex);
481-
} else if (auto *iai = dyn_cast<IndexAddrInst>(user)) {
482-
setUseful(iai, dependentVariableIndex);
483-
for (auto use : iai->getUses()) {
484-
auto *user = use->getUser();
485-
if (auto si = dyn_cast<StoreInst>(user)) {
486-
setUseful(si->getDest(), dependentVariableIndex);
487-
setUsefulAndPropagateToOperands(si->getSrc(),
488-
dependentVariableIndex);
489-
} else if (auto *cai = dyn_cast<CopyAddrInst>(user)) {
490-
setUseful(cai->getDest(), dependentVariableIndex);
491-
setUsefulAndPropagateToOperands(cai->getSrc(),
492-
dependentVariableIndex);
493-
} else if (auto *ai = dyn_cast<ApplyInst>(user)) {
494-
if (FullApplySite(ai).isIndirectResultOperand(*use))
495-
for (auto arg : ai->getArgumentsWithoutIndirectResults())
496-
setUsefulAndPropagateToOperands(arg, dependentVariableIndex);
439+
440+
ValueWorklist worklist(dti->getResult(0));
441+
442+
while (SILValue v = worklist.pop()) {
443+
for (auto use : v->getUses()) {
444+
switch (use->getUser()->getKind()) {
445+
case SILInstructionKind::UncheckedRefCastInst:
446+
case SILInstructionKind::StructExtractInst:
447+
case SILInstructionKind::BeginBorrowInst:
448+
worklist.pushIfNotVisited(cast<SingleValueInstruction>(use->getUser()));
449+
break;
450+
case SILInstructionKind::RefTailAddrInst: {
451+
auto *rta = cast<RefTailAddrInst>(use->getUser());
452+
setUseful(rta, dependentVariableIndex);
453+
// Propagate usefulness through array element addresses:
454+
// `pointer_to_address` and `index_addr` instructions.
455+
//
456+
// - Set all array element addresses as useful.
457+
// - Find instructions with array element addresses as "result":
458+
// - `store` and `copy_addr` with array element address as destination.
459+
// - `apply` with array element address as an indirect result.
460+
// - For each instruction, propagate usefulness through "arguments":
461+
// - `store` and `copy_addr`: propagate to source.
462+
// - `apply`: propagate to arguments.
463+
//
464+
// NOTE: `propagateUseful(use->getUser(), ...)` is intentionally not used
465+
// because it marks more values than necessary as useful, including:
466+
// - The `RawPointer` result of the intrinsic.
467+
// - `integer_literal` operands to `index_addr` for indexing the
468+
// `RawPointer`.
469+
// It is also blocked by TF-1032: control flow differentiation crash for
470+
// active values with no tangent space.
471+
for (auto use : rta->getUses()) {
472+
auto *user = use->getUser();
473+
if (auto *si = dyn_cast<StoreInst>(user)) {
474+
setUseful(si->getDest(), dependentVariableIndex);
475+
setUsefulAndPropagateToOperands(si->getSrc(), dependentVariableIndex);
476+
} else if (auto *cai = dyn_cast<CopyAddrInst>(user)) {
477+
setUseful(cai->getDest(), dependentVariableIndex);
478+
setUsefulAndPropagateToOperands(cai->getSrc(),
479+
dependentVariableIndex);
480+
} else if (auto *ai = dyn_cast<ApplyInst>(user)) {
481+
if (FullApplySite(ai).isIndirectResultOperand(*use))
482+
for (auto arg : ai->getArgumentsWithoutIndirectResults())
483+
setUsefulAndPropagateToOperands(arg, dependentVariableIndex);
484+
} else if (auto *iai = dyn_cast<IndexAddrInst>(user)) {
485+
setUseful(iai, dependentVariableIndex);
486+
for (auto use : iai->getUses()) {
487+
auto *user = use->getUser();
488+
if (auto si = dyn_cast<StoreInst>(user)) {
489+
setUseful(si->getDest(), dependentVariableIndex);
490+
setUsefulAndPropagateToOperands(si->getSrc(),
491+
dependentVariableIndex);
492+
} else if (auto *cai = dyn_cast<CopyAddrInst>(user)) {
493+
setUseful(cai->getDest(), dependentVariableIndex);
494+
setUsefulAndPropagateToOperands(cai->getSrc(),
495+
dependentVariableIndex);
496+
} else if (auto *ai = dyn_cast<ApplyInst>(user)) {
497+
if (FullApplySite(ai).isIndirectResultOperand(*use))
498+
for (auto arg : ai->getArgumentsWithoutIndirectResults())
499+
setUsefulAndPropagateToOperands(arg, dependentVariableIndex);
500+
}
501+
}
502+
}
497503
}
504+
break;
498505
}
506+
default:
507+
break;
499508
}
500509
}
501510
}

lib/SILOptimizer/Differentiation/Common.cpp

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -32,20 +32,31 @@ raw_ostream &getADDebugStream() { return llvm::dbgs() << "[AD] "; }
3232
// Helpers
3333
//===----------------------------------------------------------------------===//
3434

35+
static SILValue getArrayValueOfElementAddress(SILValue v) {
36+
while (true) {
37+
switch (v->getKind()) {
38+
case ValueKind::IndexAddrInst:
39+
case ValueKind::RefTailAddrInst:
40+
case ValueKind::UncheckedRefCastInst:
41+
case ValueKind::StructExtractInst:
42+
case ValueKind::BeginBorrowInst:
43+
v = cast<SingleValueInstruction>(v)->getOperand(0);
44+
break;
45+
default:
46+
return v;
47+
}
48+
}
49+
}
50+
3551
ApplyInst *getAllocateUninitializedArrayIntrinsicElementAddress(SILValue v) {
36-
// Find the `pointer_to_address` result, peering through `index_addr`.
37-
auto *ptai = dyn_cast<PointerToAddressInst>(v);
38-
if (auto *iai = dyn_cast<IndexAddrInst>(v))
39-
ptai = dyn_cast<PointerToAddressInst>(iai->getOperand(0));
40-
if (!ptai)
41-
return nullptr;
42-
auto *mdi = dyn_cast<MarkDependenceInst>(
43-
ptai->getOperand()->getDefiningInstruction());
44-
if (!mdi)
52+
SILValue arr = getArrayValueOfElementAddress(v);
53+
54+
auto *mvir = dyn_cast<MultipleValueInstructionResult>(arr);
55+
if (!mvir)
4556
return nullptr;
57+
4658
// Return the `array.uninitialized_intrinsic` application, if it exists.
47-
if (auto *dti = dyn_cast<DestructureTupleInst>(
48-
mdi->getValue()->getDefiningInstruction()))
59+
if (auto *dti = dyn_cast<DestructureTupleInst>(mvir->getParent()))
4960
return ArraySemanticsCall(dti->getOperand(),
5061
semantics::ARRAY_UNINITIALIZED_INTRINSIC);
5162
return nullptr;

lib/SILOptimizer/Differentiation/PullbackCloner.cpp

Lines changed: 36 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
#include "swift/Basic/STLExtras.h"
3535
#include "swift/SIL/ApplySite.h"
3636
#include "swift/SIL/InstructionUtils.h"
37+
#include "swift/SIL/NodeDatastructures.h"
3738
#include "swift/SIL/Projection.h"
3839
#include "swift/SIL/TypeSubstCloner.h"
3940
#include "swift/SILOptimizer/PassManager/PrettyStackTrace.h"
@@ -3572,7 +3573,7 @@ SILValue PullbackCloner::Implementation::getAdjointProjection(
35723573
originalProjection->getDefiningInstruction());
35733574
bool isAllocateUninitializedArrayIntrinsicElementAddress =
35743575
ai && definingInst &&
3575-
(isa<PointerToAddressInst>(definingInst) ||
3576+
(isa<RefTailAddrInst>(definingInst) ||
35763577
isa<IndexAddrInst>(definingInst));
35773578
if (isAllocateUninitializedArrayIntrinsicElementAddress) {
35783579
// Get the array element index of the result address.
@@ -3755,9 +3756,10 @@ void PullbackCloner::Implementation::
37553756
// %18 = function_ref @$ss27_allocateUninitializedArrayySayxG_BptBwlF : $@convention(thin) <τ_0_0> (Builtin.Word) -> (@owned Array<τ_0_0>, Builtin.RawPointer)
37563757
// %19 = apply %18<Float>(%17) : $@convention(thin) <τ_0_0> (Builtin.Word) -> (@owned Array<τ_0_0>, Builtin.RawPointer)
37573758
// (%20, %21) = destructure_tuple %19
3758-
// %22 = mark_dependence %21 on %20
3759-
// %23 = pointer_to_address %22 to [strict] $*Float
3760-
// store %0 to [trivial] %23
3759+
// %22 = begin_borrow %20
3760+
// %23 = struct_extract %22, #Array.arrayBuffer
3761+
// %24 = ref_tail_addr %22
3762+
// store %0 to [trivial] %24
37613763
// function_ref _finalizeUninitializedArray<A>(_:)
37623764
// %25 = function_ref @$ss27_finalizeUninitializedArrayySayxGABnlF : $@convention(thin) <τ_0_0> (@owned Array<τ_0_0>) -> @owned Array<τ_0_0>
37633765
// %26 = apply %25<Float>(%20) : $@convention(thin) <τ_0_0> (@owned Array<τ_0_0>) -> @owned Array<τ_0_0> // user: %27
@@ -3772,23 +3774,36 @@ void PullbackCloner::Implementation::
37723774
<< originalValue);
37733775
auto arrayAdjoint = materializeAdjointDirect(arrayAdjointValue, loc);
37743776
builder.setCurrentDebugScope(remapScope(dti->getDebugScope()));
3775-
for (auto use : dti->getResult(1)->getUses()) {
3776-
auto *mdi = dyn_cast<MarkDependenceInst>(use->getUser());
3777-
assert(mdi && "Expected mark_dependence user");
3778-
auto *ptai =
3779-
dyn_cast_or_null<PointerToAddressInst>(getSingleNonDebugUser(mdi));
3780-
assert(ptai && "Expected pointer_to_address user");
3781-
auto adjBuf = getAdjointBuffer(origBB, ptai);
3782-
auto *eltAdjBuf = getArrayAdjointElementBuffer(arrayAdjoint, 0, loc);
3783-
builder.emitInPlaceAdd(loc, adjBuf, eltAdjBuf);
3784-
for (auto use : ptai->getUses()) {
3785-
if (auto *iai = dyn_cast<IndexAddrInst>(use->getUser())) {
3786-
auto *ili = cast<IntegerLiteralInst>(iai->getIndex());
3787-
auto eltIndex = ili->getValue().getLimitedValue();
3788-
auto adjBuf = getAdjointBuffer(origBB, iai);
3789-
auto *eltAdjBuf =
3790-
getArrayAdjointElementBuffer(arrayAdjoint, eltIndex, loc);
3791-
builder.emitInPlaceAdd(loc, adjBuf, eltAdjBuf);
3777+
3778+
ValueWorklist worklist(dti->getResult(0));
3779+
3780+
while (SILValue v = worklist.pop()) {
3781+
for (auto use : v->getUses()) {
3782+
switch (use->getUser()->getKind()) {
3783+
case SILInstructionKind::UncheckedRefCastInst:
3784+
case SILInstructionKind::StructExtractInst:
3785+
case SILInstructionKind::BeginBorrowInst:
3786+
worklist.pushIfNotVisited(cast<SingleValueInstruction>(use->getUser()));
3787+
break;
3788+
case SILInstructionKind::RefTailAddrInst: {
3789+
auto *rta = cast<RefTailAddrInst>(use->getUser());
3790+
auto adjBuf = getAdjointBuffer(origBB, rta);
3791+
auto *eltAdjBuf = getArrayAdjointElementBuffer(arrayAdjoint, 0, loc);
3792+
builder.emitInPlaceAdd(loc, adjBuf, eltAdjBuf);
3793+
for (auto use : rta->getUses()) {
3794+
if (auto *iai = dyn_cast<IndexAddrInst>(use->getUser())) {
3795+
auto *ili = cast<IntegerLiteralInst>(iai->getIndex());
3796+
auto eltIndex = ili->getValue().getLimitedValue();
3797+
auto adjBuf = getAdjointBuffer(origBB, iai);
3798+
auto *eltAdjBuf =
3799+
getArrayAdjointElementBuffer(arrayAdjoint, eltIndex, loc);
3800+
builder.emitInPlaceAdd(loc, adjBuf, eltAdjBuf);
3801+
}
3802+
}
3803+
break;
3804+
}
3805+
default:
3806+
break;
37923807
}
37933808
}
37943809
}

0 commit comments

Comments
 (0)