Skip to content

Commit 36dfa4c

Browse files
committed
feat: integrate list root and client info into hyper runtime
1 parent 5f9a966 commit 36dfa4c

File tree

3 files changed

+129
-22
lines changed

3 files changed

+129
-22
lines changed

crates/rust-mcp-sdk/src/hyper_servers/hyper_runtime.rs

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@ use crate::{
44
mcp_server::HyperServer,
55
schema::{
66
schema_utils::{NotificationFromServer, RequestFromServer, ResultFromClient},
7-
CreateMessageRequestParams, CreateMessageResult, LoggingMessageNotificationParams,
7+
CreateMessageRequestParams, CreateMessageResult, InitializeRequestParams,
8+
ListRootsRequestParams, ListRootsResult, LoggingMessageNotificationParams,
89
PromptListChangedNotificationParams, ResourceListChangedNotificationParams,
910
ResourceUpdatedNotificationParams, ToolListChangedNotificationParams,
1011
},
@@ -99,6 +100,21 @@ impl HyperRuntime {
99100
runtime.send_notification(notification).await
100101
}
101102

103+
/// Request a list of root URIs from the client. Roots allow
104+
/// servers to ask for specific directories or files to operate on. A common example
105+
/// for roots is providing a set of repositories or directories a server should operate on.
106+
/// This request is typically used when the server needs to understand the file system
107+
/// structure or access specific locations that the client has permission to read from
108+
pub async fn list_roots(
109+
&self,
110+
session_id: &SessionId,
111+
params: Option<ListRootsRequestParams>,
112+
) -> SdkResult<ListRootsResult> {
113+
let runtime = self.runtime_by_session(session_id).await?;
114+
let runtime = runtime.lock().await.to_owned();
115+
runtime.list_roots(params).await
116+
}
117+
102118
pub async fn send_logging_message(
103119
&self,
104120
session_id: &SessionId,
@@ -195,4 +211,13 @@ impl HyperRuntime {
195211
let runtime = runtime.lock().await.to_owned();
196212
runtime.create_message(params).await
197213
}
214+
215+
pub async fn client_info(
216+
&self,
217+
session_id: &SessionId,
218+
) -> SdkResult<Option<InitializeRequestParams>> {
219+
let runtime = self.runtime_by_session(session_id).await?;
220+
let runtime = runtime.lock().await.to_owned();
221+
Ok(runtime.client_info())
222+
}
198223
}

crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -87,17 +87,9 @@ impl McpServer for ServerRuntime {
8787
.with_message("transport stream does not exists or is closed!".to_string()),
8888
)?;
8989

90-
// generate a new request_id for request messages
91-
let outgoing_request_id = if message.is_request() {
92-
match request_id {
93-
Some(_) => Err(RpcError::internal_error().with_message(
94-
"request_id should not have a value when sending a new request".to_string(),
95-
)),
96-
None => Ok(self.next_request_id(transport).await),
97-
}
98-
} else {
99-
Ok(request_id)
100-
}?;
90+
let outgoing_request_id = self
91+
.request_id_for_message(transport, &message, request_id)
92+
.await;
10193

10294
let mcp_message = ServerMessage::from_message(message, outgoing_request_id)?;
10395
transport
@@ -225,7 +217,18 @@ impl ServerRuntime {
225217
Ok(())
226218
}
227219

228-
pub(crate) async fn next_request_id(
220+
/// Determines the request ID for an outgoing MCP message.
221+
///
222+
/// For requests, generates a new ID using the internal counter. For responses or errors,
223+
/// uses the provided `request_id`. Notifications receive no ID.
224+
///
225+
/// # Arguments
226+
/// * `message` - The MCP message to evaluate.
227+
/// * `request_id` - An optional existing request ID (required for responses/errors).
228+
///
229+
/// # Returns
230+
/// An `Option<RequestId>`: `Some` for requests or responses/errors, `None` for notifications.
231+
pub(crate) async fn request_id_for_message(
229232
&self,
230233
transport: &Arc<
231234
dyn TransportDispatcher<
@@ -236,12 +239,16 @@ impl ServerRuntime {
236239
ServerMessage,
237240
>,
238241
>,
242+
message: &MessageFromServer,
243+
request_id: Option<RequestId>,
239244
) -> Option<RequestId> {
240245
let message_sender = transport.message_sender();
241246
let guard = message_sender.read().await;
242-
guard
243-
.as_ref()
244-
.map(|dispatcher| dispatcher.next_request_id())
247+
if let Some(dispatcher) = guard.as_ref() {
248+
dispatcher.request_id_for_message(message, request_id)
249+
} else {
250+
None
251+
}
245252
}
246253

247254
pub(crate) async fn handle_message(

crates/rust-mcp-sdk/tests/test_streamable_http.rs

Lines changed: 81 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,14 @@ use std::{collections::HashMap, error::Error, sync::Arc, time::Duration, vec};
33
use hyper::StatusCode;
44
use rust_mcp_schema::{
55
schema_utils::{
6-
ClientJsonrpcRequest, ClientMessage, ClientMessages, FromMessage, NotificationFromServer,
7-
ResultFromServer, RpcMessage, SdkError, SdkErrorCodes, ServerJsonrpcNotification,
8-
ServerJsonrpcResponse, ServerMessages,
6+
ClientJsonrpcRequest, ClientJsonrpcResponse, ClientMessage, ClientMessages, FromMessage,
7+
NotificationFromServer, RequestFromServer, ResultFromServer, RpcMessage, SdkError,
8+
SdkErrorCodes, ServerJsonrpcNotification, ServerJsonrpcRequest, ServerJsonrpcResponse,
9+
ServerMessages,
910
},
10-
CallToolRequest, CallToolRequestParams, ListToolsRequest, LoggingLevel,
11-
LoggingMessageNotificationParams, RequestId, RootsListChangedNotification, ServerNotification,
12-
ServerResult,
11+
CallToolRequest, CallToolRequestParams, ListPromptsRequestParams, ListRootsRequestParams,
12+
ListRootsResult, ListToolsRequest, LoggingLevel, LoggingMessageNotificationParams, RequestId,
13+
RootsListChangedNotification, ServerNotification, ServerRequest, ServerResult,
1314
};
1415
use rust_mcp_sdk::mcp_server::HyperServerOptions;
1516
use serde_json::{json, Map, Value};
@@ -364,6 +365,80 @@ async fn should_establish_standalone_stream_and_receive_server_messages() {
364365
server.hyper_runtime.await_server().await.unwrap()
365366
}
366367

368+
// should establish standalone SSE stream and receive server-initiated messages
369+
#[tokio::test]
370+
async fn should_establish_standalone_stream_and_receive_server_requests() {
371+
let (server, session_id) = initialize_server(None).await.unwrap();
372+
let response = get_standalone_stream(&server.streamable_url, &session_id).await;
373+
374+
assert_eq!(response.status(), StatusCode::OK);
375+
376+
assert_eq!(
377+
response
378+
.headers()
379+
.get("mcp-session-id")
380+
.unwrap()
381+
.to_str()
382+
.unwrap(),
383+
session_id
384+
);
385+
386+
assert_eq!(
387+
response
388+
.headers()
389+
.get("content-type")
390+
.unwrap()
391+
.to_str()
392+
.unwrap(),
393+
"text/event-stream"
394+
);
395+
396+
let hyper_server = Arc::new(server.hyper_runtime);
397+
let hyper_server_clone = hyper_server.clone();
398+
let session_id_clone = session_id.to_string();
399+
400+
tokio::spawn(async move {
401+
// Send a server-initiated notification that should appear on SSE stream with a valid request_id
402+
hyper_server_clone
403+
.list_roots(&session_id_clone, None)
404+
.await
405+
.unwrap();
406+
});
407+
408+
tokio::time::sleep(Duration::from_millis(2250)).await;
409+
410+
let json_rpc_message: ClientJsonrpcResponse = ClientJsonrpcResponse::new(
411+
RequestId::Integer(0),
412+
ListRootsResult {
413+
meta: None,
414+
roots: vec![],
415+
}
416+
.into(),
417+
);
418+
419+
send_post_request(
420+
&server.streamable_url,
421+
&serde_json::to_string(&json_rpc_message).unwrap(),
422+
Some(&session_id),
423+
None,
424+
)
425+
.await
426+
.expect("Request failed");
427+
428+
let event = read_sse_event(response).await.unwrap();
429+
430+
let message: ServerJsonrpcRequest = serde_json::from_str(&event).unwrap();
431+
432+
println!(">>> message {:?} ", message);
433+
434+
let RequestFromServer::ServerRequest(ServerRequest::ListRootsRequest(_)) = message.request
435+
else {
436+
panic!("invalid message received!");
437+
};
438+
439+
hyper_server.graceful_shutdown(ONE_MILLISECOND);
440+
}
441+
367442
// should not close GET SSE stream after sending multiple server notifications
368443
#[tokio::test]
369444
async fn should_not_close_get_sse_stream() {

0 commit comments

Comments
 (0)