diff --git a/Cargo.lock b/Cargo.lock index 8ec6092..02f208c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -178,7 +178,7 @@ dependencies = [ "miniz_oxide", "object", "rustc-demangle", - "windows-targets 0.52.6", + "windows-targets", ] [[package]] @@ -409,9 +409,9 @@ checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" [[package]] name = "form_urlencoded" -version = "1.2.1" +version = "1.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e13624c2627564efccf4934284bdd98cbaa14e79b0b5a141218e507b3a823456" +checksum = "cb4cb245038516f5f85277875cdaa4f7d2c9a0fa0468de06ed190163b1581fcf" dependencies = [ "percent-encoding", ] @@ -603,6 +603,16 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "http-serde" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0f056c8559e3757392c8d091e796416e4649d8e49e88b8d76df6c002f05027fd" +dependencies = [ + "http", + "serde", +] + [[package]] name = "httparse" version = "1.10.1" @@ -672,22 +682,28 @@ dependencies = [ [[package]] name = "hyper-util" -version = "0.1.11" +version = "0.1.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "497bbc33a26fdd4af9ed9c70d63f61cf56a938375fbb32df34db9b1cd6d643f2" +checksum = "3c6995591a8f1380fcb4ba966a252a4b29188d51d2b89e3a252f5305be65aea8" dependencies = [ + "base64", "bytes", "futures-channel", + "futures-core", "futures-util", "http", "http-body", "hyper", + "ipnet", "libc", + "percent-encoding", "pin-project-lite", "socket2", + "system-configuration", "tokio", "tower-service", "tracing", + "windows-registry", ] [[package]] @@ -840,9 +856,9 @@ checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" [[package]] name = "idna" -version = "1.0.3" +version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "686f825264d630750a544639377bae737628043f20d38bbc029e8f29ea968a7e" +checksum = "3b0875f23caa03898994f6ddc501886a45c7d3d62d04d2d90788d47be1b1e4de" dependencies = [ "idna_adapter", "smallvec", @@ -875,6 +891,16 @@ version = "2.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "469fb0b9cefa57e3ef31275ee7cacb78f2fdca44e4765491884a2b119d4eb130" +[[package]] +name = "iri-string" +version = "0.7.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dbc5ebe9c3a1a7a5127f920a418f7585e9e758e911d0466ed004f393b0e380b2" +dependencies = [ + "memchr", + "serde", +] + [[package]] name = "is_terminal_polyfill" version = "1.70.1" @@ -957,6 +983,7 @@ dependencies = [ "axum", "clap", "futures", + "http-serde", "openssl-sys", "reqwest", "rmcp", @@ -966,6 +993,7 @@ dependencies = [ "tokio-util", "tracing", "tracing-subscriber", + "url", "uuid", ] @@ -1144,7 +1172,7 @@ dependencies = [ "libc", "redox_syscall", "smallvec", - "windows-targets 0.52.6", + "windows-targets", ] [[package]] @@ -1155,9 +1183,9 @@ checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" [[package]] name = "percent-encoding" -version = "2.3.1" +version = "2.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" +checksum = "9b4f627cb1b25917193a259e49bdad08f671f8d9708acfd5fe0a8c1455d87220" [[package]] name = "pin-project-lite" @@ -1339,9 +1367,9 @@ dependencies = [ [[package]] name = "reqwest" -version = "0.12.15" +version = "0.12.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d19c46a6fdd48bc4dab94b6103fccc55d34c67cc0ad04653aad4ea2a07cd7bbb" +checksum = "9d0946410b9f7b082a427e4ef5c8ff541a88b357bc6c637c40db3a68ac70a36f" dependencies = [ "base64", "bytes", @@ -1356,36 +1384,32 @@ dependencies = [ "hyper-rustls", "hyper-tls", "hyper-util", - "ipnet", "js-sys", "log", "mime", "native-tls", - "once_cell", "percent-encoding", "pin-project-lite", "quinn", "rustls", - "rustls-pemfile", "rustls-pki-types", "serde", "serde_json", "serde_urlencoded", "sync_wrapper", - "system-configuration", "tokio", "tokio-native-tls", "tokio-rustls", "tokio-util", "tower", + "tower-http", "tower-service", "url", "wasm-bindgen", "wasm-bindgen-futures", "wasm-streams", "web-sys", - "webpki-roots 0.26.11", - "windows-registry", + "webpki-roots 1.0.0", ] [[package]] @@ -1485,15 +1509,6 @@ dependencies = [ "zeroize", ] -[[package]] -name = "rustls-pemfile" -version = "2.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dce314e5fee3f39953d46bb63bb8a46d40c2f8fb7cc5a3b6cab2bde9721d6e50" -dependencies = [ - "rustls-pki-types", -] - [[package]] name = "rustls-pki-types" version = "1.11.0" @@ -1592,18 +1607,28 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.219" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" +dependencies = [ + "serde_core", + "serde_derive", +] + +[[package]] +name = "serde_core" +version = "1.0.228" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f0e2c6ed6606019b4e29e69dbaba95b11854410e5347d525002456dbbb786b6" +checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.219" +version = "1.0.228" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b0276cf7f2c73365f7157c8123c21cd9a50fbbd844757af28ca1f5925fc2a00" +checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" dependencies = [ "proc-macro2", "quote", @@ -1623,14 +1648,15 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.140" +version = "1.0.145" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "20068b6e96dc6c9bd23e01df8827e6c7e1f2fddd43c21810382803c136b99373" +checksum = "402a6f66d8c709116cf22f558eab210f5a50187f702eb4d7e5ef38d9a7f1c79c" dependencies = [ "itoa", "memchr", "ryu", "serde", + "serde_core", ] [[package]] @@ -1944,6 +1970,24 @@ dependencies = [ "tracing", ] +[[package]] +name = "tower-http" +version = "0.6.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adc82fd73de2a9722ac5da747f12383d2bfdb93591ee6c58486e0097890f05f2" +dependencies = [ + "bitflags", + "bytes", + "futures-util", + "http", + "http-body", + "iri-string", + "pin-project-lite", + "tower", + "tower-layer", + "tower-service", +] + [[package]] name = "tower-layer" version = "0.3.3" @@ -2034,13 +2078,14 @@ checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" [[package]] name = "url" -version = "2.5.4" +version = "2.5.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32f8b686cadd1473f4bd0117a5d28d36b1ade384ea9b5069a1c40aefed7fda60" +checksum = "08bc136a29a3d1758e07a9cca267be308aeebf5cfd5a10f3f67ab2097683ef5b" dependencies = [ "form_urlencoded", "idna", "percent-encoding", + "serde", ] [[package]] @@ -2283,7 +2328,7 @@ dependencies = [ "windows-interface", "windows-link", "windows-result", - "windows-strings 0.4.0", + "windows-strings", ] [[package]] @@ -2336,13 +2381,13 @@ dependencies = [ [[package]] name = "windows-registry" -version = "0.4.0" +version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4286ad90ddb45071efd1a66dfa43eb02dd0dfbae1545ad6cc3c51cf34d7e8ba3" +checksum = "ad1da3e436dc7653dfdf3da67332e22bff09bb0e28b0239e1624499c7830842e" dependencies = [ + "windows-link", "windows-result", - "windows-strings 0.3.1", - "windows-targets 0.53.0", + "windows-strings", ] [[package]] @@ -2354,15 +2399,6 @@ dependencies = [ "windows-link", ] -[[package]] -name = "windows-strings" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87fa48cc5d406560701792be122a10132491cff9d0aeb23583cc2dcafc847319" -dependencies = [ - "windows-link", -] - [[package]] name = "windows-strings" version = "0.4.0" @@ -2378,7 +2414,7 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" dependencies = [ - "windows-targets 0.52.6", + "windows-targets", ] [[package]] @@ -2387,7 +2423,7 @@ version = "0.59.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" dependencies = [ - "windows-targets 0.52.6", + "windows-targets", ] [[package]] @@ -2396,30 +2432,14 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" dependencies = [ - "windows_aarch64_gnullvm 0.52.6", - "windows_aarch64_msvc 0.52.6", - "windows_i686_gnu 0.52.6", - "windows_i686_gnullvm 0.52.6", - "windows_i686_msvc 0.52.6", - "windows_x86_64_gnu 0.52.6", - "windows_x86_64_gnullvm 0.52.6", - "windows_x86_64_msvc 0.52.6", -] - -[[package]] -name = "windows-targets" -version = "0.53.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b1e4c7e8ceaaf9cb7d7507c974735728ab453b67ef8f18febdd7c11fe59dca8b" -dependencies = [ - "windows_aarch64_gnullvm 0.53.0", - "windows_aarch64_msvc 0.53.0", - "windows_i686_gnu 0.53.0", - "windows_i686_gnullvm 0.53.0", - "windows_i686_msvc 0.53.0", - "windows_x86_64_gnu 0.53.0", - "windows_x86_64_gnullvm 0.53.0", - "windows_x86_64_msvc 0.53.0", + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_gnullvm", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", ] [[package]] @@ -2428,96 +2448,48 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" -[[package]] -name = "windows_aarch64_gnullvm" -version = "0.53.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "86b8d5f90ddd19cb4a147a5fa63ca848db3df085e25fee3cc10b39b6eebae764" - [[package]] name = "windows_aarch64_msvc" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" -[[package]] -name = "windows_aarch64_msvc" -version = "0.53.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7651a1f62a11b8cbd5e0d42526e55f2c99886c77e007179efff86c2b137e66c" - [[package]] name = "windows_i686_gnu" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" -[[package]] -name = "windows_i686_gnu" -version = "0.53.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c1dc67659d35f387f5f6c479dc4e28f1d4bb90ddd1a5d3da2e5d97b42d6272c3" - [[package]] name = "windows_i686_gnullvm" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" -[[package]] -name = "windows_i686_gnullvm" -version = "0.53.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ce6ccbdedbf6d6354471319e781c0dfef054c81fbc7cf83f338a4296c0cae11" - [[package]] name = "windows_i686_msvc" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" -[[package]] -name = "windows_i686_msvc" -version = "0.53.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "581fee95406bb13382d2f65cd4a908ca7b1e4c2f1917f143ba16efe98a589b5d" - [[package]] name = "windows_x86_64_gnu" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" -[[package]] -name = "windows_x86_64_gnu" -version = "0.53.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2e55b5ac9ea33f2fc1716d1742db15574fd6fc8dadc51caab1c16a3d3b4190ba" - [[package]] name = "windows_x86_64_gnullvm" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" -[[package]] -name = "windows_x86_64_gnullvm" -version = "0.53.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0a6e035dd0599267ce1ee132e51c27dd29437f63325753051e71dd9e42406c57" - [[package]] name = "windows_x86_64_msvc" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" -[[package]] -name = "windows_x86_64_msvc" -version = "0.53.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "271414315aff87387382ec3d271b52d7ae78726f5d44ac98b4f4030c91880486" - [[package]] name = "wit-bindgen-rt" version = "0.39.0" diff --git a/Cargo.toml b/Cargo.toml index 8798a29..643e388 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,9 +12,9 @@ rmcp = { git = "https://github.com/SteffenDE/mcp-rust-sdk.git", branch = "sd-bad "transport-sse-client-reqwest", "transport-streamable-http-client-reqwest", "transport-worker", - "transport-child-process" + "transport-child-process", ] } -clap = { version = "4.5.37", features = ["derive"] } +clap = { version = "4.5.37", features = ["derive", "env"] } tokio = { version = "1", features = ["full"] } tracing = "0.1.41" tracing-subscriber = "0.3.19" @@ -22,7 +22,10 @@ anyhow = "1.0.98" uuid = { version = "1.6", features = ["v7", "fast-rng"] } futures = "0.3.31" tokio-util = "0.7.15" -reqwest = { version = "0.12", features = ["json", "stream"] } +reqwest = { version = "0.12.24", features = ["json", "stream"] } +http-serde = "2.1.1" +serde_json = "1.0.145" +url = "2.5.7" [dependencies.openssl-sys] version = "0.9" @@ -38,7 +41,7 @@ rmcp = { git = "https://github.com/SteffenDE/mcp-rust-sdk.git", branch = "sd-bad "transport-sse-server", "transport-child-process", "transport-streamable-http-server", - "macros" + "macros", ] } axum = { version = "0.8", features = ["macros"] } serde = { version = "1.0", features = ["derive"] } diff --git a/src/cli.rs b/src/cli.rs index 157f862..a0a232f 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -1,11 +1,36 @@ +use anyhow::Context; use clap::Parser; +use reqwest::header::HeaderMap; +use url::Url; + +#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, clap::ValueEnum)] +#[value(rename_all = "kebab-case")] +pub enum TransportType { + /// Automatically determine the transport type based on the server's response + #[default] + Auto, + /// Use the streamable HTTP transport + StreamableHttp, + /// Use the SSE transport + Sse, +} #[derive(Parser, Debug)] #[command(author, version, about, long_about = None)] pub struct Args { /// The URL of the SSE endpoint to connect to - #[arg(value_name = "URL")] - pub sse_url: Option, + #[arg(value_name = "URL", env = "SSE_URL", value_parser = parse_url)] + pub sse_url: Url, + + #[arg(long, env = "MCP_HEADERS", value_parser = parse_header_map)] + /// Headers to send to the server + /// This is a JSON object of header name to header value. + /// Example: `{"Authorization": "Bearer 1234567890"}` + pub headers: Option, + + /// The transport type to use + #[arg(long, env = "TRANSPORT_TYPE", default_value = "auto")] + pub transport_type: TransportType, /// Enable debug logging #[arg(long)] @@ -23,3 +48,18 @@ pub struct Args { /// Override the protocol version returned to the client pub override_protocol_version: Option, } + +fn parse_header_map(s: &str) -> Result { + let headers = { + let mut de = serde_json::Deserializer::from_str(s); + http_serde::header_map::deserialize(&mut de) + .context("failed to parse headers") + .map_err(|e| e.to_string())? + }; + Ok(headers) +} + +fn parse_url(s: &str) -> Result { + let url = Url::parse(s).map_err(|e| e.to_string())?; + Ok(url) +} diff --git a/src/core.rs b/src/core.rs index 5815b08..57e700c 100644 --- a/src/core.rs +++ b/src/core.rs @@ -1,6 +1,6 @@ use crate::state::{AppState, BufferMode, ProxyState, ReconnectFailureReason}; use crate::{DISCONNECTED_ERROR_CODE, SseClientType, StdoutSink, TRANSPORT_SEND_ERROR_CODE}; -use anyhow::{Result, anyhow}; +use anyhow::{Context, Result, anyhow}; use futures::FutureExt; use futures::SinkExt; use rmcp::model::{ @@ -35,32 +35,54 @@ pub(crate) async fn reply_disconnected(id: &RequestId, stdout_sink: &mut StdoutS } pub(crate) async fn connect(app_state: &AppState) -> Result { - // this function should try sending a POST request to the sse_url and see if - // the server responds with 405 method not supported. If so, it should call - // connect_with_sse, otherwise it should call connect_with_streamable. - let result = reqwest::Client::new() - .post(app_state.url.clone()) - .header("Accept", "application/json,text/event-stream") - .header("Content-Type", "application/json") - .body(r#"{"jsonrpc":"2.0","id":"init","method":"initialize","params":{"protocolVersion":"2025-03-26","capabilities":{},"clientInfo":{"name":"test","version":"0.1.0"}}}"#) - .send() - .await?; - - if result.status() == reqwest::StatusCode::METHOD_NOT_ALLOWED { - debug!("Server responded with 405, using SSE transport"); - return connect_with_sse(app_state).await; - } else if result.status().is_success() { - debug!("Server responded successfully, using streamable transport"); - return connect_with_streamable(app_state).await; - } else { - error!("Server returned unexpected status: {}", result.status()); - anyhow::bail!("Server returned unexpected status: {}", result.status()); + use crate::cli::TransportType::*; + match app_state.transport_type { + Auto => { + // this function should try sending a POST request to the sse_url and see if + // the server responds with 405 method not supported. If so, it should call + // connect_with_sse, otherwise it should call connect_with_streamable. + let mut headers = app_state.headers.clone().unwrap_or_default(); + headers.insert( + "Accept", + "application/json,text/event-stream".parse().unwrap(), + ); + headers.insert("Content-Type", "application/json".parse().unwrap()); + let result = reqwest::Client::new() + .post(app_state.url.clone()) + .headers(headers) + .body(r#"{"jsonrpc":"2.0","id":"init","method":"initialize","params":{"protocolVersion":"2025-03-26","capabilities":{},"clientInfo":{"name":"test","version":"0.1.0"}}}"#) + .send() + .await?; + if result.status() == reqwest::StatusCode::METHOD_NOT_ALLOWED { + debug!("Server responded with 405, using SSE transport"); + return connect_with_sse(app_state).await; + } else if result.status().is_success() { + debug!("Server responded successfully, using streamable transport"); + return connect_with_streamable(app_state).await; + } else { + error!("Server returned unexpected status: {}", result.status()); + anyhow::bail!("Server returned unexpected status: {}", result.status()); + } + } + StreamableHttp => { + debug!("Using streamable transport"); + return connect_with_streamable(app_state).await; + } + Sse => { + debug!("Using SSE transport"); + return connect_with_sse(app_state).await; + } } } pub(crate) async fn connect_with_streamable(app_state: &AppState) -> Result { + let mut builder = reqwest::Client::builder(); + if let Some(headers) = app_state.headers.clone() { + builder = builder.default_headers(headers); + } + let result = rmcp::transport::StreamableHttpClientTransport::with_client( - reqwest::Client::default(), + builder.build().context("failed to build reqwest client")?, rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig { uri: app_state.url.clone().into(), // we don't want the sdk to perform any retries @@ -75,8 +97,13 @@ pub(crate) async fn connect_with_streamable(app_state: &AppState) -> Result Result { + let mut builder = reqwest::Client::builder(); + if let Some(headers) = app_state.headers.clone() { + builder = builder.default_headers(headers); + } + let result = rmcp::transport::SseClientTransport::start_with_client( - reqwest::Client::default(), + builder.build().context("failed to build reqwest client")?, rmcp::transport::sse_client::SseClientConfig { sse_endpoint: app_state.url.clone().into(), // we don't want the sdk to perform any retries diff --git a/src/main.rs b/src/main.rs index 46d72d9..02b4304 100644 --- a/src/main.rs +++ b/src/main.rs @@ -5,7 +5,6 @@ use rmcp::{ model::{ClientJsonRpcMessage, ErrorCode, ProtocolVersion, ServerJsonRpcMessage}, transport::{StreamableHttpClientTransport, Transport, sse_client::SseClientTransport}, }; -use std::env; use tokio::io::{Stdin, Stdout}; use tokio::time::{Duration, Instant, sleep}; use tokio_util::codec::{FramedRead, FramedWrite}; @@ -103,7 +102,7 @@ async fn connect_with_retry(app_state: &AppState, delay: Duration) -> Result Result<()> { - let args = Args::parse(); + let mut args = Args::parse(); let log_level = if args.debug { tracing::Level::DEBUG } else { @@ -117,15 +116,7 @@ async fn main() -> Result<()> { tracing::subscriber::set_global_default(subscriber).context("Failed to set up logging")?; - // Get the SSE URL from args or environment - let sse_url = match args.sse_url { - Some(url) => url, - None => env::var("SSE_URL").context( - "Either the URL must be passed as the first argument or the SSE_URL environment variable must be set", - )?, - }; - - debug!("Starting MCP proxy with URL: {}", sse_url); + debug!("Starting MCP proxy with URL: {}", args.sse_url); debug!("Max disconnected time: {:?}s", args.max_disconnected_time); // Parse protocol version override if provided @@ -149,9 +140,12 @@ async fn main() -> Result<()> { let (reconnect_tx, mut reconnect_rx) = tokio::sync::mpsc::channel(10); let (timer_tx, mut timer_rx) = tokio::sync::mpsc::channel(10); + let sse_url = args.sse_url.clone(); // Initialize application state let mut app_state = AppState::new( - sse_url.clone(), + sse_url.to_string(), + args.transport_type, + args.headers.take(), args.max_disconnected_time, override_protocol_version, ); diff --git a/src/state.rs b/src/state.rs index 50f23da..196cc6b 100644 --- a/src/state.rs +++ b/src/state.rs @@ -1,3 +1,4 @@ +use crate::cli::TransportType; use crate::core::{ flush_buffer_with_errors, generate_id, initiate_post_reconnect_handshake, process_buffered_messages, process_client_request, reply_disconnected, @@ -6,6 +7,7 @@ use crate::core::{ use crate::{SseClientType, StdoutSink}; use anyhow::Result; use futures::SinkExt; +use reqwest::header::HeaderMap; use rmcp::model::{ ClientJsonRpcMessage, ClientNotification, ClientRequest, EmptyResult, InitializedNotification, InitializedNotificationMethod, ProtocolVersion, RequestId, ServerJsonRpcMessage, ServerResult, @@ -47,6 +49,8 @@ pub enum ProxyState { pub struct AppState { /// URL of the SSE server pub url: String, + pub headers: Option, + pub transport_type: TransportType, /// Maximum time to try reconnecting in seconds (None = infinity) pub max_disconnected_time: Option, /// Override protocol version @@ -82,11 +86,15 @@ pub struct AppState { impl AppState { pub fn new( url: String, + transport_type: TransportType, + headers: Option, max_disconnected_time: Option, override_protocol_version: Option, ) -> Self { Self { url, + headers, + transport_type, max_disconnected_time, override_protocol_version, disconnected_since: None,