44 "context"
55 "errors"
66 "fmt"
7+ "sync/atomic"
78 "time"
89
910 "github.com/go-redis/redis/v8/internal"
@@ -130,20 +131,7 @@ func (hs hooks) processTxPipeline(
130131}
131132
132133func (hs hooks ) withContext (ctx context.Context , fn func () error ) error {
133- done := ctx .Done ()
134- if done == nil {
135- return fn ()
136- }
137-
138- errc := make (chan error , 1 )
139- go func () { errc <- fn () }()
140-
141- select {
142- case <- done :
143- return ctx .Err ()
144- case err := <- errc :
145- return err
146- }
134+ return fn ()
147135}
148136
149137//------------------------------------------------------------------------------
@@ -316,8 +304,24 @@ func (c *baseClient) withConn(
316304 c .releaseConn (ctx , cn , err )
317305 }()
318306
319- err = fn (ctx , cn )
320- return err
307+ done := ctx .Done ()
308+ if done == nil {
309+ err = fn (ctx , cn )
310+ return err
311+ }
312+
313+ errc := make (chan error , 1 )
314+ go func () { errc <- fn (ctx , cn ) }()
315+
316+ select {
317+ case <- done :
318+ _ = cn .Close ()
319+
320+ err = ctx .Err ()
321+ return err
322+ case err = <- errc :
323+ return err
324+ }
321325 })
322326}
323327
@@ -334,7 +338,7 @@ func (c *baseClient) process(ctx context.Context, cmd Cmder) error {
334338 }
335339 }
336340
337- retryTimeout := true
341+ retryTimeout := uint32 ( 1 )
338342 err := c .withConn (ctx , func (ctx context.Context , cn * pool.Conn ) error {
339343 err := cn .WithWriter (ctx , c .opt .WriteTimeout , func (wr * proto.Writer ) error {
340344 return writeCmd (wr , cmd )
@@ -345,7 +349,9 @@ func (c *baseClient) process(ctx context.Context, cmd Cmder) error {
345349
346350 err = cn .WithReader (ctx , c .cmdTimeout (cmd ), cmd .readReply )
347351 if err != nil {
348- retryTimeout = cmd .readTimeout () == nil
352+ if cmd .readTimeout () == nil {
353+ atomic .StoreUint32 (& retryTimeout , 1 )
354+ }
349355 return err
350356 }
351357
@@ -354,7 +360,7 @@ func (c *baseClient) process(ctx context.Context, cmd Cmder) error {
354360 if err == nil {
355361 return nil
356362 }
357- retry = shouldRetry (err , retryTimeout )
363+ retry = shouldRetry (err , atomic . LoadUint32 ( & retryTimeout ) == 1 )
358364 return err
359365 })
360366 if err == nil || ! retry {
0 commit comments