Skip to content

Commit 0d70863

Browse files
committed
add --override-protocol-version
1 parent 7020b7c commit 0d70863

File tree

4 files changed

+60
-10
lines changed

4 files changed

+60
-10
lines changed

src/cli.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,8 @@ pub struct Args {
1818
/// Initial retry interal in seconds. Default is 5 seconds
1919
#[arg(long, default_value = "5")]
2020
pub initial_retry_interval: u64,
21+
22+
#[arg(long)]
23+
/// Override the protocol version returned to the client
24+
pub override_protocol_version: Option<String>,
2125
}

src/core.rs

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -64,13 +64,9 @@ pub(crate) async fn connect_with_streamable(app_state: &AppState) -> Result<SseC
6464
rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig {
6565
uri: app_state.url.clone().into(),
6666
// we don't want the sdk to perform any retries
67-
retry_config: std::sync::Arc::new(
68-
rmcp::transport::common::client_side_sse::FixedInterval {
69-
max_times: Some(0),
70-
duration: Duration::from_millis(0),
71-
},
72-
),
67+
retry_config: std::sync::Arc::new(rmcp::transport::common::client_side_sse::NeverRetry),
7368
channel_buffer_capacity: 16,
69+
allow_stateless: true,
7470
},
7571
);
7672

src/main.rs

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use anyhow::{Context, Result, anyhow};
22
use clap::Parser;
33
use futures::StreamExt;
44
use rmcp::{
5-
model::{ClientJsonRpcMessage, ErrorCode, ServerJsonRpcMessage},
5+
model::{ClientJsonRpcMessage, ErrorCode, ProtocolVersion, ServerJsonRpcMessage},
66
transport::{StreamableHttpClientTransport, Transport, sse_client::SseClientTransport},
77
};
88
use std::env;
@@ -128,12 +128,33 @@ async fn main() -> Result<()> {
128128
debug!("Starting MCP proxy with URL: {}", sse_url);
129129
debug!("Max disconnected time: {:?}s", args.max_disconnected_time);
130130

131+
// Parse protocol version override if provided
132+
let override_protocol_version = if let Some(version_str) = args.override_protocol_version {
133+
let protocol_version = match version_str.as_str() {
134+
"2024-11-05" => ProtocolVersion::V_2024_11_05,
135+
"2025-03-26" => ProtocolVersion::V_2025_03_26,
136+
_ => {
137+
return Err(anyhow!(
138+
"Unsupported protocol version: {}. Supported versions are: 2024-11-05, 2025-03-26",
139+
version_str
140+
));
141+
}
142+
};
143+
Some(protocol_version)
144+
} else {
145+
None
146+
};
147+
131148
// Set up communication channels
132149
let (reconnect_tx, mut reconnect_rx) = tokio::sync::mpsc::channel(10);
133150
let (timer_tx, mut timer_rx) = tokio::sync::mpsc::channel(10);
134151

135152
// Initialize application state
136-
let mut app_state = AppState::new(sse_url.clone(), args.max_disconnected_time);
153+
let mut app_state = AppState::new(
154+
sse_url.clone(),
155+
args.max_disconnected_time,
156+
override_protocol_version,
157+
);
137158
// Pass channel senders to state
138159
app_state.reconnect_tx = Some(reconnect_tx.clone());
139160
app_state.timer_tx = Some(timer_tx.clone());

src/state.rs

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use anyhow::Result;
88
use futures::SinkExt;
99
use rmcp::model::{
1010
ClientJsonRpcMessage, ClientNotification, ClientRequest, EmptyResult, InitializedNotification,
11-
InitializedNotificationMethod, RequestId, ServerJsonRpcMessage,
11+
InitializedNotificationMethod, ProtocolVersion, RequestId, ServerJsonRpcMessage, ServerResult,
1212
};
1313
use std::collections::HashMap;
1414
use std::time::{Duration, Instant};
@@ -49,6 +49,8 @@ pub struct AppState {
4949
pub url: String,
5050
/// Maximum time to try reconnecting in seconds (None = infinity)
5151
pub max_disconnected_time: Option<u64>,
52+
/// Override protocol version
53+
pub override_protocol_version: Option<ProtocolVersion>,
5254
/// When we were disconnected
5355
pub disconnected_since: Option<Instant>,
5456
/// Current state of the application
@@ -78,10 +80,15 @@ pub struct AppState {
7880
}
7981

8082
impl AppState {
81-
pub fn new(url: String, max_disconnected_time: Option<u64>) -> Self {
83+
pub fn new(
84+
url: String,
85+
max_disconnected_time: Option<u64>,
86+
override_protocol_version: Option<ProtocolVersion>,
87+
) -> Self {
8288
Self {
8389
url,
8490
max_disconnected_time,
91+
override_protocol_version,
8592
disconnected_since: None,
8693
state: ProxyState::Connecting,
8794
connect_tries: 0,
@@ -286,6 +293,7 @@ impl AppState {
286293
"Initial connection successful, received init response. Waiting for client initialized."
287294
);
288295
self.state = ProxyState::WaitingForClientInitialized;
296+
message = self.maybe_overwrite_protocol_version(message);
289297
}
290298
}
291299
// --- End Initialization Response Handling ---
@@ -537,4 +545,25 @@ impl AppState {
537545
// Not a response/error, return Some(original_message)
538546
Some(message)
539547
}
548+
549+
fn maybe_overwrite_protocol_version(
550+
&mut self,
551+
message: ServerJsonRpcMessage,
552+
) -> ServerJsonRpcMessage {
553+
if let Some(protocol_version) = &self.override_protocol_version {
554+
match message {
555+
ServerJsonRpcMessage::Response(mut resp) => {
556+
if let ServerResult::InitializeResult(mut initialize_result) = resp.result {
557+
initialize_result.protocol_version = protocol_version.clone();
558+
resp.result = ServerResult::InitializeResult(initialize_result);
559+
return ServerJsonRpcMessage::Response(resp);
560+
}
561+
ServerJsonRpcMessage::Response(resp)
562+
}
563+
other => other,
564+
}
565+
} else {
566+
message
567+
}
568+
}
540569
}

0 commit comments

Comments
 (0)