@@ -5,6 +5,7 @@ package tarantool
55import (
66 "bufio"
77 "bytes"
8+ "context"
89 "errors"
910 "fmt"
1011 "io"
@@ -125,8 +126,11 @@ type Connection struct {
125126 c net.Conn
126127 mutex sync.Mutex
127128 // Schema contains schema loaded on connection.
128- Schema * Schema
129+ Schema * Schema
130+ // requestId contains the last request ID for requests with nil context.
129131 requestId uint32
132+ // contextRequestId contains the last request ID for requests with context.
133+ contextRequestId uint32
130134 // Greeting contains first message sent by Tarantool.
131135 Greeting * Greeting
132136
@@ -143,16 +147,57 @@ type Connection struct {
143147
144148var _ = Connector (& Connection {}) // Check compatibility with connector interface.
145149
150+ type futureList struct {
151+ first * Future
152+ last * * Future
153+ }
154+
155+ func (list * futureList ) findFuture (reqid uint32 , fetch bool ) * Future {
156+ root := & list .first
157+ for {
158+ fut := * root
159+ if fut == nil {
160+ return nil
161+ }
162+ if fut .requestId == reqid {
163+ if fetch {
164+ * root = fut .next
165+ if fut .next == nil {
166+ list .last = root
167+ } else {
168+ fut .next = nil
169+ }
170+ }
171+ return fut
172+ }
173+ root = & fut .next
174+ }
175+ }
176+
177+ func (list * futureList ) addFuture (fut * Future ) {
178+ * list .last = fut
179+ list .last = & fut .next
180+ }
181+
182+ func (list * futureList ) clear (err error , conn * Connection ) {
183+ fut := list .first
184+ list .first = nil
185+ list .last = & list .first
186+ for fut != nil {
187+ fut .SetError (err )
188+ conn .markDone (fut )
189+ fut , fut .next = fut .next , nil
190+ }
191+ }
192+
146193type connShard struct {
147- rmut sync.Mutex
148- requests [requestsMap ]struct {
149- first * Future
150- last * * Future
151- }
152- bufmut sync.Mutex
153- buf smallWBuf
154- enc * msgpack.Encoder
155- _pad [16 ]uint64 //nolint: unused,structcheck
194+ rmut sync.Mutex
195+ requests [requestsMap ]futureList
196+ requestsWithCtx [requestsMap ]futureList
197+ bufmut sync.Mutex
198+ buf smallWBuf
199+ enc * msgpack.Encoder
200+ _pad [16 ]uint64 //nolint: unused,structcheck
156201}
157202
158203// Greeting is a message sent by Tarantool on connect.
@@ -167,6 +212,11 @@ type Opts struct {
167212 // push messages are received. If Timeout is zero, any request can be
168213 // blocked infinitely.
169214 // Also used to setup net.TCPConn.Set(Read|Write)Deadline.
215+ //
216+ // Pay attention, when using contexts with request objects,
217+ // the timeout option for Connection does not affect the lifetime
218+ // of the request. For those purposes use context.WithTimeout() as
219+ // the root context.
170220 Timeout time.Duration
171221 // Timeout between reconnect attempts. If Reconnect is zero, no
172222 // reconnect attempts will be made.
@@ -262,12 +312,13 @@ type SslOpts struct {
262312// and will not finish to make attempts on authorization failures.
263313func Connect (addr string , opts Opts ) (conn * Connection , err error ) {
264314 conn = & Connection {
265- addr : addr ,
266- requestId : 0 ,
267- Greeting : & Greeting {},
268- control : make (chan struct {}),
269- opts : opts ,
270- dec : msgpack .NewDecoder (& smallBuf {}),
315+ addr : addr ,
316+ requestId : 0 ,
317+ contextRequestId : 1 ,
318+ Greeting : & Greeting {},
319+ control : make (chan struct {}),
320+ opts : opts ,
321+ dec : msgpack .NewDecoder (& smallBuf {}),
271322 }
272323 maxprocs := uint32 (runtime .GOMAXPROCS (- 1 ))
273324 if conn .opts .Concurrency == 0 || conn .opts .Concurrency > maxprocs * 128 {
@@ -283,8 +334,11 @@ func Connect(addr string, opts Opts) (conn *Connection, err error) {
283334 conn .shard = make ([]connShard , conn .opts .Concurrency )
284335 for i := range conn .shard {
285336 shard := & conn .shard [i ]
286- for j := range shard .requests {
287- shard .requests [j ].last = & shard .requests [j ].first
337+ requestsLists := []* [requestsMap ]futureList {& shard .requests , & shard .requestsWithCtx }
338+ for _ , requests := range requestsLists {
339+ for j := range requests {
340+ requests [j ].last = & requests [j ].first
341+ }
288342 }
289343 }
290344
@@ -387,6 +441,13 @@ func (conn *Connection) Handle() interface{} {
387441 return conn .opts .Handle
388442}
389443
444+ func (conn * Connection ) cancelFuture (fut * Future , err error ) {
445+ if fut = conn .fetchFuture (fut .requestId ); fut != nil {
446+ fut .SetError (err )
447+ conn .markDone (fut )
448+ }
449+ }
450+
390451func (conn * Connection ) dial () (err error ) {
391452 var connection net.Conn
392453 network := "tcp"
@@ -580,15 +641,10 @@ func (conn *Connection) closeConnection(neterr error, forever bool) (err error)
580641 }
581642 for i := range conn .shard {
582643 conn .shard [i ].buf .Reset ()
583- requests := & conn .shard [i ].requests
584- for pos := range requests {
585- fut := requests [pos ].first
586- requests [pos ].first = nil
587- requests [pos ].last = & requests [pos ].first
588- for fut != nil {
589- fut .SetError (neterr )
590- conn .markDone (fut )
591- fut , fut .next = fut .next , nil
644+ requestsLists := []* [requestsMap ]futureList {& conn .shard [i ].requests , & conn .shard [i ].requestsWithCtx }
645+ for _ , requests := range requestsLists {
646+ for pos := range requests {
647+ requests [pos ].clear (neterr , conn )
592648 }
593649 }
594650 }
@@ -721,7 +777,7 @@ func (conn *Connection) reader(r *bufio.Reader, c net.Conn) {
721777 }
722778}
723779
724- func (conn * Connection ) newFuture () (fut * Future ) {
780+ func (conn * Connection ) newFuture (ctx context. Context ) (fut * Future ) {
725781 fut = NewFuture ()
726782 if conn .rlimit != nil && conn .opts .RLimitAction == RLimitDrop {
727783 select {
@@ -736,7 +792,7 @@ func (conn *Connection) newFuture() (fut *Future) {
736792 return
737793 }
738794 }
739- fut .requestId = conn .nextRequestId ()
795+ fut .requestId = conn .nextRequestId (ctx != nil )
740796 shardn := fut .requestId & (conn .opts .Concurrency - 1 )
741797 shard := & conn .shard [shardn ]
742798 shard .rmut .Lock ()
@@ -761,11 +817,20 @@ func (conn *Connection) newFuture() (fut *Future) {
761817 return
762818 }
763819 pos := (fut .requestId / conn .opts .Concurrency ) & (requestsMap - 1 )
764- pair := & shard .requests [pos ]
765- * pair .last = fut
766- pair .last = & fut .next
767- if conn .opts .Timeout > 0 {
768- fut .timeout = time .Since (epoch ) + conn .opts .Timeout
820+ if ctx != nil {
821+ select {
822+ case <- ctx .Done ():
823+ fut .SetError (fmt .Errorf ("context is done" ))
824+ shard .rmut .Unlock ()
825+ return
826+ default :
827+ }
828+ shard .requestsWithCtx [pos ].addFuture (fut )
829+ } else {
830+ shard .requests [pos ].addFuture (fut )
831+ if conn .opts .Timeout > 0 {
832+ fut .timeout = time .Since (epoch ) + conn .opts .Timeout
833+ }
769834 }
770835 shard .rmut .Unlock ()
771836 if conn .rlimit != nil && conn .opts .RLimitAction == RLimitWait {
@@ -785,12 +850,43 @@ func (conn *Connection) newFuture() (fut *Future) {
785850 return
786851}
787852
853+ // This method removes a future from the internal queue if the context
854+ // is "done" before the response is come. Such select logic is inspired
855+ // from this thread: https://groups.google.com/g/golang-dev/c/jX4oQEls3uk
856+ func (conn * Connection ) contextWatchdog (fut * Future , ctx context.Context ) {
857+ select {
858+ case <- fut .done :
859+ default :
860+ select {
861+ case <- ctx .Done ():
862+ conn .cancelFuture (fut , fmt .Errorf ("context is done" ))
863+ default :
864+ select {
865+ case <- fut .done :
866+ case <- ctx .Done ():
867+ conn .cancelFuture (fut , fmt .Errorf ("context is done" ))
868+ }
869+ }
870+ }
871+ }
872+
788873func (conn * Connection ) send (req Request ) * Future {
789- fut := conn .newFuture ()
874+ fut := conn .newFuture (req . Ctx () )
790875 if fut .ready == nil {
791876 return fut
792877 }
878+ if req .Ctx () != nil {
879+ select {
880+ case <- req .Ctx ().Done ():
881+ conn .cancelFuture (fut , fmt .Errorf ("context is done" ))
882+ return fut
883+ default :
884+ }
885+ }
793886 conn .putFuture (fut , req )
887+ if req .Ctx () != nil {
888+ go conn .contextWatchdog (fut , req .Ctx ())
889+ }
794890 return fut
795891}
796892
@@ -877,25 +973,11 @@ func (conn *Connection) fetchFuture(reqid uint32) (fut *Future) {
877973func (conn * Connection ) getFutureImp (reqid uint32 , fetch bool ) * Future {
878974 shard := & conn .shard [reqid & (conn .opts .Concurrency - 1 )]
879975 pos := (reqid / conn .opts .Concurrency ) & (requestsMap - 1 )
880- pair := & shard .requests [pos ]
881- root := & pair .first
882- for {
883- fut := * root
884- if fut == nil {
885- return nil
886- }
887- if fut .requestId == reqid {
888- if fetch {
889- * root = fut .next
890- if fut .next == nil {
891- pair .last = root
892- } else {
893- fut .next = nil
894- }
895- }
896- return fut
897- }
898- root = & fut .next
976+ // futures with even requests id belong to requests list with nil context
977+ if reqid % 2 == 0 {
978+ return shard .requests [pos ].findFuture (reqid , fetch )
979+ } else {
980+ return shard .requestsWithCtx [pos ].findFuture (reqid , fetch )
899981 }
900982}
901983
@@ -984,8 +1066,12 @@ func (conn *Connection) read(r io.Reader) (response []byte, err error) {
9841066 return
9851067}
9861068
987- func (conn * Connection ) nextRequestId () (requestId uint32 ) {
988- return atomic .AddUint32 (& conn .requestId , 1 )
1069+ func (conn * Connection ) nextRequestId (context bool ) (requestId uint32 ) {
1070+ if context {
1071+ return atomic .AddUint32 (& conn .contextRequestId , 2 )
1072+ } else {
1073+ return atomic .AddUint32 (& conn .requestId , 2 )
1074+ }
9891075}
9901076
9911077// Do performs a request asynchronously on the connection.
@@ -1000,6 +1086,15 @@ func (conn *Connection) Do(req Request) *Future {
10001086 return fut
10011087 }
10021088 }
1089+ if req .Ctx () != nil {
1090+ select {
1091+ case <- req .Ctx ().Done ():
1092+ fut := NewFuture ()
1093+ fut .SetError (fmt .Errorf ("context is done" ))
1094+ return fut
1095+ default :
1096+ }
1097+ }
10031098 return conn .send (req )
10041099}
10051100
0 commit comments