Skip to content

Commit 1be3b9e

Browse files
Support custom Streamable return type in AOT repository.
This commit uses a conversion service to convert custom streamable types. See: #5089
1 parent 828c7e2 commit 1be3b9e

File tree

3 files changed

+46
-3
lines changed

3 files changed

+46
-3
lines changed

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
import org.bson.Document;
2121
import org.jspecify.annotations.Nullable;
2222
import org.springframework.core.annotation.MergedAnnotation;
23+
import org.springframework.core.convert.TypeDescriptor;
24+
import org.springframework.core.convert.support.DefaultConversionService;
2325
import org.springframework.data.mapping.model.SimpleTypeHolder;
2426
import org.springframework.data.mongodb.repository.ReadPreference;
2527
import org.springframework.data.mongodb.repository.aot.AggregationBlocks.AggregationCodeBlockBuilder;
@@ -38,6 +40,7 @@
3840
import org.springframework.data.util.Streamable;
3941
import org.springframework.javapoet.CodeBlock;
4042
import org.springframework.javapoet.CodeBlock.Builder;
43+
import org.springframework.util.ClassUtils;
4144
import org.springframework.util.NumberUtils;
4245
import org.springframework.util.StringUtils;
4346

@@ -238,9 +241,21 @@ static void appendReadPreference(AotQueryMethodGenerationContext context, Builde
238241
* {@link MethodReturn} indicates so.
239242
*/
240243
public static CodeBlock potentiallyWrapStreamable(MethodReturn methodReturn, CodeBlock returningIterable) {
241-
return methodReturn.toClass().equals(Streamable.class)
242-
? CodeBlock.of("$T.of($L)", Streamable.class, returningIterable)
243-
: returningIterable;
244+
245+
Class<?> returnType = methodReturn.toClass();
246+
247+
if (returnType.equals(Streamable.class)) {
248+
return CodeBlock.of("$T.of($L)", Streamable.class, returningIterable);
249+
}
250+
251+
if (ClassUtils.isAssignable(Streamable.class, returnType)) {
252+
253+
return CodeBlock.of(
254+
"($1T) $2T.getSharedInstance().convert($3T.of($4L), $5T.valueOf($3T.class), $5T.valueOf($1T.class))",
255+
returnType, DefaultConversionService.class, Streamable.class, returningIterable, TypeDescriptor.class);
256+
}
257+
258+
return returningIterable;
244259
}
245260

246261
}

spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/AbstractPersonRepositoryIntegrationTests.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@
7373
import org.springframework.data.mongodb.core.query.Query;
7474
import org.springframework.data.mongodb.core.query.Update;
7575
import org.springframework.data.mongodb.repository.Person.Sex;
76+
import org.springframework.data.mongodb.repository.PersonRepository.Persons;
7677
import org.springframework.data.mongodb.repository.SampleEvaluationContextExtension.SampleSecurityContextHolder;
7778
import org.springframework.data.mongodb.test.util.DirtiesStateExtension;
7879
import org.springframework.data.mongodb.test.util.DirtiesStateExtension.DirtiesState;
@@ -324,6 +325,17 @@ void streamPersonByAddressCorrectly() {
324325
assertThat(result).hasSize(1).contains(dave);
325326
}
326327

328+
@Test // GH-5089
329+
void useCustomReturnTypeImplementingStreamable() {
330+
331+
Address address = new Address("Foo Street 1", "C0123", "Bar");
332+
dave.setAddress(address);
333+
repository.save(dave);
334+
335+
Persons result = repository.streamPersonsByAddress(address);
336+
assertThat(result).hasSize(1).contains(dave);
337+
}
338+
327339
@Test // GH-5089
328340
void streamPersonByAddressCorrectlyWhenPaged() {
329341

spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/PersonRepository.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import java.util.Collection;
1919
import java.util.Date;
20+
import java.util.Iterator;
2021
import java.util.List;
2122
import java.util.Optional;
2223
import java.util.UUID;
@@ -214,6 +215,8 @@ Window<Person> findByLastnameLikeOrderByLastnameAscFirstnameAsc(Pattern lastname
214215

215216
Streamable<Person> streamByAddress(Address address);
216217

218+
Persons streamPersonsByAddress(Address address);
219+
217220
Streamable<Person> streamByAddress(Address address, Pageable pageable);
218221

219222
List<Person> findByAddressZipCode(String zipCode);
@@ -502,4 +505,17 @@ Person findPersonByManyArguments(String firstname, String lastname, String email
502505

503506
List<Person> findBySpiritAnimal(User user);
504507

508+
class Persons implements Streamable<Person> {
509+
510+
private final Streamable<Person> streamable;
511+
512+
public Persons(Streamable<Person> streamable) {
513+
this.streamable = streamable;
514+
}
515+
516+
@Override
517+
public Iterator<Person> iterator() {
518+
return streamable.iterator();
519+
}
520+
}
505521
}

0 commit comments

Comments
 (0)