@@ -2515,12 +2515,11 @@ void PullbackCloner::Implementation::visitSILBasicBlock(SILBasicBlock *bb) {
25152515
25162516 // Get predecessor terminator operands.
25172517 SmallVector<std::pair<SILBasicBlock *, SILValue>, 4 > incomingValues;
2518- bbArg->getSingleTerminatorOperands (incomingValues);
2519-
2520- // Returns true if the given terminator instruction is a `switch_enum` on
2521- // an `Optional`-typed value. `switch_enum` instructions require
2522- // special-case adjoint value propagation for the operand.
2523- auto isSwitchEnumInstOnOptional =
2518+ if (bbArg->getSingleTerminatorOperands (incomingValues)) {
2519+ // Returns true if the given terminator instruction is a `switch_enum` on
2520+ // an `Optional`-typed value. `switch_enum` instructions require
2521+ // special-case adjoint value propagation for the operand.
2522+ auto isSwitchEnumInstOnOptional =
25242523 [&ctx = getASTContext ()](TermInst *termInst) {
25252524 if (!termInst)
25262525 return false ;
@@ -2531,49 +2530,52 @@ void PullbackCloner::Implementation::visitSILBasicBlock(SILBasicBlock *bb) {
25312530 return false ;
25322531 };
25332532
2534- // Check the tangent value category of the active basic block argument.
2535- switch (getTangentValueCategory (bbArg)) {
2536- // If argument has a loadable tangent value category: materialize adjoint
2537- // value of the argument, create a copy, and set the copy as the adjoint
2538- // value of incoming values.
2539- case SILValueCategory::Object: {
2540- auto bbArgAdj = getAdjointValue (bb, bbArg);
2541- auto concreteBBArgAdj = materializeAdjointDirect (bbArgAdj, pbLoc);
2542- auto concreteBBArgAdjCopy =
2533+ // Check the tangent value category of the active basic block argument.
2534+ switch (getTangentValueCategory (bbArg)) {
2535+ // If argument has a loadable tangent value category: materialize adjoint
2536+ // value of the argument, create a copy, and set the copy as the adjoint
2537+ // value of incoming values.
2538+ case SILValueCategory::Object: {
2539+ auto bbArgAdj = getAdjointValue (bb, bbArg);
2540+ auto concreteBBArgAdj = materializeAdjointDirect (bbArgAdj, pbLoc);
2541+ auto concreteBBArgAdjCopy =
25432542 builder.emitCopyValueOperation (pbLoc, concreteBBArgAdj);
2544- for (auto pair : incomingValues) {
2545- auto *predBB = std::get<0 >(pair);
2546- auto incomingValue = std::get<1 >(pair);
2547- // Handle `switch_enum` on `Optional`.
2548- auto termInst = bbArg->getSingleTerminator ();
2549- if (isSwitchEnumInstOnOptional (termInst)) {
2550- accumulateAdjointForOptional (bb, incomingValue, concreteBBArgAdjCopy);
2551- } else {
2552- blockTemporaries[getPullbackBlock (predBB)].insert (
2543+ for (auto pair : incomingValues) {
2544+ pair.second ->dump ();
2545+ auto *predBB = std::get<0 >(pair);
2546+ auto incomingValue = std::get<1 >(pair);
2547+ // Handle `switch_enum` on `Optional`.
2548+ auto termInst = bbArg->getSingleTerminator ();
2549+ if (isSwitchEnumInstOnOptional (termInst)) {
2550+ accumulateAdjointForOptional (bb, incomingValue, concreteBBArgAdjCopy);
2551+ } else {
2552+ blockTemporaries[getPullbackBlock (predBB)].insert (
25532553 concreteBBArgAdjCopy);
2554- setAdjointValue (predBB, incomingValue,
2555- makeConcreteAdjointValue (concreteBBArgAdjCopy));
2554+ setAdjointValue (predBB, incomingValue,
2555+ makeConcreteAdjointValue (concreteBBArgAdjCopy));
2556+ }
25562557 }
2558+ break ;
25572559 }
2558- break ;
2559- }
2560- // If argument has an address tangent value category: materialize adjoint
2561- // value of the argument, create a copy, and set the copy as the adjoint
2562- // value of incoming values.
2563- case SILValueCategory::Address: {
2564- auto bbArgAdjBuf = getAdjointBuffer (bb, bbArg );
2565- for ( auto pair : incomingValues) {
2566- auto incomingValue = std::get< 1 >(pair );
2567- // Handle `switch_enum` on `Optional`.
2568- auto termInst = bbArg-> getSingleTerminator ( );
2569- if ( isSwitchEnumInstOnOptional (termInst))
2570- accumulateAdjointForOptional (bb, incomingValue, bbArgAdjBuf);
2571- else
2572- addToAdjointBuffer (bb, incomingValue, bbArgAdjBuf, pbLoc) ;
2560+ // If argument has an address tangent value category: materialize adjoint
2561+ // value of the argument, create a copy, and set the copy as the adjoint
2562+ // value of incoming values.
2563+ case SILValueCategory::Address: {
2564+ auto bbArgAdjBuf = getAdjointBuffer (bb, bbArg);
2565+ for ( auto pair : incomingValues) {
2566+ auto incomingValue = std::get< 1 >(pair );
2567+ // Handle `switch_enum` on `Optional`.
2568+ auto termInst = bbArg-> getSingleTerminator ( );
2569+ if ( isSwitchEnumInstOnOptional (termInst))
2570+ accumulateAdjointForOptional (bb, incomingValue, bbArgAdjBuf );
2571+ else
2572+ addToAdjointBuffer (bb, incomingValue, bbArgAdjBuf, pbLoc );
2573+ }
2574+ break ;
25732575 }
2574- break ;
2575- }
2576- }
2576+ }
2577+ } else
2578+ llvm::report_fatal_error ( " do not know how to handle this incoming bb argument " );
25772579 }
25782580
25792581 // 3. Build the pullback successor cases for the `switch_enum`
0 commit comments