Skip to content

Commit 21bd243

Browse files
committed
improve reauth state management. fix tests
1 parent 92433e6 commit 21bd243

File tree

3 files changed

+24
-46
lines changed

3 files changed

+24
-46
lines changed

internal/auth/streaming/pool_hook.go

Lines changed: 18 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -179,40 +179,27 @@ func (r *ReAuthPoolHook) OnPut(_ context.Context, conn *pool.Conn) (bool, bool,
179179
r.workers <- struct{}{}
180180
}()
181181

182-
var err error
183-
timeout := time.After(r.reAuthTimeout)
182+
// Create timeout context for connection acquisition
183+
// This prevents indefinite waiting if the connection is stuck
184+
ctx, cancel := context.WithTimeout(context.Background(), r.reAuthTimeout)
185+
defer cancel()
184186

185187
// Try to acquire the connection for re-authentication
186188
// We need to ensure the connection is IDLE (not IN_USE) before transitioning to UNUSABLE
187189
// This prevents re-authentication from interfering with active commands
188-
const baseDelay = 10 * time.Microsecond
189-
acquired := false
190-
attempt := 0
191-
for !acquired {
192-
select {
193-
case <-timeout:
194-
// Timeout occurred, cannot acquire connection
195-
err = pool.ErrConnUnusableTimeout
196-
reAuthFn(err)
197-
return
198-
default:
199-
// Try to atomically transition from IDLE to UNUSABLE
200-
// This ensures we only acquire connections that are not actively in use
201-
stateMachine := conn.GetStateMachine()
202-
if stateMachine != nil {
203-
_, err := stateMachine.TryTransition([]pool.ConnState{pool.StateIdle}, pool.StateUnusable)
204-
if err == nil {
205-
// Successfully acquired: connection was IDLE, now UNUSABLE
206-
acquired = true
207-
}
208-
}
209-
if !acquired {
210-
// Exponential backoff: 10, 20, 40, 80... up to 5120 microseconds
211-
delay := baseDelay * time.Duration(1<<uint(attempt%10)) // Cap exponential growth
212-
time.Sleep(delay)
213-
attempt++
214-
}
215-
}
190+
// Use AwaitAndTransition to wait for the connection to become IDLE
191+
stateMachine := conn.GetStateMachine()
192+
if stateMachine == nil {
193+
// No state machine - should not happen, but handle gracefully
194+
reAuthFn(pool.ErrConnUnusableTimeout)
195+
return
196+
}
197+
198+
_, err := stateMachine.AwaitAndTransition(ctx, []pool.ConnState{pool.StateIdle}, pool.StateUnusable)
199+
if err != nil {
200+
// Timeout or other error occurred, cannot acquire connection
201+
reAuthFn(err)
202+
return
216203
}
217204

218205
// safety first
@@ -222,10 +209,7 @@ func (r *ReAuthPoolHook) OnPut(_ context.Context, conn *pool.Conn) (bool, bool,
222209
}
223210

224211
// Release the connection: transition from UNUSABLE back to IDLE
225-
stateMachine := conn.GetStateMachine()
226-
if stateMachine != nil {
227-
stateMachine.Transition(pool.StateIdle)
228-
}
212+
stateMachine.Transition(pool.StateIdle)
229213
}()
230214
}
231215

maintnotifications/handoff_worker.go

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -256,12 +256,6 @@ func (hwm *handoffWorkerManager) queueHandoff(conn *pool.Conn) error {
256256
// Get handoff info atomically to prevent race conditions
257257
shouldHandoff, endpoint, seqID := conn.GetHandoffInfo()
258258

259-
// Special case: empty endpoint means clear handoff state
260-
if endpoint == "" {
261-
conn.ClearHandoffState()
262-
return nil
263-
}
264-
265259
// on retries the connection will not be marked for handoff, but it will have retries > 0
266260
// if shouldHandoff is false and retries is 0, then we are not retrying and not do a handoff
267261
if !shouldHandoff && conn.HandoffRetries() == 0 {

maintnotifications/pool_hook_test.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -391,8 +391,8 @@ func TestConnectionHook(t *testing.T) {
391391

392392
ctx := context.Background()
393393
acceptCon, err := processor.OnGet(ctx, conn, false)
394-
if err != ErrConnectionMarkedForHandoff {
395-
t.Errorf("Expected ErrConnectionMarkedForHandoff, got %v", err)
394+
if err != ErrConnectionMarkedForHandoffWithState {
395+
t.Errorf("Expected ErrConnectionMarkedForHandoffWithState, got %v", err)
396396
}
397397
if acceptCon {
398398
t.Error("Connection should not be accepted when marked for handoff")
@@ -425,8 +425,8 @@ func TestConnectionHook(t *testing.T) {
425425
// Test OnGet with pending handoff
426426
ctx := context.Background()
427427
acceptCon, err := processor.OnGet(ctx, conn, false)
428-
if err != ErrConnectionMarkedForHandoff {
429-
t.Error("Should return ErrConnectionMarkedForHandoff for pending connection")
428+
if err != ErrConnectionMarkedForHandoffWithState {
429+
t.Errorf("Should return ErrConnectionMarkedForHandoffWithState for pending connection, got %v", err)
430430
}
431431
if acceptCon {
432432
t.Error("Should not accept connection with pending handoff")
@@ -678,8 +678,8 @@ func TestConnectionHook(t *testing.T) {
678678
if err == nil {
679679
t.Error("OnGet should fail for connection marked for handoff")
680680
}
681-
if err != ErrConnectionMarkedForHandoff {
682-
t.Errorf("Expected ErrConnectionMarkedForHandoff, got %v", err)
681+
if err != ErrConnectionMarkedForHandoffWithState {
682+
t.Errorf("Expected ErrConnectionMarkedForHandoffWithState, got %v", err)
683683
}
684684
if acceptConn {
685685
t.Error("Connection should not be accepted when marked for handoff")

0 commit comments

Comments
 (0)