Skip to content

Commit a2fc90a

Browse files
authored
Skip notification in initialization handshake (#421)
* fix: handle logging and ping in handshake We handle the initialization process more robustly. - Allow logging and ping - For other messages, we simply ignore it instead of rejecting right away * fix: inject context to notification handler
1 parent 452fe2c commit a2fc90a

File tree

1 file changed

+57
-8
lines changed

1 file changed

+57
-8
lines changed

crates/rmcp/src/service/client.rs

Lines changed: 57 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -75,18 +75,59 @@ where
7575
}
7676

7777
/// Helper function to expect a response from the stream
78-
async fn expect_response<T>(
78+
async fn expect_response<T, S>(
7979
transport: &mut T,
8080
context: &str,
81+
service: &S,
82+
peer: Peer<RoleClient>,
8183
) -> Result<(ServerResult, RequestId), ClientInitializeError>
8284
where
8385
T: Transport<RoleClient>,
86+
S: Service<RoleClient>,
8487
{
85-
let msg = expect_next_message(transport, context).await?;
86-
87-
match msg {
88-
ServerJsonRpcMessage::Response(JsonRpcResponse { id, result, .. }) => Ok((result, id)),
89-
_ => Err(ClientInitializeError::ExpectedInitResponse(Some(msg))),
88+
loop {
89+
let message = expect_next_message(transport, context).await?;
90+
match message {
91+
// Expected message to complete the initialization
92+
ServerJsonRpcMessage::Response(JsonRpcResponse { id, result, .. }) => {
93+
break Ok((result, id));
94+
}
95+
// Server could send logging messages before handshake
96+
ServerJsonRpcMessage::Notification(mut notification) => {
97+
let ServerNotification::LoggingMessageNotification(logging) =
98+
&mut notification.notification
99+
else {
100+
tracing::warn!(?notification, "Received unexpected message");
101+
continue;
102+
};
103+
104+
let mut context = NotificationContext {
105+
peer: peer.clone(),
106+
meta: Meta::default(),
107+
extensions: Extensions::default(),
108+
};
109+
110+
if let Some(meta) = logging.extensions.get_mut::<Meta>() {
111+
std::mem::swap(&mut context.meta, meta);
112+
}
113+
std::mem::swap(&mut context.extensions, &mut logging.extensions);
114+
115+
if let Err(error) = service
116+
.handle_notification(notification.notification, context)
117+
.await
118+
{
119+
tracing::warn!(?error, "Handle logging before handshake failed.");
120+
}
121+
}
122+
// Server could send pings before handshake
123+
ServerJsonRpcMessage::Request(ref request)
124+
if matches!(request.request, ServerRequest::PingRequest(_)) =>
125+
{
126+
tracing::trace!("Received ping request. Ignored.")
127+
}
128+
// Server SHOULD NOT send any other messages before handshake. We ignore them anyway
129+
_ => tracing::warn!(?message, "Received unexpected message"),
130+
}
90131
}
91132
}
92133

@@ -183,7 +224,15 @@ where
183224
context: "send initialize request".into(),
184225
})?;
185226

186-
let (response, response_id) = expect_response(&mut transport, "initialize response").await?;
227+
let (peer, peer_rx) = Peer::new(id_provider, None);
228+
229+
let (response, response_id) = expect_response(
230+
&mut transport,
231+
"initialize response",
232+
&service,
233+
peer.clone(),
234+
)
235+
.await?;
187236

188237
if id != response_id {
189238
return Err(ClientInitializeError::ConflictInitResponseId(
@@ -195,6 +244,7 @@ where
195244
let ServerResult::InitializeResult(initialize_result) = response else {
196245
return Err(ClientInitializeError::ExpectedInitResult(Some(response)));
197246
};
247+
peer.set_peer_info(initialize_result);
198248

199249
// send notification
200250
let notification = ClientJsonRpcMessage::notification(
@@ -206,7 +256,6 @@ where
206256
transport.send(notification).await.map_err(|error| {
207257
ClientInitializeError::transport::<T>(error, "send initialized notification")
208258
})?;
209-
let (peer, peer_rx) = Peer::new(id_provider, Some(initialize_result));
210259
Ok(serve_inner(service, transport, peer, peer_rx, ct))
211260
}
212261

0 commit comments

Comments
 (0)