diff --git a/clang/include/clang/Tooling/Refactoring/SwitchToIf/SwitchToIf.h b/clang/include/clang/Tooling/Refactoring/SwitchToIf/SwitchToIf.h new file mode 100644 index 0000000000000..aa46e7a9b52bd --- /dev/null +++ b/clang/include/clang/Tooling/Refactoring/SwitchToIf/SwitchToIf.h @@ -0,0 +1,46 @@ +//===--- SwitchToIf.h - Switch to if refactoring -------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_CLANG_TOOLING_REFACTORING_SWITCHTOIF_SWITCHTOIF_H +#define LLVM_CLANG_TOOLING_REFACTORING_SWITCHTOIF_SWITCHTOIF_H + +#include "clang/Tooling/Refactoring/ASTSelection.h" +#include "clang/Tooling/Refactoring/RefactoringActionRules.h" + +namespace clang { +class SwitchStmt; + +namespace tooling { + +/// A "Switch to If" refactoring converts a switch statement into an if-else +/// chain. +class SwitchToIf final : public SourceChangeRefactoringRule { +public: + /// Initiates the switch-to-if refactoring operation. + /// + /// \param Selection The selected AST node, which should be a switch statement. + static Expected + initiate(RefactoringRuleContext &Context, + SelectedASTNode Selection); + + static const RefactoringDescriptor &describe(); + +private: + SwitchToIf(const SwitchStmt *Switch) : TheSwitch(Switch) {} + + Expected + createSourceReplacements(RefactoringRuleContext &Context) override; + + const SwitchStmt *TheSwitch; +}; + +} // end namespace tooling +} // end namespace clang + +#endif // LLVM_CLANG_TOOLING_REFACTORING_SWITCHTOIF_SWITCHTOIF_H + diff --git a/clang/lib/Tooling/Refactoring/CMakeLists.txt b/clang/lib/Tooling/Refactoring/CMakeLists.txt index d3077be8810aa..35b806177fb7e 100644 --- a/clang/lib/Tooling/Refactoring/CMakeLists.txt +++ b/clang/lib/Tooling/Refactoring/CMakeLists.txt @@ -13,6 +13,7 @@ add_clang_library(clangToolingRefactoring Rename/USRFinder.cpp Rename/USRFindingAction.cpp Rename/USRLocFinder.cpp + SwitchToIf/SwitchToIf.cpp LINK_LIBS clangAST diff --git a/clang/lib/Tooling/Refactoring/RefactoringActions.cpp b/clang/lib/Tooling/Refactoring/RefactoringActions.cpp index bf98941f568b3..1f6f3f641aa8c 100644 --- a/clang/lib/Tooling/Refactoring/RefactoringActions.cpp +++ b/clang/lib/Tooling/Refactoring/RefactoringActions.cpp @@ -10,6 +10,7 @@ #include "clang/Tooling/Refactoring/RefactoringAction.h" #include "clang/Tooling/Refactoring/RefactoringOptions.h" #include "clang/Tooling/Refactoring/Rename/RenamingAction.h" +#include "clang/Tooling/Refactoring/SwitchToIf/SwitchToIf.h" namespace clang { namespace tooling { @@ -93,6 +94,24 @@ class LocalRename final : public RefactoringAction { } }; +class SwitchToIfRefactoring final : public RefactoringAction { +public: + StringRef getCommand() const override { return "switch-to-if"; } + + StringRef getDescription() const override { + return "Converts a switch statement into an if-else chain"; + } + + /// Returns a set of refactoring actions rules that are defined by this + /// action. + RefactoringActionRules createActionRules() const override { + RefactoringActionRules Rules; + Rules.push_back(createRefactoringActionRule( + ASTSelectionRequirement())); + return Rules; + } +}; + } // end anonymous namespace std::vector> createRefactoringActions() { @@ -100,6 +119,7 @@ std::vector> createRefactoringActions() { Actions.push_back(std::make_unique()); Actions.push_back(std::make_unique()); + Actions.push_back(std::make_unique()); return Actions; } diff --git a/clang/lib/Tooling/Refactoring/SwitchToIf/SwitchToIf.cpp b/clang/lib/Tooling/Refactoring/SwitchToIf/SwitchToIf.cpp new file mode 100644 index 0000000000000..cc6d30dc6d4ce --- /dev/null +++ b/clang/lib/Tooling/Refactoring/SwitchToIf/SwitchToIf.cpp @@ -0,0 +1,316 @@ +//===--- SwitchToIf.cpp - Switch to if refactoring ------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "clang/Tooling/Refactoring/SwitchToIf/SwitchToIf.h" +#include "clang/AST/ASTContext.h" +#include "clang/AST/Expr.h" +#include "clang/AST/Stmt.h" +#include "clang/Basic/SourceManager.h" +#include "clang/Lex/Lexer.h" +#include "clang/Tooling/Refactoring/AtomicChange.h" +#include "clang/Tooling/Refactoring/RefactoringDiagnostic.h" +#include "clang/Tooling/Refactoring/RefactoringRuleContext.h" +#include "llvm/ADT/SmallString.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Error.h" +#include "llvm/Support/raw_ostream.h" +#include +#include + +using namespace clang; +using namespace tooling; + +namespace { + +/// Returns the source text for the given expression. +std::string getExprText(const Expr *E, const SourceManager &SM, + const LangOptions &LangOpts) { + SourceRange Range = E->getSourceRange(); + return Lexer::getSourceText(CharSourceRange::getTokenRange(Range), SM, + LangOpts) + .str(); +} + +/// Returns the source text for a range. +std::string getSourceText(SourceRange Range, const SourceManager &SM, + const LangOptions &LangOpts) { + return Lexer::getSourceText(CharSourceRange::getTokenRange(Range), SM, + LangOpts) + .str(); +} + +/// Returns true if the statement is a break statement. +bool isBreakStmt(const Stmt *S) { + return isa(S); +} + +/// Recursively collects all statements from a statement, removing breaks. +/// This handles compound statements and stops at the first break. +void collectStatementsWithoutBreaks(const Stmt *S, + SmallVector &Result, + const SourceManager &SM, + const LangOptions &LangOpts) { + if (!S) + return; + + if (isBreakStmt(S)) + return; + + if (const CompoundStmt *CS = dyn_cast(S)) { + // Process each statement in the compound statement + for (const Stmt *Child : CS->body()) { + if (isBreakStmt(Child)) { + // Stop at first break + break; + } + collectStatementsWithoutBreaks(Child, Result, SM, LangOpts); + } + } else { + // For non-compound statements, add them directly + Result.push_back(S); + } +} + +/// Gets the statements from a case/default, removing breaks. +SmallVector getCaseStatements(const SwitchCase *SC, + const SourceManager &SM, + const LangOptions &LangOpts) { + SmallVector Result; + const Stmt *SubStmt = SC->getSubStmt(); + if (!SubStmt) + return Result; + + collectStatementsWithoutBreaks(SubStmt, Result, SM, LangOpts); + return Result; +} + +} // end anonymous namespace + +const RefactoringDescriptor &SwitchToIf::describe() { + static const RefactoringDescriptor Descriptor = { + "switch-to-if", + "Switch to If", + "Converts a switch statement into an if-else chain", + }; + return Descriptor; +} + +Expected +SwitchToIf::initiate(RefactoringRuleContext &Context, + SelectedASTNode Selection) { + // Find the SwitchStmt in the selection + const SwitchStmt *Switch = nullptr; + + // Helper lambda to recursively search for SwitchStmt + std::function findSwitch = + [&](const SelectedASTNode &Node) -> const SwitchStmt * { + if (const SwitchStmt *S = Node.Node.get()) { + return S; + } + // Search in children + for (const SelectedASTNode &Child : Node.Children) { + if (const SwitchStmt *S = findSwitch(Child)) { + return S; + } + } + return nullptr; + }; + + Switch = findSwitch(Selection); + + if (!Switch) { + return Context.createDiagnosticError( + Context.getSelectionRange().getBegin(), + diag::err_refactor_selection_invalid_ast); + } + + // Validate that the switch has at least one case + if (!Switch->getSwitchCaseList()) { + return Context.createDiagnosticError( + Switch->getSwitchLoc(), + diag::err_refactor_selection_invalid_ast); + } + + return SwitchToIf(Switch); +} + +Expected +SwitchToIf::createSourceReplacements(RefactoringRuleContext &Context) { + ASTContext &AST = Context.getASTContext(); + SourceManager &SM = AST.getSourceManager(); + const LangOptions &LangOpts = AST.getLangOpts(); + + const SwitchStmt *Switch = TheSwitch; + const Expr *Cond = Switch->getCond(); + + // Get the full source range of the switch statement + SourceLocation StartLoc = Switch->getBeginLoc(); + SourceLocation EndLoc = Switch->getEndLoc(); + + // Find the actual end location (closing brace) + if (const Stmt *Body = Switch->getBody()) { + EndLoc = Body->getEndLoc(); + } + + SourceRange SwitchRange(StartLoc, EndLoc); + + // Build the if-else chain + std::string Replacement; + llvm::raw_string_ostream OS(Replacement); + + std::string CondText = getExprText(Cond, SM, LangOpts); + + // Handle init statement if present + if (Switch->getInit()) { + std::string InitText = getSourceText(Switch->getInit()->getSourceRange(), + SM, LangOpts); + OS << InitText << " "; + } + + // Handle condition variable if present + if (Switch->getConditionVariableDeclStmt()) { + std::string VarText = getSourceText( + Switch->getConditionVariableDeclStmt()->getSourceRange(), SM, LangOpts); + OS << VarText << " "; + } + + bool First = true; + const SwitchCase *DefaultCase = nullptr; + SmallVector Cases; + + // Collect all cases and find default + for (const SwitchCase *SC = Switch->getSwitchCaseList(); SC; + SC = SC->getNextSwitchCase()) { + if (isa(SC)) { + DefaultCase = SC; + } else { + Cases.push_back(SC); + } + } + + // Process cases + for (const SwitchCase *Case : Cases) { + if (First) { + OS << "if ("; + First = false; + } else { + OS << " else if ("; + } + + const CaseStmt *CS = cast(Case); + const Expr *LHS = CS->getLHS(); + + // Handle GNU case ranges + if (CS->caseStmtIsGNURange()) { + const Expr *RHS = CS->getRHS(); + std::string LHSText = getExprText(LHS, SM, LangOpts); + std::string RHSText = getExprText(RHS, SM, LangOpts); + OS << CondText << " >= " << LHSText << " && " << CondText << " <= " + << RHSText; + } else { + std::string CaseValue = getExprText(LHS, SM, LangOpts); + OS << CondText << " == " << CaseValue; + } + + OS << ") {\n"; + + // Get statements from this case (without breaks) + SmallVector Statements = getCaseStatements(Case, SM, LangOpts); + + // Print statements + if (Statements.empty()) { + // Empty case - just add a blank line or comment + OS << " // empty case\n"; + } else { + for (const Stmt *S : Statements) { + SourceRange StmtRange = S->getSourceRange(); + std::string StmtText = getSourceText(StmtRange, SM, LangOpts); + + // Indent the statement + OS << " " << StmtText; + + // For compound statements, they already have their own braces + // For other statements, ensure proper termination + if (!isa(S) && !isa(S) && !isa(S) && + !isa(S) && !isa(S) && !isa(S) && + !isa(S) && !isa(S) && !isa(S)) { + // Check if statement already ends with semicolon by looking at the + // source text + if (!StmtText.empty() && StmtText.back() != ';') { + // Try to get the token after the statement + SourceLocation AfterEnd = Lexer::getLocForEndOfToken( + StmtRange.getEnd(), 0, SM, LangOpts); + Token Tok; + if (Lexer::getRawToken(AfterEnd, Tok, SM, LangOpts, false) || + !Tok.is(tok::semi)) { + OS << ";"; + } + } + } + OS << "\n"; + } + } + + OS << "}"; + } + + // Process default case + if (DefaultCase) { + if (First) { + OS << "if (1) { // default case\n"; + First = false; + } else { + OS << " else { // default case\n"; + } + + SmallVector Statements = getCaseStatements(DefaultCase, SM, LangOpts); + + if (Statements.empty()) { + OS << " // empty default case\n"; + } else { + for (const Stmt *S : Statements) { + SourceRange StmtRange = S->getSourceRange(); + std::string StmtText = getSourceText(StmtRange, SM, LangOpts); + + OS << " " << StmtText; + + if (!isa(S) && !isa(S) && !isa(S) && + !isa(S) && !isa(S) && !isa(S) && + !isa(S) && !isa(S) && !isa(S)) { + if (!StmtText.empty() && StmtText.back() != ';') { + SourceLocation AfterEnd = Lexer::getLocForEndOfToken( + StmtRange.getEnd(), 0, SM, LangOpts); + Token Tok; + if (Lexer::getRawToken(AfterEnd, Tok, SM, LangOpts, false) || + !Tok.is(tok::semi)) { + OS << ";"; + } + } + } + OS << "\n"; + } + } + + OS << "}"; + } + + // Flush the stream to ensure all content is written to Replacement + OS.flush(); + + // Create the atomic change + AtomicChange Change(SM, StartLoc); + + // Replace the entire switch statement + auto Err = Change.replace(SM, CharSourceRange::getTokenRange(SwitchRange), + Replacement); + if (Err) + return std::move(Err); + + return AtomicChanges{std::move(Change)}; +} + diff --git a/clang/test/Refactor/SwitchToIf/basic.cpp b/clang/test/Refactor/SwitchToIf/basic.cpp new file mode 100644 index 0000000000000..82a4ac64f603b --- /dev/null +++ b/clang/test/Refactor/SwitchToIf/basic.cpp @@ -0,0 +1,22 @@ +// RUN: clang-refactor -action switch-to-if %s -- 2>&1 | FileCheck %s + +void foo(int x) { + switch (x) { // CHECK: Start refactoring here + case 1: + bar(); + break; + case 2: + baz(); + break; + default: + qux(); + } +} + +// CHECK: if (x == 1) { +// CHECK-NEXT: bar(); +// CHECK-NEXT: } else if (x == 2) { +// CHECK-NEXT: baz(); +// CHECK-NEXT: } else { +// CHECK-NEXT: qux(); +// CHECK-NEXT: } diff --git a/clang/test/Refactor/SwitchToIf/fallthrough.cpp b/clang/test/Refactor/SwitchToIf/fallthrough.cpp new file mode 100644 index 0000000000000..8b02b44e9226b --- /dev/null +++ b/clang/test/Refactor/SwitchToIf/fallthrough.cpp @@ -0,0 +1,23 @@ +// RUN: clang-refactor -action switch-to-if %s -- 2>&1 | FileCheck %s + +void g(int v) { + switch (v) { + case 3: + alpha(); + [[fallthrough]]; + case 4: + beta(); + break; + default: + gamma(); + } +} + +// CHECK: if (v == 3) { +// CHECK-NEXT: alpha(); +// CHECK-NEXT: beta(); +// CHECK-NEXT: } else if (v == 4) { +// CHECK-NEXT: beta(); +// CHECK-NEXT: } else { +// CHECK-NEXT: gamma(); +// CHECK-NEXT: } diff --git a/clang/test/Refactor/SwitchToIf/multi_case.cpp b/clang/test/Refactor/SwitchToIf/multi_case.cpp new file mode 100644 index 0000000000000..aa8dfba9b0c97 --- /dev/null +++ b/clang/test/Refactor/SwitchToIf/multi_case.cpp @@ -0,0 +1,23 @@ +// RUN: clang-refactor -action switch-to-if %s -- 2>&1 | FileCheck %s + +void test(int n) { + switch (n) { + case 1: + case 2: + handleSmall(); + break; + case 10: + handleLarge(); + break; + default: + handleOther(); + } +} + +// CHECK: if (n == 1 || n == 2) { +// CHECK-NEXT: handleSmall(); +// CHECK-NEXT: } else if (n == 10) { +// CHECK-NEXT: handleLarge(); +// CHECK-NEXT: } else { +// CHECK-NEXT: handleOther(); +// CHECK-NEXT: } diff --git a/clang/test/Refactor/SwitchToIf/no_default.cpp b/clang/test/Refactor/SwitchToIf/no_default.cpp new file mode 100644 index 0000000000000..1f75de88a4c4e --- /dev/null +++ b/clang/test/Refactor/SwitchToIf/no_default.cpp @@ -0,0 +1,18 @@ +// RUN: clang-refactor -action switch-to-if %s -- 2>&1 | FileCheck %s + +void f(int k) { + switch (k) { + case 5: + ping(); + break; + case 7: + pong(); + break; + } +} + +// CHECK: if (k == 5) { +// CHECK-NEXT: ping(); +// CHECK-NEXT: } else if (k == 7) { +// CHECK-NEXT: pong(); +// CHECK-NEXT: }