diff --git a/CHANGELOG.md b/CHANGELOG.md index 7d3c02b..527524d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,12 @@ All notable changes to this project will be documented in this file. +## [2.0.0] - Unreleased +### Breaking changes +- Switched tokio version to 0.3. +- Removed `AsyncTlsStream` from [tokio-libtls]. `TlsStream` can now be used in all cases where `AsyncTlsStream` could previously. +- Removed `Error` from [tokio-libtls]. Now just use `libtls::TlsError` instead. + ## [1.2.0] - 2020-04-09 ### Added - New with LibreSSL 3.1.0: Support for `TLSv1.3`, diff --git a/examples/Cargo.toml b/examples/Cargo.toml index ed5aacd..1a142b2 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -8,9 +8,9 @@ edition = "2018" [dev-dependencies] futures = "0.3.4" -libtls = { path = "../libtls", version = "1.2.0" } -tokio-libtls = { path = "../tokio-libtls", version = "1.2.0" } -tokio = { version = "0.2.16", features = ["full"] } +libtls = { path = "../libtls", version = "2.0.0" } +tokio-libtls = { path = "../tokio-libtls", version = "2.0.0" } +tokio = { version = "0.3.4", features = ["full"] } [[example]] name = "config" diff --git a/libtls-sys/Cargo.toml b/libtls-sys/Cargo.toml index 8b12c7d..cda09ae 100644 --- a/libtls-sys/Cargo.toml +++ b/libtls-sys/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "libtls-sys" -version = "1.2.0" +version = "2.0.0" authors = ["Reyk Floeter "] edition = "2018" license = "ISC" diff --git a/libtls/Cargo.toml b/libtls/Cargo.toml index 47bb707..b1f0660 100644 --- a/libtls/Cargo.toml +++ b/libtls/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "libtls" -version = "1.2.0" +version = "2.0.0" authors = ["Reyk Floeter "] edition = "2018" license = "ISC" @@ -14,10 +14,10 @@ keywords = ["crypto", "tls", "ssl", "libressl", "openbsd"] travis-ci = { repository = "reyk/rust-libtls", branch = "master" } [dependencies] -libtls-sys = { path = "../libtls-sys", version = "1.2.0" } +libtls-sys = { path = "../libtls-sys", version = "2.0.0" } [dev-dependencies] rand = "0.7.3" [build-dependencies] -libtls-sys = { path = "../libtls-sys", version = "1.2.0" } +libtls-sys = { path = "../libtls-sys", version = "2.0.0" } diff --git a/tokio-libtls/Cargo.toml b/tokio-libtls/Cargo.toml index 1709724..76c031b 100644 --- a/tokio-libtls/Cargo.toml +++ b/tokio-libtls/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tokio-libtls" -version = "1.2.0" +version = "2.0.0" authors = ["Reyk Floeter "] edition = "2018" license = "ISC" @@ -14,10 +14,10 @@ keywords = ["tokio", "tls", "ssl", "libressl", "openbsd"] travis-ci = { repository = "reyk/rust-libtls", branch = "master" } [dependencies] -futures = "0.3.4" -libtls = { path = "../libtls", version = "1.2.0" } -mio = "0.6.21" -tokio = { version = "0.2.16", features = ["io-driver", "tcp", "time"] } +futures = "0.3.8" +libtls = { path = "../libtls", version = "2.0.0" } +mio = "0.7.6" +tokio = { version = "0.3.4", features = ["net", "time"] } [dev-dependencies] -tokio = { version = "0.2.16", features = ["full"] } +tokio = { version = "0.3.4", features = ["full"] } diff --git a/tokio-libtls/src/error.rs b/tokio-libtls/src/error.rs deleted file mode 100644 index 308aa7d..0000000 --- a/tokio-libtls/src/error.rs +++ /dev/null @@ -1,77 +0,0 @@ -// Copyright (c) 2019, 2020 Reyk Floeter -// -// Permission to use, copy, modify, and distribute this software for any -// purpose with or without fee is hereby granted, provided that the above -// copyright notice and this permission notice appear in all copies. -// -// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES -// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF -// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR -// ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES -// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN -// ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF -// OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. - -use crate::AsyncTlsStream; -use libtls::error::Error as TlsError; -use std::{error, fmt, io}; - -/// An error returned by [`AsyncTls`]. -/// -/// This error includes the detailed error message of a failed async -/// `libtls` operation. -/// -/// [`AsyncTls`]: ../struct.AsyncTls.html -#[derive(Debug)] -pub enum Error { - /// The connection is readable. - Readable(AsyncTlsStream), - /// The connection is writeable. - Writeable(AsyncTlsStream), - /// The connection is doing a handshake. - Handshake(AsyncTlsStream), - /// A generic error. - Error(TlsError), -} - -/// An error returned by [`AsyncTls`]. -#[deprecated( - since = "1.1.1", - note = "Please use `Error` instead of `AsyncTlsError`" -)] -pub type AsyncTlsError = Error; - -impl fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Error::Readable(_) => write!(f, "Readable I/O in progress"), - Error::Writeable(_) => write!(f, "Writable I/O in progress"), - Error::Handshake(_) => write!(f, "Handshake I/O in progress"), - Error::Error(err) => err.fmt(f), - } - } -} - -impl error::Error for Error { - fn source(&self) -> Option<&(dyn error::Error + 'static)> { - None - } -} - -impl From for Error { - fn from(err: TlsError) -> Self { - Error::Error(err) - } -} - -impl From for Error { - fn from(err: io::Error) -> Self { - err.into() - } -} - -impl From for io::Error { - fn from(err: Error) -> Self { - io::Error::new(io::ErrorKind::Other, err) - } -} diff --git a/tokio-libtls/src/lib.rs b/tokio-libtls/src/lib.rs index cf4fdcb..201b167 100644 --- a/tokio-libtls/src/lib.rs +++ b/tokio-libtls/src/lib.rs @@ -60,18 +60,14 @@ )] #![warn(missing_docs)] -/// Error handling. -pub mod error; - /// A "prelude" for crates using the `tokio-libtls` crate. pub mod prelude; -use error::Error; +use futures::ready; use libtls::{config::Config, error::Error as TlsError, tls::Tls}; -use mio::{event::Evented, unix::EventedFd, PollOpt, Ready, Token}; use prelude::*; use std::{ - io::{self, Read, Write}, + io, net::ToSocketAddrs, ops::{Deref, DerefMut}, os::unix::io::{AsRawFd, RawFd}, @@ -80,32 +76,20 @@ use std::{ time::Duration, }; use tokio::{ - io::{AsyncRead, AsyncWrite, PollEvented}, + io::{AsyncRead, AsyncWrite, ReadBuf}, net::{TcpListener, TcpStream}, time::timeout, }; -macro_rules! try_async_tls { - ($call: expr) => { - match $call { - Ok(size) => Poll::Ready(Ok(size)), - Err(err) => { - let err: io::Error = err.into(); - if err.kind() == io::ErrorKind::WouldBlock { - Poll::Pending - } else { - Poll::Ready(Err(err)) - } - } - } - }; -} - /// Wrapper for async I/O operations with `Tls`. #[derive(Debug)] pub struct TlsStream { - tls: Tls, - tcp: TcpStream, + /// The underlying `Tls` instance for this `TlsStream`. + pub tls: Tls, + /// The underlying `TcpStream` for this `TlsStream`. + /// + /// This can be used to poll the readable and writable status of the socket, if necessary. + pub tcp: TcpStream, } impl TlsStream { @@ -113,6 +97,31 @@ impl TlsStream { pub fn new(tls: Tls, tcp: TcpStream) -> Self { Self { tls, tcp } } + + /// Attempts an IO action, handling `TLS_WANT_POLLIN` and `TLS_WANT_POLLOUT`. + /// + /// This function calls `f` repeatedly, rescheduling this task whenever it returns + /// `TLS_WANT_POLLIN` or `TLS_WANT_POLLOUT`. + fn poll_io( + &mut self, + cx: &mut Context<'_>, + mut f: impl FnMut(&mut Tls) -> Result, + ) -> Poll> { + loop { + match f(&mut self.tls) { + Err(err) => return Poll::Ready(Err(err.into())), + Ok(value) => { + if value == libtls::TLS_WANT_POLLIN as isize { + ready!(self.tcp.poll_read_ready(cx)?); + } else if value == libtls::TLS_WANT_POLLOUT as isize { + ready!(self.tcp.poll_write_ready(cx)?); + } else { + return Poll::Ready(Ok(value)); + } + } + } + } + } } impl Deref for TlsStream { @@ -154,75 +163,49 @@ impl io::Write for TlsStream { impl AsyncRead for TlsStream { fn poll_read( mut self: Pin<&mut Self>, - _cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { - try_async_tls!(self.tls.read(buf)) + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + // libtls should correctly fill the uninintialized buffer, so this unsafe is okay + unsafe { + let b = &mut *(buf.unfilled_mut() as *mut [std::mem::MaybeUninit] as *mut [u8]); + let n = ready!(self.poll_io(cx, |tls| tls.tls_read(b))?) as usize; + buf.assume_init(n); + buf.advance(n); + Poll::Ready(Ok(())) + } } } impl AsyncWrite for TlsStream { fn poll_write( mut self: Pin<&mut Self>, - _cx: &mut Context<'_>, + cx: &mut Context<'_>, buf: &[u8], - ) -> Poll> { - try_async_tls!(self.tls.write(buf)) + ) -> Poll> { + let n = ready!(self.poll_io(cx, |tls| tls.tls_write(buf))?) as usize; + Poll::Ready(Ok(n)) } - fn poll_flush(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - try_async_tls!(self.tls.close()).map(|_| Ok(())) + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) } fn poll_shutdown( mut self: Pin<&mut Self>, - _cx: &mut Context<'_>, + cx: &mut Context<'_>, ) -> Poll> { - try_async_tls!(self.tls.close()).map(|_| Ok(())) - } -} - -impl Evented for TlsStream { - fn register( - &self, - poll: &mio::Poll, - token: Token, - interest: Ready, - opts: PollOpt, - ) -> io::Result<()> { - match EventedFd(&self.as_raw_fd()).register(poll, token, interest, opts) { - Err(ref err) if err.kind() == io::ErrorKind::AlreadyExists => { - self.reregister(poll, token, interest, opts) - } - Err(err) => Err(err), - Ok(_) => Ok(()), - } - } - - fn reregister( - &self, - poll: &mio::Poll, - token: Token, - interest: Ready, - opts: PollOpt, - ) -> io::Result<()> { - EventedFd(&self.as_raw_fd()).reregister(poll, token, interest, opts) - } - - fn deregister(&self, poll: &mio::Poll) -> io::Result<()> { - EventedFd(&self.as_raw_fd()).deregister(poll) + ready!(self.poll_io(cx, |tls| tls.tls_close())?); + Poll::Ready(Ok(())) } } unsafe impl Send for TlsStream {} unsafe impl Sync for TlsStream {} -/// Pollable wrapper for async I/O operations with `Tls`. -pub type AsyncTlsStream = PollEvented; - /// Async `Tls` struct. pub struct AsyncTls { - inner: Option>, + stream: Option, } impl AsyncTls { @@ -232,7 +215,7 @@ impl AsyncTls { tcp: TcpStream, config: &Config, options: Option, - ) -> io::Result { + ) -> io::Result { accept_stream(tcp, config, options).await } @@ -242,7 +225,7 @@ impl AsyncTls { tcp: TcpStream, config: &Config, options: Option, - ) -> io::Result { + ) -> io::Result { connect_stream(tcp, config, options).await } @@ -252,63 +235,21 @@ impl AsyncTls { host: &str, config: &Config, options: Option, - ) -> io::Result { + ) -> io::Result { connect(host, config, options).await } } impl Future for AsyncTls { - type Output = Result; + type Output = Result; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let inner = self - .inner - .take() - .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "cannot take inner"))?; - match inner { - Ok(tls) => { - cx.waker().wake_by_ref(); - Poll::Ready(Ok(tls)) - } - Err(Error::Readable(stream)) => { - self.inner = match stream.poll_read_ready(cx, Ready::readable()) { - Poll::Ready(_) => Some(Err(Error::Handshake(stream))), - _ => Some(Err(Error::Handshake(stream))), - }; - cx.waker().wake_by_ref(); - Poll::Pending - } - Err(Error::Writeable(stream)) => { - self.inner = match stream.poll_write_ready(cx) { - Poll::Ready(_) => Some(Err(Error::Handshake(stream))), - _ => Some(Err(Error::Writeable(stream))), - }; - cx.waker().wake_by_ref(); - Poll::Pending - } - Err(Error::Handshake(mut stream)) => { - let tls = &mut *stream.get_mut(); - let res = match tls.tls_handshake() { - Ok(res) => { - if res == libtls::TLS_WANT_POLLIN as isize { - Err(Error::Readable(stream)) - } else if res == libtls::TLS_WANT_POLLOUT as isize { - Err(Error::Writeable(stream)) - } else { - Ok(stream) - } - } - Err(err) => Err(err.into()), - }; - self.inner = Some(res); - cx.waker().wake_by_ref(); - Poll::Pending - } - Err(Error::Error(TlsError::IoError(err))) => Poll::Ready(Err(err)), - Err(Error::Error(err)) => { - Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, err.to_string()))) - } - } + let stream = self + .stream + .as_mut() + .expect("AsyncTls::poll called again after returning Poll::Ready"); + ready!(stream.poll_io(cx, |tls| tls.tls_handshake())?); + Poll::Ready(Ok(self.stream.take().unwrap())) } } @@ -320,7 +261,7 @@ pub async fn accept( listener: &mut TcpListener, config: &Config, options: Option, -) -> io::Result { +) -> io::Result { let options = options.unwrap_or_else(Options::new); let (tcp, _) = listener.accept().await?; @@ -328,10 +269,9 @@ pub async fn accept( server.configure(config)?; let client = server.accept_raw_fd(&tcp)?; - let async_tls = TlsStream::new(client, tcp); - let stream = PollEvented::new(async_tls)?; + let stream = TlsStream::new(client, tcp); let fut = AsyncTls { - inner: Some(Err(Error::Readable(stream))), + stream: Some(stream), }; // Accept with an optional timeout for the TLS handshake. @@ -351,17 +291,16 @@ pub async fn accept_stream( tcp: TcpStream, config: &Config, options: Option, -) -> io::Result { +) -> io::Result { let options = options.unwrap_or_else(Options::new); let mut server = Tls::server()?; server.configure(config)?; let client = server.accept_raw_fd(&tcp)?; - let async_tls = TlsStream::new(client, tcp); - let stream = PollEvented::new(async_tls)?; + let stream = TlsStream::new(client, tcp); let fut = AsyncTls { - inner: Some(Err(Error::Readable(stream))), + stream: Some(stream), }; // Accept with an optional timeout for the TLS handshake. @@ -381,7 +320,7 @@ pub async fn connect_stream( tcp: TcpStream, config: &Config, options: Option, -) -> io::Result { +) -> io::Result { let options = options.unwrap_or_else(Options::new); let servername = match options.servername { Some(name) => name, @@ -393,10 +332,9 @@ pub async fn connect_stream( tls.configure(config)?; tls.connect_raw_fd(&tcp, &servername)?; - let async_tls = TlsStream::new(tls, tcp); - let stream = PollEvented::new(async_tls)?; + let stream = TlsStream::new(tls, tcp); let fut = AsyncTls { - inner: Some(Err(Error::Readable(stream))), + stream: Some(stream), }; // Connect with an optional timeout for the TLS handshake. @@ -416,7 +354,7 @@ pub async fn connect( host: &str, config: &Config, options: Option, -) -> io::Result { +) -> io::Result { let mut options = options.unwrap_or_else(Options::new); // Remove _last_ colon (to satisfy the IPv6 form, e.g. [::1]::443).