@@ -289,6 +289,7 @@ func (d AuthDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) {
289289 if err != nil {
290290 return conn , err
291291 }
292+
292293 greeting := conn .Greeting ()
293294 if greeting .Salt == "" {
294295 conn .Close ()
@@ -309,7 +310,7 @@ func (d AuthDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) {
309310 }
310311 }
311312
312- if err := authenticate (conn , d .Auth , d .Username , d .Password ,
313+ if err := authenticate (ctx , conn , d .Auth , d .Username , d .Password ,
313314 conn .Greeting ().Salt ); err != nil {
314315 conn .Close ()
315316 return nil , fmt .Errorf ("failed to authenticate: %w" , err )
@@ -340,7 +341,7 @@ func (d ProtocolDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) {
340341 protocolInfo : d .RequiredProtocolInfo ,
341342 }
342343
343- protocolConn .protocolInfo , err = identify (& protocolConn )
344+ protocolConn .protocolInfo , err = identify (ctx , & protocolConn )
344345 if err != nil {
345346 protocolConn .Close ()
346347 return nil , fmt .Errorf ("failed to identify: %w" , err )
@@ -372,11 +373,12 @@ func (d GreetingDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) {
372373 greetingConn := greetingConn {
373374 Conn : conn ,
374375 }
375- version , salt , err := readGreeting (greetingConn )
376+ version , salt , err := readGreeting (ctx , & greetingConn )
376377 if err != nil {
377378 greetingConn .Close ()
378379 return nil , fmt .Errorf ("failed to read greeting: %w" , err )
379380 }
381+
380382 greetingConn .greeting = Greeting {
381383 Version : version ,
382384 Salt : salt ,
@@ -410,31 +412,67 @@ func parseAddress(address string) (string, string) {
410412 return network , address
411413}
412414
415+ // ioWaiter waits in a background until an io operation done or a context
416+ // is expired. It closes the connection and writes a context error into the
417+ // output channel on context expiration.
418+ //
419+ // A user of the helper should close the first output channel after an IO
420+ // operation done and read an error from a second channel to get the result
421+ // of waiting.
422+ func ioWaiter (ctx context.Context , conn Conn ) (chan <- struct {}, <- chan error ) {
423+ doneIO := make (chan struct {})
424+ doneWait := make (chan error , 1 )
425+
426+ go func () {
427+ defer close (doneWait )
428+
429+ select {
430+ case <- ctx .Done ():
431+ conn .Close ()
432+ <- doneIO
433+ doneWait <- ctx .Err ()
434+ case <- doneIO :
435+ doneWait <- nil
436+ }
437+ }()
438+
439+ return doneIO , doneWait
440+ }
441+
413442// readGreeting reads a greeting message.
414- func readGreeting (reader io. Reader ) (string , string , error ) {
443+ func readGreeting (ctx context. Context , conn Conn ) (string , string , error ) {
415444 var version , salt string
416445
446+ doneRead , doneWait := ioWaiter (ctx , conn )
447+
417448 data := make ([]byte , 128 )
418- _ , err := io .ReadFull (reader , data )
449+ _ , err := io .ReadFull (conn , data )
450+
451+ close (doneRead )
452+
419453 if err == nil {
420454 version = bytes .NewBuffer (data [:64 ]).String ()
421455 salt = bytes .NewBuffer (data [64 :108 ]).String ()
422456 }
423457
458+ if waitErr := <- doneWait ; waitErr != nil {
459+ err = waitErr
460+ }
461+
424462 return version , salt , err
425463}
426464
427465// identify sends info about client protocol, receives info
428466// about server protocol in response and stores it in the connection.
429- func identify (conn Conn ) (ProtocolInfo , error ) {
467+ func identify (ctx context. Context , conn Conn ) (ProtocolInfo , error ) {
430468 var info ProtocolInfo
431469
432470 req := NewIdRequest (clientProtocolInfo )
433- if err := writeRequest (conn , req ); err != nil {
471+ if err := writeRequest (ctx , conn , req ); err != nil {
434472 return info , err
435473 }
436474
437- resp , err := readResponse (conn , req )
475+ resp , err := readResponse (ctx , conn , req )
438476 if err != nil {
439477 if resp != nil &&
440478 resp .Header ().Error == iproto .ER_UNKNOWN_REQUEST_TYPE {
@@ -495,7 +533,7 @@ func checkProtocolInfo(required ProtocolInfo, actual ProtocolInfo) error {
495533}
496534
497535// authenticate authenticates for a connection.
498- func authenticate (c Conn , auth Auth , user string , pass string , salt string ) error {
536+ func authenticate (ctx context. Context , c Conn , auth Auth , user , pass , salt string ) error {
499537 var req Request
500538 var err error
501539
@@ -511,37 +549,73 @@ func authenticate(c Conn, auth Auth, user string, pass string, salt string) erro
511549 return errors .New ("unsupported method " + auth .String ())
512550 }
513551
514- if err = writeRequest (c , req ); err != nil {
552+ if err = writeRequest (ctx , c , req ); err != nil {
515553 return err
516554 }
517- if _ , err = readResponse (c , req ); err != nil {
555+ if _ , err = readResponse (ctx , c , req ); err != nil {
518556 return err
519557 }
520558 return nil
521559}
522560
523561// writeRequest writes a request to the writer.
524- func writeRequest (w writeFlusher , req Request ) error {
562+ func writeRequest (ctx context. Context , conn Conn , req Request ) error {
525563 var packet smallWBuf
526564 err := pack (& packet , msgpack .NewEncoder (& packet ), 0 , req , ignoreStreamId , nil )
527565
528566 if err != nil {
529567 return fmt .Errorf ("pack error: %w" , err )
530568 }
531- if _ , err = w .Write (packet .b ); err != nil {
569+
570+ doneWrite , doneWait := ioWaiter (ctx , conn )
571+
572+ _ , err = conn .Write (packet .b )
573+
574+ close (doneWrite )
575+
576+ if waitErr := <- doneWait ; waitErr != nil {
577+ err = waitErr
578+ }
579+
580+ if err != nil {
532581 return fmt .Errorf ("write error: %w" , err )
533582 }
534- if err = w .Flush (); err != nil {
583+
584+ doneWrite , doneWait = ioWaiter (ctx , conn )
585+
586+ err = conn .Flush ()
587+
588+ close (doneWrite )
589+
590+ if waitErr := <- doneWait ; waitErr != nil {
591+ err = waitErr
592+ }
593+
594+ if err != nil {
535595 return fmt .Errorf ("flush error: %w" , err )
536596 }
597+
598+ if waitErr := <- doneWait ; waitErr != nil {
599+ err = waitErr
600+ }
601+
537602 return err
538603}
539604
540605// readResponse reads a response from the reader.
541- func readResponse (r io. Reader , req Request ) (Response , error ) {
606+ func readResponse (ctx context. Context , conn Conn , req Request ) (Response , error ) {
542607 var lenbuf [packetLengthBytes ]byte
543608
544- respBytes , err := read (r , lenbuf [:])
609+ doneRead , doneWait := ioWaiter (ctx , conn )
610+
611+ respBytes , err := read (conn , lenbuf [:])
612+
613+ close (doneRead )
614+
615+ if waitErr := <- doneWait ; waitErr != nil {
616+ err = waitErr
617+ }
618+
545619 if err != nil {
546620 return nil , fmt .Errorf ("read error: %w" , err )
547621 }
@@ -555,10 +629,12 @@ func readResponse(r io.Reader, req Request) (Response, error) {
555629 if err != nil {
556630 return nil , fmt .Errorf ("decode response header error: %w" , err )
557631 }
632+
558633 resp , err := req .Response (header , & buf )
559634 if err != nil {
560635 return nil , fmt .Errorf ("creating response error: %w" , err )
561636 }
637+
562638 _ , err = resp .Decode ()
563639 if err != nil {
564640 switch err .(type ) {
@@ -568,5 +644,6 @@ func readResponse(r io.Reader, req Request) (Response, error) {
568644 return resp , fmt .Errorf ("decode response body error: %w" , err )
569645 }
570646 }
647+
571648 return resp , nil
572649}
0 commit comments