@@ -44,12 +44,13 @@ public static AstPipeline Optimize(AstPipeline pipeline)
4444 #endregion
4545
4646 private readonly AccumulatorSet _accumulators = new AccumulatorSet ( ) ;
47+ private AstExpression _element ; // normally either "$$ROOT" or "$_v"
4748
4849 private AstPipeline OptimizeGroupStage ( AstPipeline pipeline , int i , AstGroupStage groupStage )
4950 {
5051 try
5152 {
52- if ( IsOptimizableGroupStage ( groupStage ) )
53+ if ( IsOptimizableGroupStage ( groupStage , out _element ) )
5354 {
5455 var followingStages = GetFollowingStagesToOptimize ( pipeline , i + 1 ) ;
5556 if ( followingStages == null )
@@ -71,22 +72,22 @@ private AstPipeline OptimizeGroupStage(AstPipeline pipeline, int i, AstGroupStag
7172
7273 return pipeline ;
7374
74- static bool IsOptimizableGroupStage ( AstGroupStage groupStage )
75+ static bool IsOptimizableGroupStage ( AstGroupStage groupStage , out AstExpression element )
7576 {
76- // { $group : { _id : ?, _elements : { $push : "$$ROOT" } } }
77+ // { $group : { _id : ?, _elements : { $push : element } } }
7778 if ( groupStage . Fields . Count == 1 )
7879 {
7980 var field = groupStage . Fields [ 0 ] ;
8081 if ( field . Path == "_elements" &&
8182 field . Value is AstUnaryAccumulatorExpression unaryAccumulatorExpression &&
82- unaryAccumulatorExpression . Operator == AstUnaryAccumulatorOperator . Push &&
83- unaryAccumulatorExpression . Arg is AstVarExpression varExpression &&
84- varExpression . Name == "ROOT" )
83+ unaryAccumulatorExpression . Operator == AstUnaryAccumulatorOperator . Push )
8584 {
85+ element = unaryAccumulatorExpression . Arg ;
8686 return true ;
8787 }
8888 }
8989
90+ element = null ;
9091 return false ;
9192 }
9293
@@ -173,7 +174,7 @@ private AstStage OptimizeLimitStage(AstLimitStage stage)
173174
174175 private AstStage OptimizeMatchStage ( AstMatchStage stage )
175176 {
176- var optimizedFilter = AccumulatorMover . MoveAccumulators ( _accumulators , stage . Filter ) ;
177+ var optimizedFilter = AccumulatorMover . MoveAccumulators ( _accumulators , _element , stage . Filter ) ;
177178 return stage . Update ( optimizedFilter ) ;
178179 }
179180
@@ -201,7 +202,7 @@ private AstProjectStageSpecification OptimizeProjectStageSpecification(AstProjec
201202
202203 private AstProjectStageSpecification OptimizeProjectStageSetFieldSpecification ( AstProjectStageSetFieldSpecification specification )
203204 {
204- var optimizedValue = AccumulatorMover . MoveAccumulators ( _accumulators , specification . Value ) ;
205+ var optimizedValue = AccumulatorMover . MoveAccumulators ( _accumulators , _element , specification . Value ) ;
205206 return specification . Update ( optimizedValue ) ;
206207 }
207208
@@ -249,27 +250,29 @@ public string AddAccumulatorExpression(AstAccumulatorExpression value)
249250 private class AccumulatorMover : AstNodeVisitor
250251 {
251252 #region static
252- public static TNode MoveAccumulators < TNode > ( AccumulatorSet accumulators , TNode node )
253+ public static TNode MoveAccumulators < TNode > ( AccumulatorSet accumulators , AstExpression element , TNode node )
253254 where TNode : AstNode
254255 {
255- var mover = new AccumulatorMover ( accumulators ) ;
256+ var mover = new AccumulatorMover ( accumulators , element ) ;
256257 return mover . VisitAndConvert ( node ) ;
257258 }
258259 #endregion
259260
260261 private readonly AccumulatorSet _accumulators ;
262+ private readonly AstExpression _element ;
261263
262- private AccumulatorMover ( AccumulatorSet accumulator )
264+ private AccumulatorMover ( AccumulatorSet accumulator , AstExpression element )
263265 {
264266 _accumulators = accumulator ;
267+ _element = element ;
265268 }
266269
267270 public override AstNode VisitFilterField ( AstFilterField node )
268271 {
269- // "_elements.0.X" => { __agg0 : { $first : "$$ROOT" } } + "__agg0.X"
272+ // "_elements.0.X" => { __agg0 : { $first : element } } + "__agg0.X"
270273 if ( node . Path . StartsWith ( "_elements.0." ) )
271274 {
272- var accumulatorExpression = AstExpression . UnaryAccumulator ( AstUnaryAccumulatorOperator . First , AstExpression . Var ( "ROOT" ) ) ;
275+ var accumulatorExpression = AstExpression . UnaryAccumulator ( AstUnaryAccumulatorOperator . First , _element ) ;
273276 var accumulatorFieldName = _accumulators . AddAccumulatorExpression ( accumulatorExpression ) ;
274277 var restOfPath = node . Path . Substring ( "_elements.0." . Length ) ;
275278 var rewrittenPath = $ "{ accumulatorFieldName } .{ restOfPath } ";
@@ -288,9 +291,7 @@ public override AstNode VisitGetFieldExpression(AstGetFieldExpression node)
288291 {
289292 if ( node . FieldName is AstConstantExpression constantFieldName &&
290293 constantFieldName . Value . IsString &&
291- constantFieldName . Value . AsString == "_elements" &&
292- node . Input is AstVarExpression varExpression &&
293- varExpression . Name == "ROOT" )
294+ constantFieldName . Value . AsString == "_elements" )
294295 {
295296 throw new UnableToRemoveReferenceToElementsException ( ) ;
296297 }
@@ -300,18 +301,18 @@ node.Input is AstVarExpression varExpression &&
300301
301302 public override AstNode VisitMapExpression ( AstMapExpression node )
302303 {
303- // { $map : { input : { $getField : { input : "$$ROOT", field : "_elements" } }, as : "x", in : f(x) } } => { __agg0 : { $push : f(x => root ) } } + "$__agg0"
304+ // { $map : { input : { $getField : { input : "$$ROOT", field : "_elements" } }, as : "x", in : f(x) } } => { __agg0 : { $push : f(x => element ) } } + "$__agg0"
304305 if ( node . Input is AstGetFieldExpression mapInputGetFieldExpression &&
305306 mapInputGetFieldExpression . FieldName is AstConstantExpression mapInputconstantFieldExpression &&
306307 mapInputconstantFieldExpression . Value . IsString &&
307308 mapInputconstantFieldExpression . Value . AsString == "_elements" &&
308309 mapInputGetFieldExpression . Input is AstVarExpression mapInputGetFieldVarExpression &&
309310 mapInputGetFieldVarExpression . Name == "ROOT" )
310311 {
311- var root = AstExpression . Var ( "ROOT" , isCurrent : true ) ;
312- var rewrittenArg = ( AstExpression ) AstNodeReplacer . Replace ( node . In , ( node . As , root ) ) ;
312+ var rewrittenArg = ( AstExpression ) AstNodeReplacer . Replace ( node . In , ( node . As , _element ) ) ;
313313 var accumulatorExpression = AstExpression . UnaryAccumulator ( AstUnaryAccumulatorOperator . Push , rewrittenArg ) ;
314314 var accumulatorFieldName = _accumulators . AddAccumulatorExpression ( accumulatorExpression ) ;
315+ var root = AstExpression . Var ( "ROOT" , isCurrent : true ) ;
315316 return AstExpression . GetField ( root , accumulatorFieldName ) ;
316317 }
317318
@@ -321,7 +322,7 @@ mapInputGetFieldExpression.Input is AstVarExpression mapInputGetFieldVarExpressi
321322 public override AstNode VisitPickExpression ( AstPickExpression node )
322323 {
323324 // { $pickOperator : { source : { $getField : { input : "$$ROOT", field : "_elements" } }, as : "x", sortBy : s, selector : f(x) } }
324- // => { __agg0 : { $pickAccumulatorOperator : { sortBy : s, selector : f(x => root ) } } } + "$__agg0"
325+ // => { __agg0 : { $pickAccumulatorOperator : { sortBy : s, selector : f(x => element ) } } } + "$__agg0"
325326 if ( node . Source is AstGetFieldExpression getFieldExpression &&
326327 getFieldExpression . Input is AstVarExpression varExpression &&
327328 varExpression . Name == "ROOT" &&
@@ -330,10 +331,10 @@ getFieldExpression.FieldName is AstConstantExpression constantFieldNameExpressio
330331 constantFieldNameExpression . Value . AsString == "_elements" )
331332 {
332333 var @operator = node . Operator . ToAccumulatorOperator ( ) ;
333- var root = AstExpression . Var ( "ROOT" , isCurrent : true ) ;
334- var rewrittenSelector = ( AstExpression ) AstNodeReplacer . Replace ( node . Selector , ( node . As , root ) ) ;
334+ var rewrittenSelector = ( AstExpression ) AstNodeReplacer . Replace ( node . Selector , ( node . As , _element ) ) ;
335335 var accumulatorExpression = new AstPickAccumulatorExpression ( @operator , node . SortBy , rewrittenSelector , node . N ) ;
336336 var accumulatorFieldName = _accumulators . AddAccumulatorExpression ( accumulatorExpression ) ;
337+ var root = AstExpression . Var ( "ROOT" , isCurrent : true ) ;
337338 return AstExpression . GetField ( root , accumulatorFieldName ) ;
338339 }
339340
@@ -384,7 +385,7 @@ argGetFieldExpression.FieldName is AstConstantExpression constantFieldNameExpres
384385
385386 bool TryOptimizeAccumulatorOfElements ( out AstExpression optimizedExpression )
386387 {
387- // { $accumulator : { $getField : { input : "$$ROOT", field : "_elements" } } } => { __agg0 : { $accumulator : "$$ROOT" } } + "$__agg0"
388+ // { $accumulator : { $getField : { input : "$$ROOT", field : "_elements" } } } => { __agg0 : { $accumulator : element } } + "$__agg0"
388389 if ( node . Operator . IsAccumulator ( out var accumulatorOperator ) &&
389390 node . Arg is AstGetFieldExpression getFieldExpression &&
390391 getFieldExpression . FieldName is AstConstantExpression getFieldConstantFieldNameExpression &&
@@ -393,7 +394,7 @@ getFieldExpression.FieldName is AstConstantExpression getFieldConstantFieldNameE
393394 getFieldExpression . Input is AstVarExpression getFieldInputVarExpression &&
394395 getFieldInputVarExpression . Name == "ROOT" )
395396 {
396- var accumulatorExpression = AstExpression . UnaryAccumulator ( accumulatorOperator , root ) ;
397+ var accumulatorExpression = AstExpression . UnaryAccumulator ( accumulatorOperator , _element ) ;
397398 var accumulatorFieldName = _accumulators . AddAccumulatorExpression ( accumulatorExpression ) ;
398399 optimizedExpression = AstExpression . GetField ( root , accumulatorFieldName ) ;
399400 return true ;
@@ -406,7 +407,7 @@ getFieldExpression.Input is AstVarExpression getFieldInputVarExpression &&
406407
407408 bool TryOptimizeAccumulatorOfMappedElements ( out AstExpression optimizedExpression )
408409 {
409- // { $accumulator : { $map : { input : { $getField : { input : "$$ROOT", field : "_elements" } }, as : "x", in : f(x) } } } => { __agg0 : { $accumulator : f(x => root ) } } + "$__agg0"
410+ // { $accumulator : { $map : { input : { $getField : { input : "$$ROOT", field : "_elements" } }, as : "x", in : f(x) } } } => { __agg0 : { $accumulator : f(x => element ) } } + "$__agg0"
410411 if ( node . Operator . IsAccumulator ( out var accumulatorOperator ) &&
411412 node . Arg is AstMapExpression mapExpression &&
412413 mapExpression . Input is AstGetFieldExpression mapInputGetFieldExpression &&
@@ -416,7 +417,7 @@ mapInputGetFieldExpression.FieldName is AstConstantExpression mapInputconstantFi
416417 mapInputGetFieldExpression . Input is AstVarExpression mapInputGetFieldVarExpression &&
417418 mapInputGetFieldVarExpression . Name == "ROOT" )
418419 {
419- var rewrittenArg = ( AstExpression ) AstNodeReplacer . Replace ( mapExpression . In , ( mapExpression . As , root ) ) ;
420+ var rewrittenArg = ( AstExpression ) AstNodeReplacer . Replace ( mapExpression . In , ( mapExpression . As , _element ) ) ;
420421 var accumulatorExpression = AstExpression . UnaryAccumulator ( accumulatorOperator , rewrittenArg ) ;
421422 var accumulatorFieldName = _accumulators . AddAccumulatorExpression ( accumulatorExpression ) ;
422423 optimizedExpression = AstExpression . GetField ( root , accumulatorFieldName ) ;
0 commit comments