@@ -232,14 +232,17 @@ class PartitionOpTranslator {
232232 // raw pointers to ensure it is treated as non-Sendable and strict checking
233233 // is applied to it
234234 bool isNonSendableType (SILType type) const {
235- if (type.getASTType ()->getKind () == TypeKind::BuiltinNativeObject) {
235+ switch (type.getASTType ()->getKind ()) {
236+ case TypeKind::BuiltinNativeObject:
237+ case TypeKind::BuiltinRawPointer:
236238 // these are very unsafe... definitely not Sendable
237239 return true ;
240+ default :
241+ // consider caching this if it's a performance bottleneck
242+ return TypeChecker::conformsToProtocol (
243+ type.getASTType (), sendableProtocol, function->getParentModule ())
244+ .hasMissingConformance (function->getParentModule ());
238245 }
239- // consider caching this if it's a bottleneck
240- return TypeChecker::conformsToProtocol (
241- type.getASTType (), sendableProtocol, function->getParentModule ())
242- .hasMissingConformance (function->getParentModule ());
243246 }
244247
245248 // used to statefully track the instruction currently being translated,
@@ -403,16 +406,29 @@ class PartitionOpTranslator {
403406 applyInst->getOperandValues ().end ()}
404407 );
405408
409+ return translateIsolationCrossingSILApply (applyInst);
410+ }
411+
412+ // handles the semantics for SIL applies that cross isolation
413+ // in particular, all arguments are consumed
414+ std::vector<PartitionOp> translateIsolationCrossingSILApply (
415+ const SILInstruction *applyInst) {
406416 ApplyExpr *sourceApply = applyInst->getLoc ().getAsASTNode <ApplyExpr>();
407417 assert (sourceApply && " only ApplyExpr's should cross isolation domains" );
408418
409419 std::vector<PartitionOp> translated;
420+
421+ // require all operands
422+ for (auto op : applyInst->getOperandValues ())
423+ if (auto trackOp = trackIfNonSendable (op))
424+ translated.push_back (Require (trackOp.value ()).front ());
425+
410426 auto getSourceArg = [&](unsigned i) {
411- if (i < sourceApply->getArgs ()->size ())
412- return sourceApply->getArgs ()->getExpr (i);
427+ if (i < sourceApply->getArgs ()->size ())
428+ return sourceApply->getArgs ()->getExpr (i);
413429 assert (false && " SIL instruction has too many arguments for"
414430 " corresponding AST node" );
415- return (Expr *)nullptr ;
431+ return (Expr *)nullptr ;
416432 };
417433
418434 auto getSourceSelf = [&]() {
@@ -441,11 +457,11 @@ class PartitionOpTranslator {
441457 if (auto applyInstCast = dyn_cast<ApplyInst>(applyInst)) {
442458 handleSILOperands (applyInstCast->getArgumentsWithoutSelf ());
443459 if (applyInstCast->hasSelfArgument ())
444- handleSILSelf (applyInstCast->getSelfArgument ());
460+ handleSILSelf (applyInstCast->getSelfArgument ());
445461 } else if (auto applyInstCase = dyn_cast<TryApplyInst>(applyInst)) {
446462 handleSILOperands (applyInstCast->getArgumentsWithoutSelf ());
447463 if (applyInstCast->hasSelfArgument ())
448- handleSILSelf (applyInstCast->getSelfArgument ());
464+ handleSILSelf (applyInstCast->getSelfArgument ());
449465 } else {
450466 llvm_unreachable (" this instruction crossing isolation is not handled yet" );
451467 }
@@ -501,15 +517,15 @@ class PartitionOpTranslator {
501517
502518 // used to index the translations of SILInstructions performed
503519 // for refrence and debugging
504- int translationIndex = 0 ;
520+ static inline int translationIndex = 0 ;
505521
506522 // Some SILInstructions contribute to the partition of non-Sendable values
507523 // being analyzed. translateSILInstruction translate a SILInstruction
508524 // to its effect on the non-Sendable partition, if it has one.
509525 //
510526 // The current pattern of
511527 std::vector<PartitionOp> translateSILInstruction (SILInstruction *instruction) {
512- translationIndex++;
528+ LLVM_DEBUG ( translationIndex++;) ;
513529 currentInstruction = instruction;
514530
515531 // The following instructions are treated as assigning their result to a
@@ -910,11 +926,11 @@ class ConsumeRequireAccumulator {
910926 // that access ("require") the region consumed. Sorting is by lowest distance
911927 // first, then arbitrarily. This is used for final diagnostic output.
912928 void forEachConsumeRequire (
913- unsigned numRequiresPerConsume,
914929 llvm::function_ref<void (const PartitionOp& consumeOp, unsigned numProcessed, unsigned numSkipped)>
915930 processConsumeOp,
916931 llvm::function_ref<void(const PartitionOp& requireOp)>
917- processRequireOp) const {
932+ processRequireOp,
933+ unsigned numRequiresPerConsume = UINT_MAX) const {
918934 for (auto [consumeOp, requireOps] : requirementsForConsumptions) {
919935 unsigned numProcessed = std::min ({(unsigned ) requireOps.size (),
920936 (unsigned ) numRequiresPerConsume});
@@ -927,6 +943,19 @@ class ConsumeRequireAccumulator {
927943 }
928944 }
929945 }
946+
947+ void dump () const {
948+ forEachConsumeRequire (
949+ [](const PartitionOp& consumeOp, unsigned numProcessed, unsigned numSkipped) {
950+ llvm::dbgs () << " ┌──╼ CONSUME: " ;
951+ consumeOp.dump ();
952+ },
953+ [](const PartitionOp& requireOp) {
954+ llvm::dbgs () << " ├╼ REQUIRE: " ;
955+ requireOp.dump ();
956+ }
957+ );
958+ }
930959};
931960
932961// A RaceTracer is used to accumulate the facts that the main phase of
@@ -1075,8 +1104,9 @@ class RaceTracer {
10751104 if (workingPartition.isConsumed (consumedVal))
10761105 workingPartition.apply (PartitionOp::AssignFresh (consumedVal));
10771106
1107+ int i = 0 ;
10781108 block.forEachPartitionOp ([&](const PartitionOp& partitionOp) {
1079- if (targetOp && targetOp == partitionOp)
1109+ if (targetOp == partitionOp)
10801110 return false ; // break
10811111 workingPartition.apply (partitionOp);
10821112 if (workingPartition.isConsumed (consumedVal) && !consumedReason) {
@@ -1093,6 +1123,7 @@ class RaceTracer {
10931123 consumedReason = llvm::None;
10941124
10951125 // continue walking block
1126+ i++;
10961127 return true ;
10971128 });
10981129
@@ -1102,7 +1133,10 @@ class RaceTracer {
11021133 consumedReason = LocalConsumedReason::NonLocal ();
11031134
11041135 // if consumedReason is none, then consumedVal was not actually consumed
1105- assert (consumedReason);
1136+ assert (consumedReason
1137+ || dumpBlockSearch (SILBlock, consumedVal)
1138+ && " no consumption was found"
1139+ );
11061140
11071141 // if this is a query for consumption reason at block exit, update the cache
11081142 if (!targetOp)
@@ -1112,6 +1146,31 @@ class RaceTracer {
11121146 return consumedReason.value ();
11131147 }
11141148
1149+ bool dumpBlockSearch (SILBasicBlock * SILBlock, TrackableValueID consumedVal) {
1150+ LLVM_DEBUG (
1151+ unsigned i = 0 ;
1152+ const BlockPartitionState &block = blockStates[SILBlock];
1153+ Partition working = block.getEntryPartition ();
1154+ llvm::dbgs () << " ┌──────────╼\n │ " ;
1155+ working.dump ();
1156+ block.forEachPartitionOp ([&](const PartitionOp &op) {
1157+ llvm::dbgs () << " ├[" << i++ << " ] " ;
1158+ op.dump ();
1159+ working.apply (op);
1160+ llvm::dbgs () << " │ " ;
1161+ if (working.isConsumed (consumedVal)) {
1162+ llvm::errs ().changeColor (llvm::raw_ostream::RED, true );
1163+ llvm::errs () << " (" << consumedVal << " CONSUMED) " ;
1164+ llvm::errs ().resetColor ();
1165+ }
1166+ working.dump ();
1167+ return true ;
1168+ });
1169+ llvm::dbgs () << " └──────────╼\n " ;
1170+ );
1171+ return false ;
1172+ }
1173+
11151174public:
11161175 RaceTracer (const BasicBlockData<BlockPartitionState>& blockStates)
11171176 : blockStates(blockStates) {}
@@ -1126,7 +1185,6 @@ class RaceTracer {
11261185 }
11271186};
11281187
1129-
11301188// Instances of PartitionAnalysis perform the region-based Sendable checking.
11311189// Internally, a PartitionOpTranslator is stored to perform the translation from
11321190// SILInstructions to PartitionOps, then a fixed point iteration is run to
@@ -1258,10 +1316,14 @@ class PartitionAnalysis {
12581316 void diagnose () {
12591317 assert (solved && " diagnose should not be called before solve" );
12601318
1261- // llvm::dbgs() << function->getName() << "\n";
1319+ LLVM_DEBUG (
1320+ llvm::dbgs () << " Emitting diagnostics for function "
1321+ << function->getName () << " \n " );
12621322 RaceTracer tracer = blockStates;
12631323
12641324 for (auto [_, blockState] : blockStates) {
1325+ // populate the raceTracer with all requires of consumed valued found
1326+ // throughout the CFG
12651327 blockState.diagnoseFailures (
12661328 /* handleFailure=*/
12671329 [&](const PartitionOp& partitionOp, TrackableValueID consumedVal) {
@@ -1282,16 +1344,23 @@ class PartitionAnalysis {
12821344 });
12831345 }
12841346
1347+ LLVM_DEBUG (
1348+ llvm::dbgs () << " Accumulator Complete:\n " ;
1349+ raceTracer.getAccumulator ().dump ();
1350+ );
1351+
1352+ // ask the raceTracer to report diagnostics at the consumption sites
1353+ // for all the racy requirement sites entered into it above
12851354 raceTracer.getAccumulator ().forEachConsumeRequire (
1286- NUM_REQUIREMENTS_TO_DIAGNOSE,
12871355 /* diagnoseConsume=*/
12881356 [&](const PartitionOp& consumeOp,
12891357 unsigned numDisplayed, unsigned numHidden) {
12901358
12911359 if (tryDiagnoseAsCallSite (consumeOp, numDisplayed, numHidden))
12921360 return ;
12931361
1294- assert (false );
1362+ assert (false && " no consumptions besides callsites implemented yet" );
1363+
12951364 // default to more generic diagnostic
12961365 auto expr = getExprForPartitionOp (consumeOp);
12971366 auto diag = function->getASTContext ().Diags .diagnose (
@@ -1307,7 +1376,8 @@ class PartitionAnalysis {
13071376 function->getASTContext ().Diags .diagnose (
13081377 expr->getLoc (), diag::possible_racy_access_site)
13091378 .highlight (expr->getSourceRange ());
1310- });
1379+ },
1380+ NUM_REQUIREMENTS_TO_DIAGNOSE);
13111381 }
13121382
13131383 // try to interpret this consumeOp as a source-level callsite (ApplyExpr),
@@ -1331,7 +1401,7 @@ class PartitionAnalysis {
13311401 assert (false && " sourceExpr should be populated for ApplyExpr consumptions" );
13321402
13331403 function->getASTContext ().Diags .diagnose (
1334- apply ->getLoc (), diag::call_site_consumption_yields_race,
1404+ argExpr ->getLoc (), diag::call_site_consumption_yields_race,
13351405 findOriginalValueType (argExpr),
13361406 isolationCrossing.value ().getCallerIsolation (),
13371407 isolationCrossing.value ().getCalleeIsolation (),
0 commit comments