@@ -74,15 +74,17 @@ class DBCursorFunctionalSpecification extends FunctionalSpecification {
7474
7575 when :
7676 dbCursor = collection. find(). hint(new BasicDBObject (' a' , 1 ))
77+ def explainPlan = dbCursor. explain()
7778
7879 then :
79- dbCursor . explain() . queryPlanner . winningPlan . inputStage . keyPattern == cursorMap
80+ getKeyPattern(explainPlan) == cursorMap
8081
8182 when :
8283 dbCursor = collection. find(). addSpecial(' $hint' , new BasicDBObject (' a' , 1 ))
84+ explainPlan = dbCursor. explain()
8385
8486 then :
85- dbCursor . explain() . queryPlanner . winningPlan . inputStage . keyPattern == cursorMap
87+ getKeyPattern(explainPlan) == cursorMap
8688 }
8789
8890 def ' should use provided hint for count' () {
@@ -118,15 +120,17 @@ class DBCursorFunctionalSpecification extends FunctionalSpecification {
118120
119121 when :
120122 dbCursor = collection. find(). hint(' a_1' )
123+ def explainPlan = dbCursor. explain()
121124
122125 then :
123- dbCursor . explain() . queryPlanner . winningPlan . inputStage . keyPattern == cursorMap
126+ getKeyPattern(explainPlan) == cursorMap
124127
125128 when :
126129 dbCursor = collection. find(). addSpecial(' $hint' , ' a_1' )
130+ explainPlan = dbCursor. explain()
127131
128132 then :
129- dbCursor . explain() . queryPlanner . winningPlan . inputStage . keyPattern == cursorMap
133+ getKeyPattern(explainPlan) == cursorMap
130134 }
131135
132136 def ' should use provided hints for count' () {
@@ -209,9 +213,10 @@ class DBCursorFunctionalSpecification extends FunctionalSpecification {
209213 when :
210214 dbCursor = collection. find(). hint(new BasicDBObject (' a' , 1 ))
211215 dbCursor. addSpecial(' $explain' , 1 )
216+ def explainPlan = dbCursor. explain()
212217
213218 then :
214- dbCursor . explain() . queryPlanner . winningPlan . inputStage . keyPattern == cursorMap
219+ getKeyPattern(explainPlan) == cursorMap
215220 }
216221
217222
@@ -372,4 +377,12 @@ class DBCursorFunctionalSpecification extends FunctionalSpecification {
372377 then :
373378 executor. getReadPreference() == ReadPreference . secondaryPreferred()
374379 }
380+
381+ static DBObject getKeyPattern (DBObject explainPlan ) {
382+ if (explainPlan. queryPlanner. winningPlan. inputStage != null ) {
383+ return explainPlan. queryPlanner. winningPlan. inputStage. keyPattern
384+ } else if (explainPlan. queryPlanner. winningPlan. shards != null ) {
385+ return explainPlan. queryPlanner. winningPlan. shards[0 ]. winningPlan. inputStage. keyPattern
386+ }
387+ }
375388}
0 commit comments