@@ -3666,14 +3666,31 @@ void PullbackCloner::Implementation::
36663666 if (originalValue != dti->getResult (0 ))
36673667 return ;
36683668 // Accumulate the array's adjoint value into the adjoint buffers of its
3669- // element addresses: `pointer_to_address` and `index_addr` instructions.
3669+ // element addresses: `pointer_to_address` and (optionally) `index_addr`
3670+ // instructions.
3671+ // The input code looks like as follows:
3672+ // %17 = integer_literal $Builtin.Word, 1
3673+ // function_ref _allocateUninitializedArray<A>(_:)
3674+ // %18 = function_ref @$ss27_allocateUninitializedArrayySayxG_BptBwlF : $@convention(thin) <τ_0_0> (Builtin.Word) -> (@owned Array<τ_0_0>, Builtin.RawPointer)
3675+ // %19 = apply %18<Float>(%17) : $@convention(thin) <τ_0_0> (Builtin.Word) -> (@owned Array<τ_0_0>, Builtin.RawPointer)
3676+ // (%20, %21) = destructure_tuple %19
3677+ // %22 = mark_dependence %21 on %20
3678+ // %23 = pointer_to_address %22 to [strict] $*Float
3679+ // store %0 to [trivial] %23
3680+ // function_ref _finalizeUninitializedArray<A>(_:)
3681+ // %25 = function_ref @$ss27_finalizeUninitializedArrayySayxGABnlF : $@convention(thin) <τ_0_0> (@owned Array<τ_0_0>) -> @owned Array<τ_0_0>
3682+ // %26 = apply %25<Float>(%20) : $@convention(thin) <τ_0_0> (@owned Array<τ_0_0>) -> @owned Array<τ_0_0> // user: %27
3683+ // Note that %20 and %21 are in some sense "aliases" for each other. Here our `originalValue` is %20 in the code above.
3684+ // We need to trace from %21 down to %23 and propagate (decomposed) adjoint of originalValue to adjoint of %23.
3685+ // Then the generic adjoint propagation code would do its job to propagate %23' to %0'.
3686+ // If we're initializing multiple values we're having additional `index_addr` instructions, but
3687+ // the handling is similar.
36703688 LLVM_DEBUG (getADDebugStream ()
36713689 << " Accumulating adjoint value for array literal into element "
36723690 " address adjoint buffers"
36733691 << originalValue);
36743692 auto arrayAdjoint = materializeAdjointDirect (arrayAdjointValue, loc);
36753693 builder.setCurrentDebugScope (remapScope (dti->getDebugScope ()));
3676- builder.setInsertionPoint (arrayAdjoint->getParentBlock ());
36773694 for (auto use : dti->getResult (1 )->getUses ()) {
36783695 auto *mdi = dyn_cast<MarkDependenceInst>(use->getUser ());
36793696 assert (mdi && " Expected mark_dependence user" );
0 commit comments