Skip to content

Commit df5199b

Browse files
committed
Add AOT support for Streamable wrapper support.
We now support streamable wrappers through AOT repositories. See #1620
1 parent e22bcbf commit df5199b

File tree

3 files changed

+96
-5
lines changed

3 files changed

+96
-5
lines changed

spring-data-cassandra/src/main/java/org/springframework/data/cassandra/repository/aot/CassandraCodeBlocks.java

Lines changed: 46 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
import org.jspecify.annotations.Nullable;
2424

2525
import org.springframework.core.annotation.MergedAnnotation;
26+
import org.springframework.core.convert.TypeDescriptor;
27+
import org.springframework.core.convert.support.DefaultConversionService;
2628
import org.springframework.data.cassandra.core.CassandraOperations;
2729
import org.springframework.data.cassandra.core.ExecutableSelectOperation;
2830
import org.springframework.data.cassandra.core.cql.QueryOptions;
@@ -718,13 +720,29 @@ CodeBlock build() {
718720

719721
if (streamableResult && Collection.class.isAssignableFrom(methodReturn.toClass())) {
720722
result = CodeBlock.of("$L.getContent()", result);
721-
} else if (!streamableResult && methodReturn.toClass().equals(Streamable.class)) {
723+
} else if (!streamableResult && (isStreamable(methodReturn) || isStreamableWrapper(methodReturn))) {
722724
result = CodeBlock.of("$T.of(($T) $L)", Streamable.class, Iterable.class, result);
723725
}
724726

725-
builder.addStatement(returnBuilder //
727+
if (isStreamableWrapper(methodReturn) && canConvert(Streamable.class, methodReturn)) {
728+
729+
Builder wrapperBuilder = CodeBlock.builder();
730+
731+
builder.addStatement("$1T<$2T> $3L = $4L", Streamable.class, actualReturnType,
732+
context.localVariable("streamable"), result);
733+
734+
builder.addStatement(
735+
"return ($1T) $2T.getSharedInstance().convert($3L, $4T.valueOf($5T.class), $4T.valueOf($6T.class))",
736+
methodReturn.getTypeName(), DefaultConversionService.class, context.localVariable("streamable"),
737+
TypeDescriptor.class, Streamable.class, methodReturn.toClass());
738+
739+
builder.add(wrapperBuilder.build());
740+
} else {
741+
742+
builder.addStatement(returnBuilder //
726743
.optional("($T) $L", methodReturn.getTypeName(), result) //
727744
.build());
745+
}
728746
} else {
729747

730748
CodeBlock executionBlock = execution.build();
@@ -733,16 +751,40 @@ CodeBlock build() {
733751
returnBuilder.whenPrimitiveOrBoxed(Integer.class, "(int) $L", executionBlock);
734752
} else if (streamableResult && Collection.class.isAssignableFrom(methodReturn.toClass())) {
735753
executionBlock = CodeBlock.of("$L.getContent()", executionBlock);
736-
} else if (!streamableResult && methodReturn.toClass().equals(Streamable.class)) {
754+
} else if (!streamableResult && (isStreamable(methodReturn) || isStreamableWrapper(methodReturn))) {
737755
executionBlock = CodeBlock.of("$T.of($L)", Streamable.class, executionBlock);
738756
}
739757

740-
builder.addStatement(returnBuilder.optional(executionBlock) //
758+
if (isStreamableWrapper(methodReturn) && canConvert(Streamable.class, methodReturn)) {
759+
760+
builder.addStatement("$1T<$2T> $3L = $4L", Streamable.class, actualReturnType,
761+
context.localVariable("streamable"), executionBlock);
762+
763+
builder.addStatement(
764+
"return ($1T) $2T.getSharedInstance().convert($3L, $4T.valueOf($5T.class), $4T.valueOf($6T.class))",
765+
methodReturn.getTypeName(), DefaultConversionService.class, context.localVariable("streamable"),
766+
TypeDescriptor.class, Streamable.class, methodReturn.toClass());
767+
} else {
768+
769+
builder.addStatement(returnBuilder.optional(executionBlock) //
741770
.build());
771+
}
742772
}
743773

744774
return builder.build();
745775
}
776+
777+
private boolean canConvert(Class<?> from, MethodReturn methodReturn) {
778+
return DefaultConversionService.getSharedInstance().canConvert(from, methodReturn.toClass());
779+
}
780+
781+
private static boolean isStreamable(MethodReturn methodReturn) {
782+
return methodReturn.toClass().equals(Streamable.class);
783+
}
784+
785+
private static boolean isStreamableWrapper(MethodReturn methodReturn) {
786+
return !isStreamable(methodReturn) && Streamable.class.isAssignableFrom(methodReturn.toClass());
787+
}
746788
}
747789

