11use std:: io:: { Error , ErrorKind , IoSlice , Result } ;
22use std:: pin:: Pin ;
3- use std:: ptr;
4- use std:: task:: { Context , Poll , RawWaker , RawWakerVTable , Waker } ;
3+ use std:: task:: { Context , Poll } ;
54use std:: time:: Duration ;
65
76use bytes:: buf:: BufMut ;
87use ignore_result:: Ignore ;
9- use tokio:: io:: { AsyncBufReadExt , AsyncRead , AsyncWrite , AsyncWriteExt , BufStream , ReadBuf } ;
8+ use tokio:: io:: { AsyncBufReadExt , AsyncRead , AsyncReadExt , AsyncWrite , AsyncWriteExt , BufStream , ReadBuf } ;
109use tokio:: net:: TcpStream ;
1110use tokio:: { select, time} ;
1211use tracing:: { debug, trace} ;
@@ -26,17 +25,31 @@ use tls::*;
2625use crate :: deadline:: Deadline ;
2726use crate :: endpoint:: { EndpointRef , IterableEndpoints } ;
2827
29- const NOOP_VTABLE : RawWakerVTable =
30- RawWakerVTable :: new ( |_| RawWaker :: new ( ptr:: null ( ) , & NOOP_VTABLE ) , |_| { } , |_| { } , |_| { } ) ;
31- const NOOP_WAKER : RawWaker = RawWaker :: new ( ptr:: null ( ) , & NOOP_VTABLE ) ;
32-
3328#[ derive( Debug ) ]
3429pub enum Connection {
3530 Raw ( TcpStream ) ,
3631 #[ cfg( feature = "tls" ) ]
3732 Tls ( TlsStream < TcpStream > ) ,
3833}
3934
35+ pub trait AsyncReadToBuf : AsyncReadExt {
36+ async fn read_to_buf ( & mut self , buf : & mut impl BufMut ) -> Result < usize >
37+ where
38+ Self : Unpin , {
39+ let chunk = buf. chunk_mut ( ) ;
40+ let read_to = unsafe { std:: mem:: transmute ( chunk. as_uninit_slice_mut ( ) ) } ;
41+ let n = self . read ( read_to) . await ?;
42+ if n != 0 {
43+ unsafe {
44+ buf. advance_mut ( n) ;
45+ }
46+ }
47+ Ok ( n)
48+ }
49+ }
50+
51+ impl < T > AsyncReadToBuf for T where T : AsyncReadExt { }
52+
4053impl AsyncRead for Connection {
4154 fn poll_read ( self : Pin < & mut Self > , cx : & mut Context < ' _ > , buf : & mut ReadBuf < ' _ > ) -> Poll < Result < ( ) > > {
4255 match self . get_mut ( ) {
@@ -56,6 +69,14 @@ impl AsyncWrite for Connection {
5669 }
5770 }
5871
72+ fn poll_write_vectored ( self : Pin < & mut Self > , cx : & mut Context < ' _ > , bufs : & [ IoSlice < ' _ > ] ) -> Poll < Result < usize > > {
73+ match self . get_mut ( ) {
74+ Self :: Raw ( stream) => Pin :: new ( stream) . poll_write_vectored ( cx, bufs) ,
75+ #[ cfg( feature = "tls" ) ]
76+ Self :: Tls ( stream) => Pin :: new ( stream) . poll_write_vectored ( cx, bufs) ,
77+ }
78+ }
79+
5980 fn poll_flush ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Result < ( ) > > {
6081 match self . get_mut ( ) {
6182 Self :: Raw ( stream) => Pin :: new ( stream) . poll_flush ( cx) ,
@@ -73,86 +94,52 @@ impl AsyncWrite for Connection {
7394 }
7495}
7596
76- impl Connection {
77- pub fn new_raw ( stream : TcpStream ) -> Self {
78- Self :: Raw ( stream)
97+ pub struct ConnReader < ' a > {
98+ conn : & ' a mut Connection ,
99+ }
100+
101+ impl AsyncRead for ConnReader < ' _ > {
102+ fn poll_read ( self : Pin < & mut Self > , cx : & mut Context < ' _ > , buf : & mut ReadBuf < ' _ > ) -> Poll < Result < ( ) > > {
103+ Pin :: new ( & mut self . get_mut ( ) . conn ) . poll_read ( cx, buf)
79104 }
105+ }
80106
81- #[ cfg( feature = "tls" ) ]
82- pub fn new_tls ( stream : TlsStream < TcpStream > ) -> Self {
83- Self :: Tls ( stream)
107+ pub struct ConnWriter < ' a > {
108+ conn : & ' a mut Connection ,
109+ }
110+
111+ impl AsyncWrite for ConnWriter < ' _ > {
112+ fn poll_write ( self : Pin < & mut Self > , cx : & mut Context < ' _ > , buf : & [ u8 ] ) -> Poll < Result < usize > > {
113+ Pin :: new ( & mut self . get_mut ( ) . conn ) . poll_write ( cx, buf)
84114 }
85115
86- pub fn try_write_vectored ( & mut self , bufs : & [ IoSlice < ' _ > ] ) -> Result < usize > {
87- let waker = unsafe { Waker :: from_raw ( NOOP_WAKER ) } ;
88- let mut context = Context :: from_waker ( & waker) ;
89- match Pin :: new ( self ) . poll_write_vectored ( & mut context, bufs) {
90- Poll :: Pending => Err ( ErrorKind :: WouldBlock . into ( ) ) ,
91- Poll :: Ready ( result) => result,
92- }
116+ fn poll_write_vectored ( self : Pin < & mut Self > , cx : & mut Context < ' _ > , bufs : & [ IoSlice < ' _ > ] ) -> Poll < Result < usize > > {
117+ Pin :: new ( & mut self . get_mut ( ) . conn ) . poll_write_vectored ( cx, bufs)
93118 }
94119
95- pub fn try_read_buf ( & mut self , buf : & mut impl BufMut ) -> Result < usize > {
96- let waker = unsafe { Waker :: from_raw ( NOOP_WAKER ) } ;
97- let mut context = Context :: from_waker ( & waker) ;
98- let chunk = buf. chunk_mut ( ) ;
99- let mut read_buf = unsafe { ReadBuf :: uninit ( chunk. as_uninit_slice_mut ( ) ) } ;
100- match Pin :: new ( self ) . poll_read ( & mut context, & mut read_buf) {
101- Poll :: Pending => Err ( ErrorKind :: WouldBlock . into ( ) ) ,
102- Poll :: Ready ( Err ( err) ) => Err ( err) ,
103- Poll :: Ready ( Ok ( ( ) ) ) => {
104- let n = read_buf. filled ( ) . len ( ) ;
105- unsafe { buf. advance_mut ( n) } ;
106- Ok ( n)
107- } ,
108- }
120+ fn poll_flush ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Result < ( ) > > {
121+ Pin :: new ( & mut self . get_mut ( ) . conn ) . poll_flush ( cx)
109122 }
110123
111- pub async fn readable ( & self ) -> Result < ( ) > {
112- match self {
113- Self :: Raw ( stream) => stream. readable ( ) . await ,
114- #[ cfg( feature = "tls" ) ]
115- Self :: Tls ( stream) => {
116- let ( stream, session) = stream. get_ref ( ) ;
117- if session. wants_read ( ) {
118- stream. readable ( ) . await
119- } else {
120- // plaintext data are available for read
121- std:: future:: ready ( Ok ( ( ) ) ) . await
122- }
123- } ,
124- }
124+ fn poll_shutdown ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Result < ( ) > > {
125+ Pin :: new ( & mut self . get_mut ( ) . conn ) . poll_shutdown ( cx)
125126 }
127+ }
126128
127- pub async fn writable ( & self ) -> Result < ( ) > {
128- match self {
129- Self :: Raw ( stream) => stream. writable ( ) . await ,
130- #[ cfg( feature = "tls" ) ]
131- Self :: Tls ( stream) => {
132- let ( stream, _session) = stream. get_ref ( ) ;
133- stream. writable ( ) . await
134- } ,
135- }
129+ impl Connection {
130+ pub fn new_raw ( stream : TcpStream ) -> Self {
131+ Self :: Raw ( stream)
136132 }
137133
138- pub fn wants_write ( & self ) -> bool {
139- match self {
140- Self :: Raw ( _) => false ,
141- #[ cfg( feature = "tls" ) ]
142- Self :: Tls ( stream) => {
143- let ( _stream, session) = stream. get_ref ( ) ;
144- session. wants_write ( )
145- } ,
146- }
134+ pub fn split ( & mut self ) -> ( ConnReader < ' _ > , ConnWriter < ' _ > ) {
135+ let reader = ConnReader { conn : self } ;
136+ let writer = ConnWriter { conn : unsafe { std:: ptr:: read ( & reader. conn ) } } ;
137+ ( reader, writer)
147138 }
148139
149- pub fn try_flush ( & mut self ) -> Result < ( ) > {
150- let waker = unsafe { Waker :: from_raw ( NOOP_WAKER ) } ;
151- let mut context = Context :: from_waker ( & waker) ;
152- match Pin :: new ( self ) . poll_flush ( & mut context) {
153- Poll :: Pending => Err ( ErrorKind :: WouldBlock . into ( ) ) ,
154- Poll :: Ready ( result) => result,
155- }
140+ #[ cfg( feature = "tls" ) ]
141+ pub fn new_tls ( stream : TlsStream < TcpStream > ) -> Self {
142+ Self :: Tls ( stream)
156143 }
157144
158145 pub async fn command ( self , cmd : & str ) -> Result < String > {
0 commit comments