Skip to content

Commit 287b713

Browse files
committed
refactor: request-id generation and messaging functions
1 parent 080b3a5 commit 287b713

File tree

10 files changed

+233
-190
lines changed

10 files changed

+233
-190
lines changed

crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime.rs

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,19 @@ use crate::schema::{
77
ServerMessages,
88
},
99
InitializeRequest, InitializeRequestParams, InitializeResult, InitializedNotification,
10-
RpcError, ServerResult,
10+
RequestId, RpcError, ServerResult,
1111
};
1212
use async_trait::async_trait;
1313
use futures::future::{join_all, try_join_all};
1414
use futures::StreamExt;
1515

16-
use rust_mcp_transport::{IoStream, McpDispatch, MessageDispatcher, Transport};
17-
use std::sync::{Arc, RwLock};
16+
use rust_mcp_transport::{
17+
IoStream, McpDispatch, MessageDispatcher, RequestIdGen, RequestIdGenNumeric, Transport,
18+
};
19+
use std::{
20+
sync::{Arc, RwLock},
21+
time::Duration,
22+
};
1823
use tokio::io::{AsyncBufReadExt, BufReader};
1924
use tokio::sync::Mutex;
2025

@@ -41,6 +46,7 @@ pub struct ClientRuntime {
4146
// Details about the connected server
4247
server_details: Arc<RwLock<Option<InitializeResult>>>,
4348
handlers: Mutex<Vec<tokio::task::JoinHandle<Result<(), McpSdkError>>>>,
49+
request_id_gen: Box<dyn RequestIdGen>,
4450
}
4551

4652
impl ClientRuntime {
@@ -61,6 +67,7 @@ impl ClientRuntime {
6167
client_details,
6268
server_details: Arc::new(RwLock::new(None)),
6369
handlers: Mutex::new(vec![]),
70+
request_id_gen: Box::new(RequestIdGenNumeric::new(None)),
6471
}
6572
}
6673

@@ -284,6 +291,33 @@ impl McpClient for ClientRuntime {
284291
}
285292
}
286293

294+
async fn send(
295+
&self,
296+
message: MessageFromClient,
297+
request_id: Option<RequestId>,
298+
timeout: Option<Duration>,
299+
) -> SdkResult<Option<ServerMessage>> {
300+
let sender = self.sender();
301+
let sender = sender.read().await;
302+
let sender = sender
303+
.as_ref()
304+
.ok_or(schema_utils::SdkError::connection_closed())?;
305+
306+
let outgoing_request_id = self
307+
.request_id_gen
308+
.request_id_for_message(&message, request_id);
309+
310+
let mcp_message = ClientMessage::from_message(message, outgoing_request_id)?;
311+
312+
let response = sender
313+
.send_message(ClientMessages::Single(mcp_message), timeout)
314+
.await?
315+
.map(|res| res.as_single())
316+
.transpose()?;
317+
318+
Ok(response)
319+
}
320+
287321
async fn is_shut_down(&self) -> bool {
288322
self.transport.is_shut_down().await
289323
}

crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs

Lines changed: 14 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ use futures::future::try_join_all;
1616
use futures::{StreamExt, TryFutureExt};
1717
#[cfg(feature = "hyper-server")]
1818
use rust_mcp_transport::SessionId;
19-
use rust_mcp_transport::{IoStream, TransportDispatcher};
19+
use rust_mcp_transport::{IoStream, RequestIdGen, RequestIdGenNumeric, TransportDispatcher};
2020
use std::collections::HashMap;
2121
use std::sync::Arc;
2222
use std::time::Duration;
@@ -45,6 +45,7 @@ pub struct ServerRuntime {
4545
#[cfg(feature = "hyper-server")]
4646
session_id: Option<SessionId>,
4747
transport_map: tokio::sync::RwLock<HashMap<String, TransportType>>,
48+
request_id_gen: Box<dyn RequestIdGen>,
4849
client_details_tx: watch::Sender<Option<InitializeRequestParams>>,
4950
client_details_rx: watch::Receiver<Option<InitializeRequestParams>>,
5051
}
@@ -79,22 +80,26 @@ impl McpServer for ServerRuntime {
7980
message: MessageFromServer,
8081
request_id: Option<RequestId>,
8182
request_timeout: Option<Duration>,
82-
) -> SdkResult<Option<ClientMessages>> {
83+
) -> SdkResult<Option<ClientMessage>> {
8384
let transport_map = self.transport_map.read().await;
8485
let transport = transport_map.get(DEFAULT_STREAM_ID).ok_or(
8586
RpcError::internal_error()
8687
.with_message("transport stream does not exists or is closed!".to_string()),
8788
)?;
8889

8990
let outgoing_request_id = self
90-
.request_id_for_message(transport, &message, request_id)
91-
.await;
91+
.request_id_gen
92+
.request_id_for_message(&message, request_id);
9293

9394
let mcp_message = ServerMessage::from_message(message, outgoing_request_id)?;
94-
transport
95+
96+
let response = transport
9597
.send_message(ServerMessages::Single(mcp_message), request_timeout)
96-
.map_err(|err| err.into())
97-
.await
98+
.await?
99+
.map(|res| res.as_single())
100+
.transpose()?;
101+
102+
Ok(response)
98103
}
99104

