Skip to content

Commit c2520b8

Browse files
authored
refactor: use two stage accept (#87)
2 parents 8144fde + b3c37ff commit c2520b8

File tree

10 files changed

+41
-22
lines changed

10 files changed

+41
-22
lines changed

examples/errors.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ async fn main() -> anyhow::Result<()> {
6060
let server = RpcServer::new(server);
6161
let handle = tokio::task::spawn(async move {
6262
for _ in 0..1 {
63-
let (req, chan) = server.accept().await?;
63+
let (req, chan) = server.accept().await?.read_first().await?;
6464
match req {
6565
IoRequest::Write(req) => chan.rpc_map_err(req, fs, Fs::write).await,
6666
}?

examples/modularize.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,10 @@ async fn main() -> Result<()> {
3434
async fn run_server<C: ServiceEndpoint<AppService>>(server_conn: C, handler: app::Handler) {
3535
let server = RpcServer::new(server_conn);
3636
loop {
37-
match server.accept().await {
37+
let Ok(accepting) = server.accept().await else {
38+
continue;
39+
};
40+
match accepting.read_first().await {
3841
Err(err) => warn!(?err, "server accept failed"),
3942
Ok((req, chan)) => {
4043
let handler = handler.clone();

examples/store.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ async fn main() -> anyhow::Result<()> {
168168
let s = server;
169169
let store = Store;
170170
loop {
171-
let (req, chan) = s.accept().await?;
171+
let (req, chan) = s.accept().await?.read_first().await?;
172172
use StoreRequest::*;
173173
let store = store.clone();
174174
#[rustfmt::skip]

src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@
6565
//! let handler = Handler;
6666
//! loop {
6767
//! // accept connections
68-
//! let (msg, chan) = server.accept().await?;
68+
//! let (msg, chan) = server.accept().await?.read_first().await?;
6969
//! // dispatch the message to the appropriate handler
7070
//! match msg {
7171
//! PingRequest::Ping(ping) => chan.rpc(ping, handler, Handler::ping).await?,

src/server.rs

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -117,8 +117,14 @@ where
117117
}
118118
}
119119

120-
impl<S: Service, C: ServiceEndpoint<S>> RpcServer<S, C> {
121-
/// Accepts a new channel from a client, and reads the first request.
120+
/// The result of accepting a new connection.
121+
pub struct Accepting<S: Service, C: ServiceEndpoint<S>> {
122+
send: C::SendSink,
123+
recv: C::RecvStream,
124+
}
125+
126+
impl<S: Service, C: ServiceEndpoint<S>> Accepting<S, C> {
127+
/// Read the first message from the client.
122128
///
123129
/// The return value is a tuple of `(request, channel)`. Here `request` is the
124130
/// first request which is already read from the stream. The `channel` is a
@@ -127,13 +133,8 @@ impl<S: Service, C: ServiceEndpoint<S>> RpcServer<S, C> {
127133
///
128134
/// Often sink and stream will wrap an an underlying byte stream. In this case you can
129135
/// call into_inner() on them to get it back to perform byte level reads and writes.
130-
pub async fn accept(&self) -> result::Result<(S::Req, RpcChannel<S, C>), RpcServerError<C>> {
131-
let (send, mut recv) = self
132-
.source
133-
.accept_bi()
134-
.await
135-
.map_err(RpcServerError::Accept)?;
136-
136+
pub async fn read_first(self) -> result::Result<(S::Req, RpcChannel<S, C>), RpcServerError<C>> {
137+
let Accepting { send, mut recv } = self;
137138
// get the first message from the client. This will tell us what it wants to do.
138139
let request: S::Req = recv
139140
.next()
@@ -144,6 +145,19 @@ impl<S: Service, C: ServiceEndpoint<S>> RpcServer<S, C> {
144145
.map_err(RpcServerError::RecvError)?;
145146
Ok((request, RpcChannel::new(send, recv)))
146147
}
148+
}
149+
150+
impl<S: Service, C: ServiceEndpoint<S>> RpcServer<S, C> {
151+
/// Accepts a new channel from a client. The result is an [Accepting] object that
152+
/// can be used to read the first request.
153+
pub async fn accept(&self) -> result::Result<Accepting<S, C>, RpcServerError<C>> {
154+
let (send, recv) = self
155+
.source
156+
.accept_bi()
157+
.await
158+
.map_err(RpcServerError::Accept)?;
159+
Ok(Accepting { send, recv })
160+
}
147161

148162
/// Get the underlying service endpoint
149163
pub fn into_inner(self) -> C {
@@ -309,7 +323,7 @@ where
309323
{
310324
let server = RpcServer::<S, C>::new(conn);
311325
loop {
312-
let (req, chan) = server.accept().await?;
326+
let (req, chan) = server.accept().await?.read_first().await?;
313327
let target = target.clone();
314328
handler(chan, req, target).await?;
315329
}

tests/flume.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ async fn flume_channel_mapped_bench() -> anyhow::Result<()> {
6767
tokio::task::spawn(async move {
6868
let service = ComputeService;
6969
loop {
70-
let (req, chan) = server.accept().await?;
70+
let (req, chan) = server.accept().await?.read_first().await?;
7171
let service = service.clone();
7272
tokio::spawn(async move {
7373
let req: OuterRequest = req;

tests/hyper.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -176,8 +176,10 @@ async fn hyper_channel_errors() -> anyhow::Result<()> {
176176
let (res_tx, res_rx) = flume::unbounded();
177177
let handle = tokio::spawn(async move {
178178
loop {
179-
let x = server.accept().await;
180-
let res = match x {
179+
let Ok(x) = server.accept().await else {
180+
continue;
181+
};
182+
let res = match x.read_first().await {
181183
Ok((req, chan)) => match req {
182184
TestRequest::BigRequest(req) => {
183185
chan.rpc(req, TestService, TestService::big).await

tests/math.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ impl ComputeService {
167167
let s = server;
168168
let service = ComputeService;
169169
loop {
170-
let (req, chan) = s.accept().await?;
170+
let (req, chan) = s.accept().await?.read_first().await?;
171171
let service = service.clone();
172172
tokio::spawn(async move { Self::handle_rpc_request(service, req, chan).await });
173173
}
@@ -206,7 +206,7 @@ impl ComputeService {
206206
let service = ComputeService;
207207
while received < count {
208208
received += 1;
209-
let (req, chan) = s.accept().await?;
209+
let (req, chan) = s.accept().await?.read_first().await?;
210210
let service = service.clone();
211211
tokio::spawn(async move {
212212
use ComputeRequest::*;
@@ -236,7 +236,7 @@ impl ComputeService {
236236
let service = ComputeService;
237237
let request_stream = stream! {
238238
loop {
239-
yield s2.accept().await;
239+
yield s2.accept().await?.read_first().await;
240240
}
241241
};
242242
let process_stream = request_stream.map(move |r| {

tests/slow_math.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ impl ComputeService {
113113
let s = server;
114114
let service = ComputeService;
115115
loop {
116-
let (req, chan) = s.accept().await?;
116+
let (req, chan) = s.accept().await?.read_first().await?;
117117
use ComputeRequest::*;
118118
let service = service.clone();
119119
#[rustfmt::skip]

tests/try.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ async fn try_server_streaming() -> anyhow::Result<()> {
7777
let server = RpcServer::<TryService, _>::new(server);
7878
let server_handle = tokio::task::spawn(async move {
7979
loop {
80-
let (req, chan) = server.accept().await?;
80+
let (req, chan) = server.accept().await?.read_first().await?;
8181
let handler = Handler;
8282
match req {
8383
TryRequest::StreamN(req) => {

0 commit comments

Comments
 (0)