diff --git a/internal/integration/unified/client_entity.go b/internal/integration/unified/client_entity.go index bc981793df..325138c02a 100644 --- a/internal/integration/unified/client_entity.go +++ b/internal/integration/unified/client_entity.go @@ -38,6 +38,52 @@ var securitySensitiveCommands = []string{ "createUser", "updateUser", "copydbgetnonce", "copydbsaslstart", "copydb", } +// eventSequencer allows for sequence-based event filtering for +// awaitMinPoolSizeMS support. +// +// Per the unified test format spec, when awaitMinPoolSizeMS is specified, any +// CMAP and SDAM events that occur during connection pool initialization +// (before minPoolSize is reached) must be ignored. We track this by +// assigning a monotonically increasing sequence number to each event as it's +// recorded. After pool initialization completes, we set eventCutoffSeq to the +// current sequence number. Event accessors for CMAP and SDAM types then +// filter out any events with sequence <= eventCutoffSeq. +// +// Sequencing is thread-safe to support concurrent operations that may generate +// events (e.g., connection checkouts generating CMAP events). +type eventSequencer struct { + counter atomic.Int64 + cutoff atomic.Int64 + + mu sync.RWMutex + + // pool events are heterogeneous, so we track their sequence separately + poolSeq []int64 + seqByEventType map[monitoringEventType][]int64 +} + +// setCutoff marks the current sequence as the filtering cutoff point. +func (es *eventSequencer) setCutoff() { + es.cutoff.Store(es.counter.Load()) +} + +// recordEvent stores the sequence number for a given event type. +func (es *eventSequencer) recordEvent(eventType monitoringEventType) { + next := es.counter.Add(1) + + es.mu.Lock() + es.seqByEventType[eventType] = append(es.seqByEventType[eventType], next) + es.mu.Unlock() +} + +func (es *eventSequencer) recordPooledEvent() { + next := es.counter.Add(1) + + es.mu.Lock() + es.poolSeq = append(es.poolSeq, next) + es.mu.Unlock() +} + // clientEntity is a wrapper for a mongo.Client object that also holds additional information required during test // execution. type clientEntity struct { @@ -72,30 +118,8 @@ type clientEntity struct { entityMap *EntityMap - logQueue chan orderedLogMessage -} - -// awaitMinimumPoolSize waits for the client's connection pool to reach the -// specified minimum size. This is a best effort operation that times out after -// some predefined amount of time to avoid blocking tests indefinitely. -func awaitMinimumPoolSize(ctx context.Context, entity *clientEntity, minPoolSize uint64) error { - // Don't spend longer than 500ms awaiting minPoolSize. 
- awaitCtx, cancel := context.WithTimeout(ctx, 500*time.Millisecond) - defer cancel() - - ticker := time.NewTicker(100 * time.Millisecond) - defer ticker.Stop() - - for { - select { - case <-awaitCtx.Done(): - return fmt.Errorf("timed out waiting for client to reach minPoolSize") - case <-ticker.C: - if uint64(entity.eventsCount[connectionReadyEvent]) >= minPoolSize { - return nil - } - } - } + logQueue chan orderedLogMessage + eventSequencer eventSequencer } func newClientEntity(ctx context.Context, em *EntityMap, entityOptions *entityOptions) (*clientEntity, error) { @@ -118,6 +142,9 @@ func newClientEntity(ctx context.Context, em *EntityMap, entityOptions *entityOp serverDescriptionChangedEventsCount: make(map[serverDescriptionChangedEventInfo]int32), entityMap: em, observeSensitiveCommands: entityOptions.ObserveSensitiveCommands, + eventSequencer: eventSequencer{ + seqByEventType: make(map[monitoringEventType][]int64), + }, } entity.setRecordEvents(true) @@ -226,8 +253,9 @@ func newClientEntity(ctx context.Context, em *EntityMap, entityOptions *entityOp return nil, fmt.Errorf("error creating mongo.Client: %w", err) } - if entityOptions.AwaitMinPoolSize && clientOpts.MinPoolSize != nil && *clientOpts.MinPoolSize > 0 { - if err := awaitMinimumPoolSize(ctx, entity, *clientOpts.MinPoolSize); err != nil { + if entityOptions.AwaitMinPoolSizeMS != nil && *entityOptions.AwaitMinPoolSizeMS > 0 && + clientOpts.MinPoolSize != nil && *clientOpts.MinPoolSize > 0 { + if err := awaitMinimumPoolSize(ctx, entity, *clientOpts.MinPoolSize, *entityOptions.AwaitMinPoolSizeMS); err != nil { return nil, err } } @@ -326,8 +354,47 @@ func (c *clientEntity) failedEvents() []*event.CommandFailedEvent { return events } -func (c *clientEntity) poolEvents() []*event.PoolEvent { - return c.pooled +// filterEventsBySeq filters events by sequence number for the given eventType. +// See comments on eventSequencer for more details. +func filterEventsBySeq[T any](c *clientEntity, events []T, eventType monitoringEventType) []T { + cutoff := c.eventSequencer.cutoff.Load() + if cutoff == 0 { + return events + } + + // Lock order: eventProcessMu -> eventSequencer.mu (matches writers) + c.eventProcessMu.RLock() + c.eventSequencer.mu.RLock() + + // Snapshot to minimize time under locks and avoid races + localEvents := append([]T(nil), events...) + + var seqSlice []int64 + if eventType == poolAnyEvent { + seqSlice = c.eventSequencer.poolSeq + } else { + seqSlice = c.eventSequencer.seqByEventType[eventType] + } + + localSeqs := append([]int64(nil), seqSlice...) + + c.eventSequencer.mu.RUnlock() + c.eventProcessMu.RUnlock() + + // guard against index out of range. 
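+ // The events slice and the recorded sequence numbers may be appended
+ // concurrently, so only compare up to the shorter length.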
+ n := len(localEvents) + if len(localSeqs) < n { + n = len(localSeqs) + } + + filtered := make([]T, 0, n) + for i := 0; i < n; i++ { + if localSeqs[i] > cutoff { + filtered = append(filtered, localEvents[i]) + } + } + + return filtered } func (c *clientEntity) numberConnectionsCheckedOut() int32 { @@ -516,7 +583,10 @@ func (c *clientEntity) processPoolEvent(evt *event.PoolEvent) { eventType := monitoringEventTypeFromPoolEvent(evt) if _, ok := c.observedEvents[eventType]; ok { + c.eventProcessMu.Lock() c.pooled = append(c.pooled, evt) + c.eventSequencer.recordPooledEvent() + c.eventProcessMu.Unlock() } c.addEventsCount(eventType) @@ -539,6 +609,7 @@ func (c *clientEntity) processServerDescriptionChangedEvent(evt *event.ServerDes if _, ok := c.observedEvents[serverDescriptionChangedEvent]; ok { c.serverDescriptionChanged = append(c.serverDescriptionChanged, evt) + c.eventSequencer.recordEvent(serverDescriptionChangedEvent) } // Record object-specific unified spec test data on an event. @@ -558,6 +629,7 @@ func (c *clientEntity) processServerHeartbeatFailedEvent(evt *event.ServerHeartb if _, ok := c.observedEvents[serverHeartbeatFailedEvent]; ok { c.serverHeartbeatFailedEvent = append(c.serverHeartbeatFailedEvent, evt) + c.eventSequencer.recordEvent(serverHeartbeatFailedEvent) } c.addEventsCount(serverHeartbeatFailedEvent) @@ -573,6 +645,7 @@ func (c *clientEntity) processServerHeartbeatStartedEvent(evt *event.ServerHeart if _, ok := c.observedEvents[serverHeartbeatStartedEvent]; ok { c.serverHeartbeatStartedEvent = append(c.serverHeartbeatStartedEvent, evt) + c.eventSequencer.recordEvent(serverHeartbeatStartedEvent) } c.addEventsCount(serverHeartbeatStartedEvent) @@ -588,6 +661,7 @@ func (c *clientEntity) processServerHeartbeatSucceededEvent(evt *event.ServerHea if _, ok := c.observedEvents[serverHeartbeatSucceededEvent]; ok { c.serverHeartbeatSucceeded = append(c.serverHeartbeatSucceeded, evt) + c.eventSequencer.recordEvent(serverHeartbeatSucceededEvent) } c.addEventsCount(serverHeartbeatSucceededEvent) @@ -603,6 +677,7 @@ func (c *clientEntity) processTopologyDescriptionChangedEvent(evt *event.Topolog if _, ok := c.observedEvents[topologyDescriptionChangedEvent]; ok { c.topologyDescriptionChanged = append(c.topologyDescriptionChanged, evt) + c.eventSequencer.recordEvent(topologyDescriptionChangedEvent) } c.addEventsCount(topologyDescriptionChangedEvent) @@ -618,6 +693,7 @@ func (c *clientEntity) processTopologyOpeningEvent(evt *event.TopologyOpeningEve if _, ok := c.observedEvents[topologyOpeningEvent]; ok { c.topologyOpening = append(c.topologyOpening, evt) + c.eventSequencer.recordEvent(topologyOpeningEvent) } c.addEventsCount(topologyOpeningEvent) @@ -633,6 +709,7 @@ func (c *clientEntity) processTopologyClosedEvent(evt *event.TopologyClosedEvent if _, ok := c.observedEvents[topologyClosedEvent]; ok { c.topologyClosed = append(c.topologyClosed, evt) + c.eventSequencer.recordEvent(topologyClosedEvent) } c.addEventsCount(topologyClosedEvent) @@ -724,3 +801,27 @@ func evaluateUseMultipleMongoses(clientOpts *options.ClientOptions, useMultipleM } return nil } + +// awaitMinimumPoolSize waits for the client's connection pool to reach the +// specified minimum size, then clears all CMAP and SDAM events that occurred +// during pool initialization. 
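+// The events are not deleted; the sequencer cutoff is advanced so the CMAP
+// and SDAM event accessors filter them out.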
+func awaitMinimumPoolSize(ctx context.Context, entity *clientEntity, minPoolSize uint64, timeoutMS int) error { + awaitCtx, cancel := context.WithTimeout(ctx, time.Duration(timeoutMS)*time.Millisecond) + defer cancel() + + ticker := time.NewTicker(100 * time.Millisecond) + defer ticker.Stop() + + for { + select { + case <-awaitCtx.Done(): + return fmt.Errorf("timed out waiting for client to reach minPoolSize") + case <-ticker.C: + if uint64(entity.getEventCount(connectionReadyEvent)) >= minPoolSize { + entity.eventSequencer.setCutoff() + + return nil + } + } + } +} diff --git a/internal/integration/unified/client_entity_test.go b/internal/integration/unified/client_entity_test.go new file mode 100644 index 0000000000..5100ce6c3b --- /dev/null +++ b/internal/integration/unified/client_entity_test.go @@ -0,0 +1,171 @@ +// Copyright (C) MongoDB, Inc. 2017-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package unified + +import ( + "testing" + + "go.mongodb.org/mongo-driver/v2/event" + "go.mongodb.org/mongo-driver/v2/internal/assert" +) + +// Helper functions to condense event recording in tests +func recordPoolEvent(c *clientEntity) { + c.pooled = append(c.pooled, &event.PoolEvent{}) + c.eventSequencer.recordPooledEvent() +} + +func recordServerDescChanged(c *clientEntity) { + c.serverDescriptionChanged = append(c.serverDescriptionChanged, &event.ServerDescriptionChangedEvent{}) + c.eventSequencer.recordEvent(serverDescriptionChangedEvent) +} + +func recordTopologyOpening(c *clientEntity) { + c.topologyOpening = append(c.topologyOpening, &event.TopologyOpeningEvent{}) + c.eventSequencer.recordEvent(topologyOpeningEvent) +} + +func Test_eventSequencer(t *testing.T) { + tests := []struct { + name string + setupEvents func(*clientEntity) + cutoffAfter int // Set cutoff after this many events (0 = no cutoff) + expectedPooled int + expectedSDAM map[monitoringEventType]int + }{ + { + name: "no cutoff filters nothing", + cutoffAfter: 0, + setupEvents: func(c *clientEntity) { + recordPoolEvent(c) + recordPoolEvent(c) + recordPoolEvent(c) + recordServerDescChanged(c) + recordServerDescChanged(c) + }, + expectedPooled: 3, + expectedSDAM: map[monitoringEventType]int{ + serverDescriptionChangedEvent: 2, + }, + }, + { + name: "cutoff after 2 pool events filters first 2", + cutoffAfter: 2, + setupEvents: func(c *clientEntity) { + recordPoolEvent(c) + recordPoolEvent(c) + // Cutoff will be set here (after event 2) + recordPoolEvent(c) + recordPoolEvent(c) + recordPoolEvent(c) + }, + expectedPooled: 3, // Events 3, 4, 5 + expectedSDAM: map[monitoringEventType]int{}, + }, + { + name: "cutoff filters mixed pool and SDAM events", + cutoffAfter: 4, + setupEvents: func(c *clientEntity) { + recordPoolEvent(c) + recordServerDescChanged(c) + recordPoolEvent(c) + recordTopologyOpening(c) + // Cutoff will be set here (after event 4) + recordPoolEvent(c) + recordServerDescChanged(c) + recordTopologyOpening(c) + }, + expectedPooled: 1, + expectedSDAM: map[monitoringEventType]int{ + serverDescriptionChangedEvent: 1, + topologyOpeningEvent: 1, + }, + }, + { + name: "cutoff after all events filters everything", + cutoffAfter: 3, + setupEvents: func(c *clientEntity) { + recordPoolEvent(c) + recordPoolEvent(c) + recordServerDescChanged(c) + // Cutoff will be set here (after all 3 events) + }, + expectedPooled: 0, + expectedSDAM: 
map[monitoringEventType]int{ + serverDescriptionChangedEvent: 0, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create a minimal clientEntity + client := &clientEntity{ + eventSequencer: eventSequencer{ + seqByEventType: make(map[monitoringEventType][]int64), + }, + } + + // Setup events + tt.setupEvents(client) + + // Set cutoff if specified + if tt.cutoffAfter > 0 { + // Manually set cutoff to the specified event sequence + client.eventSequencer.cutoff.Store(int64(tt.cutoffAfter)) + } + + // Test pool event filtering + filteredPool := filterEventsBySeq(client, client.pooled, poolAnyEvent) + assert.Equal(t, tt.expectedPooled, len(filteredPool), "pool events count mismatch") + + // Test SDAM event filtering + for eventType, expectedCount := range tt.expectedSDAM { + var actualCount int + + switch eventType { + case serverDescriptionChangedEvent: + actualCount = len(filterEventsBySeq(client, client.serverDescriptionChanged, serverDescriptionChangedEvent)) + case serverHeartbeatSucceededEvent: + actualCount = len(filterEventsBySeq(client, client.serverHeartbeatSucceeded, serverHeartbeatSucceededEvent)) + case topologyOpeningEvent: + actualCount = len(filterEventsBySeq(client, client.topologyOpening, topologyOpeningEvent)) + } + + assert.Equal(t, expectedCount, actualCount, "%s count mismatch", eventType) + } + }) + } +} + +func Test_eventSequencer_setCutoff(t *testing.T) { + client := &clientEntity{ + eventSequencer: eventSequencer{ + seqByEventType: make(map[monitoringEventType][]int64), + }, + } + + // Record some events + recordPoolEvent(client) + recordPoolEvent(client) + + // Verify counter is at 2 + assert.Equal(t, int64(2), client.eventSequencer.counter.Load(), "counter should be 2") + + // Set cutoff + client.eventSequencer.setCutoff() + + // Verify cutoff matches counter + assert.Equal(t, int64(2), client.eventSequencer.cutoff.Load(), "cutoff should be 2") + + // Record more events + recordPoolEvent(client) + + // Verify counter incremented but cutoff didn't + assert.Equal(t, int64(3), client.eventSequencer.counter.Load(), "counter should be 3") + assert.Equal(t, int64(2), client.eventSequencer.cutoff.Load(), "cutoff should still be 2") +} diff --git a/internal/integration/unified/entity.go b/internal/integration/unified/entity.go index b1b827a124..8222233290 100644 --- a/internal/integration/unified/entity.go +++ b/internal/integration/unified/entity.go @@ -83,11 +83,10 @@ type entityOptions struct { ClientEncryptionOpts *clientEncryptionOpts `bson:"clientEncryptionOpts"` - // If true, the unified spec runner must wait for the connection pool to be - // populated for all servers according to the minPoolSize option. If false, - // not specified, or if minPoolSize equals 0, there is no need to wait for any - // specific pool state. - AwaitMinPoolSize bool `bson:"awaitMinPoolSize"` + // Maximum duration (in milliseconds) that the test runner MUST wait for each + // connection pool to be populated with minPoolSize. Any CMAP and SDAM events + // that occur before the pool is populated will be ignored. 
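+ // A nil or non-positive value means the runner does not wait.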
+ AwaitMinPoolSizeMS *int `bson:"awaitMinPoolSizeMS"` } func (eo *entityOptions) setHeartbeatFrequencyMS(freq time.Duration) { diff --git a/internal/integration/unified/event.go b/internal/integration/unified/event.go index abbec74439..94f0220a24 100644 --- a/internal/integration/unified/event.go +++ b/internal/integration/unified/event.go @@ -37,6 +37,9 @@ const ( topologyDescriptionChangedEvent monitoringEventType = "TopologyDescriptionChangedEvent" topologyOpeningEvent monitoringEventType = "TopologyOpeningEvent" topologyClosedEvent monitoringEventType = "TopologyClosedEvent" + + // sentinel: indicates "use pooled (CMAP) sequence". + poolAnyEvent monitoringEventType = "_PoolAny" ) func monitoringEventTypeFromString(eventStr string) (monitoringEventType, bool) { diff --git a/internal/integration/unified/event_verification.go b/internal/integration/unified/event_verification.go index 0521f0653e..8fec9ddb70 100644 --- a/internal/integration/unified/event_verification.go +++ b/internal/integration/unified/event_verification.go @@ -312,7 +312,7 @@ func verifyCommandEvents(ctx context.Context, client *clientEntity, expectedEven } func verifyCMAPEvents(client *clientEntity, expectedEvents *expectedEvents) error { - pooled := client.poolEvents() + pooled := filterEventsBySeq(client, client.pooled, poolAnyEvent) if len(expectedEvents.CMAPEvents) == 0 && len(pooled) != 0 { return fmt.Errorf("expected no cmap events to be sent but got %s", stringifyEventsForClient(client)) } @@ -443,7 +443,7 @@ func stringifyEventsForClient(client *clientEntity) string { } str.WriteString("\nPool Events\n\n") - for _, evt := range client.poolEvents() { + for _, evt := range filterEventsBySeq(client, client.pooled, poolAnyEvent) { str.WriteString(fmt.Sprintf("[%s] Event Type: %q\n", evt.Address, evt.Type)) } @@ -522,13 +522,13 @@ func getNextTopologyClosedEvent( func verifySDAMEvents(client *clientEntity, expectedEvents *expectedEvents) error { var ( - changed = client.serverDescriptionChanged - started = client.serverHeartbeatStartedEvent - succeeded = client.serverHeartbeatSucceeded - failed = client.serverHeartbeatFailedEvent - tchanged = client.topologyDescriptionChanged - topening = client.topologyOpening - tclosed = client.topologyClosed + changed = filterEventsBySeq(client, client.serverDescriptionChanged, serverDescriptionChangedEvent) + started = filterEventsBySeq(client, client.serverHeartbeatStartedEvent, serverHeartbeatStartedEvent) + succeeded = filterEventsBySeq(client, client.serverHeartbeatSucceeded, serverHeartbeatSucceededEvent) + failed = filterEventsBySeq(client, client.serverHeartbeatFailedEvent, serverHeartbeatFailedEvent) + tchanged = filterEventsBySeq(client, client.topologyDescriptionChanged, topologyDescriptionChangedEvent) + topening = filterEventsBySeq(client, client.topologyOpening, topologyOpeningEvent) + tclosed = filterEventsBySeq(client, client.topologyClosed, topologyClosedEvent) ) vol := func() int {