@@ -270,21 +270,24 @@ func checkExpectations(mt *mtest.T, expectations *[]*expectation, id0, id1 bson.
270270 return
271271 }
272272
273+ startedEvents := make ([]* cmdStartedEvt , len (* expectations ))
274+ succeededEvents := make ([]* cmdSucceededEvt , len (* expectations ))
275+ failedEvents := make ([]* cmdFailedEvt , len (* expectations ))
276+
273277 for idx , expectation := range * expectations {
274- var err error
278+ startedEvents [idx ] = expectation .CommandStartedEvent
279+ succeededEvents [idx ] = expectation .CommandSucceededEvent
280+ failedEvents [idx ] = expectation .CommandFailedEvent
281+ }
275282
276- if expectation .CommandStartedEvent != nil {
277- err = compareStartedEvent (mt , expectation , id0 , id1 )
278- }
279- if expectation .CommandSucceededEvent != nil {
280- err = compareSucceededEvent (mt , expectation )
281- }
282- if expectation .CommandFailedEvent != nil {
283- err = compareFailedEvent (mt , expectation )
284- }
283+ var err error
284+ err = compareStartedEvents (mt , startedEvents , id0 , id1 )
285+ assert .Nil (mt , err , "expectation comparison %s" , err )
286+ err = compareSucceededEvents (mt , succeededEvents )
287+ assert .Nil (mt , err , "expectation comparison %s" , err )
288+ err = compareFailedEvents (mt , failedEvents )
289+ assert .Nil (mt , err , "expectation comparison %s" , err )
285290
286- assert .Nil (mt , err , "expectation comparison error at index %v: %s" , idx , err )
287- }
288291}
289292
290293// newMatchError appends `expected` and `actual` BSON data to an error.
@@ -298,83 +301,105 @@ func newMatchError(mt *mtest.T, expected bson.Raw, actual bson.Raw, format strin
298301 return fmt .Errorf ("%s\n Expected %s\n Got: %s" , msg , string (expectedJSON ), string (actualJSON ))
299302}
300303
301- func compareStartedEvent (mt * mtest.T , expectation * expectation , id0 , id1 bson.Raw ) error {
304+ func compareStartedEvents (mt * mtest.T , expectations [] * cmdStartedEvt , id0 , id1 bson.Raw ) error {
302305 mt .Helper ()
303306
304- expected := expectation .CommandStartedEvent
305-
306- if len (expected .Extra ) > 0 {
307- return fmt .Errorf ("unrecognized fields for CommandStartedEvent: %v" , expected .Extra )
308- }
309-
310- evt := mt .GetStartedEvent ()
311- if evt == nil {
312- return errors .New ("expected CommandStartedEvent, got nil" )
313- }
314-
315- if expected .CommandName != "" && expected .CommandName != evt .CommandName {
316- return fmt .Errorf ("command name mismatch for started event; expected %s, got %s" , expected .CommandName , evt .CommandName )
317- }
318- if expected .DatabaseName != "" && expected .DatabaseName != evt .DatabaseName {
319- return fmt .Errorf ("database name mismatch; expected %s, got %s" , expected .DatabaseName , evt .DatabaseName )
307+ expectedCmds := make (map [string ]bool )
308+ for _ , expected := range expectations {
309+ expectedCmds [expected .CommandName ] = true
320310 }
321311
322- eElems , err := expected .Command .Elements ()
323- if err != nil {
324- return fmt .Errorf ("error getting expected command elements: %s" , err )
325- }
326-
327- for _ , elem := range eElems {
328- key := elem .Key ()
329- val := elem .Value ()
330-
331- actualVal , err := evt .Command .LookupErr (key )
312+ compare := func (expected * cmdStartedEvt ) error {
313+ if expected == nil {
314+ return nil
315+ }
316+ if len (expected .Extra ) > 0 {
317+ return fmt .Errorf ("unrecognized fields for CommandStartedEvent: %v" , expected .Extra )
318+ }
332319
333- // Keys that may be nil
334- if val .Type == bson .TypeNull {
335- // Expected value is BSON null. Expect the actual field to be omitted.
336- if errors .Is (err , bsoncore .ErrElementNotFound ) {
337- continue
320+ var evt * event.CommandStartedEvent
321+ // skip events not in expectations
322+ for {
323+ evt = mt .GetStartedEvent ()
324+ if evt == nil {
325+ return errors .New ("expected CommandStartedEvent, got nil" )
338326 }
339- if err != nil {
340- return newMatchError ( mt , expected . Command , evt . Command , "expected key %q to be omitted but got error: %v" , key , err )
327+ if v , ok := expectedCmds [ expected . CommandName ]; ok && v {
328+ break
341329 }
342- return newMatchError (mt , expected .Command , evt .Command , "expected key %q to be omitted but got %q" , key , actualVal )
343330 }
344- assert .Nil (mt , err , "expected command to contain key %q" , key )
345331
346- if key == "batchSize" {
347- // Some command monitoring tests expect that the driver will send a lower batch size if the required batch
348- // size is lower than the operation limit. We only do this for legacy servers <= 3.0 because those server
349- // versions do not support the limit option, but not for 3.2+. We've already validated that the command
350- // contains a batchSize field above and we can skip the actual value comparison below.
351- continue
332+ if expected .CommandName != "" && expected .CommandName != evt .CommandName {
333+ return fmt .Errorf ("command name mismatch for started event; expected %s, got %s" , expected .CommandName , evt .CommandName )
334+ }
335+ if expected .DatabaseName != "" && expected .DatabaseName != evt .DatabaseName {
336+ return fmt .Errorf ("database name mismatch; expected %s, got %s" , expected .DatabaseName , evt .DatabaseName )
337+ }
338+
339+ eElems , err := expected .Command .Elements ()
340+ if err != nil {
341+ return fmt .Errorf ("error getting expected command elements: %s" , err )
352342 }
353343
354- switch key {
355- case "lsid" :
356- sessName := val .StringValue ()
357- var expectedID bson.Raw
358- actualID := actualVal .Document ()
344+ for _ , elem := range eElems {
345+ key := elem .Key ()
346+ val := elem .Value ()
359347
360- switch sessName {
361- case "session0" :
362- expectedID = id0
363- case "session1" :
364- expectedID = id1
365- default :
366- return newMatchError (mt , expected .Command , evt .Command , "unrecognized session identifier in command document: %s" , sessName )
348+ actualVal , err := evt .Command .LookupErr (key )
349+
350+ // Keys that may be nil
351+ if val .Type == bson .TypeNull {
352+ // Expected value is BSON null. Expect the actual field to be omitted.
353+ if errors .Is (err , bsoncore .ErrElementNotFound ) {
354+ continue
355+ }
356+ if err != nil {
357+ return newMatchError (mt , expected .Command , evt .Command , "expected key %q to be omitted but got error: %v" , key , err )
358+ }
359+ return newMatchError (mt , expected .Command , evt .Command , "expected key %q to be omitted but got %q" , key , actualVal )
367360 }
361+ assert .Nil (mt , err , "expected command to contain key %q" , key )
368362
369- if ! bytes .Equal (expectedID , actualID ) {
370- return newMatchError (mt , expected .Command , evt .Command , "session ID mismatch for session %s; expected %s, got %s" , sessName , expectedID ,
371- actualID )
363+ if key == "batchSize" {
364+ // Some command monitoring tests expect that the driver will send a lower batch size if the required batch
365+ // size is lower than the operation limit. We only do this for legacy servers <= 3.0 because those server
366+ // versions do not support the limit option, but not for 3.2+. We've already validated that the command
367+ // contains a batchSize field above and we can skip the actual value comparison below.
368+ continue
372369 }
373- default :
374- if err := compareValues (mt , key , val , actualVal ); err != nil {
375- return newMatchError (mt , expected .Command , evt .Command , "%s" , err )
370+
371+ switch key {
372+ case "lsid" :
373+ sessName := val .StringValue ()
374+ var expectedID bson.Raw
375+ actualID := actualVal .Document ()
376+
377+ switch sessName {
378+ case "session0" :
379+ expectedID = id0
380+ case "session1" :
381+ expectedID = id1
382+ default :
383+ return newMatchError (mt , expected .Command , evt .Command , "unrecognized session identifier in command document: %s" , sessName )
384+ }
385+
386+ if ! bytes .Equal (expectedID , actualID ) {
387+ return newMatchError (mt , expected .Command , evt .Command , "session ID mismatch for session %s; expected %s, got %s" , sessName , expectedID ,
388+ actualID )
389+ }
390+ default :
391+ if err := compareValues (mt , key , val , actualVal ); err != nil {
392+ return newMatchError (mt , expected .Command , evt .Command , "%s" , err )
393+ }
376394 }
377395 }
396+ return nil
397+ }
398+ for idx , expected := range expectations {
399+ err := compare (expected )
400+ if err != nil {
401+ return fmt .Errorf ("error at index %d: %s" , idx , err )
402+ }
378403 }
379404 return nil
380405}
@@ -416,60 +441,108 @@ func compareWriteErrors(mt *mtest.T, expected, actual bson.Raw) error {
416441 return nil
417442}
418443
419- func compareSucceededEvent (mt * mtest.T , expectation * expectation ) error {
444+ func compareSucceededEvents (mt * mtest.T , expectations [] * cmdSucceededEvt ) error {
420445 mt .Helper ()
421446
422- expected := expectation .CommandSucceededEvent
423- if len (expected .Extra ) > 0 {
424- return fmt .Errorf ("unrecognized fields for CommandSucceededEvent: %v" , expected .Extra )
425- }
426- evt := mt .GetSucceededEvent ()
427- if evt == nil {
428- return errors .New ("expected CommandSucceededEvent, got nil" )
447+ expectedCmds := make (map [string ]bool )
448+ for _ , expected := range expectations {
449+ expectedCmds [expected .CommandName ] = true
429450 }
430451
431- if expected .CommandName != "" && expected .CommandName != evt .CommandName {
432- return fmt .Errorf ("command name mismatch for succeeded event; expected %s, got %s" , expected .CommandName , evt .CommandName )
433- }
452+ compare := func (expected * cmdSucceededEvt ) error {
453+ if expected == nil {
454+ return nil
455+ }
456+ if len (expected .Extra ) > 0 {
457+ return fmt .Errorf ("unrecognized fields for CommandSucceededEvent: %v" , expected .Extra )
458+ }
434459
435- eElems , err := expected .Reply .Elements ()
436- if err != nil {
437- return fmt .Errorf ("error getting expected reply elements: %s" , err )
438- }
460+ var evt * event.CommandSucceededEvent
461+ // skip events not in expectations
462+ for {
463+ evt = mt .GetSucceededEvent ()
464+ if evt == nil {
465+ return errors .New ("expected CommandSucceededEvent, got nil" )
466+ }
467+ if v , ok := expectedCmds [expected .CommandName ]; ok && v {
468+ break
469+ }
470+ }
439471
440- for _ , elem := range eElems {
441- key := elem .Key ()
442- val := elem .Value ()
443- actualVal := evt .Reply .Lookup (key )
472+ if expected .CommandName != "" && expected .CommandName != evt .CommandName {
473+ return fmt .Errorf ("command name mismatch for succeeded event; expected %s, got %s" , expected .CommandName , evt .CommandName )
474+ }
444475
445- switch key {
446- case "writeErrors" :
447- if err = compareWriteErrors (mt , val .Array (), actualVal .Array ()); err != nil {
448- return newMatchError (mt , expected .Reply , evt .Reply , "%s" , err )
449- }
450- default :
451- if err := compareValues (mt , key , val , actualVal ); err != nil {
452- return newMatchError (mt , expected .Reply , evt .Reply , "%s" , err )
476+ eElems , err := expected .Reply .Elements ()
477+ if err != nil {
478+ return fmt .Errorf ("error getting expected reply elements: %s" , err )
479+ }
480+
481+ for _ , elem := range eElems {
482+ key := elem .Key ()
483+ val := elem .Value ()
484+ actualVal := evt .Reply .Lookup (key )
485+
486+ switch key {
487+ case "writeErrors" :
488+ if err = compareWriteErrors (mt , val .Array (), actualVal .Array ()); err != nil {
489+ return newMatchError (mt , expected .Reply , evt .Reply , "%s" , err )
490+ }
491+ default :
492+ if err := compareValues (mt , key , val , actualVal ); err != nil {
493+ return newMatchError (mt , expected .Reply , evt .Reply , "%s" , err )
494+ }
453495 }
454496 }
497+ return nil
498+ }
499+ for idx , expected := range expectations {
500+ err := compare (expected )
501+ if err != nil {
502+ return fmt .Errorf ("error at index %d: %s" , idx , err )
503+ }
455504 }
456505 return nil
457506}
458507
459- func compareFailedEvent (mt * mtest.T , expectation * expectation ) error {
508+ func compareFailedEvents (mt * mtest.T , expectations [] * cmdFailedEvt ) error {
460509 mt .Helper ()
461510
462- expected := expectation .CommandFailedEvent
463- if len (expected .Extra ) > 0 {
464- return fmt .Errorf ("unrecognized fields for CommandFailedEvent: %v" , expected .Extra )
465- }
466- evt := mt .GetFailedEvent ()
467- if evt == nil {
468- return errors .New ("expected CommandFailedEvent, got nil" )
511+ expectedCmds := make (map [string ]bool )
512+ for _ , expected := range expectations {
513+ expectedCmds [expected .CommandName ] = true
469514 }
470515
471- if expected .CommandName != "" && expected .CommandName != evt .CommandName {
472- return fmt .Errorf ("command name mismatch for failed event; expected %s, got %s" , expected .CommandName , evt .CommandName )
516+ compare := func (expected * cmdFailedEvt ) error {
517+ if expected == nil {
518+ return nil
519+ }
520+ if len (expected .Extra ) > 0 {
521+ return fmt .Errorf ("unrecognized fields for CommandFailedEvent: %v" , expected .Extra )
522+ }
523+
524+ var evt * event.CommandFailedEvent
525+ // skip events not in expectations
526+ for {
527+ evt = mt .GetFailedEvent ()
528+ if evt == nil {
529+ return errors .New ("expected CommandFailedEvent, got nil" )
530+ }
531+ if v , ok := expectedCmds [expected .CommandName ]; ok && v {
532+ break
533+ }
534+ }
535+
536+ if expected .CommandName != "" && expected .CommandName != evt .CommandName {
537+ return fmt .Errorf ("command name mismatch for failed event; expected %s, got %s" , expected .CommandName , evt .CommandName )
538+ }
539+ return nil
540+ }
541+ for idx , expected := range expectations {
542+ err := compare (expected )
543+ if err != nil {
544+ return fmt .Errorf ("error at index %d: %s" , idx , err )
545+ }
473546 }
474547 return nil
475548}
0 commit comments