Skip to content

Commit bcc4024

Browse files
authored
Merge pull request #26 from n0-computer/rklaehn/poison-mpsc-sender
Make sure sending fails after the first io error or future drop
2 parents b36a4b4 + 10dc097 commit bcc4024

File tree

4 files changed

+188
-8
lines changed

4 files changed

+188
-8
lines changed

Cargo.lock

Lines changed: 7 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ tokio = { workspace = true, features = ["full"] }
5757
thousands = "0.2.0"
5858
# macro tests
5959
trybuild = "1.0.104"
60+
testresult = "0.4.1"
6061

6162
[features]
6263
# enable the remote transport

src/lib.rs

Lines changed: 63 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -389,7 +389,7 @@ pub mod channel {
389389
}
390390
}
391391

392-
/// A sender that can be wrapped in a `Box<dyn DynSender<T>>`.
392+
/// A sender that can be wrapped in a `Arc<dyn DynSender<T>>`.
393393
pub trait DynSender<T>: Debug + Send + Sync + 'static {
394394
/// Send a message.
395395
///
@@ -446,6 +446,13 @@ pub mod channel {
446446

447447
impl<T: RpcMessage> Sender<T> {
448448
/// Send a message and yield until either it is sent or an error occurs.
449+
///
450+
/// ## Cancellation safety
451+
///
452+
/// If the future is dropped before completion, and if this is a remote sender,
453+
/// then the sender will be closed and further sends will return an [`io::Error`]
454+
/// with [`io::ErrorKind::BrokenPipe`]. Therefore, make sure to always poll the
455+
/// future until completion if you want to reuse the sender or any clone afterwards.
449456
pub async fn send(&self, value: T) -> std::result::Result<(), SendError> {
450457
match self {
451458
Sender::Tokio(tx) => {
@@ -469,6 +476,13 @@ pub mod channel {
469476
/// all.
470477
///
471478
/// Returns true if the message was sent.
479+
///
480+
/// ## Cancellation safety
481+
///
482+
/// If the future is dropped before completion, and if this is a remote sender,
483+
/// then the sender will be closed and further sends will return an [`io::Error`]
484+
/// with [`io::ErrorKind::BrokenPipe`]. Therefore, make sure to always poll the
485+
/// future until completion if you want to reuse the sender or any clone afterwards.
472486
pub async fn try_send(&mut self, value: T) -> std::result::Result<bool, SendError> {
473487
match self {
474488
Sender::Tokio(tx) => match tx.try_send(value) {
@@ -1092,7 +1106,9 @@ pub mod rpc {
10921106
#[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "rpc")))]
10931107
pub mod rpc {
10941108
//! Module for cross-process RPC using [`quinn`].
1095-
use std::{fmt::Debug, future::Future, io, marker::PhantomData, pin::Pin, sync::Arc};
1109+
use std::{
1110+
fmt::Debug, future::Future, io, marker::PhantomData, ops::DerefMut, pin::Pin, sync::Arc,
1111+
};
10961112

10971113
use n0_future::{future::Boxed as BoxFuture, task::JoinSet};
10981114
use quinn::ConnectionError;
@@ -1307,11 +1323,11 @@ pub mod rpc {
13071323
impl<T: RpcMessage> From<quinn::SendStream> for spsc::Sender<T> {
13081324
fn from(write: quinn::SendStream) -> Self {
13091325
spsc::Sender::Boxed(Arc::new(QuinnSender(tokio::sync::Mutex::new(
1310-
QuinnSenderInner {
1326+
QuinnSenderState::Open(QuinnSenderInner {
13111327
send: write,
13121328
buffer: SmallVec::new(),
13131329
_marker: PhantomData,
1314-
},
1330+
}),
13151331
))))
13161332
}
13171333
}
@@ -1400,7 +1416,14 @@ pub mod rpc {
14001416
}
14011417
}
14021418

1403-
struct QuinnSender<T>(tokio::sync::Mutex<QuinnSenderInner<T>>);
1419+
#[derive(Default)]
1420+
enum QuinnSenderState<T> {
1421+
Open(QuinnSenderInner<T>),
1422+
#[default]
1423+
Closed,
1424+
}
1425+
1426+
struct QuinnSender<T>(tokio::sync::Mutex<QuinnSenderState<T>>);
14041427

14051428
impl<T> Debug for QuinnSender<T> {
14061429
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
@@ -1413,18 +1436,50 @@ pub mod rpc {
14131436
&self,
14141437
value: T,
14151438
) -> Pin<Box<dyn Future<Output = io::Result<()>> + Send + Sync + '_>> {
1416-
Box::pin(async { self.0.lock().await.send(value).await })
1439+
Box::pin(async {
1440+
let mut guard = self.0.lock().await;
1441+
let sender = std::mem::take(guard.deref_mut());
1442+
match sender {
1443+
QuinnSenderState::Open(mut sender) => {
1444+
let res = sender.send(value).await;
1445+
if res.is_ok() {
1446+
*guard = QuinnSenderState::Open(sender);
1447+
}
1448+
res
1449+
}
1450+
QuinnSenderState::Closed => Err(io::ErrorKind::BrokenPipe.into()),
1451+
}
1452+
})
14171453
}
14181454

14191455
fn try_send(
14201456
&self,
14211457
value: T,
14221458
) -> Pin<Box<dyn Future<Output = io::Result<bool>> + Send + Sync + '_>> {
1423-
Box::pin(async { self.0.lock().await.try_send(value).await })
1459+
Box::pin(async {
1460+
let mut guard = self.0.lock().await;
1461+
let sender = std::mem::take(guard.deref_mut());
1462+
match sender {
1463+
QuinnSenderState::Open(mut sender) => {
1464+
let res = sender.try_send(value).await;
1465+
if res.is_ok() {
1466+
*guard = QuinnSenderState::Open(sender);
1467+
}
1468+
res
1469+
}
1470+
QuinnSenderState::Closed => Err(io::ErrorKind::BrokenPipe.into()),
1471+
}
1472+
})
14241473
}
14251474

14261475
fn closed(&self) -> Pin<Box<dyn Future<Output = ()> + Send + Sync + '_>> {
1427-
Box::pin(async { self.0.lock().await.closed().await })
1476+
Box::pin(async {
1477+
let mut guard = self.0.lock().await;
1478+
match guard.deref_mut() {
1479+
QuinnSenderState::Open(sender) => sender.closed().await,
1480+
QuinnSenderState::Closed => {}
1481+
}
1482+
})
14281483
}
14291484

14301485
fn is_rpc(&self) -> bool {

tests/mpsc_sender.rs

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
use std::{
2+
io::ErrorKind,
3+
net::{Ipv4Addr, SocketAddr, SocketAddrV4},
4+
time::Duration,
5+
};
6+
7+
use irpc::{
8+
channel::{spsc, SendError},
9+
util::{make_client_endpoint, make_server_endpoint},
10+
};
11+
use quinn::Endpoint;
12+
use testresult::TestResult;
13+
use tokio::time::timeout;
14+
15+
fn create_connected_endpoints() -> TestResult<(Endpoint, Endpoint, SocketAddr)> {
16+
let addr = SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0).into();
17+
let (server, cert) = make_server_endpoint(addr)?;
18+
let client = make_client_endpoint(addr, &[cert.as_slice()])?;
19+
let port = server.local_addr()?.port();
20+
let server_addr = SocketAddrV4::new(Ipv4Addr::LOCALHOST, port).into();
21+
Ok((server, client, server_addr))
22+
}
23+
24+
/// Checks that all clones of a `Sender` will get the closed signal as soon as
25+
/// a send fails with an io error.
26+
#[tokio::test]
27+
async fn mpsc_sender_clone_closed_error() -> TestResult<()> {
28+
tracing_subscriber::fmt::try_init().ok();
29+
let (server, client, server_addr) = create_connected_endpoints()?;
30+
// accept a single bidi stream on a single connection, then immediately stop it
31+
let server = tokio::spawn(async move {
32+
let conn = server.accept().await.unwrap().await?;
33+
let (_, mut recv) = conn.accept_bi().await?;
34+
recv.stop(1u8.into())?;
35+
TestResult::Ok(())
36+
});
37+
let conn = client.connect(server_addr, "localhost")?.await?;
38+
let (send, _) = conn.open_bi().await?;
39+
let send1 = spsc::Sender::<Vec<u8>>::from(send);
40+
let send2 = send1.clone();
41+
let send3 = send1.clone();
42+
let second_client = tokio::spawn(async move {
43+
send2.closed().await;
44+
});
45+
let third_client = tokio::spawn(async move {
46+
// this should fail with an io error, since the stream was stopped
47+
loop {
48+
match send3.send(vec![1, 2, 3]).await {
49+
Err(SendError::Io(e)) if e.kind() == ErrorKind::BrokenPipe => break,
50+
_ => {}
51+
};
52+
}
53+
});
54+
// send until we get an error because the remote side stopped the stream
55+
while send1.send(vec![1, 2, 3]).await.is_ok() {}
56+
match send1.send(vec![4, 5, 6]).await {
57+
Err(SendError::Io(e)) if e.kind() == ErrorKind::BrokenPipe => {}
58+
e => panic!("Expected SendError::Io with kind BrokenPipe, got {:?}", e),
59+
};
60+
// check that closed signal was received by the second sender
61+
second_client.await?;
62+
// check that the third sender will get the right kind of io error eventually
63+
third_client.await?;
64+
// server should finish without errors
65+
server.await??;
66+
Ok(())
67+
}
68+
69+
/// Checks that all clones of a `Sender` will get the closed signal as soon as
70+
/// a send future gets dropped before completing.
71+
#[tokio::test]
72+
async fn mpsc_sender_clone_drop_error() -> TestResult<()> {
73+
let (server, client, server_addr) = create_connected_endpoints()?;
74+
// accept a single bidi stream on a single connection, then read indefinitely
75+
// until we get an error or the stream is finished
76+
let server = tokio::spawn(async move {
77+
let conn = server.accept().await.unwrap().await?;
78+
let (_, mut recv) = conn.accept_bi().await?;
79+
let mut buf = vec![0u8; 1024];
80+
while let Ok(Some(_)) = recv.read(&mut buf).await {}
81+
TestResult::Ok(())
82+
});
83+
let conn = client.connect(server_addr, "localhost")?.await?;
84+
let (send, _) = conn.open_bi().await?;
85+
let send1 = spsc::Sender::<Vec<u8>>::from(send);
86+
let send2 = send1.clone();
87+
let send3 = send1.clone();
88+
let second_client = tokio::spawn(async move {
89+
send2.closed().await;
90+
});
91+
let third_client = tokio::spawn(async move {
92+
// this should fail with an io error, since the stream was stopped
93+
loop {
94+
match send3.send(vec![1, 2, 3]).await {
95+
Err(SendError::Io(e)) if e.kind() == ErrorKind::BrokenPipe => break,
96+
_ => {}
97+
};
98+
}
99+
});
100+
// send a lot of data with a tiny timeout, this will cause the send future to be dropped
101+
loop {
102+
let send_future = send1.send(vec![0u8; 1024 * 1024]);
103+
// not sure if there is a better way. I want to poll the future a few times so it has time to
104+
// start sending, but don't want to give it enough time to complete.
105+
// I don't think now_or_never would work, since it wouldn't have time to start sending
106+
if timeout(Duration::from_micros(1), send_future)
107+
.await
108+
.is_err()
109+
{
110+
break;
111+
}
112+
}
113+
server.await??;
114+
second_client.await?;
115+
third_client.await?;
116+
Ok(())
117+
}

0 commit comments

Comments
 (0)