2424import net .sf .jsqlparser .expression .Function ;
2525import net .sf .jsqlparser .parser .CCJSqlParserUtil ;
2626import net .sf .jsqlparser .schema .Column ;
27+ import net .sf .jsqlparser .statement .Statement ;
28+ import net .sf .jsqlparser .statement .delete .Delete ;
2729import net .sf .jsqlparser .statement .select .OrderByElement ;
2830import net .sf .jsqlparser .statement .select .PlainSelect ;
2931import net .sf .jsqlparser .statement .select .Select ;
3032import net .sf .jsqlparser .statement .select .SelectExpressionItem ;
3133import net .sf .jsqlparser .statement .select .SelectItem ;
34+ import net .sf .jsqlparser .statement .update .Update ;
3235
3336import java .util .ArrayList ;
3437import java .util .Collections ;
5457public class JSqlParserQueryEnhancer implements QueryEnhancer {
5558
5659 private final DeclaredQuery query ;
60+ private final ParsedType parsedType ;
5761
5862 /**
5963 * @param query the query we want to enhance. Must not be {@literal null}.
6064 */
6165 public JSqlParserQueryEnhancer (DeclaredQuery query ) {
6266 this .query = query ;
67+ this .parsedType = detectParsedType ();
68+ }
69+
70+ /**
71+ * Detects what type of query is provided.
72+ *
73+ * @return the parsed type
74+ */
75+ private ParsedType detectParsedType () {
76+ try {
77+ Statement statement = CCJSqlParserUtil .parse (this .query .getQueryString ());
78+
79+ if (statement instanceof Update ) {
80+ return ParsedType .UPDATE ;
81+ } else if (statement instanceof Delete ) {
82+ return ParsedType .DELETE ;
83+ } else if (statement instanceof Select ) {
84+ return ParsedType .SELECT ;
85+ } else {
86+ return ParsedType .SELECT ;
87+ }
88+
89+ } catch (JSQLParserException e ) {
90+ throw new IllegalArgumentException ("The query you provided is not a valid SQL Query!" , e );
91+ }
6392 }
6493
6594 @ Override
6695 public String applySorting (Sort sort , @ Nullable String alias ) {
67-
6896 String queryString = query .getQueryString ();
6997 Assert .hasText (queryString , "Query must not be null or empty!" );
7098
99+ if (this .parsedType != ParsedType .SELECT ) {
100+ return queryString ;
101+ }
102+
71103 if (sort .isUnsorted ()) {
72104 return queryString ;
73105 }
@@ -120,6 +152,10 @@ private Set<String> getSelectionAliases(PlainSelect selectBody) {
120152 */
121153 Set <String > getSelectionAliases () {
122154
155+ if (this .parsedType != ParsedType .SELECT ) {
156+ return new HashSet <>();
157+ }
158+
123159 Select selectStatement = parseSelectStatement (this .query .getQueryString ());
124160 PlainSelect selectBody = (PlainSelect ) selectStatement .getSelectBody ();
125161 return this .getSelectionAliases (selectBody );
@@ -132,6 +168,9 @@ Set<String> getSelectionAliases() {
132168 * @return a {@literal Set} of aliases used in the query. Guaranteed to be not {@literal null}.
133169 */
134170 private Set <String > getJoinAliases (String query ) {
171+ if (this .parsedType != ParsedType .SELECT ) {
172+ return new HashSet <>();
173+ }
135174 return getJoinAliases ((PlainSelect ) parseSelectStatement (query ).getSelectBody ());
136175 }
137176
@@ -211,6 +250,10 @@ public String detectAlias() {
211250 @ Nullable
212251 private String detectAlias (String query ) {
213252
253+ if (this .parsedType != ParsedType .SELECT ) {
254+ return null ;
255+ }
256+
214257 Select selectStatement = parseSelectStatement (query );
215258 PlainSelect selectBody = (PlainSelect ) selectStatement .getSelectBody ();
216259 return detectAlias (selectBody );
@@ -233,6 +276,10 @@ private static String detectAlias(PlainSelect selectBody) {
233276 @ Override
234277 public String createCountQueryFor (@ Nullable String countProjection ) {
235278
279+ if (this .parsedType != ParsedType .SELECT ) {
280+ return this .query .getQueryString ();
281+ }
282+
236283 Assert .hasText (this .query .getQueryString (), "OriginalQuery must not be null or empty!" );
237284
238285 Select selectStatement = parseSelectStatement (this .query .getQueryString ());
@@ -278,6 +325,10 @@ public String createCountQueryFor(@Nullable String countProjection) {
278325 @ Override
279326 public String getProjection () {
280327
328+ if (this .parsedType != ParsedType .SELECT ) {
329+ return "" ;
330+ }
331+
281332 Assert .hasText (query .getQueryString (), "Query must not be null or empty!" );
282333
283334 Select selectStatement = parseSelectStatement (query .getQueryString ());
@@ -327,3 +378,15 @@ public DeclaredQuery getQuery() {
327378 return this .query ;
328379 }
329380}
381+
382+ /**
383+ * An enum to represent the top level parsed statement of the provided query.
384+ * <ul>
385+ * <li>{@code ParsedType.DELETE}: means the top level statement is {@link Delete}</li>
386+ * <li>{@code ParsedType.UPDATE}: means the top level statement is {@link Update}</li>
387+ * <li>{@code ParsedType.SELECT}: means the top level statement is {@link Select}</li>
388+ * </ul>
389+ */
390+ enum ParsedType {
391+ DELETE , UPDATE , SELECT ;
392+ }
0 commit comments