100105
async fn send_batch(
@@ -211,40 +216,6 @@ impl ServerRuntime {
211216
Ok(())
212217
}
213218

214-
/// Determines the request ID for an outgoing MCP message.
215-
///
216-
/// For requests, generates a new ID using the internal counter. For responses or errors,
217-
/// uses the provided `request_id`. Notifications receive no ID.
218-
///
219-
/// # Arguments
220-
/// * `message` - The MCP message to evaluate.
221-
/// * `request_id` - An optional existing request ID (required for responses/errors).
222-
///
223-
/// # Returns
224-
/// An `Option<RequestId>`: `Some` for requests or responses/errors, `None` for notifications.
225-
pub(crate) async fn request_id_for_message(
226-
&self,
227-
transport: &Arc<
228-
dyn TransportDispatcher<
229-
ClientMessages,
230-
MessageFromServer,
231-
ClientMessage,
232-
ServerMessages,
233-
ServerMessage,
234-
>,
235-
>,
236-
message: &MessageFromServer,
237-
request_id: Option<RequestId>,
238-
) -> Option<RequestId> {
239-
let message_sender = transport.message_sender();
240-
let guard = message_sender.read().await;
241-
if let Some(dispatcher) = guard.as_ref() {
242-
dispatcher.request_id_for_message(message, request_id)
243-
} else {
244-
None
245-
}
246-
}
247-
248219
pub(crate) async fn handle_message(
249220
&self,
250221
message: ClientMessage,
@@ -471,6 +442,7 @@ impl ServerRuntime {
471442
transport_map: tokio::sync::RwLock::new(HashMap::new()),
472443
client_details_tx,
473444
client_details_rx,
445+
request_id_gen: Box::new(RequestIdGenNumeric::new(None)),
474446
}
475447
}
476448

@@ -497,6 +469,7 @@ impl ServerRuntime {
497469
transport_map: tokio::sync::RwLock::new(map),
498470
client_details_tx,
499471
client_details_rx,
472+
request_id_gen: Box::new(RequestIdGenNumeric::new(None)),
500473
}
501474
}
502475
}

crates/rust-mcp-sdk/src/mcp_traits/mcp_client.rs

