1616package org .springframework .data .mongodb .core ;
1717
1818import java .util .ArrayList ;
19- import java .util .Collections ;
2019import java .util .List ;
2120import java .util .Optional ;
2221import java .util .stream .Collectors ;
2524import org .bson .conversions .Bson ;
2625import org .springframework .context .ApplicationEventPublisher ;
2726import org .springframework .dao .DataIntegrityViolationException ;
27+ import org .springframework .data .mapping .PersistentEntity ;
2828import org .springframework .data .mapping .callback .EntityCallbacks ;
2929import org .springframework .data .mongodb .BulkOperationException ;
30+ import org .springframework .data .mongodb .core .aggregation .AggregationOperationContext ;
31+ import org .springframework .data .mongodb .core .aggregation .AggregationUpdate ;
32+ import org .springframework .data .mongodb .core .aggregation .RelaxedTypeBasedAggregationOperationContext ;
3033import org .springframework .data .mongodb .core .convert .QueryMapper ;
3134import org .springframework .data .mongodb .core .convert .UpdateMapper ;
3235import org .springframework .data .mongodb .core .mapping .MongoPersistentEntity ;
@@ -133,12 +136,12 @@ public BulkOperations insert(List<? extends Object> documents) {
133136
134137 @ Override
135138 @ SuppressWarnings ("unchecked" )
136- public BulkOperations updateOne (Query query , Update update ) {
139+ public BulkOperations updateOne (Query query , UpdateDefinition update ) {
137140
138141 Assert .notNull (query , "Query must not be null" );
139142 Assert .notNull (update , "Update must not be null" );
140143
141- return updateOne ( Collections . singletonList ( Pair . of ( query , update )) );
144+ return update ( query , update , false , false );
142145 }
143146
144147 @ Override
@@ -155,12 +158,14 @@ public BulkOperations updateOne(List<Pair<Query, Update>> updates) {
155158
156159 @ Override
157160 @ SuppressWarnings ("unchecked" )
158- public BulkOperations updateMulti (Query query , Update update ) {
161+ public BulkOperations updateMulti (Query query , UpdateDefinition update ) {
159162
160163 Assert .notNull (query , "Query must not be null" );
161164 Assert .notNull (update , "Update must not be null" );
162165
163- return updateMulti (Collections .singletonList (Pair .of (query , update )));
166+ update (query , update , false , true );
167+
168+ return this ;
164169 }
165170
166171 @ Override
@@ -176,7 +181,7 @@ public BulkOperations updateMulti(List<Pair<Query, Update>> updates) {
176181 }
177182
178183 @ Override
179- public BulkOperations upsert (Query query , Update update ) {
184+ public BulkOperations upsert (Query query , UpdateDefinition update ) {
180185 return update (query , update , true , true );
181186 }
182187
@@ -294,7 +299,7 @@ private WriteModel<Document> extractAndMapWriteModel(SourceAwareWriteModelHolder
294299 maybeInvokeBeforeSaveCallback (it .getSource (), target );
295300 }
296301
297- return mapWriteModel (it .getModel ());
302+ return mapWriteModel (it .getSource (), it . getModel ());
298303 }
299304
300305 /**
@@ -306,7 +311,7 @@ private WriteModel<Document> extractAndMapWriteModel(SourceAwareWriteModelHolder
306311 * @param multi whether to issue a multi-update.
307312 * @return the {@link BulkOperations} with the update registered.
308313 */
309- private BulkOperations update (Query query , Update update , boolean upsert , boolean multi ) {
314+ private BulkOperations update (Query query , UpdateDefinition update , boolean upsert , boolean multi ) {
310315
311316 Assert .notNull (query , "Query must not be null" );
312317 Assert .notNull (update , "Update must not be null" );
@@ -322,11 +327,16 @@ private BulkOperations update(Query query, Update update, boolean upsert, boolea
322327 return this ;
323328 }
324329
325- private WriteModel <Document > mapWriteModel (WriteModel <Document > writeModel ) {
330+ private WriteModel <Document > mapWriteModel (Object source , WriteModel <Document > writeModel ) {
326331
327332 if (writeModel instanceof UpdateOneModel ) {
328333
329334 UpdateOneModel <Document > model = (UpdateOneModel <Document >) writeModel ;
335+ if (source instanceof AggregationUpdate aggregationUpdate ) {
336+
337+ List <Document > pipeline = mapUpdatePipeline (aggregationUpdate );
338+ return new UpdateOneModel <>(getMappedQuery (model .getFilter ()), pipeline , model .getOptions ());
339+ }
330340
331341 return new UpdateOneModel <>(getMappedQuery (model .getFilter ()), getMappedUpdate (model .getUpdate ()),
332342 model .getOptions ());
@@ -335,6 +345,11 @@ private WriteModel<Document> mapWriteModel(WriteModel<Document> writeModel) {
335345 if (writeModel instanceof UpdateManyModel ) {
336346
337347 UpdateManyModel <Document > model = (UpdateManyModel <Document >) writeModel ;
348+ if (source instanceof AggregationUpdate aggregationUpdate ) {
349+
350+ List <Document > pipeline = mapUpdatePipeline (aggregationUpdate );
351+ return new UpdateManyModel <>(getMappedQuery (model .getFilter ()), pipeline , model .getOptions ());
352+ }
338353
339354 return new UpdateManyModel <>(getMappedQuery (model .getFilter ()), getMappedUpdate (model .getUpdate ()),
340355 model .getOptions ());
@@ -357,6 +372,19 @@ private WriteModel<Document> mapWriteModel(WriteModel<Document> writeModel) {
357372 return writeModel ;
358373 }
359374
375+ private List <Document > mapUpdatePipeline (AggregationUpdate source ) {
376+ Class <?> type = bulkOperationContext .getEntity ().isPresent ()
377+ ? bulkOperationContext .getEntity ().map (PersistentEntity ::getType ).get ()
378+ : Object .class ;
379+ AggregationOperationContext context = new RelaxedTypeBasedAggregationOperationContext (type ,
380+ bulkOperationContext .getUpdateMapper ().getMappingContext (), bulkOperationContext .getQueryMapper ());
381+
382+ List <Document > pipeline = new AggregationUtil (bulkOperationContext .getQueryMapper (),
383+ bulkOperationContext .getQueryMapper ().getMappingContext ()).createPipeline (source ,
384+ context );
385+ return pipeline ;
386+ }
387+
360388 private Bson getMappedUpdate (Bson update ) {
361389 return bulkOperationContext .getUpdateMapper ().getMappedObject (update , bulkOperationContext .getEntity ());
362390 }
0 commit comments