Skip to content

Commit cc45f1c

Browse files
authored
feat: refactor and improve middleware pipeline (#114)
* feat: refactor and improve middleware pipeline * fix: typo
1 parent 8fada34 commit cc45f1c

File tree

15 files changed

+1457
-574
lines changed

15 files changed

+1457
-574
lines changed

crates/rust-mcp-extra/src/id_generator/snow_flake_id_generator.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ impl SnowflakeIdGenerator {
5656
.expect("invalid system time!")
5757
.as_millis() as u64;
5858

59-
now - *SHORTER_EPOCH
59+
now.saturating_sub(*SHORTER_EPOCH)
6060
}
6161

6262
fn next_id(&self) -> u64 {

crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use crate::hyper_servers::error::TransportServerResult;
22
use crate::mcp_http::{McpAppState, McpHttpHandler};
33
use axum::{extract::State, response::IntoResponse, routing::get, Extension, Router};
4+
use http::{HeaderMap, Method, Uri};
45
use std::sync::Arc;
56

67
#[derive(Clone)]
@@ -35,13 +36,16 @@ pub fn routes(sse_endpoint: &str, sse_message_endpoint: &str) -> Router<Arc<McpA
3536
/// # Returns
3637
/// * `TransportServerResult<impl IntoResponse>` - The SSE response stream or an error
3738
pub async fn handle_sse(
39+
headers: HeaderMap,
40+
uri: Uri,
3841
Extension(sse_message_endpoint): Extension<SseMessageEndpoint>,
3942
Extension(http_handler): Extension<Arc<McpHttpHandler>>,
4043
State(state): State<Arc<McpAppState>>,
4144
) -> TransportServerResult<impl IntoResponse> {
4245
let SseMessageEndpoint(sse_message_endpoint) = sse_message_endpoint;
46+
let request = McpHttpHandler::create_request(Method::GET, uri, headers, None);
4347
let generic_response = http_handler
44-
.handle_sse_connection(state.clone(), Some(&sse_message_endpoint))
48+
.handle_sse_connection(request, state.clone(), Some(&sse_message_endpoint))
4549
.await?;
4650
let (parts, body) = generic_response.into_parts();
4751
let resp = axum::response::Response::from_parts(parts, axum::body::Body::new(body));

crates/rust-mcp-sdk/src/hyper_servers/server.rs

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@ use crate::{
22
error::SdkResult,
33
id_generator::{FastIdGenerator, UuidGenerator},
44
mcp_http::{
5-
utils::{
5+
http_utils::{
66
DEFAULT_MESSAGES_ENDPOINT, DEFAULT_SSE_ENDPOINT, DEFAULT_STREAMABLE_HTTP_ENDPOINT,
77
},
8+
middleware::dns_rebind_protector::DnsRebindProtector,
89
McpAppState, McpHttpHandler,
910
},
1011
mcp_server::hyper_runtime::HyperRuntime,
@@ -203,6 +204,11 @@ impl HyperServerOptions {
203204
.as_deref()
204205
.unwrap_or(DEFAULT_STREAMABLE_HTTP_ENDPOINT)
205206
}
207+
208+
pub fn needs_dns_protection(&self) -> bool {
209+
self.dns_rebinding_protection
210+
&& (self.allowed_hosts.is_some() || self.allowed_origins.is_some())
211+
}
206212
}
207213

208214
/// Default implementation for HyperServerOptions
@@ -270,13 +276,18 @@ impl HyperServer {
270276
ping_interval: server_options.ping_interval,
271277
transport_options: Arc::clone(&server_options.transport_options),
272278
enable_json_response: server_options.enable_json_response.unwrap_or(false),
273-
allowed_hosts: server_options.allowed_hosts.take(),
274-
allowed_origins: server_options.allowed_origins.take(),
275-
dns_rebinding_protection: server_options.dns_rebinding_protection,
276279
event_store: server_options.event_store.as_ref().map(Arc::clone),
277280
});
278281

279-
let http_handler = McpHttpHandler::new(); //TODO: add auth handlers
282+
let mut http_handler = McpHttpHandler::new();
283+
284+
if server_options.needs_dns_protection() {
285+
http_handler.add_middleware(DnsRebindProtector::new(
286+
server_options.allowed_hosts.take(),
287+
server_options.allowed_origins.take(),
288+
));
289+
}
290+
280291
let app = app_routes(Arc::clone(&state), &server_options, http_handler);
281292
Self {
282293
app,

crates/rust-mcp-sdk/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ pub mod error;
33
mod hyper_servers;
44
mod mcp_handlers;
55
#[cfg(feature = "hyper-server")]
6-
pub(crate) mod mcp_http;
6+
pub mod mcp_http;
77
mod mcp_macros;
88
mod mcp_runtimes;
99
mod mcp_traits;
Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
mod app_state;
2+
pub(crate) mod http_utils;
23
mod mcp_http_handler;
3-
pub(crate) mod mcp_http_utils;
4-
5-
mod mcp_http_middleware; //TODO:
4+
pub mod middleware;
5+
mod types;
66

77
pub use app_state::*;
8+
pub use http_utils::*;
89
pub use mcp_http_handler::*;
9-
pub use mcp_http_middleware::Middleware;
10+
pub use types::*;
1011

11-
pub(crate) mod utils {
12-
pub use super::mcp_http_utils::*;
13-
}
12+
pub use middleware::Middleware;

crates/rust-mcp-sdk/src/mcp_http/app_state.rs

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -19,23 +19,7 @@ pub struct McpAppState {
1919
pub ping_interval: Duration,
2020
pub transport_options: Arc<TransportOptions>,
2121
pub enable_json_response: bool,
22-
/// List of allowed host header values for DNS rebinding protection.
23-
/// If not specified, host validation is disabled.
24-
pub allowed_hosts: Option<Vec<String>>,
25-
/// List of allowed origin header values for DNS rebinding protection.
26-
/// If not specified, origin validation is disabled.
27-
pub allowed_origins: Option<Vec<String>>,
28-
/// Enable DNS rebinding protection (requires allowedHosts and/or allowedOrigins to be configured).
29-
/// Default is false for backwards compatibility.
30-
pub dns_rebinding_protection: bool,
3122
/// Event store for resumability support
3223
/// If provided, resumability will be enabled, allowing clients to reconnect and resume messages
3324
pub event_store: Option<Arc<dyn EventStore>>,
3425
}
35-
36-
impl McpAppState {
37-
pub fn needs_dns_protection(&self) -> bool {
38-
self.dns_rebinding_protection
39-
&& (self.allowed_hosts.is_some() || self.allowed_origins.is_some())
40-
}
41-
}

crates/rust-mcp-sdk/src/mcp_http/mcp_http_utils.rs renamed to crates/rust-mcp-sdk/src/mcp_http/http_utils.rs

Lines changed: 28 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use crate::mcp_http::types::GenericBody;
12
use crate::schema::schema_utils::{ClientMessage, SdkError};
23
use crate::{
34
error::SdkResult,
@@ -11,10 +12,10 @@ use crate::{
1112
use axum::http::HeaderValue;
1213
use bytes::Bytes;
1314
use futures::stream;
14-
use http::header::{ACCEPT, CONNECTION, CONTENT_TYPE, HOST, ORIGIN};
15+
use http::header::{ACCEPT, CONNECTION, CONTENT_TYPE};
1516
use http_body::Frame;
1617
use http_body_util::StreamBody;
17-
use http_body_util::{combinators::BoxBody, BodyExt, Full};
18+
use http_body_util::{BodyExt, Full};
1819
use hyper::{HeaderMap, StatusCode};
1920
use rust_mcp_transport::{
2021
EventId, McpDispatch, SessionId, SseEvent, SseTransport, StreamId, ID_SEPARATOR,
@@ -32,8 +33,6 @@ pub(crate) const DEFAULT_MESSAGES_ENDPOINT: &str = "/messages";
3233
pub(crate) const DEFAULT_STREAMABLE_HTTP_ENDPOINT: &str = "/mcp";
3334
const DUPLEX_BUFFER_SIZE: usize = 8192;
3435

35-
pub type GenericBody = BoxBody<Bytes, TransportServerError>;
36-
3736
/// Creates an empty HTTP response body.
3837
///
3938
/// This function constructs a `GenericBody` containing an empty `Bytes` buffer,
@@ -45,6 +44,20 @@ pub fn empty_response() -> GenericBody {
4544
.boxed()
4645
}
4746

47+
pub fn build_response(
48+
status_code: StatusCode,
49+
payload: String,
50+
) -> Result<http::Response<GenericBody>, TransportServerError> {
51+
let body = Full::new(Bytes::from(payload))
52+
.map_err(|err| TransportServerError::HttpError(err.to_string()))
53+
.boxed();
54+
55+
http::Response::builder()
56+
.status(status_code)
57+
.body(body)
58+
.map_err(|err| TransportServerError::HttpError(err.to_string()))
59+
}
60+
4861
/// Creates an initial SSE event that returns the messages endpoint
4962
///
5063
/// Constructs an SSE event containing the messages endpoint URL with the session ID.
@@ -251,7 +264,7 @@ fn is_result(json_str: &str) -> Result<bool, serde_json::Error> {
251264
}
252265
}
253266

254-
pub async fn create_standalone_stream(
267+
pub(crate) async fn create_standalone_stream(
255268
session_id: SessionId,
256269
last_event_id: Option<EventId>,
257270
state: Arc<McpAppState>,
@@ -287,7 +300,7 @@ pub async fn create_standalone_stream(
287300
Ok(response)
288301
}
289302

290-
pub async fn start_new_session(
303+
pub(crate) async fn start_new_session(
291304
state: Arc<McpAppState>,
292305
payload: &str,
293306
) -> TransportServerResult<http::Response<GenericBody>> {
@@ -421,7 +434,7 @@ async fn single_shot_stream(
421434
}
422435
}
423436

424-
pub async fn process_incoming_message_return(
437+
pub(crate) async fn process_incoming_message_return(
425438
session_id: SessionId,
426439
state: Arc<McpAppState>,
427440
payload: &str,
@@ -446,7 +459,7 @@ pub async fn process_incoming_message_return(
446459
}
447460
}
448461

449-
pub async fn process_incoming_message(
462+
pub(crate) async fn process_incoming_message(
450463
session_id: SessionId,
451464
state: Arc<McpAppState>,
452465
payload: &str,
@@ -499,11 +512,11 @@ pub async fn process_incoming_message(
499512
}
500513
}
501514

502-
pub fn is_empty_sse_message(sse_payload: &str) -> bool {
515+
pub(crate) fn is_empty_sse_message(sse_payload: &str) -> bool {
503516
sse_payload.is_empty() || sse_payload.trim() == ":"
504517
}
505518

506-
pub async fn delete_session(
519+
pub(crate) async fn delete_session(
507520
session_id: SessionId,
508521
state: Arc<McpAppState>,
509522
) -> TransportServerResult<http::Response<GenericBody>> {
@@ -529,7 +542,7 @@ pub async fn delete_session(
529542
}
530543
}
531544

532-
pub fn acceptable_content_type(headers: &HeaderMap) -> bool {
545+
pub(crate) fn acceptable_content_type(headers: &HeaderMap) -> bool {
533546
let accept_header = headers
534547
.get("content-type")
535548
.and_then(|val| val.to_str().ok())
@@ -539,7 +552,7 @@ pub fn acceptable_content_type(headers: &HeaderMap) -> bool {
539552
.any(|val| val.trim().starts_with("application/json"))
540553
}
541554

542-
pub fn validate_mcp_protocol_version_header(headers: &HeaderMap) -> SdkResult<()> {
555+
pub(crate) fn validate_mcp_protocol_version_header(headers: &HeaderMap) -> SdkResult<()> {
543556
let protocol_version_header = headers
544557
.get(MCP_PROTOCOL_VERSION_HEADER)
545558
.and_then(|val| val.to_str().ok())
@@ -553,7 +566,7 @@ pub fn validate_mcp_protocol_version_header(headers: &HeaderMap) -> SdkResult<()
553566
validate_mcp_protocol_version(protocol_version_header)
554567
}
555568

556-
pub fn accepts_event_stream(headers: &HeaderMap) -> bool {
569+
pub(crate) fn accepts_event_stream(headers: &HeaderMap) -> bool {
557570
let accept_header = headers
558571
.get(ACCEPT)
559572
.and_then(|val| val.to_str().ok())
@@ -564,7 +577,7 @@ pub fn accepts_event_stream(headers: &HeaderMap) -> bool {
564577
.any(|val| val.trim().starts_with("text/event-stream"))
565578
}
566579

567-
pub fn valid_streaming_http_accept_header(headers: &HeaderMap) -> bool {
580+
pub(crate) fn valid_streaming_http_accept_header(headers: &HeaderMap) -> bool {
568581
let accept_header = headers
569582
.get(ACCEPT)
570583
.and_then(|val| val.to_str().ok())
@@ -593,53 +606,6 @@ pub fn error_response(
593606
.map_err(|err| TransportServerError::HttpError(err.to_string()))
594607
}
595608

596-
// Protect against DNS rebinding attacks by validating Host and Origin headers.
597-
pub(crate) async fn protect_dns_rebinding(
598-
headers: &http::HeaderMap,
599-
state: Arc<McpAppState>,
600-
) -> Result<(), SdkError> {
601-
if !state.needs_dns_protection() {
602-
// If protection is not needed, pass the request to the next handler
603-
return Ok(());
604-
}
605-
606-
if let Some(allowed_hosts) = state.allowed_hosts.as_ref() {
607-
if !allowed_hosts.is_empty() {
608-
let Some(host) = headers.get(HOST).and_then(|h| h.to_str().ok()) else {
609-
return Err(SdkError::bad_request().with_message("Invalid Host header: [unknown] "));
610-
};
611-
612-
if !allowed_hosts
613-
.iter()
614-
.any(|allowed| allowed.eq_ignore_ascii_case(host))
615-
{
616-
return Err(SdkError::bad_request()
617-
.with_message(format!("Invalid Host header: \"{host}\" ").as_str()));
618-
}
619-
}
620-
}
621-
622-
if let Some(allowed_origins) = state.allowed_origins.as_ref() {
623-
if !allowed_origins.is_empty() {
624-
let Some(origin) = headers.get(ORIGIN).and_then(|h| h.to_str().ok()) else {
625-
return Err(
626-
SdkError::bad_request().with_message("Invalid Origin header: [unknown] ")
627-
);
628-
};
629-
630-
if !allowed_origins
631-
.iter()
632-
.any(|allowed| allowed.eq_ignore_ascii_case(origin))
633-
{
634-
return Err(SdkError::bad_request()
635-
.with_message(format!("Invalid Origin header: \"{origin}\" ").as_str()));
636-
}
637-
}
638-
}
639-
640-
Ok(())
641-
}
642-
643609
/// Extracts the value of a query parameter from an HTTP request by key.
644610
///
645611
/// This function parses the query string from the request URI and searches
@@ -653,7 +619,7 @@ pub(crate) async fn protect_dns_rebinding(
653619
/// * `Some(String)` containing the value of the query parameter if found.
654620
/// * `None` if the query string is missing or the key is not present.
655621
///
656-
pub fn query_param(request: &http::Request<&str>, key: &str) -> Option<String> {
622+
pub(crate) fn query_param(request: &http::Request<&str>, key: &str) -> Option<String> {
657623
request.uri().query().and_then(|query| {
658624
for pair in query.split('&') {
659625
let mut split = pair.splitn(2, '=');

0 commit comments

Comments
 (0)