2222
2323import org .bson .Document ;
2424import org .jspecify .annotations .NullUnmarked ;
25-
2625import org .springframework .core .ResolvableType ;
2726import org .springframework .core .annotation .MergedAnnotation ;
2827import org .springframework .data .domain .SliceImpl ;
2928import org .springframework .data .domain .Sort .Order ;
3029import org .springframework .data .mongodb .core .MongoOperations ;
3130import org .springframework .data .mongodb .core .aggregation .Aggregation ;
31+ import org .springframework .data .mongodb .core .aggregation .AggregationOperation ;
3232import org .springframework .data .mongodb .core .aggregation .AggregationOptions ;
3333import org .springframework .data .mongodb .core .aggregation .AggregationPipeline ;
3434import org .springframework .data .mongodb .core .aggregation .AggregationResults ;
@@ -80,12 +80,7 @@ CodeBlock build() {
8080
8181 builder .add ("\n " );
8282
83- Class <?> outputType = queryMethod .getReturnedObjectType ();
84- if (MongoSimpleTypes .HOLDER .isSimpleType (outputType )) {
85- outputType = Document .class ;
86- } else if (ClassUtils .isAssignable (AggregationResults .class , outputType )) {
87- outputType = queryMethod .getReturnType ().getComponentType ().getType ();
88- }
83+ Class <?> outputType = getOutputType (queryMethod );
8984
9085 if (ReflectionUtils .isVoid (queryMethod .getReturnedObjectType ())) {
9186 builder .addStatement ("$L.aggregate($L, $T.class)" , mongoOpsRef , aggregationVariableName , outputType );
@@ -146,7 +141,6 @@ CodeBlock build() {
146141 builder .addStatement ("return $L.aggregateStream($L, $T.class)" , mongoOpsRef , aggregationVariableName ,
147142 outputType );
148143 } else {
149-
150144 builder .addStatement ("return $L.aggregate($L, $T.class).getMappedResults()" , mongoOpsRef ,
151145 aggregationVariableName , outputType );
152146 }
@@ -155,6 +149,17 @@ CodeBlock build() {
155149
156150 return builder .build ();
157151 }
152+
153+ }
154+
155+ private static Class <?> getOutputType (MongoQueryMethod queryMethod ) {
156+ Class <?> outputType = queryMethod .getReturnedObjectType ();
157+ if (MongoSimpleTypes .HOLDER .isSimpleType (outputType )) {
158+ outputType = Document .class ;
159+ } else if (ClassUtils .isAssignable (AggregationResults .class , outputType ) && queryMethod .getReturnType ().getComponentType () != null ) {
160+ outputType = queryMethod .getReturnType ().getComponentType ().getType ();
161+ }
162+ return outputType ;
158163 }
159164
160165 @ NullUnmarked
@@ -173,13 +178,7 @@ static class AggregationCodeBlockBuilder {
173178
174179 this .context = context ;
175180 this .queryMethod = queryMethod ;
176- String parameterNames = StringUtils .collectionToDelimitedString (context .getAllParameterNames (), ", " );
177-
178- if (StringUtils .hasText (parameterNames )) {
179- this .parameterNames = ", " + parameterNames ;
180- } else {
181- this .parameterNames = "" ;
182- }
181+ this .parameterNames = StringUtils .collectionToDelimitedString (context .getAllParameterNames (), ", " );
183182 }
184183
185184 AggregationCodeBlockBuilder stages (AggregationInteraction aggregation ) {
@@ -231,7 +230,8 @@ private CodeBlock pipeline(String pipelineVariableName) {
231230 builder .add (aggregationStages (context .localVariable ("stages" ), source .stages ()));
232231
233232 if (StringUtils .hasText (sortParameter )) {
234- builder .add (sortingStage (sortParameter ));
233+ Class <?> outputType = getOutputType (queryMethod );
234+ builder .add (sortingStage (sortParameter , outputType ));
235235 }
236236
237237 if (StringUtils .hasText (limitParameter )) {
@@ -244,6 +244,7 @@ private CodeBlock pipeline(String pipelineVariableName) {
244244
245245 builder .addStatement ("$T $L = createPipeline($L)" , AggregationPipeline .class , pipelineVariableName ,
246246 context .localVariable ("stages" ));
247+
247248 return builder .build ();
248249 }
249250
@@ -312,7 +313,7 @@ private CodeBlock aggregationStages(String stageListVariableName, Collection<Str
312313 return builder .build ();
313314 }
314315
315- private CodeBlock sortingStage (String sortProvider ) {
316+ private CodeBlock sortingStage (String sortProvider , Class <?> outputType ) {
316317
317318 Builder builder = CodeBlock .builder ();
318319
@@ -322,8 +323,17 @@ private CodeBlock sortingStage(String sortProvider) {
322323 builder .addStatement ("$1L.append($2L.getProperty(), $2L.isAscending() ? 1 : -1);" ,
323324 context .localVariable ("sortDocument" ), context .localVariable ("order" ));
324325 builder .endControlFlow ();
325- builder .addStatement ("stages.add(new $T($S, $L))" , Document .class , "$sort" ,
326- context .localVariable ("sortDocument" ));
326+
327+ if (outputType == Document .class || MongoSimpleTypes .HOLDER .isSimpleType (outputType )
328+ || ClassUtils .isAssignable (context .getRepositoryInformation ().getDomainType (), outputType )) {
329+ builder .addStatement ("$L.add(new $T($S, $L))" , context .localVariable ("stages" ), Document .class , "$sort" ,
330+ context .localVariable ("sortDocument" ));
331+ } else {
332+ builder .addStatement ("$L.add(($T) _ctx -> new $T($S, _ctx.getMappedObject($L, $T.class)))" ,
333+ context .localVariable ("stages" ), AggregationOperation .class , Document .class , "$sort" ,
334+ context .localVariable ("sortDocument" ), outputType );
335+ }
336+
327337 builder .endControlFlow ();
328338
329339 return builder .build ();
@@ -333,7 +343,7 @@ private CodeBlock pagingStage(String pageableProvider, boolean slice) {
333343
334344 Builder builder = CodeBlock .builder ();
335345
336- builder .add (sortingStage (pageableProvider + ".getSort()" ));
346+ builder .add (sortingStage (pageableProvider + ".getSort()" , getOutputType ( queryMethod ) ));
337347
338348 builder .beginControlFlow ("if ($L.isPaged())" , pageableProvider );
339349 builder .beginControlFlow ("if ($L.getOffset() > 0)" , pageableProvider );
0 commit comments