Skip to content

Commit 0a99a3c

Browse files
committed
Remove allocations
1 parent 21f4d0e commit 0a99a3c

File tree

1 file changed

+41
-17
lines changed

1 file changed

+41
-17
lines changed

src/lib.rs

Lines changed: 41 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,27 @@ use SignatureAlgorithm::{
1919
EcdsaSha256, EcdsaSha384, Ed25519, NoSignature, RsaSha1, RsaSha256, RsaSha384, RsaSha512,
2020
};
2121

22+
mod private {
23+
use super::*;
24+
25+
pub struct TlsConnectFuture<S> {
26+
pub inner: tokio_rustls::Connect<S>,
27+
}
28+
29+
impl<S> Future for TlsConnectFuture<S>
30+
where
31+
S: AsyncRead + AsyncWrite + Unpin,
32+
{
33+
type Output = io::Result<RustlsStream<S>>;
34+
35+
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
36+
// SAFETY: If `self` is pinned, so is `inner`.
37+
let fut = unsafe { self.map_unchecked_mut(|this| &mut this.inner) };
38+
fut.poll(cx).map_ok(RustlsStream)
39+
}
40+
}
41+
}
42+
2243
#[derive(Clone)]
2344
pub struct MakeRustlsConnect {
2445
config: Arc<ClientConfig>,
@@ -63,20 +84,23 @@ where
6384
{
6485
type Stream = RustlsStream<S>;
6586
type Error = io::Error;
66-
type Future = Pin<Box<dyn Future<Output = io::Result<RustlsStream<S>>> + Send>>;
87+
type Future = private::TlsConnectFuture<S>;
6788

6889
fn connect(self, stream: S) -> Self::Future {
69-
Box::pin(async move {
70-
self.0
71-
.connector
72-
.connect(self.0.hostname, stream)
73-
.await
74-
.map(|s| RustlsStream(Box::pin(s)))
75-
})
90+
private::TlsConnectFuture {
91+
inner: self.0.connector.connect(self.0.hostname, stream),
92+
}
7693
}
7794
}
7895

79-
pub struct RustlsStream<S>(Pin<Box<TlsStream<S>>>);
96+
pub struct RustlsStream<S>(TlsStream<S>);
97+
98+
impl<S> RustlsStream<S> {
99+
pub fn project_stream(self: Pin<&mut Self>) -> Pin<&mut TlsStream<S>> {
100+
// SAFETY: When `Self` is pinned, so is the inner `TlsStream`.
101+
unsafe { self.map_unchecked_mut(|this| &mut this.0) }
102+
}
103+
}
80104

81105
impl<S> tokio_postgres::tls::TlsStream for RustlsStream<S>
82106
where
@@ -115,11 +139,11 @@ where
115139
S: AsyncRead + AsyncWrite + Unpin,
116140
{
117141
fn poll_read(
118-
mut self: Pin<&mut Self>,
142+
self: Pin<&mut Self>,
119143
cx: &mut Context,
120144
buf: &mut ReadBuf<'_>,
121145
) -> Poll<tokio::io::Result<()>> {
122-
self.0.as_mut().poll_read(cx, buf)
146+
self.project_stream().poll_read(cx, buf)
123147
}
124148
}
125149

@@ -128,19 +152,19 @@ where
128152
S: AsyncRead + AsyncWrite + Unpin,
129153
{
130154
fn poll_write(
131-
mut self: Pin<&mut Self>,
155+
self: Pin<&mut Self>,
132156
cx: &mut Context,
133157
buf: &[u8],
134158
) -> Poll<tokio::io::Result<usize>> {
135-
self.0.as_mut().poll_write(cx, buf)
159+
self.project_stream().poll_write(cx, buf)
136160
}
137161

138-
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<tokio::io::Result<()>> {
139-
self.0.as_mut().poll_flush(cx)
162+
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<tokio::io::Result<()>> {
163+
self.project_stream().poll_flush(cx)
140164
}
141165

142-
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<tokio::io::Result<()>> {
143-
self.0.as_mut().poll_shutdown(cx)
166+
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<tokio::io::Result<()>> {
167+
self.project_stream().poll_shutdown(cx)
144168
}
145169
}
146170

0 commit comments

Comments
 (0)