@@ -9,7 +9,7 @@ use std::{
99use tokio:: { io:: AsyncWrite , net:: TcpStream } ;
1010
1111use crate :: {
12- error:: { ErrorKind , Result } ,
12+ error:: { Error , ErrorKind , Result } ,
1313 options:: ServerAddress ,
1414 runtime,
1515} ;
@@ -78,7 +78,12 @@ async fn tcp_try_connect(address: &SocketAddr) -> Result<TcpStream> {
7878}
7979
8080async fn tcp_connect ( address : & ServerAddress ) -> Result < TcpStream > {
81- let mut socket_addrs: Vec < _ > = runtime:: resolve_address ( address) . await ?. collect ( ) ;
81+ // "Happy Eyeballs": try addresses in parallel, interleaving IPv6 and IPv4, preferring IPv6.
82+ // Based on the implementation in https://codeberg.org/KMK/happy-eyeballs.
83+ let ( addrs_v6, addrs_v4) : ( Vec < _ > , Vec < _ > ) = runtime:: resolve_address ( address)
84+ . await ?
85+ . partition ( |a| matches ! ( a, SocketAddr :: V6 ( _) ) ) ;
86+ let socket_addrs = interleave ( addrs_v6, addrs_v4) ;
8287
8388 if socket_addrs. is_empty ( ) {
8489 return Err ( ErrorKind :: DnsResolve {
@@ -87,19 +92,58 @@ async fn tcp_connect(address: &ServerAddress) -> Result<TcpStream> {
8792 . into ( ) ) ;
8893 }
8994
90- // After considering various approaches, we decided to do what other drivers do, namely try
91- // each of the addresses in sequence with a preference for IPv4.
92- socket_addrs. sort_by_key ( |addr| if addr. is_ipv4 ( ) { 0 } else { 1 } ) ;
95+ fn handle_join (
96+ result : std:: result:: Result < Result < TcpStream > , tokio:: task:: JoinError > ,
97+ ) -> Result < TcpStream > {
98+ match result {
99+ Ok ( r) => r,
100+ // JoinError indicates the task was cancelled or paniced, which should never happen
101+ // here.
102+ Err ( e) => Err ( Error :: internal ( format ! ( "TCP connect task failure: {}" , e) ) ) ,
103+ }
104+ }
105+
106+ static CONNECTION_ATTEMPT_DELAY : Duration = Duration :: from_millis ( 250 ) ;
93107
108+ // Race connections
109+ let mut attempts = tokio:: task:: JoinSet :: new ( ) ;
94110 let mut connect_error = None ;
111+ ' spawn: for a in socket_addrs {
112+ attempts. spawn ( async move { tcp_try_connect ( & a) . await } ) ;
113+ let sleep = tokio:: time:: sleep ( CONNECTION_ATTEMPT_DELAY ) ;
114+ tokio:: pin!( sleep) ; // required for select!
115+ while !attempts. is_empty ( ) {
116+ tokio:: select! {
117+ biased;
118+ connect_res = attempts. join_next( ) => {
119+ match connect_res. map( handle_join) {
120+ // The gating `while !attempts.is_empty()` should mean this never happens.
121+ None => return Err ( Error :: internal( "empty TCP connect task set" ) ) ,
122+ // A connection succeeded, return it. The JoinSet will cancel remaining tasks on drop.
123+ Some ( Ok ( cnx) ) => return Ok ( cnx) ,
124+ // A connection failed. Remember the error and wait for any other remaining attempts.
125+ Some ( Err ( e) ) => {
126+ connect_error. get_or_insert( e) ;
127+ } ,
128+ }
129+ }
130+ // CONNECTION_ATTEMPT_DELAY expired, spawn a new connection attempt.
131+ _ = & mut sleep => continue ' spawn
132+ }
133+ }
134+ }
95135
96- for address in & socket_addrs {
97- connect_error = match tcp_try_connect ( address) . await {
98- Ok ( stream) => return Ok ( stream) ,
99- Err ( err) => Some ( err) ,
100- } ;
136+ // No more address to try. Drain the attempts until one succeeds.
137+ while let Some ( result) = attempts. join_next ( ) . await {
138+ match handle_join ( result) {
139+ Ok ( cnx) => return Ok ( cnx) ,
140+ Err ( e) => {
141+ connect_error. get_or_insert ( e) ;
142+ }
143+ }
101144 }
102145
146+ // All attempts failed. Return the first error.
103147 Err ( connect_error. unwrap_or_else ( || {
104148 ErrorKind :: Internal {
105149 message : "connecting to all DNS results failed but no error reported" . to_string ( ) ,
@@ -108,6 +152,17 @@ async fn tcp_connect(address: &ServerAddress) -> Result<TcpStream> {
108152 } ) )
109153}
110154
155+ fn interleave < T > ( left : Vec < T > , right : Vec < T > ) -> Vec < T > {
156+ let mut out = Vec :: with_capacity ( left. len ( ) + right. len ( ) ) ;
157+ let ( mut left, mut right) = ( left. into_iter ( ) , right. into_iter ( ) ) ;
158+ while let Some ( a) = left. next ( ) {
159+ out. push ( a) ;
160+ std:: mem:: swap ( & mut left, & mut right) ;
161+ }
162+ out. extend ( right) ;
163+ out
164+ }
165+
111166impl tokio:: io:: AsyncRead for AsyncStream {
112167 fn poll_read (
113168 mut self : Pin < & mut Self > ,
0 commit comments