Lines changed: 7 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use crate::schema::{
1010
InitializeRequestParams, InitializeResult, ListPromptsRequest, ListPromptsRequestParams,
1111
ListResourceTemplatesRequest, ListResourceTemplatesRequestParams, ListResourcesRequest,
1212
ListResourcesRequestParams, ListRootsRequest, ListToolsRequest, ListToolsRequestParams,
13-
LoggingLevel, PingRequest, ReadResourceRequest, ReadResourceRequestParams,
13+
LoggingLevel, PingRequest, ReadResourceRequest, ReadResourceRequestParams, RequestId,
1414
RootsListChangedNotification, RootsListChangedNotificationParams, RpcError, ServerCapabilities,
1515
SetLevelRequest, SetLevelRequestParams, SubscribeRequest, SubscribeRequestParams,
1616
UnsubscribeRequest, UnsubscribeRequestParams,
@@ -175,27 +175,15 @@ pub trait McpClient: Sync + Send {
175175
request: RequestFromClient,
176176
timeout: Option<Duration>,
177177
) -> SdkResult<ResultFromServer> {
178-
let sender = self.sender();
179-
let sender = sender.read().await;
180-
let sender = sender
181-
.as_ref()
182-
.ok_or(schema_utils::SdkError::connection_closed())?;
183-
184-
let request_id = sender.next_request_id();
185-
186-
let mcp_message =
187-
ClientMessage::from_message(MessageFromClient::from(request), Some(request_id))?;
188-
let response = sender
189-
.send_message(ClientMessages::Single(mcp_message), timeout)
178+
let response = self
179+
.send(MessageFromClient::RequestFromClient(request), None, timeout)
190180
.await?;
191181

192182
let server_message = response.ok_or_else(|| {
193183
RpcError::internal_error()
194-
.with_message("An empty response was received from the server.".to_string())
184+
.with_message("An empty response was received from the client.".to_string())
195185
})?;
196186

197-
let server_message = server_message.as_single()?;
198-
199187
if server_message.is_error() {
200188
return Err(server_message.as_error()?.error.into());
201189
}
@@ -205,27 +193,10 @@ pub trait McpClient: Sync + Send {
205193

206194
async fn send(
207195
&self,
208-
message: ClientMessage,
196+
message: MessageFromClient,
197+
request_id: Option<RequestId>,
209198
timeout: Option<Duration>,
210-
) -> SdkResult<Option<ServerMessage>> {
211-
let sender = self.sender();
212-
let sender = sender.read().await;
213-
let sender = sender
214-
.as_ref()
215-
.ok_or(schema_utils::SdkError::connection_closed())?;
216-
217-
let response = sender
218-
.send_message(ClientMessages::Single(message), timeout)
219-
.await?;
220-
221-
match response {
222-
Some(res) => {
223-
let server_results = res.as_single()?;
224-
Ok(Some(server_results))
225-
}
226-
None => Ok(None),
227-
}
228-
}
199+
) -> SdkResult<Option<ServerMessage>>;
229200

230201
async fn send_batch(
231202
&self,

crates/rust-mcp-sdk/src/mcp_traits/mcp_server.rs

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@ use std::time::Duration;
22

33
use crate::schema::{
44
schema_utils::{
5-
ClientMessage, ClientMessages, McpMessage, MessageFromServer, NotificationFromServer,
6-
RequestFromServer, ResultFromClient, ServerMessage,
5+
ClientMessage, McpMessage, MessageFromServer, NotificationFromServer, RequestFromServer,
6+
ResultFromClient, ServerMessage,
77
},
88
CallToolRequest, CreateMessageRequest, CreateMessageRequestParams, CreateMessageResult,
99
GetPromptRequest, Implementation, InitializeRequestParams, InitializeResult,
@@ -44,7 +44,7 @@ pub trait McpServer: Sync + Send {
4444
message: MessageFromServer,
4545
request_id: Option<RequestId>,
4646
request_timeout: Option<Duration>,
47-
) -> SdkResult<Option<ClientMessages>>;
47+
) -> SdkResult<Option<ClientMessage>>;
4848

4949
async fn send_batch(
5050
&self,
@@ -84,13 +84,11 @@ pub trait McpServer: Sync + Send {
8484
.send(MessageFromServer::RequestFromServer(request), None, timeout)
8585
.await?;
8686

87-
let client_messages = response.ok_or_else(|| {
87+
let client_message = response.ok_or_else(|| {
8888
RpcError::internal_error()
8989
.with_message("An empty response was received from the client.".to_string())
9090
})?;
9191

92-
let client_message = client_messages.as_single()?;
93-
9492
if client_message.is_error() {
9593
return Err(client_message.as_error()?.error.into());
9694
}

crates/rust-mcp-sdk/tests/common/common.rs

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -128,8 +128,10 @@ use futures::stream::Stream;
128128
// stream: &mut impl Stream<Item = Result<hyper::body::Bytes, hyper::Error>>,
129129
pub async fn read_sse_event_from_stream(
130130
stream: &mut (impl Stream<Item = Result<hyper::body::Bytes, reqwest::Error>> + Unpin),
131-
) -> Option<String> {
131+
event_count: usize,
132+
) -> Option<Vec<String>> {
132133
let mut buffer = String::new();
134+
let mut events = vec![];
133135

134136
while let Some(item) = stream.next().await {
135137
match item {
@@ -158,7 +160,10 @@ pub async fn read_sse_event_from_stream(
158160

159161
// Return if data was found
160162
if let Some(data) = data {
161-
return Some(data);
163+
events.push(data);
164+
if events.len().eq(&event_count) {
165+
return Some(events);
166+
}
162167
}
163168
}
164169
}
@@ -171,9 +176,9 @@ pub async fn read_sse_event_from_stream(
171176
None
172177
}
173178

174-
pub async fn read_sse_event(response: Response) -> Option<String> {
179+
pub async fn read_sse_event(response: Response, event_count: usize) -> Option<Vec<String>> {
175180
let mut stream = response.bytes_stream();
176-
read_sse_event_from_stream(&mut stream).await
181+
read_sse_event_from_stream(&mut stream, event_count).await
177182
}
178183

179184
pub fn test_client_info() -> InitializeRequestParams {

0 commit comments

Comments
 (0)