11#![ cfg( feature = "early-data" ) ]
22
3- use std:: io:: { self , BufReader , Cursor , Read , Write } ;
4- use std:: net:: { SocketAddr , TcpListener } ;
3+ use std:: io:: { self , BufReader , Cursor } ;
4+ use std:: net:: SocketAddr ;
55use std:: pin:: Pin ;
66use std:: sync:: Arc ;
77use std:: task:: { Context , Poll } ;
8- use std:: thread;
98
109use futures_util:: { future:: Future , ready} ;
11- use rustls:: { self , ClientConfig , RootCertStore , ServerConfig , ServerConnection , Stream } ;
12- use tokio:: io:: { AsyncRead , AsyncReadExt , AsyncWriteExt , ReadBuf } ;
13- use tokio:: net:: TcpStream ;
14- use tokio_rustls:: { client:: TlsStream , TlsConnector } ;
10+ use pin_project_lite:: pin_project;
11+ use rustls:: { self , ClientConfig , RootCertStore , ServerConfig } ;
12+ use tokio:: io:: { AsyncRead , AsyncReadExt , AsyncWrite , AsyncWriteExt , ReadBuf } ;
13+ use tokio:: net:: { TcpListener , TcpStream } ;
14+ use tokio_rustls:: { client, server, TlsAcceptor , TlsConnector } ;
1515
1616struct Read1 < T > ( T ) ;
1717
@@ -33,12 +33,27 @@ impl<T: AsyncRead + Unpin> Future for Read1<T> {
3333 }
3434}
3535
36+ pin_project ! {
37+ struct TlsStreamEarlyWrapper <IO > {
38+ #[ pin]
39+ inner: server:: TlsStream <IO >
40+ }
41+ }
42+
43+ impl < IO > AsyncRead for TlsStreamEarlyWrapper < IO >
44+ where
45+ IO : AsyncRead + AsyncWrite + Unpin {
46+ fn poll_read ( self : Pin < & mut Self > , cx : & mut Context < ' _ > , buf : & mut ReadBuf < ' _ > ) -> Poll < io:: Result < ( ) > > {
47+ return self . project ( ) . inner . poll_read_early_data ( cx, buf) ;
48+ }
49+ }
50+
3651async fn send (
3752 config : Arc < ClientConfig > ,
3853 addr : SocketAddr ,
3954 data : & [ u8 ] ,
4055 vectored : bool ,
41- ) -> io:: Result < ( TlsStream < TcpStream > , Vec < u8 > ) > {
56+ ) -> io:: Result < ( client :: TlsStream < TcpStream > , Vec < u8 > ) > {
4257 let connector = TlsConnector :: from ( config) . early_data ( true ) ;
4358 let stream = TcpStream :: connect ( & addr) . await ?;
4459 let domain = pki_types:: ServerName :: try_from ( "foobar.com" ) . unwrap ( ) ;
@@ -75,38 +90,33 @@ async fn test_0rtt_impl(vectored: bool) -> io::Result<()> {
7590 . unwrap ( ) ;
7691 server. max_early_data_size = 8192 ;
7792 let server = Arc :: new ( server) ;
93+ let acceptor = Arc :: new ( TlsAcceptor :: from ( server) ) ;
7894
79- let listener = TcpListener :: bind ( "127.0.0.1:0" ) ?;
95+ let listener = TcpListener :: bind ( "127.0.0.1:0" ) . await ?;
8096 let server_port = listener. local_addr ( ) . unwrap ( ) . port ( ) ;
81- thread:: spawn ( move || loop {
82- let ( mut sock, _addr) = listener. accept ( ) . unwrap ( ) ;
97+ tokio:: spawn ( async move {
98+ loop {
99+ let ( mut sock, _addr) = listener. accept ( ) . await . unwrap ( ) ;
100+
101+ let acceptor = acceptor. clone ( ) ;
102+ tokio:: spawn ( async move {
103+ let stream = acceptor. accept ( & mut sock) . await . unwrap ( ) ;
83104
84- let server = Arc :: clone ( & server) ;
85- thread:: spawn ( move || {
86- let mut conn = ServerConnection :: new ( server) . unwrap ( ) ;
87- conn. complete_io ( & mut sock) . unwrap ( ) ;
105+ let mut buf = Vec :: new ( ) ;
106+ let mut stream_wrapper = TlsStreamEarlyWrapper { inner : stream } ;
107+ stream_wrapper. read_to_end ( & mut buf) . await . unwrap ( ) ;
108+ let mut stream = stream_wrapper. inner ;
109+ stream. write_all ( b"EARLY:" ) . await . unwrap ( ) ;
110+ stream. write_all ( & buf) . await . unwrap ( ) ;
88111
89- if let Some ( mut early_data) = conn. early_data ( ) {
90112 let mut buf = Vec :: new ( ) ;
91- early_data. read_to_end ( & mut buf) . unwrap ( ) ;
92- let mut stream = Stream :: new ( & mut conn, & mut sock) ;
93- stream. write_all ( b"EARLY:" ) . unwrap ( ) ;
94- stream. write_all ( & buf) . unwrap ( ) ;
95- }
96-
97- let mut stream = Stream :: new ( & mut conn, & mut sock) ;
98- stream. write_all ( b"LATE:" ) . unwrap ( ) ;
99- loop {
100- let mut buf = [ 0 ; 1024 ] ;
101- let n = stream. read ( & mut buf) . unwrap ( ) ;
102- if n == 0 {
103- conn. send_close_notify ( ) ;
104- conn. complete_io ( & mut sock) . unwrap ( ) ;
105- break ;
106- }
107- stream. write_all ( & buf[ ..n] ) . unwrap ( ) ;
108- }
109- } ) ;
113+ stream. read_to_end ( & mut buf) . await . unwrap ( ) ;
114+ stream. write_all ( b"LATE:" ) . await . unwrap ( ) ;
115+ stream. write_all ( & buf) . await . unwrap ( ) ;
116+
117+ stream. shutdown ( ) . await . unwrap ( ) ;
118+ } ) ;
119+ }
110120 } ) ;
111121
112122 let mut chain = BufReader :: new ( Cursor :: new ( include_str ! ( "end.chain" ) ) ) ;
@@ -125,7 +135,7 @@ async fn test_0rtt_impl(vectored: bool) -> io::Result<()> {
125135
126136 let ( io, buf) = send ( config. clone ( ) , addr, b"hello" , vectored) . await ?;
127137 assert ! ( !io. get_ref( ) . 1 . is_early_data_accepted( ) ) ;
128- assert_eq ! ( "LATE:hello" , String :: from_utf8_lossy( & buf) ) ;
138+ assert_eq ! ( "EARLY: LATE:hello" , String :: from_utf8_lossy( & buf) ) ;
129139
130140 let ( io, buf) = send ( config, addr, b"world!" , vectored) . await ?;
131141 assert ! ( io. get_ref( ) . 1 . is_early_data_accepted( ) ) ;
0 commit comments