@@ -4,6 +4,8 @@ use std::os::unix::io::{AsRawFd, RawFd};
44#[ cfg( windows) ]
55use std:: os:: windows:: io:: { AsRawSocket , RawSocket } ;
66use std:: pin:: Pin ;
7+ #[ cfg( feature = "early-data" ) ]
8+ use std:: task:: Waker ;
79use std:: task:: { Context , Poll } ;
810
911use rustls:: ClientConnection ;
@@ -20,7 +22,7 @@ pub struct TlsStream<IO> {
2022 pub ( crate ) state : TlsState ,
2123
2224 #[ cfg( feature = "early-data" ) ]
23- pub ( crate ) early_waker : Option < std :: task :: Waker > ,
25+ pub ( crate ) early_waker : Option < Waker > ,
2426}
2527
2628impl < IO > TlsStream < IO > {
@@ -152,78 +154,70 @@ where
152154 let mut stream =
153155 Stream :: new ( & mut this. io , & mut this. session ) . set_eof ( !this. state . readable ( ) ) ;
154156
155- #[ allow( clippy:: match_single_binding) ]
156- match this. state {
157- #[ cfg( feature = "early-data" ) ]
158- TlsState :: EarlyData ( ref mut pos, ref mut data) => {
159- use std:: io:: Write ;
160-
161- // write early data
162- if let Some ( mut early_data) = stream. session . early_data ( ) {
163- let len = match early_data. write ( buf) {
164- Ok ( n) => n,
165- Err ( err) => return Poll :: Ready ( Err ( err) ) ,
166- } ;
167- if len != 0 {
168- data. extend_from_slice ( & buf[ ..len] ) ;
169- return Poll :: Ready ( Ok ( len) ) ;
170- }
171- }
172-
173- // complete handshake
174- while stream. session . is_handshaking ( ) {
175- ready ! ( stream. handshake( cx) ) ?;
176- }
177-
178- // write early data (fallback)
179- if !stream. session . is_early_data_accepted ( ) {
180- while * pos < data. len ( ) {
181- let len = ready ! ( stream. as_mut_pin( ) . poll_write( cx, & data[ * pos..] ) ) ?;
182- * pos += len;
183- }
184- }
185-
186- // end
187- this. state = TlsState :: Stream ;
188-
189- if let Some ( waker) = this. early_waker . take ( ) {
190- waker. wake ( ) ;
191- }
192-
193- stream. as_mut_pin ( ) . poll_write ( cx, buf)
157+ #[ cfg( feature = "early-data" ) ]
158+ {
159+ let bufs = [ io:: IoSlice :: new ( buf) ] ;
160+ let written = ready ! ( poll_handle_early_data(
161+ & mut this. state,
162+ & mut stream,
163+ & mut this. early_waker,
164+ cx,
165+ & bufs
166+ ) ) ?;
167+ if written != 0 {
168+ return Poll :: Ready ( Ok ( written) ) ;
194169 }
195- _ => stream. as_mut_pin ( ) . poll_write ( cx, buf) ,
196170 }
171+
172+ stream. as_mut_pin ( ) . poll_write ( cx, buf)
197173 }
198174
199- fn poll_flush ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < io:: Result < ( ) > > {
175+ /// Note: that it does not guarantee the final data to be sent.
176+ /// To be cautious, you must manually call `flush`.
177+ fn poll_write_vectored (
178+ self : Pin < & mut Self > ,
179+ cx : & mut Context < ' _ > ,
180+ bufs : & [ io:: IoSlice < ' _ > ] ,
181+ ) -> Poll < io:: Result < usize > > {
200182 let this = self . get_mut ( ) ;
201183 let mut stream =
202184 Stream :: new ( & mut this. io , & mut this. session ) . set_eof ( !this. state . readable ( ) ) ;
203185
204186 #[ cfg( feature = "early-data" ) ]
205187 {
206- if let TlsState :: EarlyData ( ref mut pos, ref mut data) = this. state {
207- // complete handshake
208- while stream. session . is_handshaking ( ) {
209- ready ! ( stream. handshake( cx) ) ?;
210- }
188+ let written = ready ! ( poll_handle_early_data(
189+ & mut this. state,
190+ & mut stream,
191+ & mut this. early_waker,
192+ cx,
193+ bufs
194+ ) ) ?;
195+ if written != 0 {
196+ return Poll :: Ready ( Ok ( written) ) ;
197+ }
198+ }
211199
212- // write early data (fallback)
213- if !stream. session . is_early_data_accepted ( ) {
214- while * pos < data. len ( ) {
215- let len = ready ! ( stream. as_mut_pin( ) . poll_write( cx, & data[ * pos..] ) ) ?;
216- * pos += len;
217- }
218- }
200+ stream. as_mut_pin ( ) . poll_write_vectored ( cx, bufs)
201+ }
219202
220- this. state = TlsState :: Stream ;
203+ #[ inline]
204+ fn is_write_vectored ( & self ) -> bool {
205+ true
206+ }
221207
222- if let Some ( waker) = this. early_waker . take ( ) {
223- waker. wake ( ) ;
224- }
225- }
226- }
208+ fn poll_flush ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < io:: Result < ( ) > > {
209+ let this = self . get_mut ( ) ;
210+ let mut stream =
211+ Stream :: new ( & mut this. io , & mut this. session ) . set_eof ( !this. state . readable ( ) ) ;
212+
213+ #[ cfg( feature = "early-data" ) ]
214+ ready ! ( poll_handle_early_data(
215+ & mut this. state,
216+ & mut stream,
217+ & mut this. early_waker,
218+ cx,
219+ & [ ]
220+ ) ) ?;
227221
228222 stream. as_mut_pin ( ) . poll_flush ( cx)
229223 }
@@ -248,3 +242,69 @@ where
248242 stream. as_mut_pin ( ) . poll_shutdown ( cx)
249243 }
250244}
245+
246+ #[ cfg( feature = "early-data" ) ]
247+ fn poll_handle_early_data < IO > (
248+ state : & mut TlsState ,
249+ stream : & mut Stream < IO , ClientConnection > ,
250+ early_waker : & mut Option < Waker > ,
251+ cx : & mut Context < ' _ > ,
252+ bufs : & [ io:: IoSlice < ' _ > ] ,
253+ ) -> Poll < io:: Result < usize > >
254+ where
255+ IO : AsyncRead + AsyncWrite + Unpin ,
256+ {
257+ if let TlsState :: EarlyData ( pos, data) = state {
258+ use std:: io:: Write ;
259+
260+ // write early data
261+ if let Some ( mut early_data) = stream. session . early_data ( ) {
262+ let mut written = 0 ;
263+
264+ for buf in bufs {
265+ if buf. is_empty ( ) {
266+ continue ;
267+ }
268+
269+ let len = match early_data. write ( buf) {
270+ Ok ( 0 ) => break ,
271+ Ok ( n) => n,
272+ Err ( err) => return Poll :: Ready ( Err ( err) ) ,
273+ } ;
274+
275+ written += len;
276+ data. extend_from_slice ( & buf[ ..len] ) ;
277+
278+ if len < buf. len ( ) {
279+ break ;
280+ }
281+ }
282+
283+ if written != 0 {
284+ return Poll :: Ready ( Ok ( written) ) ;
285+ }
286+ }
287+
288+ // complete handshake
289+ while stream. session . is_handshaking ( ) {
290+ ready ! ( stream. handshake( cx) ) ?;
291+ }
292+
293+ // write early data (fallback)
294+ if !stream. session . is_early_data_accepted ( ) {
295+ while * pos < data. len ( ) {
296+ let len = ready ! ( stream. as_mut_pin( ) . poll_write( cx, & data[ * pos..] ) ) ?;
297+ * pos += len;
298+ }
299+ }
300+
301+ // end
302+ * state = TlsState :: Stream ;
303+
304+ if let Some ( waker) = early_waker. take ( ) {
305+ waker. wake ( ) ;
306+ }
307+ }
308+
309+ Poll :: Ready ( Ok ( 0 ) )
310+ }
0 commit comments