1919import java .util .List ;
2020import java .util .Optional ;
2121import java .util .regex .Pattern ;
22+ import java .util .stream .Stream ;
2223
2324import org .bson .Document ;
2425import org .jspecify .annotations .NullUnmarked ;
4950import org .springframework .data .mongodb .repository .query .MongoQueryMethod ;
5051import org .springframework .data .repository .aot .generate .AotQueryMethodGenerationContext ;
5152import org .springframework .data .util .ReflectionUtils ;
52- import org .springframework .javapoet .ClassName ;
5353import org .springframework .javapoet .CodeBlock ;
5454import org .springframework .javapoet .CodeBlock .Builder ;
5555import org .springframework .javapoet .TypeName ;
@@ -182,17 +182,15 @@ CodeBlock build() {
182182 String mongoOpsRef = context .fieldNameOf (MongoOperations .class );
183183 Builder builder = CodeBlock .builder ();
184184
185+ Class <?> domainType = context .getRepositoryInformation ().getDomainType ();
185186 boolean isProjecting = context .getActualReturnType () != null
186- && !ObjectUtils .nullSafeEquals (TypeName .get (context .getRepositoryInformation ().getDomainType ()),
187- context .getActualReturnType ());
187+ && !ObjectUtils .nullSafeEquals (TypeName .get (domainType ), context .getActualReturnType ());
188188
189- Object actualReturnType = isProjecting ? context .getActualReturnType ().getType ()
190- : context .getRepositoryInformation ().getDomainType ();
189+ Object actualReturnType = isProjecting ? context .getActualReturnType ().getType () : domainType ;
191190
192191 builder .add ("\n " );
193- builder .addStatement ("$T<$T> remover = $L.remove($T.class)" , ExecutableRemove .class ,
194- context .getRepositoryInformation ().getDomainType (), mongoOpsRef ,
195- context .getRepositoryInformation ().getDomainType ());
192+ builder .addStatement ("$T<$T> $L = $L.remove($T.class)" , ExecutableRemove .class , domainType ,
193+ context .localVariable ("remover" ), mongoOpsRef , domainType );
196194
197195 DeleteExecution .Type type = DeleteExecution .Type .FIND_AND_REMOVE_ALL ;
198196 if (!queryMethod .isCollectionQuery ()) {
@@ -204,11 +202,20 @@ CodeBlock build() {
204202 }
205203
206204 actualReturnType = ClassUtils .isPrimitiveOrWrapper (context .getMethod ().getReturnType ())
207- ? ClassName .get (context .getMethod ().getReturnType ())
205+ ? TypeName .get (context .getMethod ().getReturnType ())
208206 : queryMethod .isCollectionQuery () ? context .getReturnTypeName () : actualReturnType ;
209207
210- builder .addStatement ("return ($T) new $T(remover, $T.$L).execute($L)" , actualReturnType , DeleteExecution .class ,
211- DeleteExecution .Type .class , type .name (), queryVariableName );
208+ if (ClassUtils .isVoidType (context .getMethod ().getReturnType ())) {
209+ builder .addStatement ("new $T($L, $T.$L).execute($L)" , DeleteExecution .class , context .localVariable ("remover" ),
210+ DeleteExecution .Type .class , type .name (), queryVariableName );
211+ } else if (context .getMethod ().getReturnType () == Optional .class ) {
212+ builder .addStatement ("return $T.ofNullable(($T) new $T($L, $T.$L).execute($L))" , Optional .class ,
213+ actualReturnType , DeleteExecution .class , context .localVariable ("remover" ), DeleteExecution .Type .class ,
214+ type .name (), queryVariableName );
215+ } else {
216+ builder .addStatement ("return ($T) new $T($L, $T.$L).execute($L)" , actualReturnType , DeleteExecution .class ,
217+ context .localVariable ("remover" ), DeleteExecution .Type .class , type .name (), queryVariableName );
218+ }
212219
213220 return builder .build ();
214221 }
@@ -318,14 +325,25 @@ CodeBlock build() {
318325
319326 Class <?> returnType = ClassUtils .resolvePrimitiveIfNecessary (queryMethod .getReturnedObjectType ());
320327
321- builder .addStatement ("$T $L = $L.aggregate($L, $T.class)" , AggregationResults .class ,
322- context .localVariable ("results" ), mongoOpsRef , aggregationVariableName , outputType );
323- if (!queryMethod .isCollectionQuery ()) {
324- builder .addStatement ("return $T.<$T>firstElement(convertSimpleRawResults($T.class, $L.getMappedResults()))" ,
325- CollectionUtils .class , returnType , returnType , context .localVariable ("results" ));
328+ if (queryMethod .isStreamQuery ()) {
329+
330+ builder .addStatement ("$T<$T> $L = $L.aggregateStream($L, $T.class)" , Stream .class , Document .class ,
331+ context .localVariable ("results" ), mongoOpsRef , aggregationVariableName , outputType );
332+
333+ builder .addStatement ("return $L.map(it -> ($T) convertSimpleRawResult($T.class, it))" ,
334+ context .localVariable ("results" ), returnType , returnType );
326335 } else {
327- builder .addStatement ("return convertSimpleRawResults($T.class, $L.getMappedResults())" , returnType ,
328- context .localVariable ("results" ));
336+
337+ builder .addStatement ("$T $L = $L.aggregate($L, $T.class)" , AggregationResults .class ,
338+ context .localVariable ("results" ), mongoOpsRef , aggregationVariableName , outputType );
339+
340+ if (!queryMethod .isCollectionQuery ()) {
341+ builder .addStatement ("return $T.<$T>firstElement(convertSimpleRawResults($T.class, $L.getMappedResults()))" ,
342+ CollectionUtils .class , returnType , returnType , context .localVariable ("results" ));
343+ } else {
344+ builder .addStatement ("return convertSimpleRawResults($T.class, $L.getMappedResults())" , returnType ,
345+ context .localVariable ("results" ));
346+ }
329347 }
330348 } else {
331349 if (queryMethod .isSliceQuery ()) {
@@ -339,8 +357,15 @@ CodeBlock build() {
339357 context .getPageableParameterName (), context .localVariable ("results" ), context .getPageableParameterName (),
340358 context .localVariable ("hasNext" ));
341359 } else {
342- builder .addStatement ("return $L.aggregate($L, $T.class).getMappedResults()" , mongoOpsRef ,
343- aggregationVariableName , outputType );
360+
361+ if (queryMethod .isStreamQuery ()) {
362+ builder .addStatement ("return $L.aggregateStream($L, $T.class)" , mongoOpsRef , aggregationVariableName ,
363+ outputType );
364+ } else {
365+
366+ builder .addStatement ("return $L.aggregate($L, $T.class).getMappedResults()" , mongoOpsRef ,
367+ aggregationVariableName , outputType );
368+ }
344369 }
345370 }
346371
@@ -420,8 +445,16 @@ CodeBlock build() {
420445 builder .addStatement ("return $L.matching($L).scroll($L)" , context .localVariable ("finder" ), query .name (),
421446 scrollPositionParameterName );
422447 } else {
423- builder .addStatement ("return $L.matching($L).$L" , context .localVariable ("finder" ), query .name (),
424- terminatingMethod );
448+ if (query .isCount () && !ClassUtils .isAssignable (Long .class , context .getActualReturnType ().getRawClass ())) {
449+
450+ Class <?> returnType = ClassUtils .resolvePrimitiveIfNecessary (queryMethod .getReturnedObjectType ());
451+ builder .addStatement ("return $T.convertNumberToTargetClass($L.matching($L).$L, $T.class)" , NumberUtils .class ,
452+ context .localVariable ("finder" ), query .name (), terminatingMethod , returnType );
453+
454+ } else {
455+ builder .addStatement ("return $L.matching($L).$L" , context .localVariable ("finder" ), query .name (),
456+ terminatingMethod );
457+ }
425458 }
426459
427460 return builder .build ();
0 commit comments