Skip to content

Commit b46e342

Browse files
committed
Fix COUNT/EXISTS projections for entities without an identifier.
We now issue a COUNT(1) respective SELECT 1 for COUNT queries and EXISTS queries for entities that do not specify an identifier. Previously these query projections could fail because of empty select lists. Closes #1310
1 parent 92e77a4 commit b46e342

File tree

5 files changed

+100
-21
lines changed

5 files changed

+100
-21
lines changed

spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/core/R2dbcEntityTemplate.java

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,7 @@ Mono<Long> doCount(Query query, Class<?> entityClass, SqlIdentifier tableName) {
298298

299299
Expression countExpression = entity.hasIdProperty()
300300
? table.column(entity.getRequiredIdProperty().getColumnName())
301-
: Expressions.asterisk();
301+
: Expressions.just("1");
302302
return spec.withProjection(Functions.count(countExpression));
303303
});
304304

@@ -333,13 +333,14 @@ Mono<Boolean> doExists(Query query, Class<?> entityClass, SqlIdentifier tableNam
333333
RelationalPersistentEntity<?> entity = getRequiredEntity(entityClass);
334334
StatementMapper statementMapper = dataAccessStrategy.getStatementMapper().forType(entityClass);
335335

336-
SqlIdentifier columnName = entity.hasIdProperty() ? entity.getRequiredIdProperty().getColumnName()
337-
: SqlIdentifier.unquoted("*");
336+
StatementMapper.SelectSpec selectSpec = statementMapper.createSelect(tableName).limit(1);
337+
if (entity.hasIdProperty()) {
338+
selectSpec = selectSpec //
339+
.withProjection(entity.getRequiredIdProperty().getColumnName());
338340

339-
StatementMapper.SelectSpec selectSpec = statementMapper //
340-
.createSelect(tableName) //
341-
.withProjection(columnName) //
342-
.limit(1);
341+
} else {
342+
selectSpec = selectSpec.withProjection(Expressions.just("1"));
343+
}
343344

344345
Optional<CriteriaDefinition> criteria = query.getCriteria();
345346
if (criteria.isPresent()) {

spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/query/QueryMapper.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,8 @@ public List<OrderByField> getMappedSort(Table table, Sort sort, @Nullable Relati
153153
*/
154154
public Expression getMappedObject(Expression expression, @Nullable RelationalPersistentEntity<?> entity) {
155155

156-
if (entity == null || expression instanceof AsteriskFromTable) {
156+
if (entity == null || expression instanceof AsteriskFromTable
157+
|| expression instanceof Expressions.SimpleExpression) {
157158
return expression;
158159
}
159160

spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/repository/query/R2dbcQueryCreator.java

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,16 @@
2525
import org.springframework.data.domain.Sort;
2626
import org.springframework.data.r2dbc.core.ReactiveDataAccessStrategy;
2727
import org.springframework.data.r2dbc.core.StatementMapper;
28-
import org.springframework.data.relational.repository.Lock;
2928
import org.springframework.data.relational.core.mapping.RelationalPersistentEntity;
3029
import org.springframework.data.relational.core.mapping.RelationalPersistentProperty;
3130
import org.springframework.data.relational.core.query.Criteria;
32-
import org.springframework.data.relational.core.sql.*;
31+
import org.springframework.data.relational.core.sql.Column;
32+
import org.springframework.data.relational.core.sql.Expression;
33+
import org.springframework.data.relational.core.sql.Expressions;
34+
import org.springframework.data.relational.core.sql.Functions;
35+
import org.springframework.data.relational.core.sql.SqlIdentifier;
36+
import org.springframework.data.relational.core.sql.Table;
37+
import org.springframework.data.relational.repository.Lock;
3338
import org.springframework.data.relational.repository.query.RelationalEntityMetadata;
3439
import org.springframework.data.relational.repository.query.RelationalParameterAccessor;
3540
import org.springframework.data.relational.repository.query.RelationalQueryCreator;
@@ -164,18 +169,14 @@ private Expression[] getSelectProjection() {
164169
expressions.add(column);
165170
}
166171

167-
} else if (tree.isExistsProjection()) {
168-
169-
expressions = dataAccessStrategy.getIdentifierColumns(entityToRead).stream() //
170-
.map(table::column) //
171-
.collect(Collectors.toList());
172-
} else if (tree.isCountProjection()) {
172+
} else if (tree.isExistsProjection() || tree.isCountProjection()) {
173173

174174
Expression countExpression = entityMetadata.getTableEntity().hasIdProperty()
175175
? table.column(entityMetadata.getTableEntity().getRequiredIdProperty().getColumnName())
176-
: Expressions.asterisk();
176+
: Expressions.just("1");
177177

178-
expressions = Collections.singletonList(Functions.count(countExpression));
178+
expressions = Collections
179+
.singletonList(tree.isCountProjection() ? Functions.count(countExpression) : countExpression);
179180
} else {
180181
expressions = dataAccessStrategy.getAllColumns(entityToRead).stream() //
181182
.map(table::column) //

spring-data-r2dbc/src/test/java/org/springframework/data/r2dbc/core/R2dbcEntityTemplateUnitTests.java

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,30 @@ void shouldProjectExistsResult() {
122122
.verifyComplete();
123123
}
124124

125+
@Test // gh-1310
126+
void shouldProjectExistsResultWithoutId() {
127+
128+
MockResult result = MockResult.builder().row(MockRow.builder().identified(0, Object.class, null).build()).build();
129+
130+
recorder.addStubbing(s -> s.startsWith("SELECT 1"), result);
131+
132+
entityTemplate.select(WithoutId.class).exists() //
133+
.as(StepVerifier::create) //
134+
.expectNext(true).verifyComplete();
135+
}
136+
137+
@Test // gh-1310
138+
void shouldProjectCountResultWithoutId() {
139+
140+
MockResult result = MockResult.builder().row(MockRow.builder().identified(0, Long.class, 1L).build()).build();
141+
142+
recorder.addStubbing(s -> s.startsWith("SELECT COUNT(1)"), result);
143+
144+
entityTemplate.select(WithoutId.class).count() //
145+
.as(StepVerifier::create) //
146+
.expectNext(1L).verifyComplete();
147+
}
148+
125149
@Test // gh-469
126150
void shouldExistsByCriteria() {
127151

@@ -477,6 +501,12 @@ void updateShouldInvokeCallback() {
477501
Parameter.from("before-save"));
478502
}
479503

504+
@Value
505+
static class WithoutId {
506+
507+
String name;
508+
}
509+
480510
@Value
481511
@With
482512
static class Person {

spring-data-r2dbc/src/test/java/org/springframework/data/r2dbc/repository/query/PartTreeR2dbcQueryUnitTests.java

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
import org.mockito.junit.jupiter.MockitoExtension;
3939
import org.mockito.junit.jupiter.MockitoSettings;
4040
import org.mockito.quality.Strictness;
41+
4142
import org.springframework.beans.factory.annotation.Value;
4243
import org.springframework.data.annotation.Id;
4344
import org.springframework.data.projection.SpelAwareProxyProjectionFactory;
@@ -49,10 +50,10 @@
4950
import org.springframework.data.r2dbc.dialect.DialectResolver;
5051
import org.springframework.data.r2dbc.dialect.R2dbcDialect;
5152
import org.springframework.data.r2dbc.mapping.R2dbcMappingContext;
52-
import org.springframework.data.relational.repository.Lock;
5353
import org.springframework.data.relational.core.mapping.RelationalMappingContext;
5454
import org.springframework.data.relational.core.mapping.Table;
5555
import org.springframework.data.relational.core.sql.LockMode;
56+
import org.springframework.data.relational.repository.Lock;
5657
import org.springframework.data.relational.repository.query.RelationalParametersParameterAccessor;
5758
import org.springframework.data.repository.Repository;
5859
import org.springframework.data.repository.core.support.DefaultRepositoryMetadata;
@@ -748,6 +749,32 @@ void bindsParametersFromPublisher() throws Exception {
748749
verify(bindTarget, times(1)).bind(0, "John");
749750
}
750751

752+
@Test // GH-1310
753+
void createsQueryWithoutIdForCountProjection() throws Exception {
754+
755+
R2dbcQueryMethod queryMethod = getQueryMethod(WithoutIdRepository.class, "countByFirstName", String.class);
756+
PartTreeR2dbcQuery r2dbcQuery = new PartTreeR2dbcQuery(queryMethod, operations, r2dbcConverter, dataAccessStrategy);
757+
PreparedOperation<?> query = createQuery(queryMethod, r2dbcQuery, "John");
758+
759+
PreparedOperationAssert.assertThat(query) //
760+
.selects("COUNT(1)") //
761+
.from(TABLE) //
762+
.where(TABLE + ".first_name = $1");
763+
}
764+
765+
@Test // GH-1310
766+
void createsQueryWithoutIdForExistsProjection() throws Exception {
767+
768+
R2dbcQueryMethod queryMethod = getQueryMethod(WithoutIdRepository.class, "existsByFirstName", String.class);
769+
PartTreeR2dbcQuery r2dbcQuery = new PartTreeR2dbcQuery(queryMethod, operations, r2dbcConverter, dataAccessStrategy);
770+
PreparedOperation<?> query = createQuery(queryMethod, r2dbcQuery, "John");
771+
772+
PreparedOperationAssert.assertThat(query) //
773+
.selects("1") //
774+
.from(TABLE) //
775+
.where(TABLE + ".first_name = $1 LIMIT 1");
776+
}
777+
751778
private PreparedOperation<?> createQuery(R2dbcQueryMethod queryMethod, PartTreeR2dbcQuery r2dbcQuery,
752779
Object... parameters) {
753780
return createQuery(r2dbcQuery, getAccessor(queryMethod, parameters));
@@ -759,8 +786,13 @@ private PreparedOperation<?> createQuery(PartTreeR2dbcQuery r2dbcQuery,
759786
}
760787

761788
private R2dbcQueryMethod getQueryMethod(String methodName, Class<?>... parameterTypes) throws Exception {
762-
Method method = UserRepository.class.getMethod(methodName, parameterTypes);
763-
return new R2dbcQueryMethod(method, new DefaultRepositoryMetadata(UserRepository.class),
789+
return getQueryMethod(UserRepository.class, methodName, parameterTypes);
790+
}
791+
792+
private R2dbcQueryMethod getQueryMethod(Class<?> repository, String methodName, Class<?>... parameterTypes)
793+
throws Exception {
794+
Method method = repository.getMethod(methodName, parameterTypes);
795+
return new R2dbcQueryMethod(method, new DefaultRepositoryMetadata(repository),
764796
new SpelAwareProxyProjectionFactory(), mappingContext);
765797
}
766798

@@ -946,6 +978,13 @@ interface UserRepository extends Repository<User, Long> {
946978

947979
}
948980

981+
interface WithoutIdRepository extends Repository<WithoutId, Long> {
982+
983+
Mono<Boolean> existsByFirstName(String firstName);
984+
985+
Mono<Long> countByFirstName(String firstName);
986+
}
987+
949988
@Table("users")
950989
@Data
951990
private static class User {
@@ -958,6 +997,13 @@ private static class User {
958997
private Boolean active;
959998
}
960999

1000+
@Table("users")
1001+
@Data
1002+
private static class WithoutId {
1003+
1004+
private String firstName;
1005+
}
1006+
9611007
interface UserProjection {
9621008

9631009
String getFirstName();

0 commit comments

Comments
 (0)