From b57bd9646ed379aa3581beb0d6ae287757e260c6 Mon Sep 17 00:00:00 2001 From: Sean McArthur Date: Mon, 20 Oct 2025 10:23:44 -0400 Subject: [PATCH] fix(http2): fix internals of HTTP/2 CONNECT upgrades This refactors the way hyper handles HTTP/2 CONNECT / Extended CONNECT. Before, an uninhabited enum was used to try to prevent sending of the `Buf` type once the STREAM had been upgraded. However, the way it was originally written was incorrect, and will eventually have compilation issues. The change here is to spawn an extra task and use a channel to bridge the IO operations of the `Upgraded` object to be `Cursor` buffers in the new task. ref: https://github.com/rust-lang/rust/issues/147588 --- src/client/conn/http2.rs | 38 ------ src/client/dispatch.rs | 7 +- src/proto/h2/client.rs | 41 +++--- src/proto/h2/mod.rs | 193 +-------------------------- src/proto/h2/server.rs | 41 +++--- src/proto/h2/upgrade.rs | 280 +++++++++++++++++++++++++++++++++++++++ src/rt/bounds.rs | 70 +++++++--- 7 files changed, 382 insertions(+), 288 deletions(-) create mode 100644 src/proto/h2/upgrade.rs diff --git a/src/client/conn/http2.rs b/src/client/conn/http2.rs index 356d909dd9..0efaabe41e 100644 --- a/src/client/conn/http2.rs +++ b/src/client/conn/http2.rs @@ -640,44 +640,6 @@ mod tests { } } - #[tokio::test] - #[ignore] // only compilation is checked - async fn not_send_not_sync_executor_of_send_futures() { - #[derive(Clone)] - struct TokioExecutor { - // !Send, !Sync - _x: std::marker::PhantomData>, - } - - impl crate::rt::Executor for TokioExecutor - where - F: std::future::Future + 'static + Send, - F::Output: Send + 'static, - { - fn execute(&self, fut: F) { - tokio::task::spawn(fut); - } - } - - #[allow(unused)] - async fn run(io: impl crate::rt::Read + crate::rt::Write + Send + Unpin + 'static) { - let (_sender, conn) = - crate::client::conn::http2::handshake::<_, _, http_body_util::Empty>( - TokioExecutor { - _x: Default::default(), - }, - io, - ) - .await - .unwrap(); - - tokio::task::spawn_local(async move { - // can't use spawn here because when executor is !Send - conn.await.unwrap(); - }); - } - } - #[tokio::test] #[ignore] // only compilation is checked async fn send_not_sync_executor_of_send_futures() { diff --git a/src/client/dispatch.rs b/src/client/dispatch.rs index 412f7af52c..2cbc1b9225 100644 --- a/src/client/dispatch.rs +++ b/src/client/dispatch.rs @@ -325,22 +325,23 @@ impl TrySendError { #[cfg(feature = "http2")] pin_project! { - pub struct SendWhen + pub struct SendWhen where B: Body, B: 'static, { #[pin] - pub(crate) when: ResponseFutMap, + pub(crate) when: ResponseFutMap, #[pin] pub(crate) call_back: Option, Response>>, } } #[cfg(feature = "http2")] -impl Future for SendWhen +impl Future for SendWhen where B: Body + 'static, + E: crate::rt::bounds::Http2UpgradedExec, { type Output = (); diff --git a/src/proto/h2/client.rs b/src/proto/h2/client.rs index 3860a5afaf..455c70980c 100644 --- a/src/proto/h2/client.rs +++ b/src/proto/h2/client.rs @@ -18,7 +18,7 @@ use http::{Method, StatusCode}; use pin_project_lite::pin_project; use super::ping::{Ponger, Recorder}; -use super::{ping, H2Upgraded, PipeToSendStream, SendBuf}; +use super::{ping, PipeToSendStream, SendBuf}; use crate::body::{Body, Incoming as IncomingBody}; use crate::client::dispatch::{Callback, SendWhen, TrySendError}; use crate::common::either::Either; @@ -26,9 +26,8 @@ use crate::common::io::Compat; use crate::common::time::Time; use crate::ext::Protocol; use crate::headers; -use crate::proto::h2::UpgradedSendStream; use crate::proto::Dispatched; -use crate::rt::bounds::Http2ClientConnExec; +use crate::rt::bounds::{Http2ClientConnExec, Http2UpgradedExec}; use crate::upgrade::Upgraded; use crate::{Request, Response}; use h2::client::ResponseFuture; @@ -151,7 +150,7 @@ where T: Read + Write + Unpin, B: Body + 'static, B::Data: Send + 'static, - E: Http2ClientConnExec + Unpin, + E: Http2ClientConnExec + Clone + Unpin, B::Error: Into>, { let (h2_tx, mut conn) = new_builder(config) @@ -357,7 +356,7 @@ where pin_project! { #[project = H2ClientFutureProject] - pub enum H2ClientFuture + pub enum H2ClientFuture where B: http_body::Body, B: 'static, @@ -372,7 +371,7 @@ pin_project! { }, Send { #[pin] - send_when: SendWhen, + send_when: SendWhen, }, Task { #[pin] @@ -381,11 +380,12 @@ pin_project! { } } -impl Future for H2ClientFuture +impl Future for H2ClientFuture where B: http_body::Body + 'static, B::Error: Into>, T: Read + Write + Unpin, + E: Http2UpgradedExec, { type Output = (); @@ -484,7 +484,7 @@ impl ClientTask where B: Body + 'static + Unpin, B::Data: Send, - E: Http2ClientConnExec + Unpin, + E: Http2ClientConnExec + Clone + Unpin, B::Error: Into>, T: Read + Write + Unpin, { @@ -529,6 +529,7 @@ where fut: f.fut, ping: Some(ping), send_stream: Some(send_stream), + exec: self.executor.clone(), }, call_back: Some(f.cb), }, @@ -537,28 +538,29 @@ where } pin_project! { - pub(crate) struct ResponseFutMap + pub(crate) struct ResponseFutMap where B: Body, B: 'static, { #[pin] fut: ResponseFuture, - #[pin] ping: Option, #[pin] send_stream: Option::Data>>>>, + exec: E, } } -impl Future for ResponseFutMap +impl Future for ResponseFutMap where B: Body + 'static, + E: Http2UpgradedExec, { type Output = Result, (crate::Error, Option>)>; - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut this = self.project(); + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut this = self.as_mut().project(); let result = ready!(this.fut.poll(cx)); @@ -585,13 +587,10 @@ where let mut res = Response::from_parts(parts, IncomingBody::empty()); let (pending, on_upgrade) = crate::upgrade::pending(); - let io = H2Upgraded { - ping, - send_stream: unsafe { UpgradedSendStream::new(send_stream) }, - recv_stream, - buf: Bytes::new(), - }; - let upgraded = Upgraded::new(io, Bytes::new()); + + let (h2_up, up_task) = super::upgrade::pair(send_stream, recv_stream, ping); + self.exec.execute_upgrade(up_task); + let upgraded = Upgraded::new(h2_up, Bytes::new()); pending.fulfill(upgraded); res.extensions_mut().insert(on_upgrade); @@ -620,7 +619,7 @@ where B: Body + 'static + Unpin, B::Data: Send, B::Error: Into>, - E: Http2ClientConnExec + Unpin, + E: Http2ClientConnExec + Clone + Unpin, T: Read + Write + Unpin, { type Output = crate::Result; diff --git a/src/proto/h2/mod.rs b/src/proto/h2/mod.rs index 73bda4c224..e6f40f3467 100644 --- a/src/proto/h2/mod.rs +++ b/src/proto/h2/mod.rs @@ -1,22 +1,20 @@ use std::error::Error as StdError; use std::future::Future; use std::io::{Cursor, IoSlice}; -use std::mem; use std::pin::Pin; use std::task::{Context, Poll}; -use bytes::{Buf, Bytes}; +use bytes::Buf; use futures_core::ready; -use h2::{Reason, RecvStream, SendStream}; +use h2::SendStream; use http::header::{HeaderName, CONNECTION, TE, TRANSFER_ENCODING, UPGRADE}; use http::HeaderMap; use pin_project_lite::pin_project; use crate::body::Body; -use crate::proto::h2::ping::Recorder; -use crate::rt::{Read, ReadBufCursor, Write}; pub(crate) mod ping; +pub(crate) mod upgrade; cfg_client! { pub(crate) mod client; @@ -259,188 +257,3 @@ impl Buf for SendBuf { } } } - -struct H2Upgraded -where - B: Buf, -{ - ping: Recorder, - send_stream: UpgradedSendStream, - recv_stream: RecvStream, - buf: Bytes, -} - -impl Read for H2Upgraded -where - B: Buf, -{ - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - mut read_buf: ReadBufCursor<'_>, - ) -> Poll> { - if self.buf.is_empty() { - self.buf = loop { - match ready!(self.recv_stream.poll_data(cx)) { - None => return Poll::Ready(Ok(())), - Some(Ok(buf)) if buf.is_empty() && !self.recv_stream.is_end_stream() => { - continue - } - Some(Ok(buf)) => { - self.ping.record_data(buf.len()); - break buf; - } - Some(Err(e)) => { - return Poll::Ready(match e.reason() { - Some(Reason::NO_ERROR) | Some(Reason::CANCEL) => Ok(()), - Some(Reason::STREAM_CLOSED) => { - Err(std::io::Error::new(std::io::ErrorKind::BrokenPipe, e)) - } - _ => Err(h2_to_io_error(e)), - }) - } - } - }; - } - let cnt = std::cmp::min(self.buf.len(), read_buf.remaining()); - read_buf.put_slice(&self.buf[..cnt]); - self.buf.advance(cnt); - let _ = self.recv_stream.flow_control().release_capacity(cnt); - Poll::Ready(Ok(())) - } -} - -impl Write for H2Upgraded -where - B: Buf, -{ - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - if buf.is_empty() { - return Poll::Ready(Ok(0)); - } - self.send_stream.reserve_capacity(buf.len()); - - // We ignore all errors returned by `poll_capacity` and `write`, as we - // will get the correct from `poll_reset` anyway. - let cnt = match ready!(self.send_stream.poll_capacity(cx)) { - None => Some(0), - Some(Ok(cnt)) => self - .send_stream - .write(&buf[..cnt], false) - .ok() - .map(|()| cnt), - Some(Err(_)) => None, - }; - - if let Some(cnt) = cnt { - return Poll::Ready(Ok(cnt)); - } - - Poll::Ready(Err(h2_to_io_error( - match ready!(self.send_stream.poll_reset(cx)) { - Ok(Reason::NO_ERROR) | Ok(Reason::CANCEL) | Ok(Reason::STREAM_CLOSED) => { - return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into())) - } - Ok(reason) => reason.into(), - Err(e) => e, - }, - ))) - } - - fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn poll_shutdown( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - if self.send_stream.write(&[], true).is_ok() { - return Poll::Ready(Ok(())); - } - - Poll::Ready(Err(h2_to_io_error( - match ready!(self.send_stream.poll_reset(cx)) { - Ok(Reason::NO_ERROR) => return Poll::Ready(Ok(())), - Ok(Reason::CANCEL) | Ok(Reason::STREAM_CLOSED) => { - return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into())) - } - Ok(reason) => reason.into(), - Err(e) => e, - }, - ))) - } -} - -fn h2_to_io_error(e: h2::Error) -> std::io::Error { - if e.is_io() { - e.into_io().unwrap() - } else { - std::io::Error::new(std::io::ErrorKind::Other, e) - } -} - -struct UpgradedSendStream(SendStream>>); - -impl UpgradedSendStream -where - B: Buf, -{ - unsafe fn new(inner: SendStream>) -> Self { - assert_eq!(mem::size_of::(), mem::size_of::>()); - Self(mem::transmute(inner)) - } - - fn reserve_capacity(&mut self, cnt: usize) { - unsafe { self.as_inner_unchecked().reserve_capacity(cnt) } - } - - fn poll_capacity(&mut self, cx: &mut Context<'_>) -> Poll>> { - unsafe { self.as_inner_unchecked().poll_capacity(cx) } - } - - fn poll_reset(&mut self, cx: &mut Context<'_>) -> Poll> { - unsafe { self.as_inner_unchecked().poll_reset(cx) } - } - - fn write(&mut self, buf: &[u8], end_of_stream: bool) -> Result<(), std::io::Error> { - let send_buf = SendBuf::Cursor(Cursor::new(buf.into())); - unsafe { - self.as_inner_unchecked() - .send_data(send_buf, end_of_stream) - .map_err(h2_to_io_error) - } - } - - unsafe fn as_inner_unchecked(&mut self) -> &mut SendStream> { - &mut *(&mut self.0 as *mut _ as *mut _) - } -} - -#[repr(transparent)] -struct Neutered { - _inner: B, - impossible: Impossible, -} - -enum Impossible {} - -unsafe impl Send for Neutered {} - -impl Buf for Neutered { - fn remaining(&self) -> usize { - match self.impossible {} - } - - fn chunk(&self) -> &[u8] { - match self.impossible {} - } - - fn advance(&mut self, _cnt: usize) { - match self.impossible {} - } -} diff --git a/src/proto/h2/server.rs b/src/proto/h2/server.rs index 7995b349bf..483ed96dd9 100644 --- a/src/proto/h2/server.rs +++ b/src/proto/h2/server.rs @@ -19,9 +19,8 @@ use crate::common::time::Time; use crate::ext::Protocol; use crate::headers; use crate::proto::h2::ping::Recorder; -use crate::proto::h2::{H2Upgraded, UpgradedSendStream}; use crate::proto::Dispatched; -use crate::rt::bounds::Http2ServerConnExec; +use crate::rt::bounds::{Http2ServerConnExec, Http2UpgradedExec}; use crate::rt::{Read, Write}; use crate::service::HttpService; @@ -308,6 +307,7 @@ where connect_parts, respond, self.date_header, + exec.clone(), ); exec.execute_h2stream(fut); @@ -357,7 +357,7 @@ where pin_project! { #[allow(missing_debug_implementations)] - pub struct H2Stream + pub struct H2Stream where B: Body, { @@ -365,6 +365,7 @@ pin_project! { #[pin] state: H2StreamState, date_header: bool, + exec: E, } } @@ -392,7 +393,7 @@ struct ConnectParts { recv_stream: RecvStream, } -impl H2Stream +impl H2Stream where B: Body, { @@ -401,11 +402,13 @@ where connect_parts: Option, respond: SendResponse>, date_header: bool, - ) -> H2Stream { + exec: E, + ) -> H2Stream { H2Stream { reply: respond, state: H2StreamState::Service { fut, connect_parts }, date_header, + exec, } } } @@ -423,16 +426,17 @@ macro_rules! reply { }}; } -impl H2Stream +impl H2Stream where F: Future, E>>, B: Body, B::Data: 'static, B::Error: Into>, + Ex: Http2UpgradedExec, E: Into>, { - fn poll2(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let mut me = self.project(); + fn poll2(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut me = self.as_mut().project(); loop { let next = match me.state.as_mut().project() { H2StreamStateProj::Service { @@ -488,15 +492,15 @@ where warn!("successful response to CONNECT request disallows content-length header"); } let send_stream = reply!(me, res, false); - connect_parts.pending.fulfill(Upgraded::new( - H2Upgraded { - ping: connect_parts.ping, - recv_stream: connect_parts.recv_stream, - send_stream: unsafe { UpgradedSendStream::new(send_stream) }, - buf: Bytes::new(), - }, - Bytes::new(), - )); + let (h2_up, up_task) = super::upgrade::pair( + send_stream, + connect_parts.recv_stream, + connect_parts.ping, + ); + connect_parts + .pending + .fulfill(Upgraded::new(h2_up, Bytes::new())); + self.exec.execute_upgrade(up_task); return Poll::Ready(Ok(())); } } @@ -525,12 +529,13 @@ where } } -impl Future for H2Stream +impl Future for H2Stream where F: Future, E>>, B: Body, B::Data: 'static, B::Error: Into>, + Ex: Http2UpgradedExec, E: Into>, { type Output = (); diff --git a/src/proto/h2/upgrade.rs b/src/proto/h2/upgrade.rs new file mode 100644 index 0000000000..40a98de08a --- /dev/null +++ b/src/proto/h2/upgrade.rs @@ -0,0 +1,280 @@ +use std::future::Future; +use std::io::Cursor; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use bytes::{Buf, Bytes}; +use futures_channel::{mpsc, oneshot}; +use futures_core::{ready, Stream}; +use h2::{Reason, RecvStream, SendStream}; +use pin_project_lite::pin_project; + +use super::ping::Recorder; +use super::SendBuf; +use crate::rt::{Read, ReadBufCursor, Write}; + +pub(super) fn pair( + send_stream: SendStream>, + recv_stream: RecvStream, + ping: Recorder, +) -> (H2Upgraded, UpgradedSendStreamTask) { + let (tx, rx) = mpsc::channel(1); + let (error_tx, error_rx) = oneshot::channel(); + + ( + H2Upgraded { + send_stream: UpgradedSendStreamBridge { tx, error_rx }, + recv_stream, + ping, + buf: Bytes::new(), + }, + UpgradedSendStreamTask { + h2_tx: send_stream, + rx, + error_tx: Some(error_tx), + }, + ) +} + +pub(super) struct H2Upgraded { + ping: Recorder, + send_stream: UpgradedSendStreamBridge, + recv_stream: RecvStream, + buf: Bytes, +} + +struct UpgradedSendStreamBridge { + tx: mpsc::Sender>>, + error_rx: oneshot::Receiver, +} + +pin_project! { + #[must_use = "futures do nothing unless polled"] + pub struct UpgradedSendStreamTask { + #[pin] + h2_tx: SendStream>, + #[pin] + rx: mpsc::Receiver>>, + error_tx: Option>, + } +} + +// ===== impl UpgradedSendStreamTask ===== + +impl UpgradedSendStreamTask +where + B: Buf, +{ + fn tick(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut me = self.project(); + + // this is a manual `select()` over 3 "futures", so we always need + // to be sure they are ready and/or we are waiting notification of + // one of the sides hanging up, so the task doesn't live around + // longer than it's meant to. + loop { + // we don't have the next chunk of data yet, so just reserve 1 byte to make + // sure there's some capacity available. h2 will handle the capacity management + // for the actual body chunk. + me.h2_tx.reserve_capacity(1); + + if me.h2_tx.capacity() == 0 { + // poll_capacity oddly needs a loop + 'capacity: loop { + match me.h2_tx.poll_capacity(cx) { + Poll::Ready(Some(Ok(0))) => {} + Poll::Ready(Some(Ok(_))) => break, + Poll::Ready(Some(Err(e))) => { + return Poll::Ready(Err(crate::Error::new_body_write(e))) + } + Poll::Ready(None) => { + // None means the stream is no longer in a + // streaming state, we either finished it + // somehow, or the remote reset us. + return Poll::Ready(Err(crate::Error::new_body_write( + "send stream capacity unexpectedly closed", + ))); + } + Poll::Pending => break 'capacity, + } + } + } + + match me.h2_tx.poll_reset(cx) { + Poll::Ready(Ok(reason)) => { + trace!("stream received RST_STREAM: {:?}", reason); + return Poll::Ready(Err(crate::Error::new_body_write(::h2::Error::from( + reason, + )))); + } + Poll::Ready(Err(err)) => { + return Poll::Ready(Err(crate::Error::new_body_write(err))) + } + Poll::Pending => (), + } + + match me.rx.as_mut().poll_next(cx) { + Poll::Ready(Some(cursor)) => { + me.h2_tx + .send_data(SendBuf::Cursor(cursor), false) + .map_err(crate::Error::new_body_write)?; + } + Poll::Ready(None) => { + me.h2_tx + .send_data(SendBuf::None, true) + .map_err(crate::Error::new_body_write)?; + return Poll::Ready(Ok(())); + } + Poll::Pending => { + return Poll::Pending; + } + } + } + } +} + +impl Future for UpgradedSendStreamTask +where + B: Buf, +{ + type Output = (); + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.as_mut().tick(cx) { + Poll::Ready(Ok(())) => Poll::Ready(()), + Poll::Ready(Err(err)) => { + if let Some(tx) = self.error_tx.take() { + let _oh_well = tx.send(err); + } + Poll::Ready(()) + } + Poll::Pending => Poll::Pending, + } + } +} + +// ===== impl H2Upgraded ===== + +impl Read for H2Upgraded { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + mut read_buf: ReadBufCursor<'_>, + ) -> Poll> { + if self.buf.is_empty() { + self.buf = loop { + match ready!(self.recv_stream.poll_data(cx)) { + None => return Poll::Ready(Ok(())), + Some(Ok(buf)) if buf.is_empty() && !self.recv_stream.is_end_stream() => { + continue + } + Some(Ok(buf)) => { + self.ping.record_data(buf.len()); + break buf; + } + Some(Err(e)) => { + return Poll::Ready(match e.reason() { + Some(Reason::NO_ERROR) | Some(Reason::CANCEL) => Ok(()), + Some(Reason::STREAM_CLOSED) => { + Err(std::io::Error::new(std::io::ErrorKind::BrokenPipe, e)) + } + _ => Err(h2_to_io_error(e)), + }) + } + } + }; + } + let cnt = std::cmp::min(self.buf.len(), read_buf.remaining()); + read_buf.put_slice(&self.buf[..cnt]); + self.buf.advance(cnt); + let _ = self.recv_stream.flow_control().release_capacity(cnt); + Poll::Ready(Ok(())) + } +} + +impl Write for H2Upgraded { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + if buf.is_empty() { + return Poll::Ready(Ok(0)); + } + + match self.send_stream.tx.poll_ready(cx) { + Poll::Ready(Ok(())) => {} + Poll::Ready(Err(_task_dropped)) => { + // if the task dropped, check if there was an error + // otherwise i guess its a broken pipe + return match Pin::new(&mut self.send_stream.error_rx).poll(cx) { + Poll::Ready(Ok(reason)) => Poll::Ready(Err(io_error(reason))), + Poll::Ready(Err(_task_dropped)) => { + Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into())) + } + Poll::Pending => Poll::Pending, + }; + } + Poll::Pending => return Poll::Pending, + } + + let n = buf.len(); + match self.send_stream.tx.start_send(Cursor::new(buf.into())) { + Ok(()) => Poll::Ready(Ok(n)), + Err(_task_dropped) => { + // if the task dropped, check if there was an error + // otherwise i guess its a broken pipe + match Pin::new(&mut self.send_stream.error_rx).poll(cx) { + Poll::Ready(Ok(reason)) => Poll::Ready(Err(io_error(reason))), + Poll::Ready(Err(_task_dropped)) => { + Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into())) + } + Poll::Pending => Poll::Pending, + } + } + } + } + + fn poll_flush( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + match self.send_stream.tx.poll_ready(cx) { + Poll::Ready(Ok(())) => Poll::Ready(Ok(())), + Poll::Ready(Err(_task_dropped)) => { + // if the task dropped, check if there was an error + // otherwise it was a clean close + match Pin::new(&mut self.send_stream.error_rx).poll(cx) { + Poll::Ready(Ok(reason)) => Poll::Ready(Err(io_error(reason))), + Poll::Ready(Err(_task_dropped)) => Poll::Ready(Ok(())), + Poll::Pending => Poll::Pending, + } + } + Poll::Pending => Poll::Pending, + } + } + + fn poll_shutdown( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + self.send_stream.tx.close_channel(); + match Pin::new(&mut self.send_stream.error_rx).poll(cx) { + Poll::Ready(Ok(reason)) => Poll::Ready(Err(io_error(reason))), + Poll::Ready(Err(_task_dropped)) => Poll::Ready(Ok(())), + Poll::Pending => Poll::Pending, + } + } +} + +fn io_error(e: crate::Error) -> std::io::Error { + std::io::Error::new(std::io::ErrorKind::Other, e) +} + +fn h2_to_io_error(e: h2::Error) -> std::io::Error { + if e.is_io() { + e.into_io().unwrap() + } else { + std::io::Error::new(std::io::ErrorKind::Other, e) + } +} diff --git a/src/rt/bounds.rs b/src/rt/bounds.rs index aa3075e079..0af623b1bf 100644 --- a/src/rt/bounds.rs +++ b/src/rt/bounds.rs @@ -3,11 +3,34 @@ //! Traits in this module ease setting bounds and usually automatically //! implemented by implementing another trait. -#[cfg(all(feature = "server", feature = "http2"))] -pub use self::h2::Http2ServerConnExec; - #[cfg(all(feature = "client", feature = "http2"))] pub use self::h2_client::Http2ClientConnExec; +#[cfg(all(feature = "server", feature = "http2"))] +pub use self::h2_server::Http2ServerConnExec; + +#[cfg(all(any(feature = "client", feature = "server"), feature = "http2"))] +pub(crate) use self::h2_common::Http2UpgradedExec; + +#[cfg(all(any(feature = "client", feature = "server"), feature = "http2"))] +mod h2_common { + use crate::proto::h2::upgrade::UpgradedSendStreamTask; + use crate::rt::Executor; + + pub trait Http2UpgradedExec { + #[doc(hidden)] + fn execute_upgrade(&self, fut: UpgradedSendStreamTask); + } + + #[doc(hidden)] + impl Http2UpgradedExec for E + where + E: Executor>, + { + fn execute_upgrade(&self, fut: UpgradedSendStreamTask) { + self.execute(fut) + } + } +} #[cfg(all(feature = "client", feature = "http2"))] #[cfg_attr(docsrs, doc(cfg(all(feature = "client", feature = "http2"))))] @@ -25,35 +48,40 @@ mod h2_client { /// This trait is sealed and cannot be implemented for types outside this crate. /// /// [`Executor`]: crate::rt::Executor - pub trait Http2ClientConnExec: sealed_client::Sealed<(B, T)> + pub trait Http2ClientConnExec: + super::Http2UpgradedExec + sealed_client::Sealed<(B, T)> + Clone where B: http_body::Body, B::Error: Into>, T: Read + Write + Unpin, { #[doc(hidden)] - fn execute_h2_future(&mut self, future: H2ClientFuture); + fn execute_h2_future(&mut self, future: H2ClientFuture); } impl Http2ClientConnExec for E where - E: Executor>, + E: Clone, + E: Executor>, + E: super::Http2UpgradedExec, B: http_body::Body + 'static, B::Error: Into>, - H2ClientFuture: Future, + H2ClientFuture: Future, T: Read + Write + Unpin, { - fn execute_h2_future(&mut self, future: H2ClientFuture) { + fn execute_h2_future(&mut self, future: H2ClientFuture) { self.execute(future) } } impl sealed_client::Sealed<(B, T)> for E where - E: Executor>, + E: Clone, + E: Executor>, + E: super::Http2UpgradedExec, B: http_body::Body + 'static, B::Error: Into>, - H2ClientFuture: Future, + H2ClientFuture: Future, T: Read + Write + Unpin, { } @@ -65,7 +93,7 @@ mod h2_client { #[cfg(all(feature = "server", feature = "http2"))] #[cfg_attr(docsrs, doc(cfg(all(feature = "server", feature = "http2"))))] -mod h2 { +mod h2_server { use crate::{proto::h2::server::H2Stream, rt::Executor}; use http_body::Body; use std::future::Future; @@ -78,27 +106,33 @@ mod h2 { /// This trait is sealed and cannot be implemented for types outside this crate. /// /// [`Executor`]: crate::rt::Executor - pub trait Http2ServerConnExec: sealed::Sealed<(F, B)> + Clone { + pub trait Http2ServerConnExec: + super::Http2UpgradedExec + sealed::Sealed<(F, B)> + Clone + { #[doc(hidden)] - fn execute_h2stream(&mut self, fut: H2Stream); + fn execute_h2stream(&mut self, fut: H2Stream); } #[doc(hidden)] impl Http2ServerConnExec for E where - E: Executor> + Clone, - H2Stream: Future, + E: Clone, + E: Executor>, + E: super::Http2UpgradedExec, + H2Stream: Future, B: Body, { - fn execute_h2stream(&mut self, fut: H2Stream) { + fn execute_h2stream(&mut self, fut: H2Stream) { self.execute(fut) } } impl sealed::Sealed<(F, B)> for E where - E: Executor> + Clone, - H2Stream: Future, + E: Clone, + E: Executor>, + E: super::Http2UpgradedExec, + H2Stream: Future, B: Body, { }