-
Notifications
You must be signed in to change notification settings - Fork 15.1k
[refactoring] idea: convert a switch statement to an if #34352 FIXED #167142
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[refactoring] idea: convert a switch statement to an if #34352 FIXED #167142
Conversation
|
Thank you for submitting a Pull Request (PR) to the LLVM Project! This PR will be automatically labeled and the relevant teams will be notified. If you wish to, you can add reviewers by using the "Reviewers" section on this page. If this is not working for you, it is probably because you do not have write permissions for the repository. In which case you can instead tag reviewers by name in a comment by using If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate is once a week. Please remember that you are asking for valuable time from other developers. If you have further questions, they may be answered by the LLVM GitHub User Guide. You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums. |
|
@llvm/pr-subscribers-clang Author: None (krishnateja2314) ChangesFull diff: https://github.com/llvm/llvm-project/pull/167142.diff 8 Files Affected:
diff --git a/clang/include/clang/Tooling/Refactoring/SwitchToIf/SwitchToIf.h b/clang/include/clang/Tooling/Refactoring/SwitchToIf/SwitchToIf.h
new file mode 100644
index 0000000000000..aa46e7a9b52bd
--- /dev/null
+++ b/clang/include/clang/Tooling/Refactoring/SwitchToIf/SwitchToIf.h
@@ -0,0 +1,46 @@
+//===--- SwitchToIf.h - Switch to if refactoring -------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CLANG_TOOLING_REFACTORING_SWITCHTOIF_SWITCHTOIF_H
+#define LLVM_CLANG_TOOLING_REFACTORING_SWITCHTOIF_SWITCHTOIF_H
+
+#include "clang/Tooling/Refactoring/ASTSelection.h"
+#include "clang/Tooling/Refactoring/RefactoringActionRules.h"
+
+namespace clang {
+class SwitchStmt;
+
+namespace tooling {
+
+/// A "Switch to If" refactoring converts a switch statement into an if-else
+/// chain.
+class SwitchToIf final : public SourceChangeRefactoringRule {
+public:
+ /// Initiates the switch-to-if refactoring operation.
+ ///
+ /// \param Selection The selected AST node, which should be a switch statement.
+ static Expected<SwitchToIf>
+ initiate(RefactoringRuleContext &Context,
+ SelectedASTNode Selection);
+
+ static const RefactoringDescriptor &describe();
+
+private:
+ SwitchToIf(const SwitchStmt *Switch) : TheSwitch(Switch) {}
+
+ Expected<AtomicChanges>
+ createSourceReplacements(RefactoringRuleContext &Context) override;
+
+ const SwitchStmt *TheSwitch;
+};
+
+} // end namespace tooling
+} // end namespace clang
+
+#endif // LLVM_CLANG_TOOLING_REFACTORING_SWITCHTOIF_SWITCHTOIF_H
+
diff --git a/clang/lib/Tooling/Refactoring/CMakeLists.txt b/clang/lib/Tooling/Refactoring/CMakeLists.txt
index d3077be8810aa..35b806177fb7e 100644
--- a/clang/lib/Tooling/Refactoring/CMakeLists.txt
+++ b/clang/lib/Tooling/Refactoring/CMakeLists.txt
@@ -13,6 +13,7 @@ add_clang_library(clangToolingRefactoring
Rename/USRFinder.cpp
Rename/USRFindingAction.cpp
Rename/USRLocFinder.cpp
+ SwitchToIf/SwitchToIf.cpp
LINK_LIBS
clangAST
diff --git a/clang/lib/Tooling/Refactoring/RefactoringActions.cpp b/clang/lib/Tooling/Refactoring/RefactoringActions.cpp
index bf98941f568b3..1f6f3f641aa8c 100644
--- a/clang/lib/Tooling/Refactoring/RefactoringActions.cpp
+++ b/clang/lib/Tooling/Refactoring/RefactoringActions.cpp
@@ -10,6 +10,7 @@
#include "clang/Tooling/Refactoring/RefactoringAction.h"
#include "clang/Tooling/Refactoring/RefactoringOptions.h"
#include "clang/Tooling/Refactoring/Rename/RenamingAction.h"
+#include "clang/Tooling/Refactoring/SwitchToIf/SwitchToIf.h"
namespace clang {
namespace tooling {
@@ -93,6 +94,24 @@ class LocalRename final : public RefactoringAction {
}
};
+class SwitchToIfRefactoring final : public RefactoringAction {
+public:
+ StringRef getCommand() const override { return "switch-to-if"; }
+
+ StringRef getDescription() const override {
+ return "Converts a switch statement into an if-else chain";
+ }
+
+ /// Returns a set of refactoring actions rules that are defined by this
+ /// action.
+ RefactoringActionRules createActionRules() const override {
+ RefactoringActionRules Rules;
+ Rules.push_back(createRefactoringActionRule<SwitchToIf>(
+ ASTSelectionRequirement()));
+ return Rules;
+ }
+};
+
} // end anonymous namespace
std::vector<std::unique_ptr<RefactoringAction>> createRefactoringActions() {
@@ -100,6 +119,7 @@ std::vector<std::unique_ptr<RefactoringAction>> createRefactoringActions() {
Actions.push_back(std::make_unique<LocalRename>());
Actions.push_back(std::make_unique<ExtractRefactoring>());
+ Actions.push_back(std::make_unique<SwitchToIfRefactoring>());
return Actions;
}
diff --git a/clang/lib/Tooling/Refactoring/SwitchToIf/SwitchToIf.cpp b/clang/lib/Tooling/Refactoring/SwitchToIf/SwitchToIf.cpp
new file mode 100644
index 0000000000000..cc6d30dc6d4ce
--- /dev/null
+++ b/clang/lib/Tooling/Refactoring/SwitchToIf/SwitchToIf.cpp
@@ -0,0 +1,316 @@
+//===--- SwitchToIf.cpp - Switch to if refactoring ------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "clang/Tooling/Refactoring/SwitchToIf/SwitchToIf.h"
+#include "clang/AST/ASTContext.h"
+#include "clang/AST/Expr.h"
+#include "clang/AST/Stmt.h"
+#include "clang/Basic/SourceManager.h"
+#include "clang/Lex/Lexer.h"
+#include "clang/Tooling/Refactoring/AtomicChange.h"
+#include "clang/Tooling/Refactoring/RefactoringDiagnostic.h"
+#include "clang/Tooling/Refactoring/RefactoringRuleContext.h"
+#include "llvm/ADT/SmallString.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Error.h"
+#include "llvm/Support/raw_ostream.h"
+#include <functional>
+#include <optional>
+
+using namespace clang;
+using namespace tooling;
+
+namespace {
+
+/// Returns the source text for the given expression.
+std::string getExprText(const Expr *E, const SourceManager &SM,
+ const LangOptions &LangOpts) {
+ SourceRange Range = E->getSourceRange();
+ return Lexer::getSourceText(CharSourceRange::getTokenRange(Range), SM,
+ LangOpts)
+ .str();
+}
+
+/// Returns the source text for a range.
+std::string getSourceText(SourceRange Range, const SourceManager &SM,
+ const LangOptions &LangOpts) {
+ return Lexer::getSourceText(CharSourceRange::getTokenRange(Range), SM,
+ LangOpts)
+ .str();
+}
+
+/// Returns true if the statement is a break statement.
+bool isBreakStmt(const Stmt *S) {
+ return isa<BreakStmt>(S);
+}
+
+/// Recursively collects all statements from a statement, removing breaks.
+/// This handles compound statements and stops at the first break.
+void collectStatementsWithoutBreaks(const Stmt *S,
+ SmallVector<const Stmt *, 16> &Result,
+ const SourceManager &SM,
+ const LangOptions &LangOpts) {
+ if (!S)
+ return;
+
+ if (isBreakStmt(S))
+ return;
+
+ if (const CompoundStmt *CS = dyn_cast<CompoundStmt>(S)) {
+ // Process each statement in the compound statement
+ for (const Stmt *Child : CS->body()) {
+ if (isBreakStmt(Child)) {
+ // Stop at first break
+ break;
+ }
+ collectStatementsWithoutBreaks(Child, Result, SM, LangOpts);
+ }
+ } else {
+ // For non-compound statements, add them directly
+ Result.push_back(S);
+ }
+}
+
+/// Gets the statements from a case/default, removing breaks.
+SmallVector<const Stmt *, 16> getCaseStatements(const SwitchCase *SC,
+ const SourceManager &SM,
+ const LangOptions &LangOpts) {
+ SmallVector<const Stmt *, 16> Result;
+ const Stmt *SubStmt = SC->getSubStmt();
+ if (!SubStmt)
+ return Result;
+
+ collectStatementsWithoutBreaks(SubStmt, Result, SM, LangOpts);
+ return Result;
+}
+
+} // end anonymous namespace
+
+const RefactoringDescriptor &SwitchToIf::describe() {
+ static const RefactoringDescriptor Descriptor = {
+ "switch-to-if",
+ "Switch to If",
+ "Converts a switch statement into an if-else chain",
+ };
+ return Descriptor;
+}
+
+Expected<SwitchToIf>
+SwitchToIf::initiate(RefactoringRuleContext &Context,
+ SelectedASTNode Selection) {
+ // Find the SwitchStmt in the selection
+ const SwitchStmt *Switch = nullptr;
+
+ // Helper lambda to recursively search for SwitchStmt
+ std::function<const SwitchStmt *(const SelectedASTNode &)> findSwitch =
+ [&](const SelectedASTNode &Node) -> const SwitchStmt * {
+ if (const SwitchStmt *S = Node.Node.get<SwitchStmt>()) {
+ return S;
+ }
+ // Search in children
+ for (const SelectedASTNode &Child : Node.Children) {
+ if (const SwitchStmt *S = findSwitch(Child)) {
+ return S;
+ }
+ }
+ return nullptr;
+ };
+
+ Switch = findSwitch(Selection);
+
+ if (!Switch) {
+ return Context.createDiagnosticError(
+ Context.getSelectionRange().getBegin(),
+ diag::err_refactor_selection_invalid_ast);
+ }
+
+ // Validate that the switch has at least one case
+ if (!Switch->getSwitchCaseList()) {
+ return Context.createDiagnosticError(
+ Switch->getSwitchLoc(),
+ diag::err_refactor_selection_invalid_ast);
+ }
+
+ return SwitchToIf(Switch);
+}
+
+Expected<AtomicChanges>
+SwitchToIf::createSourceReplacements(RefactoringRuleContext &Context) {
+ ASTContext &AST = Context.getASTContext();
+ SourceManager &SM = AST.getSourceManager();
+ const LangOptions &LangOpts = AST.getLangOpts();
+
+ const SwitchStmt *Switch = TheSwitch;
+ const Expr *Cond = Switch->getCond();
+
+ // Get the full source range of the switch statement
+ SourceLocation StartLoc = Switch->getBeginLoc();
+ SourceLocation EndLoc = Switch->getEndLoc();
+
+ // Find the actual end location (closing brace)
+ if (const Stmt *Body = Switch->getBody()) {
+ EndLoc = Body->getEndLoc();
+ }
+
+ SourceRange SwitchRange(StartLoc, EndLoc);
+
+ // Build the if-else chain
+ std::string Replacement;
+ llvm::raw_string_ostream OS(Replacement);
+
+ std::string CondText = getExprText(Cond, SM, LangOpts);
+
+ // Handle init statement if present
+ if (Switch->getInit()) {
+ std::string InitText = getSourceText(Switch->getInit()->getSourceRange(),
+ SM, LangOpts);
+ OS << InitText << " ";
+ }
+
+ // Handle condition variable if present
+ if (Switch->getConditionVariableDeclStmt()) {
+ std::string VarText = getSourceText(
+ Switch->getConditionVariableDeclStmt()->getSourceRange(), SM, LangOpts);
+ OS << VarText << " ";
+ }
+
+ bool First = true;
+ const SwitchCase *DefaultCase = nullptr;
+ SmallVector<const SwitchCase *, 16> Cases;
+
+ // Collect all cases and find default
+ for (const SwitchCase *SC = Switch->getSwitchCaseList(); SC;
+ SC = SC->getNextSwitchCase()) {
+ if (isa<DefaultStmt>(SC)) {
+ DefaultCase = SC;
+ } else {
+ Cases.push_back(SC);
+ }
+ }
+
+ // Process cases
+ for (const SwitchCase *Case : Cases) {
+ if (First) {
+ OS << "if (";
+ First = false;
+ } else {
+ OS << " else if (";
+ }
+
+ const CaseStmt *CS = cast<CaseStmt>(Case);
+ const Expr *LHS = CS->getLHS();
+
+ // Handle GNU case ranges
+ if (CS->caseStmtIsGNURange()) {
+ const Expr *RHS = CS->getRHS();
+ std::string LHSText = getExprText(LHS, SM, LangOpts);
+ std::string RHSText = getExprText(RHS, SM, LangOpts);
+ OS << CondText << " >= " << LHSText << " && " << CondText << " <= "
+ << RHSText;
+ } else {
+ std::string CaseValue = getExprText(LHS, SM, LangOpts);
+ OS << CondText << " == " << CaseValue;
+ }
+
+ OS << ") {\n";
+
+ // Get statements from this case (without breaks)
+ SmallVector<const Stmt *, 16> Statements = getCaseStatements(Case, SM, LangOpts);
+
+ // Print statements
+ if (Statements.empty()) {
+ // Empty case - just add a blank line or comment
+ OS << " // empty case\n";
+ } else {
+ for (const Stmt *S : Statements) {
+ SourceRange StmtRange = S->getSourceRange();
+ std::string StmtText = getSourceText(StmtRange, SM, LangOpts);
+
+ // Indent the statement
+ OS << " " << StmtText;
+
+ // For compound statements, they already have their own braces
+ // For other statements, ensure proper termination
+ if (!isa<CompoundStmt>(S) && !isa<IfStmt>(S) && !isa<ForStmt>(S) &&
+ !isa<WhileStmt>(S) && !isa<SwitchStmt>(S) && !isa<DoStmt>(S) &&
+ !isa<BreakStmt>(S) && !isa<ReturnStmt>(S) && !isa<GotoStmt>(S)) {
+ // Check if statement already ends with semicolon by looking at the
+ // source text
+ if (!StmtText.empty() && StmtText.back() != ';') {
+ // Try to get the token after the statement
+ SourceLocation AfterEnd = Lexer::getLocForEndOfToken(
+ StmtRange.getEnd(), 0, SM, LangOpts);
+ Token Tok;
+ if (Lexer::getRawToken(AfterEnd, Tok, SM, LangOpts, false) ||
+ !Tok.is(tok::semi)) {
+ OS << ";";
+ }
+ }
+ }
+ OS << "\n";
+ }
+ }
+
+ OS << "}";
+ }
+
+ // Process default case
+ if (DefaultCase) {
+ if (First) {
+ OS << "if (1) { // default case\n";
+ First = false;
+ } else {
+ OS << " else { // default case\n";
+ }
+
+ SmallVector<const Stmt *, 16> Statements = getCaseStatements(DefaultCase, SM, LangOpts);
+
+ if (Statements.empty()) {
+ OS << " // empty default case\n";
+ } else {
+ for (const Stmt *S : Statements) {
+ SourceRange StmtRange = S->getSourceRange();
+ std::string StmtText = getSourceText(StmtRange, SM, LangOpts);
+
+ OS << " " << StmtText;
+
+ if (!isa<CompoundStmt>(S) && !isa<IfStmt>(S) && !isa<ForStmt>(S) &&
+ !isa<WhileStmt>(S) && !isa<SwitchStmt>(S) && !isa<DoStmt>(S) &&
+ !isa<BreakStmt>(S) && !isa<ReturnStmt>(S) && !isa<GotoStmt>(S)) {
+ if (!StmtText.empty() && StmtText.back() != ';') {
+ SourceLocation AfterEnd = Lexer::getLocForEndOfToken(
+ StmtRange.getEnd(), 0, SM, LangOpts);
+ Token Tok;
+ if (Lexer::getRawToken(AfterEnd, Tok, SM, LangOpts, false) ||
+ !Tok.is(tok::semi)) {
+ OS << ";";
+ }
+ }
+ }
+ OS << "\n";
+ }
+ }
+
+ OS << "}";
+ }
+
+ // Flush the stream to ensure all content is written to Replacement
+ OS.flush();
+
+ // Create the atomic change
+ AtomicChange Change(SM, StartLoc);
+
+ // Replace the entire switch statement
+ auto Err = Change.replace(SM, CharSourceRange::getTokenRange(SwitchRange),
+ Replacement);
+ if (Err)
+ return std::move(Err);
+
+ return AtomicChanges{std::move(Change)};
+}
+
diff --git a/clang/test/Refactor/SwitchToIf/basic.cpp b/clang/test/Refactor/SwitchToIf/basic.cpp
new file mode 100644
index 0000000000000..82a4ac64f603b
--- /dev/null
+++ b/clang/test/Refactor/SwitchToIf/basic.cpp
@@ -0,0 +1,22 @@
+// RUN: clang-refactor -action switch-to-if %s -- 2>&1 | FileCheck %s
+
+void foo(int x) {
+ switch (x) { // CHECK: Start refactoring here
+ case 1:
+ bar();
+ break;
+ case 2:
+ baz();
+ break;
+ default:
+ qux();
+ }
+}
+
+// CHECK: if (x == 1) {
+// CHECK-NEXT: bar();
+// CHECK-NEXT: } else if (x == 2) {
+// CHECK-NEXT: baz();
+// CHECK-NEXT: } else {
+// CHECK-NEXT: qux();
+// CHECK-NEXT: }
diff --git a/clang/test/Refactor/SwitchToIf/fallthrough.cpp b/clang/test/Refactor/SwitchToIf/fallthrough.cpp
new file mode 100644
index 0000000000000..8b02b44e9226b
--- /dev/null
+++ b/clang/test/Refactor/SwitchToIf/fallthrough.cpp
@@ -0,0 +1,23 @@
+// RUN: clang-refactor -action switch-to-if %s -- 2>&1 | FileCheck %s
+
+void g(int v) {
+ switch (v) {
+ case 3:
+ alpha();
+ [[fallthrough]];
+ case 4:
+ beta();
+ break;
+ default:
+ gamma();
+ }
+}
+
+// CHECK: if (v == 3) {
+// CHECK-NEXT: alpha();
+// CHECK-NEXT: beta();
+// CHECK-NEXT: } else if (v == 4) {
+// CHECK-NEXT: beta();
+// CHECK-NEXT: } else {
+// CHECK-NEXT: gamma();
+// CHECK-NEXT: }
diff --git a/clang/test/Refactor/SwitchToIf/multi_case.cpp b/clang/test/Refactor/SwitchToIf/multi_case.cpp
new file mode 100644
index 0000000000000..aa8dfba9b0c97
--- /dev/null
+++ b/clang/test/Refactor/SwitchToIf/multi_case.cpp
@@ -0,0 +1,23 @@
+// RUN: clang-refactor -action switch-to-if %s -- 2>&1 | FileCheck %s
+
+void test(int n) {
+ switch (n) {
+ case 1:
+ case 2:
+ handleSmall();
+ break;
+ case 10:
+ handleLarge();
+ break;
+ default:
+ handleOther();
+ }
+}
+
+// CHECK: if (n == 1 || n == 2) {
+// CHECK-NEXT: handleSmall();
+// CHECK-NEXT: } else if (n == 10) {
+// CHECK-NEXT: handleLarge();
+// CHECK-NEXT: } else {
+// CHECK-NEXT: handleOther();
+// CHECK-NEXT: }
diff --git a/clang/test/Refactor/SwitchToIf/no_default.cpp b/clang/test/Refactor/SwitchToIf/no_default.cpp
new file mode 100644
index 0000000000000..1f75de88a4c4e
--- /dev/null
+++ b/clang/test/Refactor/SwitchToIf/no_default.cpp
@@ -0,0 +1,18 @@
+// RUN: clang-refactor -action switch-to-if %s -- 2>&1 | FileCheck %s
+
+void f(int k) {
+ switch (k) {
+ case 5:
+ ping();
+ break;
+ case 7:
+ pong();
+ break;
+ }
+}
+
+// CHECK: if (k == 5) {
+// CHECK-NEXT: ping();
+// CHECK-NEXT: } else if (k == 7) {
+// CHECK-NEXT: pong();
+// CHECK-NEXT: }
|
|
Thank you for your patch! Could you please make it a clang-tidy check? Please see discussion in #166822 for reason. |
No description provided.