748790
}

spring-data-cassandra/src/test/java/org/springframework/data/cassandra/repository/aot/CassandraRepositoryContributorIntegrationTests.java

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,24 @@ void shouldConvertResultToStreamableWhenPageableParameterIsUsed() {
216216
.containsExactly("Flynn (Walter Jr.)", "Skyler");
217217
}
218218

219+
@Test // GH-1620
220+
void shouldConvertResultToStreamableWrapper() {
221+
222+
assertThat(fragment.findWrapperByLastname("White")) //
223+
.isInstanceOf(PersonRepository.People.class) //
224+
.extracting(Person::getFirstname) //
225+
.hasSize(3);
226+
}
227+
228+
@Test // GH-1620
229+
void shouldConvertResultToStreamableWrapperWhenPageableParameterIsUsed() {
230+
231+
assertThat(fragment.findWrapperByLastname("White", PageRequest.of(0, 2, Sort.by("firstname"))))
232+
.isInstanceOf(PersonRepository.People.class) //
233+
.extracting(Person::getFirstname) //
234+
.containsExactly("Flynn (Walter Jr.)", "Skyler");
235+
}
236+
219237
@Test // GH-1566
220238
void shouldFindByFirstnameContains() {
221239

@@ -331,6 +349,15 @@ void shouldQueryDeclaredWindow() {
331349
assertThat(second.hasNext()).isFalse();
332350
}
333351

352+
@Test // GH-1620
353+
void shouldQueryDeclaredStreamable() {
354+
355+
PersonRepository.People people = fragment.findDeclaredWrapperByLastname("White",
356+
PageRequest.of(0, 2, Sort.by("firstname")));
357+
358+
assertThat(people).hasSize(2);
359+
}
360+
334361
@Test // GH-1566
335362
void shouldFindNamedQuery() {
336363

spring-data-cassandra/src/test/java/org/springframework/data/cassandra/repository/aot/PersonRepository.java

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
*/
1616
package org.springframework.data.cassandra.repository.aot;
1717

18+
import java.util.Iterator;
1819
import java.util.List;
1920
import java.util.Map;
2021
import java.util.Optional;
@@ -38,10 +39,10 @@
3839
import org.springframework.data.domain.Vector;
3940
import org.springframework.data.domain.Window;
4041
import org.springframework.data.repository.CrudRepository;
42+
import org.springframework.data.util.Streamable;
4143

4244
import com.datastax.oss.driver.api.core.DefaultConsistencyLevel;
4345
import com.datastax.oss.driver.api.core.cql.ResultSet;
44-
import org.springframework.data.util.Streamable;
4546

4647
/**
4748
* AOT repository interface for {@link Person} entities.
@@ -68,6 +69,10 @@ public interface PersonRepository extends CrudRepository<Person, String> {
6869

6970
Streamable<Person> streamByLastname(String lastname, Pageable pageable);
7071

72+
People findWrapperByLastname(String lastname);
73+
74+
People findWrapperByLastname(String lastname, Pageable pageable);
75+
7176
List<Person> findByLastnameOrderByFirstnameAsc(String lastname);
7277

7378
Person findByFirstnameStartsWith(String prefix);
@@ -122,6 +127,9 @@ public interface PersonRepository extends CrudRepository<Person, String> {
122127
Window<Person> findDeclaredWindowByLastname(String lastname, ScrollPosition scrollPosition, int sliceLimit,
123128
Limit pageSize);
124129

130+
@Query(value = "select * from person where lastname = ?0 LIMIT 3")
131+
People findDeclaredWrapperByLastname(String lastname, Pageable pageable);
132+
125133
// -------------------------------------------------------------------------
126134
// Value Expressions
127135
// -------------------------------------------------------------------------
@@ -240,4 +248,18 @@ public void setLastname(String lastname) {
240248
this.lastname = lastname;
241249
}
242250
}
251+
252+
class People implements Streamable<Person> {
253+
254+
private final Streamable<Person> people;
255+
256+
public People(Streamable<Person> people) {
257+
this.people = people;
258+
}
259+
260+
@Override
261+
public Iterator<Person> iterator() {
262+
return people.iterator();
263+
}
264+
}
243265
}

0 commit comments

Comments
 (0)