|
16 | 16 | #include "swift/SILOptimizer/Differentiation/Common.h" |
17 | 17 |
|
18 | 18 | #include "swift/Basic/Assertions.h" |
| 19 | +#include "swift/SIL/NodeDatastructures.h" |
19 | 20 | #include "swift/SIL/Projection.h" |
20 | 21 | #include "swift/SIL/SILArgument.h" |
21 | 22 | #include "swift/SILOptimizer/Analysis/DominanceAnalysis.h" |
@@ -435,67 +436,75 @@ void DifferentiableActivityInfo::setUsefulThroughArrayInitialization( |
435 | 436 | auto *dti = dyn_cast<DestructureTupleInst>(use->getUser()); |
436 | 437 | if (!dti) |
437 | 438 | continue; |
438 | | - // The second tuple field of the return value is the `RawPointer`. |
439 | | - for (auto use : dti->getResult(1)->getUses()) { |
440 | | - // The `RawPointer` passes through a `mark_dependence(pointer_to_address`. |
441 | | - // That instruction's first use is a `store` whose source is useful; its |
442 | | - // subsequent uses are `index_addr`s whose only use is a useful `store`. |
443 | | - auto *mdi = dyn_cast<MarkDependenceInst>(use->getUser()); |
444 | | - assert( |
445 | | - mdi && |
446 | | - "Expected a mark_dependence user for uninitialized array intrinsic."); |
447 | | - auto *ptai = dyn_cast<PointerToAddressInst>(getSingleNonDebugUser(mdi)); |
448 | | - assert(ptai && "Expected a pointer_to_address."); |
449 | | - setUseful(ptai, dependentVariableIndex); |
450 | | - // Propagate usefulness through array element addresses: |
451 | | - // `pointer_to_address` and `index_addr` instructions. |
452 | | - // |
453 | | - // - Set all array element addresses as useful. |
454 | | - // - Find instructions with array element addresses as "result": |
455 | | - // - `store` and `copy_addr` with array element address as destination. |
456 | | - // - `apply` with array element address as an indirect result. |
457 | | - // - For each instruction, propagate usefulness through "arguments": |
458 | | - // - `store` and `copy_addr`: propagate to source. |
459 | | - // - `apply`: propagate to arguments. |
460 | | - // |
461 | | - // NOTE: `propagateUseful(use->getUser(), ...)` is intentionally not used |
462 | | - // because it marks more values than necessary as useful, including: |
463 | | - // - The `RawPointer` result of the intrinsic. |
464 | | - // - `integer_literal` operands to `index_addr` for indexing the |
465 | | - // `RawPointer`. |
466 | | - // It is also blocked by TF-1032: control flow differentiation crash for |
467 | | - // active values with no tangent space. |
468 | | - for (auto use : ptai->getUses()) { |
469 | | - auto *user = use->getUser(); |
470 | | - if (auto *si = dyn_cast<StoreInst>(user)) { |
471 | | - setUseful(si->getDest(), dependentVariableIndex); |
472 | | - setUsefulAndPropagateToOperands(si->getSrc(), dependentVariableIndex); |
473 | | - } else if (auto *cai = dyn_cast<CopyAddrInst>(user)) { |
474 | | - setUseful(cai->getDest(), dependentVariableIndex); |
475 | | - setUsefulAndPropagateToOperands(cai->getSrc(), |
476 | | - dependentVariableIndex); |
477 | | - } else if (auto *ai = dyn_cast<ApplyInst>(user)) { |
478 | | - if (FullApplySite(ai).isIndirectResultOperand(*use)) |
479 | | - for (auto arg : ai->getArgumentsWithoutIndirectResults()) |
480 | | - setUsefulAndPropagateToOperands(arg, dependentVariableIndex); |
481 | | - } else if (auto *iai = dyn_cast<IndexAddrInst>(user)) { |
482 | | - setUseful(iai, dependentVariableIndex); |
483 | | - for (auto use : iai->getUses()) { |
484 | | - auto *user = use->getUser(); |
485 | | - if (auto si = dyn_cast<StoreInst>(user)) { |
486 | | - setUseful(si->getDest(), dependentVariableIndex); |
487 | | - setUsefulAndPropagateToOperands(si->getSrc(), |
488 | | - dependentVariableIndex); |
489 | | - } else if (auto *cai = dyn_cast<CopyAddrInst>(user)) { |
490 | | - setUseful(cai->getDest(), dependentVariableIndex); |
491 | | - setUsefulAndPropagateToOperands(cai->getSrc(), |
492 | | - dependentVariableIndex); |
493 | | - } else if (auto *ai = dyn_cast<ApplyInst>(user)) { |
494 | | - if (FullApplySite(ai).isIndirectResultOperand(*use)) |
495 | | - for (auto arg : ai->getArgumentsWithoutIndirectResults()) |
496 | | - setUsefulAndPropagateToOperands(arg, dependentVariableIndex); |
| 439 | + |
| 440 | + ValueWorklist worklist(dti->getResult(0)); |
| 441 | + |
| 442 | + while (SILValue v = worklist.pop()) { |
| 443 | + for (auto use : v->getUses()) { |
| 444 | + switch (use->getUser()->getKind()) { |
| 445 | + case SILInstructionKind::UncheckedRefCastInst: |
| 446 | + case SILInstructionKind::StructExtractInst: |
| 447 | + case SILInstructionKind::BeginBorrowInst: |
| 448 | + worklist.pushIfNotVisited(cast<SingleValueInstruction>(use->getUser())); |
| 449 | + break; |
| 450 | + case SILInstructionKind::RefTailAddrInst: { |
| 451 | + auto *rta = cast<RefTailAddrInst>(use->getUser()); |
| 452 | + setUseful(rta, dependentVariableIndex); |
| 453 | + // Propagate usefulness through array element addresses: |
| 454 | + // `pointer_to_address` and `index_addr` instructions. |
| 455 | + // |
| 456 | + // - Set all array element addresses as useful. |
| 457 | + // - Find instructions with array element addresses as "result": |
| 458 | + // - `store` and `copy_addr` with array element address as destination. |
| 459 | + // - `apply` with array element address as an indirect result. |
| 460 | + // - For each instruction, propagate usefulness through "arguments": |
| 461 | + // - `store` and `copy_addr`: propagate to source. |
| 462 | + // - `apply`: propagate to arguments. |
| 463 | + // |
| 464 | + // NOTE: `propagateUseful(use->getUser(), ...)` is intentionally not used |
| 465 | + // because it marks more values than necessary as useful, including: |
| 466 | + // - The `RawPointer` result of the intrinsic. |
| 467 | + // - `integer_literal` operands to `index_addr` for indexing the |
| 468 | + // `RawPointer`. |
| 469 | + // It is also blocked by TF-1032: control flow differentiation crash for |
| 470 | + // active values with no tangent space. |
| 471 | + for (auto use : rta->getUses()) { |
| 472 | + auto *user = use->getUser(); |
| 473 | + if (auto *si = dyn_cast<StoreInst>(user)) { |
| 474 | + setUseful(si->getDest(), dependentVariableIndex); |
| 475 | + setUsefulAndPropagateToOperands(si->getSrc(), dependentVariableIndex); |
| 476 | + } else if (auto *cai = dyn_cast<CopyAddrInst>(user)) { |
| 477 | + setUseful(cai->getDest(), dependentVariableIndex); |
| 478 | + setUsefulAndPropagateToOperands(cai->getSrc(), |
| 479 | + dependentVariableIndex); |
| 480 | + } else if (auto *ai = dyn_cast<ApplyInst>(user)) { |
| 481 | + if (FullApplySite(ai).isIndirectResultOperand(*use)) |
| 482 | + for (auto arg : ai->getArgumentsWithoutIndirectResults()) |
| 483 | + setUsefulAndPropagateToOperands(arg, dependentVariableIndex); |
| 484 | + } else if (auto *iai = dyn_cast<IndexAddrInst>(user)) { |
| 485 | + setUseful(iai, dependentVariableIndex); |
| 486 | + for (auto use : iai->getUses()) { |
| 487 | + auto *user = use->getUser(); |
| 488 | + if (auto si = dyn_cast<StoreInst>(user)) { |
| 489 | + setUseful(si->getDest(), dependentVariableIndex); |
| 490 | + setUsefulAndPropagateToOperands(si->getSrc(), |
| 491 | + dependentVariableIndex); |
| 492 | + } else if (auto *cai = dyn_cast<CopyAddrInst>(user)) { |
| 493 | + setUseful(cai->getDest(), dependentVariableIndex); |
| 494 | + setUsefulAndPropagateToOperands(cai->getSrc(), |
| 495 | + dependentVariableIndex); |
| 496 | + } else if (auto *ai = dyn_cast<ApplyInst>(user)) { |
| 497 | + if (FullApplySite(ai).isIndirectResultOperand(*use)) |
| 498 | + for (auto arg : ai->getArgumentsWithoutIndirectResults()) |
| 499 | + setUsefulAndPropagateToOperands(arg, dependentVariableIndex); |
| 500 | + } |
| 501 | + } |
| 502 | + } |
497 | 503 | } |
| 504 | + break; |
498 | 505 | } |
| 506 | + default: |
| 507 | + break; |
499 | 508 | } |
500 | 509 | } |
501 | 510 | } |
|
0 commit comments