@@ -4410,20 +4410,37 @@ struct AsyncHandlerParamDesc : public AsyncHandlerDesc {
44104410 }
44114411};
44124412
4413- enum class ConditionType { NIL, NOT_NIL };
4413+ // / The type of a condition in a conditional statement.
4414+ enum class ConditionType {
4415+ NIL, // == nil
4416+ NOT_NIL, // != nil
4417+ SUCCESS_PATTERN, // case .success
4418+ FAILURE_PATTEN // case .failure
4419+ };
4420+
4421+ // / Indicates whether a condition describes a success or failure path. For
4422+ // / example, a check for whether an error parameter is present is a failure
4423+ // / path. A check for a nil error parameter is a success path. This is distinct
4424+ // / from ConditionType, as it relies on contextual information about what values
4425+ // / need to be checked for success or failure.
4426+ enum class ConditionPath { SUCCESS, FAILURE };
4427+
4428+ static ConditionPath flippedConditionPath (ConditionPath Path) {
4429+ switch (Path) {
4430+ case ConditionPath::SUCCESS:
4431+ return ConditionPath::FAILURE;
4432+ case ConditionPath::FAILURE:
4433+ return ConditionPath::SUCCESS;
4434+ }
4435+ llvm_unreachable (" Unhandled case in switch!" );
4436+ }
44144437
44154438// / Finds the `Subject` being compared to in various conditions. Also finds any
44164439// / pattern that may have a bound name.
44174440struct CallbackCondition {
44184441 Optional<ConditionType> Type;
44194442 const Decl *Subject = nullptr ;
44204443 const Pattern *BindPattern = nullptr ;
4421- // Bit of a hack. When the `Subject` is a `Result` type we use this to
4422- // distinguish between the `.success` and `.failure` case (as opposed to just
4423- // checking whether `Subject` == `TheErrDecl`)
4424- bool ErrorCase = false ;
4425-
4426- CallbackCondition () = default ;
44274444
44284445 // / Initializes a `CallbackCondition` with a `!=` or `==` comparison of
44294446 // / an `Optional` typed `Subject` to `nil`, ie.
@@ -4489,65 +4506,17 @@ struct CallbackCondition {
44894506
44904507 bool isValid () const { return Type.hasValue (); }
44914508
4492- // / Given an `if` condition `Cond` and a set of `Decls`, find any
4493- // / `CallbackCondition`s in `Cond` that use one of those `Decls` and add them
4494- // / to the map `AddTo`. Return `true` if all elements in the condition are
4495- // / "handled", ie. every condition can be mapped to a single `Decl` in
4496- // / `Decls`.
4497- static bool all (StmtCondition Cond, llvm::DenseSet<const Decl *> Decls,
4498- llvm::DenseMap<const Decl *, CallbackCondition> &AddTo) {
4499- bool Handled = true ;
4500- for (auto &CondElement : Cond) {
4501- if (auto *BoolExpr = CondElement.getBooleanOrNull ()) {
4502- SmallVector<Expr *, 1 > Exprs;
4503- Exprs.push_back (BoolExpr);
4504-
4505- while (!Exprs.empty ()) {
4506- auto *Next = Exprs.pop_back_val ();
4507- if (auto *ACE = dyn_cast<AutoClosureExpr>(Next))
4508- Next = ACE->getSingleExpressionBody ();
4509-
4510- if (auto *BE = dyn_cast_or_null<BinaryExpr>(Next)) {
4511- auto *Operator = isOperator (BE);
4512- if (Operator) {
4513- if (Operator->getBaseName () == " &&" ) {
4514- Exprs.push_back (BE->getLHS ());
4515- Exprs.push_back (BE->getRHS ());
4516- } else {
4517- addCond (CallbackCondition (BE, Operator), Decls, AddTo, Handled);
4518- }
4519- continue ;
4520- }
4521- }
4522-
4523- Handled = false ;
4524- }
4525- } else if (auto *P = CondElement.getPatternOrNull ()) {
4526- addCond (CallbackCondition (P, CondElement.getInitializer ()), Decls,
4527- AddTo, Handled);
4528- }
4529- }
4530- return Handled && !AddTo.empty ();
4531- }
4532-
45334509private:
4534- static void addCond (const CallbackCondition &CC,
4535- llvm::DenseSet<const Decl *> Decls,
4536- llvm::DenseMap<const Decl *, CallbackCondition> &AddTo,
4537- bool &Handled) {
4538- if (!CC.isValid () || !Decls.count (CC.Subject ) ||
4539- !AddTo.try_emplace (CC.Subject , CC).second )
4540- Handled = false ;
4541- }
4542-
45434510 void initFromEnumPattern (const Decl *D, const EnumElementPattern *EEP) {
45444511 if (auto *EED = EEP->getElementDecl ()) {
45454512 auto eedTy = EED->getParentEnum ()->getDeclaredType ();
45464513 if (!eedTy || !eedTy->isResult ())
45474514 return ;
4548- if (EED->getNameStr () == StringRef (" failure" ))
4549- ErrorCase = true ;
4550- Type = ConditionType::NOT_NIL;
4515+ if (EED->getNameStr () == StringRef (" failure" )) {
4516+ Type = ConditionType::FAILURE_PATTEN;
4517+ } else {
4518+ Type = ConditionType::SUCCESS_PATTERN;
4519+ }
45514520 Subject = D;
45524521 BindPattern = EEP->getSubPattern ();
45534522 }
@@ -4581,6 +4550,27 @@ struct CallbackCondition {
45814550 }
45824551};
45834552
4553+ // / A CallbackCondition with additional semantic information about whether it
4554+ // / is for a success path or failure path.
4555+ struct ClassifiedCondition : public CallbackCondition {
4556+ ConditionPath Path;
4557+
4558+ explicit ClassifiedCondition (CallbackCondition Cond, ConditionPath Path)
4559+ : CallbackCondition(Cond), Path(Path) {}
4560+ };
4561+
4562+ // / A wrapper for a map of parameter decls to their classified conditions, or
4563+ // / \c None if they are not present in any conditions.
4564+ struct ClassifiedCallbackConditions final
4565+ : llvm::MapVector<const Decl *, ClassifiedCondition> {
4566+ Optional<ClassifiedCondition> lookup (const Decl *D) const {
4567+ auto Res = find (D);
4568+ if (Res == end ())
4569+ return None;
4570+ return Res->second ;
4571+ }
4572+ };
4573+
45844574// / A list of nodes to print, along with a list of locations that may have
45854575// / preceding comments attached, which also need printing. For example:
45864576// /
@@ -4726,7 +4716,7 @@ class ClassifiedBlock {
47264716 Nodes.addNode (Node);
47274717 }
47284718
4729- void addBinding (const CallbackCondition &FromCondition,
4719+ void addBinding (const ClassifiedCondition &FromCondition,
47304720 DiagnosticEngine &DiagEngine) {
47314721 if (!FromCondition.BindPattern )
47324722 return ;
@@ -4750,9 +4740,8 @@ class ClassifiedBlock {
47504740 BoundNames.try_emplace (FromCondition.Subject , Name);
47514741 }
47524742
4753- void addAllBindings (
4754- const llvm::DenseMap<const Decl *, CallbackCondition> &FromConditions,
4755- DiagnosticEngine &DiagEngine) {
4743+ void addAllBindings (const ClassifiedCallbackConditions &FromConditions,
4744+ DiagnosticEngine &DiagEngine) {
47564745 for (auto &Entry : FromConditions) {
47574746 addBinding (Entry.second , DiagEngine);
47584747 if (DiagEngine.hadAnyError ())
@@ -4841,12 +4830,108 @@ struct CallbackClassifier {
48414830 CurrentBlock->addPossibleCommentLoc (endCommentLoc);
48424831 }
48434832
4833+ // / Given a callback condition, classify it as a success or failure path, or
4834+ // / \c None if it cannot be classified.
4835+ Optional<ClassifiedCondition>
4836+ classifyCallbackCondition (const CallbackCondition &Cond) {
4837+ if (!Cond.isValid ())
4838+ return None;
4839+
4840+ // For certain types of condition, they need to appear in certain lists.
4841+ auto CondType = *Cond.Type ;
4842+ switch (CondType) {
4843+ case ConditionType::NOT_NIL:
4844+ case ConditionType::NIL:
4845+ if (!UnwrapParams.count (Cond.Subject ))
4846+ return None;
4847+ break ;
4848+ case ConditionType::SUCCESS_PATTERN:
4849+ case ConditionType::FAILURE_PATTEN:
4850+ if (!IsResultParam || Cond.Subject != ErrParam)
4851+ return None;
4852+ break ;
4853+ }
4854+
4855+ // Let's start with a success path, and flip any negative conditions.
4856+ auto Path = ConditionPath::SUCCESS;
4857+
4858+ // If it's an error param, that's a flip.
4859+ if (Cond.Subject == ErrParam && !IsResultParam)
4860+ Path = flippedConditionPath (Path);
4861+
4862+ // If we have a nil or failure condition, that's a flip.
4863+ switch (CondType) {
4864+ case ConditionType::NIL:
4865+ case ConditionType::FAILURE_PATTEN:
4866+ Path = flippedConditionPath (Path);
4867+ break ;
4868+ case ConditionType::NOT_NIL:
4869+ case ConditionType::SUCCESS_PATTERN:
4870+ break ;
4871+ }
4872+ return ClassifiedCondition (Cond, Path);
4873+ }
4874+
4875+ // / Classifies all the conditions present in a given StmtCondition. Returns
4876+ // / \c true if there were any conditions that couldn't be classified,
4877+ // / \c false otherwise.
4878+ bool classifyConditionsOf (StmtCondition Cond,
4879+ ClassifiedCallbackConditions &Conditions) {
4880+ bool UnhandledConditions = false ;
4881+ auto TryAddCond = [&](CallbackCondition CC) {
4882+ auto Classified = classifyCallbackCondition (CC);
4883+ if (!Classified) {
4884+ UnhandledConditions = true ;
4885+ return ;
4886+ }
4887+ // If we've seen multiple conditions for the same subject, don't handle
4888+ // this.
4889+ if (!Conditions.insert ({CC.Subject , *Classified}).second )
4890+ UnhandledConditions = true ;
4891+ };
4892+
4893+ for (auto &CondElement : Cond) {
4894+ if (auto *BoolExpr = CondElement.getBooleanOrNull ()) {
4895+ SmallVector<Expr *, 1 > Exprs;
4896+ Exprs.push_back (BoolExpr);
4897+
4898+ while (!Exprs.empty ()) {
4899+ auto *Next = Exprs.pop_back_val ();
4900+ if (auto *ACE = dyn_cast<AutoClosureExpr>(Next))
4901+ Next = ACE->getSingleExpressionBody ();
4902+
4903+ if (auto *BE = dyn_cast_or_null<BinaryExpr>(Next)) {
4904+ auto *Operator = isOperator (BE);
4905+ if (Operator) {
4906+ // If we have an && operator, decompose its arguments.
4907+ if (Operator->getBaseName () == " &&" ) {
4908+ Exprs.push_back (BE->getLHS ());
4909+ Exprs.push_back (BE->getRHS ());
4910+ } else {
4911+ // Otherwise check to see if we have an == nil or != nil
4912+ // condition.
4913+ TryAddCond (CallbackCondition (BE, Operator));
4914+ }
4915+ continue ;
4916+ }
4917+ }
4918+ UnhandledConditions = true ;
4919+ }
4920+ } else if (auto *P = CondElement.getPatternOrNull ()) {
4921+ TryAddCond (CallbackCondition (P, CondElement.getInitializer ()));
4922+ }
4923+ }
4924+ return UnhandledConditions || Conditions.empty ();
4925+ }
4926+
4927+ // / Classifies the conditions of a conditional statement, and adds the
4928+ // / necessary nodes to either the success or failure block.
48444929 void classifyConditional (Stmt *Statement, StmtCondition Condition,
48454930 NodesToPrint ThenNodesToPrint, Stmt *ElseStmt) {
4846- llvm::DenseMap< const Decl *, CallbackCondition> CallbackConditions;
4847- bool UnhandledConditions =
4848- ! CallbackCondition::all ( Condition, UnwrapParams , CallbackConditions);
4849- CallbackCondition ErrCondition = CallbackConditions.lookup (ErrParam);
4931+ ClassifiedCallbackConditions CallbackConditions;
4932+ bool UnhandledConditions = classifyConditionsOf (
4933+ Condition, CallbackConditions);
4934+ auto ErrCondition = CallbackConditions.lookup (ErrParam);
48504935
48514936 if (UnhandledConditions) {
48524937 // Some unknown conditions. If there's an else, assume we can't handle
@@ -4862,12 +4947,11 @@ struct CallbackClassifier {
48624947 } else if (ElseStmt) {
48634948 DiagEngine.diagnose (Statement->getStartLoc (),
48644949 diag::unknown_callback_conditions);
4865- } else if (ErrCondition.isValid () &&
4866- ErrCondition.Type == ConditionType::NOT_NIL) {
4950+ } else if (ErrCondition && ErrCondition->Path == ConditionPath::FAILURE) {
48674951 Blocks.ErrorBlock .addNode (Statement);
48684952 } else {
48694953 for (auto &Entry : CallbackConditions) {
4870- if (Entry.second .Type == ConditionType::NIL ) {
4954+ if (Entry.second .Path == ConditionPath::FAILURE ) {
48714955 Blocks.ErrorBlock .addNode (Statement);
48724956 return ;
48734957 }
@@ -4877,42 +4961,36 @@ struct CallbackClassifier {
48774961 return ;
48784962 }
48794963
4880- ClassifiedBlock *ThenBlock = &Blocks.SuccessBlock ;
4881- ClassifiedBlock *ElseBlock = &Blocks.ErrorBlock ;
4882-
4883- if (ErrCondition.isValid () && (!IsResultParam || ErrCondition.ErrorCase ) &&
4884- ErrCondition.Type == ConditionType::NOT_NIL) {
4885- ClassifiedBlock *TempBlock = ThenBlock;
4886- ThenBlock = ElseBlock;
4887- ElseBlock = TempBlock;
4888- } else {
4889- Optional<ConditionType> CondType;
4890- for (auto &Entry : CallbackConditions) {
4891- if (IsResultParam || Entry.second .Subject != ErrParam) {
4892- if (!CondType) {
4893- CondType = Entry.second .Type ;
4894- } else if (CondType != Entry.second .Type ) {
4895- // Similar to the unknown conditions case. Add the whole if unless
4896- // there's an else, in which case use the fallback instead.
4897- // TODO: Split the `if` statement
4898-
4899- if (ElseStmt) {
4900- DiagEngine.diagnose (Statement->getStartLoc (),
4901- diag::mixed_callback_conditions);
4902- } else {
4903- CurrentBlock->addNode (Statement);
4904- }
4905- return ;
4906- }
4964+ // If all the conditions were classified, make sure they're all consistently
4965+ // on the success or failure path.
4966+ Optional<ConditionPath> Path;
4967+ for (auto &Entry : CallbackConditions) {
4968+ auto &Cond = Entry.second ;
4969+ if (!Path) {
4970+ Path = Cond.Path ;
4971+ } else if (*Path != Cond.Path ) {
4972+ // Similar to the unknown conditions case. Add the whole if unless
4973+ // there's an else, in which case use the fallback instead.
4974+ // TODO: Split the `if` statement
4975+
4976+ if (ElseStmt) {
4977+ DiagEngine.diagnose (Statement->getStartLoc (),
4978+ diag::mixed_callback_conditions);
4979+ } else {
4980+ CurrentBlock->addNode (Statement);
49074981 }
4908- }
4909-
4910- if (CondType == ConditionType::NIL) {
4911- ClassifiedBlock *TempBlock = ThenBlock;
4912- ThenBlock = ElseBlock;
4913- ElseBlock = TempBlock;
4982+ return ;
49144983 }
49154984 }
4985+ assert (Path && " Didn't classify a path?" );
4986+
4987+ auto *ThenBlock = &Blocks.SuccessBlock ;
4988+ auto *ElseBlock = &Blocks.ErrorBlock ;
4989+
4990+ // If the condition is for a failure path, the error block is ThenBlock, and
4991+ // the success block is ElseBlock.
4992+ if (*Path == ConditionPath::FAILURE)
4993+ std::swap (ThenBlock, ElseBlock);
49164994
49174995 // We'll be dropping the statement, but make sure to keep any attached
49184996 // comments.
@@ -4983,13 +5061,15 @@ struct CallbackClassifier {
49835061 return ;
49845062 }
49855063
4986- CallbackCondition CC (ErrParam, &Items[0 ]);
4987- ClassifiedBlock *Block = &Blocks.SuccessBlock ;
4988- ClassifiedBlock *OtherBlock = &Blocks.ErrorBlock ;
4989- if (CC.ErrorCase ) {
4990- Block = &Blocks.ErrorBlock ;
4991- OtherBlock = &Blocks.SuccessBlock ;
4992- }
5064+ auto *Block = &Blocks.SuccessBlock ;
5065+ auto *OtherBlock = &Blocks.ErrorBlock ;
5066+ auto SuccessNodes = NodesToPrint::inBraceStmt (CS->getBody ());
5067+
5068+ // Classify the case pattern.
5069+ auto CC = classifyCallbackCondition (
5070+ CallbackCondition (ErrParam, &Items[0 ]));
5071+ if (CC && CC->Path == ConditionPath::FAILURE)
5072+ std::swap (Block, OtherBlock);
49935073
49945074 // We'll be dropping the case, but make sure to keep any attached
49955075 // comments. Because these comments will effectively be part of the
@@ -5000,8 +5080,9 @@ struct CallbackClassifier {
50005080 if (CS == Cases.back ())
50015081 Block->addPossibleCommentLoc (SS->getRBraceLoc ());
50025082
5003- setNodes (Block, OtherBlock, NodesToPrint::inBraceStmt (CS->getBody ()));
5004- Block->addBinding (CC, DiagEngine);
5083+ setNodes (Block, OtherBlock, std::move (SuccessNodes));
5084+ if (CC)
5085+ Block->addBinding (*CC, DiagEngine);
50055086 if (DiagEngine.hadAnyError ())
50065087 return ;
50075088 }
0 commit comments