Skip to content
Merged
85 changes: 64 additions & 21 deletions error.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,27 +52,54 @@ type Error interface {
var _ Error = proto.RedisError("")

func isContextError(err error) bool {
switch err {
case context.Canceled, context.DeadlineExceeded:
return true
default:
return false
// Check for wrapped context errors using errors.Is
return errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded)
}

// isTimeoutError checks if an error is a timeout error, even if wrapped.
// Returns (isTimeout, shouldRetryOnTimeout) where:
// - isTimeout: true if the error is any kind of timeout error
// - shouldRetryOnTimeout: true if Timeout() method returns true
func isTimeoutError(err error) (isTimeout bool, hasTimeoutFlag bool) {
// Check for timeoutError interface (works with wrapped errors)
var te timeoutError
if errors.As(err, &te) {
return true, te.Timeout()
}

// Check for net.Error specifically (common case for network timeouts)
var netErr net.Error
if errors.As(err, &netErr) {
return true, netErr.Timeout()
}

return false, false
}

func shouldRetry(err error, retryTimeout bool) bool {
switch err {
case io.EOF, io.ErrUnexpectedEOF:
if err == nil {
return false
}

// Check for EOF errors (works with wrapped errors)
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
return true
case nil, context.Canceled, context.DeadlineExceeded:
}

// Check for context errors (works with wrapped errors)
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return false
case pool.ErrPoolTimeout:
}

// Check for pool timeout (works with wrapped errors)
if errors.Is(err, pool.ErrPoolTimeout) {
// connection pool timeout, increase retries. #3289
return true
}

if v, ok := err.(timeoutError); ok {
if v.Timeout() {
// Check for timeout errors (works with wrapped errors)
if isTimeout, hasTimeoutFlag := isTimeoutError(err); isTimeout {
if hasTimeoutFlag {
return retryTimeout
}
return true
Expand Down Expand Up @@ -115,23 +142,37 @@ func shouldRetry(err error, retryTimeout bool) bool {
if strings.HasPrefix(s, "TRYAGAIN ") {
return true
}
if strings.HasPrefix(s, "MASTERDOWN ") {
return true
}

return false
}

func isRedisError(err error) bool {
_, ok := err.(proto.RedisError)
return ok
// Check if error implements the Error interface (works with wrapped errors)
var redisErr Error
if errors.As(err, &redisErr) {
return true
}
// Also check for proto.RedisError specifically
var protoRedisErr proto.RedisError
return errors.As(err, &protoRedisErr)
}

func isBadConn(err error, allowTimeout bool, addr string) bool {
switch err {
case nil:
return false
case context.Canceled, context.DeadlineExceeded:
return true
case pool.ErrConnUnusableTimeout:
return true
if err == nil {
return false
}

// Check for context errors (works with wrapped errors)
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return true
}

// Check for pool timeout errors (works with wrapped errors)
if errors.Is(err, pool.ErrConnUnusableTimeout) {
return true
}

if isRedisError(err) {
Expand All @@ -151,7 +192,9 @@ func isBadConn(err error, allowTimeout bool, addr string) bool {
}

if allowTimeout {
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
// Check for network timeout errors (works with wrapped errors)
var netErr net.Error
if errors.As(err, &netErr) && netErr.Timeout() {
return false
}
}
Expand Down
201 changes: 192 additions & 9 deletions error_wrapping_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"context"
"errors"
"fmt"
"io"
"strings"
"testing"

"github.com/redis/go-redis/v9"
Expand Down Expand Up @@ -443,17 +445,198 @@ func TestCustomErrorTypeWrapping(t *testing.T) {
}
}

// Helper function to check if a string contains a substring
func contains(s, substr string) bool {
return len(s) >= len(substr) && (s == substr || len(s) > len(substr) && findSubstring(s, substr))
// TestTimeoutErrorWrapping tests that timeout errors work correctly when wrapped
func TestTimeoutErrorWrapping(t *testing.T) {
// Test 1: Wrapped timeoutError interface
t.Run("Wrapped timeoutError with Timeout()=true", func(t *testing.T) {
timeoutErr := &testTimeoutError{timeout: true, msg: "i/o timeout"}
wrappedErr := fmt.Errorf("hook wrapper: %w", timeoutErr)
doubleWrappedErr := fmt.Errorf("another wrapper: %w", wrappedErr)

// Should NOT retry when retryTimeout=false
if redis.ShouldRetry(doubleWrappedErr, false) {
t.Errorf("Should not retry timeout error when retryTimeout=false")
}

// Should retry when retryTimeout=true
if !redis.ShouldRetry(doubleWrappedErr, true) {
t.Errorf("Should retry timeout error when retryTimeout=true")
}
})

// Test 2: Wrapped timeoutError with Timeout()=false
t.Run("Wrapped timeoutError with Timeout()=false", func(t *testing.T) {
timeoutErr := &testTimeoutError{timeout: false, msg: "connection error"}
wrappedErr := fmt.Errorf("hook wrapper: %w", timeoutErr)

// Should always retry when Timeout()=false
if !redis.ShouldRetry(wrappedErr, false) {
t.Errorf("Should retry non-timeout error even when retryTimeout=false")
}
if !redis.ShouldRetry(wrappedErr, true) {
t.Errorf("Should retry non-timeout error when retryTimeout=true")
}
})

// Test 3: Wrapped net.Error with Timeout()=true
t.Run("Wrapped net.Error", func(t *testing.T) {
netErr := &testNetError{timeout: true, temporary: true, msg: "network timeout"}
wrappedErr := fmt.Errorf("hook context: %w", netErr)

// Should respect retryTimeout parameter
if redis.ShouldRetry(wrappedErr, false) {
t.Errorf("Should not retry network timeout when retryTimeout=false")
}
if !redis.ShouldRetry(wrappedErr, true) {
t.Errorf("Should retry network timeout when retryTimeout=true")
}
})

// Test 4: Multiple levels of wrapping
t.Run("Multiple levels of wrapping", func(t *testing.T) {
timeoutErr := &testTimeoutError{timeout: true, msg: "timeout"}
customErr := &AppError{
Code: "TIMEOUT_ERROR",
Message: "Operation timed out",
RequestID: "req-timeout-123",
Err: timeoutErr,
}
wrappedErr := fmt.Errorf("hook wrapper: %w", customErr)

// Should still detect timeout through multiple wrappers
if redis.ShouldRetry(wrappedErr, false) {
t.Errorf("Should not retry timeout through custom error when retryTimeout=false")
}
if !redis.ShouldRetry(wrappedErr, true) {
t.Errorf("Should retry timeout through custom error when retryTimeout=true")
}

// Should be able to extract custom error
var appErr *AppError
if !errors.As(wrappedErr, &appErr) {
t.Errorf("Should be able to extract AppError from wrapped error")
}
})
}

// testTimeoutError implements the timeoutError interface for testing
type testTimeoutError struct {
timeout bool
msg string
}

func (e *testTimeoutError) Error() string {
return e.msg
}

func (e *testTimeoutError) Timeout() bool {
return e.timeout
}

// testNetError implements net.Error for testing
type testNetError struct {
timeout bool
temporary bool
msg string
}

func (e *testNetError) Error() string {
return e.msg
}

func findSubstring(s, substr string) bool {
for i := 0; i <= len(s)-len(substr); i++ {
if s[i:i+len(substr)] == substr {
return true
func (e *testNetError) Timeout() bool {
return e.timeout
}

func (e *testNetError) Temporary() bool {
return e.temporary
}

// TestContextErrorWrapping tests that context errors work correctly when wrapped
func TestContextErrorWrapping(t *testing.T) {
t.Run("Wrapped context.Canceled", func(t *testing.T) {
wrappedErr := fmt.Errorf("operation failed: %w", context.Canceled)
doubleWrappedErr := fmt.Errorf("hook wrapper: %w", wrappedErr)

// Should NOT retry
if redis.ShouldRetry(doubleWrappedErr, false) {
t.Errorf("Should not retry wrapped context.Canceled")
}
}
return false
if redis.ShouldRetry(doubleWrappedErr, true) {
t.Errorf("Should not retry wrapped context.Canceled even with retryTimeout=true")
}
})

t.Run("Wrapped context.DeadlineExceeded", func(t *testing.T) {
wrappedErr := fmt.Errorf("timeout: %w", context.DeadlineExceeded)
doubleWrappedErr := fmt.Errorf("hook wrapper: %w", wrappedErr)

// Should NOT retry
if redis.ShouldRetry(doubleWrappedErr, false) {
t.Errorf("Should not retry wrapped context.DeadlineExceeded")
}
if redis.ShouldRetry(doubleWrappedErr, true) {
t.Errorf("Should not retry wrapped context.DeadlineExceeded even with retryTimeout=true")
}
})
}

// TestIOErrorWrapping tests that io errors work correctly when wrapped
func TestIOErrorWrapping(t *testing.T) {
t.Run("Wrapped io.EOF", func(t *testing.T) {
wrappedErr := fmt.Errorf("read failed: %w", io.EOF)
doubleWrappedErr := fmt.Errorf("hook wrapper: %w", wrappedErr)

// Should retry
if !redis.ShouldRetry(doubleWrappedErr, false) {
t.Errorf("Should retry wrapped io.EOF")
}
})

t.Run("Wrapped io.ErrUnexpectedEOF", func(t *testing.T) {
wrappedErr := fmt.Errorf("read failed: %w", io.ErrUnexpectedEOF)

// Should retry
if !redis.ShouldRetry(wrappedErr, false) {
t.Errorf("Should retry wrapped io.ErrUnexpectedEOF")
}
})
}

// TestPoolErrorWrapping tests that pool errors work correctly when wrapped
func TestPoolErrorWrapping(t *testing.T) {
t.Run("Wrapped pool.ErrPoolTimeout", func(t *testing.T) {
wrappedErr := fmt.Errorf("connection failed: %w", redis.ErrPoolTimeout)
doubleWrappedErr := fmt.Errorf("hook wrapper: %w", wrappedErr)

// Should retry
if !redis.ShouldRetry(doubleWrappedErr, false) {
t.Errorf("Should retry wrapped pool.ErrPoolTimeout")
}
})
}

// TestRedisErrorWrapping tests that RedisError detection works with wrapped errors
func TestRedisErrorWrapping(t *testing.T) {
t.Run("Wrapped proto.RedisError", func(t *testing.T) {
redisErr := proto.RedisError("ERR something went wrong")
wrappedErr := fmt.Errorf("command failed: %w", redisErr)
doubleWrappedErr := fmt.Errorf("hook wrapper: %w", wrappedErr)

// Create a command and set the wrapped error
cmd := redis.NewStatusCmd(context.Background(), "GET", "key")
cmd.SetErr(doubleWrappedErr)

// The error should still be recognized as a Redis error
// This is tested indirectly through the typed error system
if !strings.Contains(cmd.Err().Error(), "ERR something went wrong") {
t.Errorf("Error message not preserved through wrapping")
}
})
}

// Helper function to check if a string contains a substring
func contains(s, substr string) bool {
return strings.Contains(s, substr)
}

Loading
Loading