@@ -2085,37 +2085,38 @@ bool PullbackCloner::Implementation::run() {
20852085
20862086 // Collect differentiation parameter adjoints.
20872087 // Do a first pass to collect non-inout values.
2088- unsigned pullbackInoutArgumentIndex = 0 ;
20892088 for (auto i : getConfig ().parameterIndices ->getIndices ()) {
2090- auto isParameterInout = conv.getParameters ()[i].isIndirectMutating ();
2091- if (!isParameterInout) {
2089+ if (!conv.getParameters ()[i].isIndirectMutating ()) {
2090+ addRetElt (i);
2091+ }
2092+ }
2093+
2094+ // Do a second pass for all inout parameters, however this is only necessary
2095+ // for functions with multiple basic blocks. For functions with a single
2096+ // basic block adjoint accumulation for those parameters is already done by
2097+ // per-instruction visitors.
2098+ if (getOriginal ().size () > 1 ) {
2099+ const auto &pullbackConv = pullback.getConventions ();
2100+ SmallVector<SILArgument *, 1 > pullbackInOutArgs;
2101+ for (auto pullbackArg : enumerate(pullback.getArgumentsWithoutIndirectResults ())) {
2102+ if (pullbackConv.getParameters ()[pullbackArg.index ()].isIndirectMutating ())
2103+ pullbackInOutArgs.push_back (pullbackArg.value ());
2104+ }
2105+
2106+ unsigned pullbackInoutArgumentIdx = 0 ;
2107+ for (auto i : getConfig ().parameterIndices ->getIndices ()) {
2108+ // Skip non-inout parameters.
2109+ if (!conv.getParameters ()[i].isIndirectMutating ())
2110+ continue ;
2111+
2112+ // For functions with multiple basic blocks, accumulation is needed
2113+ // for `inout` parameters because pullback basic blocks have different
2114+ // adjoint buffers.
2115+ pullbackIndirectResults.push_back (pullbackInOutArgs[pullbackInoutArgumentIdx++]);
20922116 addRetElt (i);
20932117 }
20942118 }
20952119
2096- // Do a second pass for all inout parameters.
2097- for (auto i : getConfig ().parameterIndices ->getIndices ()) {
2098- // Skip non-inout parameters.
2099- auto isParameterInout = conv.getParameters ()[i].isIndirectMutating ();
2100- if (!isParameterInout)
2101- continue ;
2102-
2103- // Skip `inout` parameters for functions with a single basic block:
2104- // adjoint accumulation for those parameters is already done by
2105- // per-instruction visitors.
2106- if (getOriginal ().size () == 1 )
2107- continue ;
2108-
2109- // For functions with multiple basic blocks, accumulation is needed
2110- // for `inout` parameters because pullback basic blocks have different
2111- // adjoint buffers.
2112- auto pullbackInoutArgument =
2113- getPullback ()
2114- .getArgumentsWithoutIndirectResults ()[pullbackInoutArgumentIndex++];
2115- pullbackIndirectResults.push_back (pullbackInoutArgument);
2116- addRetElt (i);
2117- }
2118-
21192120 // Copy them to adjoint indirect results.
21202121 assert (indParamAdjoints.size () == pullbackIndirectResults.size () &&
21212122 " Indirect parameter adjoint count mismatch" );
0 commit comments