@@ -30,8 +30,7 @@ use webpki::DNSNameRef;
3030/// async_std::task::block_on(async {
3131/// let connector = TlsConnector::default();
3232/// let tcp_stream = async_std::net::TcpStream::connect("example.com").await?;
33- /// let handshake = connector.connect("example.com", tcp_stream)?;
34- /// let encrypted_stream = handshake.await?;
33+ /// let encrypted_stream = connector.connect("example.com", tcp_stream).await?;
3534///
3635/// Ok(()) as async_std::io::Result<()>
3736/// });
@@ -83,11 +82,10 @@ impl TlsConnector {
8382 /// Connect to a server. `stream` can be any type implementing `AsyncRead` and `AsyncWrite`,
8483 /// such as TcpStreams or Unix domain sockets.
8584 ///
86- /// The function will return an error if the domain is not of valid format.
87- /// Otherwise, it will return a `Connect` Future, representing the connecting part of a
88- /// Tls handshake. It will resolve when the handshake is over.
85+ /// The function will return a `Connect` Future, representing the connecting part of a Tls
86+ /// handshake. It will resolve when the handshake is over.
8987 #[ inline]
90- pub fn connect < ' a , IO > ( & self , domain : impl AsRef < str > , stream : IO ) -> io :: Result < Connect < IO > >
88+ pub fn connect < ' a , IO > ( & self , domain : impl AsRef < str > , stream : IO ) -> Connect < IO >
9189 where
9290 IO : AsyncRead + AsyncWrite + Unpin ,
9391 {
@@ -96,24 +94,27 @@ impl TlsConnector {
9694
9795 // NOTE: Currently private, exposing ClientSession exposes rusttls
9896 // Early data should be exposed differently
99- fn connect_with < ' a , IO , F > (
100- & self ,
101- domain : impl AsRef < str > ,
102- stream : IO ,
103- f : F ,
104- ) -> io:: Result < Connect < IO > >
97+ fn connect_with < ' a , IO , F > ( & self , domain : impl AsRef < str > , stream : IO , f : F ) -> Connect < IO >
10598 where
10699 IO : AsyncRead + AsyncWrite + Unpin ,
107100 F : FnOnce ( & mut ClientSession ) ,
108101 {
109- let domain = DNSNameRef :: try_from_ascii_str ( domain. as_ref ( ) )
110- . map_err ( |_| io:: Error :: new ( io:: ErrorKind :: InvalidInput , "invalid domain" ) ) ?;
102+ let domain = match DNSNameRef :: try_from_ascii_str ( domain. as_ref ( ) ) {
103+ Ok ( domain) => domain,
104+ Err ( _) => {
105+ return Connect ( ConnectInner :: Error ( Some ( io:: Error :: new (
106+ io:: ErrorKind :: InvalidInput ,
107+ "invalid domain" ,
108+ ) ) ) )
109+ }
110+ } ;
111+
111112 let mut session = ClientSession :: new ( & self . inner , domain) ;
112113 f ( & mut session) ;
113114
114115 #[ cfg( not( feature = "early-data" ) ) ]
115116 {
116- Ok ( Connect ( client:: MidHandshake :: Handshaking (
117+ Connect ( ConnectInner :: Handshake ( client:: MidHandshake :: Handshaking (
117118 client:: TlsStream {
118119 session,
119120 io : stream,
@@ -124,7 +125,7 @@ impl TlsConnector {
124125
125126 #[ cfg( feature = "early-data" ) ]
126127 {
127- Ok ( Connect ( if self . early_data {
128+ Connect ( ConnectInner :: Handshake ( if self . early_data {
128129 client:: MidHandshake :: EarlyData ( client:: TlsStream {
129130 session,
130131 io : stream,
@@ -145,13 +146,23 @@ impl TlsConnector {
145146
146147/// Future returned from `TlsConnector::connect` which will resolve
147148/// once the connection handshake has finished.
148- pub struct Connect < IO > ( client:: MidHandshake < IO > ) ;
149+ pub struct Connect < IO > ( ConnectInner < IO > ) ;
150+
151+ enum ConnectInner < IO > {
152+ Error ( Option < io:: Error > ) ,
153+ Handshake ( client:: MidHandshake < IO > ) ,
154+ }
149155
150156impl < IO : AsyncRead + AsyncWrite + Unpin > Future for Connect < IO > {
151157 type Output = io:: Result < client:: TlsStream < IO > > ;
152158
153159 #[ inline]
154160 fn poll ( mut self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Self :: Output > {
155- Pin :: new ( & mut self . 0 ) . poll ( cx)
161+ match self . 0 {
162+ ConnectInner :: Error ( ref mut err) => {
163+ Poll :: Ready ( Err ( err. take ( ) . expect ( "Polled twice after being Ready" ) ) )
164+ }
165+ ConnectInner :: Handshake ( ref mut handshake) => Pin :: new ( handshake) . poll ( cx) ,
166+ }
156167 }
157168}
0 commit comments