Skip to content

Commit b11b85a

Browse files
svenrienstraschauder
authored andcommitted
Add support for CASE statement in select and order by clauses.
Original pull request #1844
1 parent 1a860bf commit b11b85a

File tree

9 files changed

+303
-11
lines changed

9 files changed

+303
-11
lines changed
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
package org.springframework.data.relational.core.sql;
2+
3+
import java.util.ArrayList;
4+
import java.util.List;
5+
6+
import static java.util.stream.Collectors.joining;
7+
8+
/**
9+
* Case with one or more conditions expression.
10+
* <p>
11+
* Results in a rendered condition:
12+
* <pre>
13+
* CASE
14+
* WHEN condition1 THEN result1
15+
* WHEN condition2 THEN result2
16+
* ELSE result
17+
* END
18+
* </pre>
19+
* </p>
20+
*
21+
* @author Sven Rienstra
22+
* @since 3.4
23+
*/
24+
public class CaseExpression extends AbstractSegment implements Expression {
25+
private final List<When> whenList;
26+
private final Expression elseExpression;
27+
28+
private CaseExpression(List<When> whenList, Expression elseExpression) {
29+
30+
super(children(whenList, elseExpression));
31+
this.whenList = whenList;
32+
this.elseExpression = elseExpression;
33+
}
34+
35+
/**
36+
* Create CASE {@link Expression} with initial {@link When} condition.
37+
* @param condition initial {@link When} condition
38+
* @return the {@link CaseExpression}
39+
*/
40+
public static CaseExpression create(When condition) {
41+
return new CaseExpression(List.of(condition), null);
42+
}
43+
44+
/**
45+
* Add additional {@link When} condition
46+
* @param condition the {@link When} condition
47+
* @return the {@link CaseExpression}
48+
*/
49+
public CaseExpression when(When condition) {
50+
List<When> conditions = new ArrayList<>(this.whenList);
51+
conditions.add(condition);
52+
return new CaseExpression(conditions, elseExpression);
53+
}
54+
55+
/**
56+
* Add ELSE clause
57+
* @param elseExpression the {@link Expression} else value
58+
* @return the {@link CaseExpression}
59+
*/
60+
public CaseExpression elseExpression(Literal elseExpression) {
61+
return new CaseExpression(whenList, elseExpression);
62+
}
63+
64+
/**
65+
* @return the {@link When} conditions
66+
*/
67+
public List<When> getWhenList() {
68+
return whenList;
69+
}
70+
71+
/**
72+
* @return the ELSE {@link Literal} value
73+
*/
74+
public Expression getElseExpression() {
75+
return elseExpression;
76+
}
77+
78+
@Override
79+
public String toString() {
80+
return "CASE " + whenList.stream().map(When::toString).collect(joining(" ")) + (elseExpression != null ? " ELSE " + elseExpression : "") + " END";
81+
}
82+
83+
private static Segment[] children(List<When> whenList, Expression elseExpression) {
84+
85+
List<Segment> segments = new ArrayList<>();
86+
segments.addAll(whenList);
87+
88+
if (elseExpression != null) {
89+
segments.add(elseExpression);
90+
}
91+
92+
return segments.toArray(new Segment[segments.size()]);
93+
}
94+
}
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
package org.springframework.data.relational.core.sql;
2+
3+
/**
4+
* When segment for Case statement.
5+
* <p>
6+
* Results in a rendered condition: {@code WHEN <condition> THEN <value>}.
7+
* </p>
8+
*
9+
* @author Sven Rienstra
10+
* @since 3.4
11+
*/
12+
public class When extends AbstractSegment {
13+
private final Condition condition;
14+
private final Expression value;
15+
16+
private When(Condition condition, Expression value) {
17+
18+
super(condition, value);
19+
20+
this.condition = condition;
21+
this.value = value;
22+
}
23+
24+
/**
25+
* Creates a new {@link When} given two {@link Expression} condition and {@link Literal} value.
26+
*
27+
* @param condition the condition {@link Expression}.
28+
* @param value the {@link Literal} value.
29+
* @return the {@link When}.
30+
*/
31+
public static When when(Condition condition, Expression value) {
32+
return new When(condition, value);
33+
}
34+
35+
/**
36+
* @return the condition
37+
*/
38+
public Condition getCondition() {
39+
return condition;
40+
}
41+
42+
/**
43+
* @return the value
44+
*/
45+
public Expression getValue() {
46+
return value;
47+
}
48+
49+
@Override
50+
public String toString() {
51+
return "WHEN " + condition + " THEN " + value;
52+
}
53+
}
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
package org.springframework.data.relational.core.sql.render;
2+
3+
import org.springframework.data.relational.core.sql.CaseExpression;
4+
import org.springframework.data.relational.core.sql.Literal;
5+
import org.springframework.data.relational.core.sql.Visitable;
6+
import org.springframework.data.relational.core.sql.When;
7+
8+
/**
9+
* Renderer for {@link CaseExpression}.
10+
*
11+
* @author Sven Rienstra
12+
* @since 3.4
13+
*/
14+
public class CaseExpressionVisitor extends TypedSingleConditionRenderSupport<CaseExpression> implements PartRenderer {
15+
private final StringBuilder part = new StringBuilder();
16+
17+
CaseExpressionVisitor(RenderContext context) {
18+
super(context);
19+
}
20+
21+
@Override
22+
Delegation leaveNested(Visitable segment) {
23+
24+
if (hasDelegatedRendering()) {
25+
CharSequence renderedPart = consumeRenderedPart();
26+
27+
if (segment instanceof When) {
28+
part.append(" ");
29+
part.append(renderedPart);
30+
} else if (segment instanceof Literal<?>) {
31+
part.append(" ELSE ");
32+
part.append(renderedPart);
33+
}
34+
}
35+
36+
return super.leaveNested(segment);
37+
}
38+
39+
@Override
40+
Delegation enterMatched(CaseExpression segment) {
41+
42+
part.append("CASE");
43+
44+
return super.enterMatched(segment);
45+
}
46+
47+
@Override
48+
Delegation leaveMatched(CaseExpression segment) {
49+
50+
part.append(" END");
51+
52+
return super.leaveMatched(segment);
53+
}
54+
55+
@Override
56+
public CharSequence getRenderedPart() {
57+
return part;
58+
}
59+
}

