@@ -816,8 +816,11 @@ mod tests {
816816 use http_body:: Body ;
817817 use http_body_util:: { BodyExt , Empty , Full } ;
818818 use hyper:: { body, body:: Bytes , client, service:: service_fn} ;
819- use std:: { convert:: Infallible , error:: Error as StdError , net:: SocketAddr } ;
820- use tokio:: net:: { TcpListener , TcpStream } ;
819+ use std:: { convert:: Infallible , error:: Error as StdError , net:: SocketAddr , time:: Duration } ;
820+ use tokio:: {
821+ net:: { TcpListener , TcpStream } ,
822+ pin,
823+ } ;
821824
822825 const BODY : & [ u8 ] = b"Hello, world!" ;
823826
@@ -871,6 +874,40 @@ mod tests {
871874 assert_eq ! ( body, BODY ) ;
872875 }
873876
877+ #[ cfg( not( miri) ) ]
878+ #[ tokio:: test]
879+ async fn graceful_shutdown ( ) {
880+ let listener = TcpListener :: bind ( SocketAddr :: from ( ( [ 127 , 0 , 0 , 1 ] , 0 ) ) )
881+ . await
882+ . unwrap ( ) ;
883+
884+ let listener_addr = listener. local_addr ( ) . unwrap ( ) ;
885+
886+ // Spawn the task in background so that we can connect there
887+ let listen_task = tokio:: spawn ( async move { listener. accept ( ) . await . unwrap ( ) } ) ;
888+ // Only connect a stream, do not send headers or anything
889+ let _stream = TcpStream :: connect ( listener_addr) . await . unwrap ( ) ;
890+
891+ let ( stream, _) = listen_task. await . unwrap ( ) ;
892+ let stream = TokioIo :: new ( stream) ;
893+ let builder = auto:: Builder :: new ( TokioExecutor :: new ( ) ) ;
894+ let connection = builder. serve_connection ( stream, service_fn ( hello) ) ;
895+
896+ pin ! ( connection) ;
897+
898+ connection. as_mut ( ) . graceful_shutdown ( ) ;
899+
900+ let connection_error = tokio:: time:: timeout ( Duration :: from_millis ( 200 ) , connection)
901+ . await
902+ . expect ( "Connection should have finished in a timely manner after graceful shutdown." )
903+ . expect_err ( "Connection should have been interrupted." ) ;
904+
905+ let connection_error = connection_error
906+ . downcast_ref :: < std:: io:: Error > ( )
907+ . expect ( "The error should have been `std::io::Error`." ) ;
908+ assert_eq ! ( connection_error. kind( ) , std:: io:: ErrorKind :: Interrupted ) ;
909+ }
910+
874911 async fn connect_h1 < B > ( addr : SocketAddr ) -> client:: conn:: http1:: SendRequest < B >
875912 where
876913 B : Body + Send + ' static ,
0 commit comments