@@ -10,7 +10,7 @@ use std::thread;
1010use futures_util:: { future:: Future , ready} ;
1111use rustls:: pki_types:: ServerName ;
1212use rustls:: { self , ClientConfig , ServerConnection , Stream } ;
13- use tokio:: io:: { AsyncRead , AsyncReadExt , AsyncWriteExt , ReadBuf } ;
13+ use tokio:: io:: { AsyncRead , AsyncReadExt , AsyncWrite , AsyncWriteExt , ReadBuf } ;
1414use tokio:: net:: TcpStream ;
1515use tokio_rustls:: client:: TlsStream ;
1616use tokio_rustls:: TlsConnector ;
@@ -35,14 +35,15 @@ impl<T: AsyncRead + Unpin> Future for Read1<T> {
3535 }
3636}
3737
38- async fn send (
38+ async fn send < S : AsyncRead + AsyncWrite + Unpin > (
3939 config : Arc < ClientConfig > ,
4040 addr : SocketAddr ,
41+ wrapper : impl Fn ( TcpStream ) -> S ,
4142 data : & [ u8 ] ,
4243 vectored : bool ,
43- ) -> io:: Result < ( TlsStream < TcpStream > , Vec < u8 > ) > {
44+ ) -> io:: Result < ( TlsStream < S > , Vec < u8 > ) > {
4445 let connector = TlsConnector :: from ( config) . early_data ( true ) ;
45- let stream = TcpStream :: connect ( & addr) . await ?;
46+ let stream = wrapper ( TcpStream :: connect ( & addr) . await ?) ;
4647 let domain = ServerName :: try_from ( "foobar.com" ) . unwrap ( ) ;
4748
4849 let mut stream = connector. connect ( domain, stream) . await ?;
@@ -58,15 +59,23 @@ async fn send(
5859
5960#[ tokio:: test]
6061async fn test_0rtt ( ) -> io:: Result < ( ) > {
61- test_0rtt_impl ( false ) . await
62+ test_0rtt_impl ( |s| s , false ) . await
6263}
6364
6465#[ tokio:: test]
6566async fn test_0rtt_vectored ( ) -> io:: Result < ( ) > {
66- test_0rtt_impl ( true ) . await
67+ test_0rtt_impl ( |s| s , true ) . await
6768}
6869
69- async fn test_0rtt_impl ( vectored : bool ) -> io:: Result < ( ) > {
70+ #[ tokio:: test]
71+ async fn test_0rtt_vectored_flush_pending ( ) -> io:: Result < ( ) > {
72+ test_0rtt_impl ( utils:: FlushWrapper :: new, false ) . await
73+ }
74+
75+ async fn test_0rtt_impl < S : AsyncRead + AsyncWrite + Unpin > (
76+ wrapper : impl Fn ( TcpStream ) -> S ,
77+ vectored : bool ,
78+ ) -> io:: Result < ( ) > {
7079 let ( mut server, mut client) = utils:: make_configs ( ) ;
7180 server. max_early_data_size = 8192 ;
7281 let server = Arc :: new ( server) ;
@@ -108,11 +117,11 @@ async fn test_0rtt_impl(vectored: bool) -> io::Result<()> {
108117 let client = Arc :: new ( client) ;
109118 let addr = SocketAddr :: from ( ( [ 127 , 0 , 0 , 1 ] , server_port) ) ;
110119
111- let ( io, buf) = send ( client. clone ( ) , addr, b"hello" , vectored) . await ?;
120+ let ( io, buf) = send ( client. clone ( ) , addr, & wrapper , b"hello" , vectored) . await ?;
112121 assert ! ( !io. get_ref( ) . 1 . is_early_data_accepted( ) ) ;
113122 assert_eq ! ( "LATE:hello" , String :: from_utf8_lossy( & buf) ) ;
114123
115- let ( io, buf) = send ( client, addr, b"world!" , vectored) . await ?;
124+ let ( io, buf) = send ( client, addr, wrapper , b"world!" , vectored) . await ?;
116125 assert ! ( io. get_ref( ) . 1 . is_early_data_accepted( ) ) ;
117126 assert_eq ! ( "EARLY:world!LATE:" , String :: from_utf8_lossy( & buf) ) ;
118127
0 commit comments