spring-data-relational/src/main/java/org/springframework/data/relational/core/sql/render/ExpressionVisitor.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,10 @@ Delegation enterMatched(Expression segment) {
108108
CastVisitor visitor = new CastVisitor(context);
109109
partRenderer = visitor;
110110
return Delegation.delegateTo(visitor);
111+
} else if (segment instanceof CaseExpression) {
112+
CaseExpressionVisitor visitor = new CaseExpressionVisitor(context);
113+
partRenderer = visitor;
114+
return Delegation.delegateTo(visitor);
111115
} else {
112116
// works for literals and just and possibly more
113117
value = segment.toString();

spring-data-relational/src/main/java/org/springframework/data/relational/core/sql/render/OrderByClauseVisitor.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@
1515
*/
1616
package org.springframework.data.relational.core.sql.render;
1717

18+
1819
import org.springframework.data.relational.core.sql.Column;
20+
import org.springframework.data.relational.core.sql.CaseExpression;
1921
import org.springframework.data.relational.core.sql.Expressions;
2022
import org.springframework.data.relational.core.sql.OrderByField;
2123
import org.springframework.data.relational.core.sql.SimpleFunction;
@@ -83,7 +85,7 @@ Delegation enterNested(Visitable segment) {
8385
return Delegation.delegateTo((SimpleFunctionVisitor)delegate);
8486
}
8587

86-
if (segment instanceof Expressions.SimpleExpression) {
88+
if (segment instanceof Expressions.SimpleExpression || segment instanceof CaseExpression) {
8789
delegate = new ExpressionVisitor(context);
8890
return Delegation.delegateTo((ExpressionVisitor)delegate);
8991
}

spring-data-relational/src/main/java/org/springframework/data/relational/core/sql/render/TypedSingleConditionRenderSupport.java

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import org.springframework.data.relational.core.sql.Condition;
1919
import org.springframework.data.relational.core.sql.Expression;
2020
import org.springframework.data.relational.core.sql.Visitable;
21+
import org.springframework.data.relational.core.sql.When;
2122
import org.springframework.lang.Nullable;
2223
import org.springframework.util.Assert;
2324

@@ -26,6 +27,7 @@
2627
* delegate nested {@link Expression} and {@link Condition} rendering.
2728
*
2829
* @author Mark Paluch
30+
* @author Sven Rienstra
2931
* @since 1.1
3032
*/
3133
abstract class TypedSingleConditionRenderSupport<T extends Visitable> extends TypedSubtreeVisitor<T> {
@@ -40,8 +42,8 @@ abstract class TypedSingleConditionRenderSupport<T extends Visitable> extends Ty
4042
@Override
4143
Delegation enterNested(Visitable segment) {
4244

43-
if (segment instanceof Expression) {
44-
ExpressionVisitor visitor = new ExpressionVisitor(context);
45+
if (segment instanceof When) {
46+
WhenVisitor visitor = new WhenVisitor(context);
4547
current = visitor;
4648
return Delegation.delegateTo(visitor);
4749
}
@@ -52,6 +54,12 @@ Delegation enterNested(Visitable segment) {
5254
return Delegation.delegateTo(visitor);
5355
}
5456

57+
if (segment instanceof Expression) {
58+
ExpressionVisitor visitor = new ExpressionVisitor(context);
59+
current = visitor;
60+
return Delegation.delegateTo(visitor);
61+
}
62+
5563
throw new IllegalStateException("Cannot provide visitor for " + segment);
5664
}
5765

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
package org.springframework.data.relational.core.sql.render;
2+
3+
import org.springframework.data.relational.core.sql.Visitable;
4+
import org.springframework.data.relational.core.sql.When;
5+
6+
/**
7+
* Renderer for {@link When} segments.
8+
*
9+
* @author Sven Rienstra
10+
* @since 3.4
11+
*/
12+
public class WhenVisitor extends TypedSingleConditionRenderSupport<When> implements PartRenderer {
13+
private final StringBuilder part = new StringBuilder();
14+
private boolean conditionRendered;
15+
16+
WhenVisitor(RenderContext context) {
17+
super(context);
18+
}
19+
20+
@Override
21+
Delegation leaveNested(Visitable segment) {
22+
23+
if (hasDelegatedRendering()) {
24+
25+
if (conditionRendered) {
26+
part.append(" THEN ");
27+
}
28+
29+
part.append(consumeRenderedPart());
30+
conditionRendered = true;
31+
}
32+
33+
return super.leaveNested(segment);
34+
}
35+
36+
@Override
37+
Delegation enterMatched(When segment) {
38+
39+
part.append("WHEN ");
40+
41+
return super.enterMatched(segment);
42+
}
43+
44+
@Override
45+
public CharSequence getRenderedPart() {
46+
return part;
47+
}
48+
}

spring-data-relational/src/test/java/org/springframework/data/relational/core/sql/render/OrderByClauseVisitorUnitTests.java

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,7 @@
1818
import static org.assertj.core.api.Assertions.*;
1919

2020
import org.junit.jupiter.api.Test;
21-
import org.springframework.data.relational.core.sql.Column;
22-
import org.springframework.data.relational.core.sql.Expression;
23-
import org.springframework.data.relational.core.sql.Expressions;
24-
import org.springframework.data.relational.core.sql.OrderByField;
25-
import org.springframework.data.relational.core.sql.SQL;
26-
import org.springframework.data.relational.core.sql.Select;
27-
import org.springframework.data.relational.core.sql.SimpleFunction;
28-
import org.springframework.data.relational.core.sql.Table;
21+
import org.springframework.data.relational.core.sql.*;
2922

3023
import java.util.Arrays;
3124
import java.util.List;
@@ -129,4 +122,18 @@ void shouldRenderOrderBySimpleExpression() {
129122

130123
assertThat(visitor.getRenderedPart().toString()).isEqualTo("1 ASC");
131124
}
125+
126+
@Test
127+
void shouldRenderOrderByCase() {
128+
Table employee = SQL.table("employee").as("emp");
129+
Column column = employee.column("name");
130+
131+
CaseExpression caseExpression = CaseExpression.create(When.when(column.isNull(), SQL.literalOf(1))).elseExpression(SQL.literalOf(2));
132+
Select select = Select.builder().select(column).from(employee).orderBy(OrderByField.from(caseExpression).asc()).build();
133+
134+
OrderByClauseVisitor visitor = new OrderByClauseVisitor(new SimpleRenderContext(NamingStrategies.asIs()));
135+
select.visit(visitor);
136+
137+
assertThat(visitor.getRenderedPart().toString()).isEqualTo("CASE WHEN emp.name IS NULL THEN 1 ELSE 2 END ASC");
138+
}
132139
}

spring-data-relational/src/test/java/org/springframework/data/relational/core/sql/render/SelectRendererUnitTests.java

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -688,6 +688,23 @@ void asteriskOfAliasedTableUsesAlias() {
688688
assertThat(rendered).isEqualTo("SELECT e.*, e.id FROM employee e");
689689
}
690690

691+
@Test
692+
void rendersCaseExpression() {
693+
Table table = SQL.table("table");
694+
Column column = table.column("name");
695+
696+
CaseExpression caseExpression = CaseExpression.create(When.when(column.isNull(), SQL.literalOf(1))) //
697+
.when(When.when(column.isNotNull(), SQL.literalOf(2))) //
698+
.elseExpression(SQL.literalOf(3));
699+
700+
Select select = StatementBuilder.select(caseExpression) //
701+
.from(table) //
702+
.build();
703+
704+
String rendered = SqlRenderer.toString(select);
705+
assertThat(rendered).isEqualTo("SELECT CASE WHEN table.name IS NULL THEN 1 WHEN table.name IS NOT NULL THEN 2 ELSE 3 END FROM table");
706+
}
707+
691708
/**
692709
* Tests the rendering of analytic functions.
693710
*/

0 commit comments

Comments
 (0)