@@ -18,7 +18,7 @@ use std::net::{self, Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6, SocketAddr}
1818use std:: ops:: Neg ;
1919use std:: os:: unix:: prelude:: * ;
2020use std:: sync:: atomic:: { AtomicBool , Ordering , ATOMIC_BOOL_INIT } ;
21- use std:: time:: Duration ;
21+ use std:: time:: { Duration , Instant } ;
2222
2323use libc:: { self , c_void, c_int, sockaddr_in, sockaddr_storage, sockaddr_in6} ;
2424use libc:: { sockaddr, socklen_t, AF_INET , AF_INET6 , ssize_t} ;
@@ -118,6 +118,67 @@ impl Socket {
118118 }
119119 }
120120
121+ pub fn connect_timeout ( & self , addr : & SocketAddr , timeout : Duration ) -> io:: Result < ( ) > {
122+ self . set_nonblocking ( true ) ?;
123+ let r = self . connect ( addr) ;
124+ self . set_nonblocking ( false ) ?;
125+
126+ match r {
127+ Ok ( ( ) ) => return Ok ( ( ) ) ,
128+ // there's no io::ErrorKind conversion registered for EINPROGRESS :(
129+ Err ( ref e) if e. raw_os_error ( ) == Some ( libc:: EINPROGRESS ) => { }
130+ Err ( e) => return Err ( e) ,
131+ }
132+
133+ let mut pollfd = libc:: pollfd {
134+ fd : self . fd ,
135+ events : libc:: POLLOUT ,
136+ revents : 0 ,
137+ } ;
138+
139+ if timeout. as_secs ( ) == 0 && timeout. subsec_nanos ( ) == 0 {
140+ return Err ( io:: Error :: new ( io:: ErrorKind :: InvalidInput ,
141+ "cannot set a 0 duration timeout" ) ) ;
142+ }
143+
144+ let start = Instant :: now ( ) ;
145+
146+ loop {
147+ let elapsed = start. elapsed ( ) ;
148+ if elapsed >= timeout {
149+ return Err ( io:: Error :: new ( io:: ErrorKind :: TimedOut , "connection timed out" ) ) ;
150+ }
151+
152+ let timeout = timeout - elapsed;
153+ let mut timeout = timeout. as_secs ( )
154+ . saturating_mul ( 1_000 )
155+ . saturating_add ( timeout. subsec_nanos ( ) as u64 / 1_000_000 ) ;
156+ if timeout == 0 {
157+ timeout = 1 ;
158+ }
159+
160+ let timeout = cmp:: min ( timeout, c_int:: max_value ( ) as u64 ) as c_int ;
161+
162+ match unsafe { libc:: poll ( & mut pollfd, 1 , timeout) } {
163+ -1 => {
164+ let err = io:: Error :: last_os_error ( ) ;
165+ if err. kind ( ) != io:: ErrorKind :: Interrupted {
166+ return Err ( err) ;
167+ }
168+ }
169+ 0 => return Err ( io:: Error :: new ( io:: ErrorKind :: TimedOut , "connection timed out" ) ) ,
170+ _ => {
171+ if pollfd. revents & libc:: POLLOUT == 0 {
172+ if let Some ( e) = self . take_error ( ) ? {
173+ return Err ( e) ;
174+ }
175+ }
176+ return Ok ( ( ) ) ;
177+ }
178+ }
179+ }
180+ }
181+
121182 pub fn local_addr ( & self ) -> io:: Result < SocketAddr > {
122183 unsafe {
123184 let mut storage: libc:: sockaddr_storage = mem:: zeroed ( ) ;
0 commit comments