Skip to content

Commit 6a38ac6

Browse files
committed
update tests
1 parent c95989c commit 6a38ac6

File tree

1 file changed

+128
-79
lines changed

1 file changed

+128
-79
lines changed

tests/advanced_test.rs

Lines changed: 128 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,103 @@ use rmcp::{
66
object,
77
transport::{ConfigureCommandExt, TokioChildProcess},
88
};
9-
use std::{net::SocketAddr, time::Duration};
9+
use std::{
10+
net::SocketAddr,
11+
sync::{Arc, Mutex},
12+
time::Duration,
13+
};
1014
use tokio::{
1115
io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
1216
time::{sleep, timeout},
1317
};
1418

19+
/// A guard that ensures processes are killed on drop, especially on test failures (panics)
20+
struct TestGuard {
21+
child: Option<tokio::process::Child>,
22+
server_handle: Option<tokio::process::Child>,
23+
stderr_buffer: Arc<Mutex<Vec<String>>>,
24+
}
25+
26+
impl TestGuard {
27+
fn new(
28+
child: tokio::process::Child,
29+
server_handle: tokio::process::Child,
30+
stderr_buffer: Arc<Mutex<Vec<String>>>,
31+
) -> Self {
32+
Self {
33+
child: Some(child),
34+
server_handle: Some(server_handle),
35+
stderr_buffer,
36+
}
37+
}
38+
}
39+
40+
impl Drop for TestGuard {
41+
fn drop(&mut self) {
42+
// If we're dropping because of a panic, print the stderr content
43+
if std::thread::panicking() {
44+
eprintln!("Test failed! Process stderr output:");
45+
for line in self.stderr_buffer.lock().unwrap().iter() {
46+
eprintln!("{}", line);
47+
}
48+
}
49+
50+
// Force kill both processes
51+
if let Some(mut child) = self.child.take() {
52+
let _ = child.start_kill();
53+
}
54+
if let Some(mut server_handle) = self.server_handle.take() {
55+
let _ = server_handle.start_kill();
56+
}
57+
}
58+
}
59+
60+
/// Spawns a proxy process with stdin, stdout, and stderr all captured
61+
async fn spawn_proxy(
62+
server_url: &str,
63+
extra_args: Vec<&str>,
64+
) -> Result<(
65+
tokio::process::Child,
66+
tokio::io::BufReader<tokio::process::ChildStdout>,
67+
tokio::io::BufReader<tokio::process::ChildStderr>,
68+
tokio::process::ChildStdin,
69+
)> {
70+
let mut cmd = tokio::process::Command::new("./target/debug/mcp-proxy");
71+
cmd.arg(server_url)
72+
.args(extra_args)
73+
.stdout(std::process::Stdio::piped())
74+
.stderr(std::process::Stdio::piped())
75+
.stdin(std::process::Stdio::piped());
76+
77+
let mut child = cmd.spawn()?;
78+
let stdin = child.stdin.take().unwrap();
79+
let stdout = BufReader::new(child.stdout.take().unwrap());
80+
let stderr = BufReader::new(child.stderr.take().unwrap());
81+
82+
Ok((child, stdout, stderr, stdin))
83+
}
84+
85+
/// Collects stderr lines in the background
86+
fn collect_stderr(
87+
mut stderr_reader: BufReader<tokio::process::ChildStderr>,
88+
) -> Arc<Mutex<Vec<String>>> {
89+
let stderr_buffer = Arc::new(Mutex::new(Vec::new()));
90+
let buffer_clone = stderr_buffer.clone();
91+
92+
tokio::spawn(async move {
93+
let mut line = String::new();
94+
while let Ok(bytes_read) = stderr_reader.read_line(&mut line).await {
95+
if bytes_read == 0 {
96+
break;
97+
}
98+
buffer_clone.lock().unwrap().push(line.clone());
99+
line.clear();
100+
}
101+
});
102+
103+
stderr_buffer
104+
}
105+
15106
// Creates a new SSE server for testing
16107
// Starts the echo-server as a subprocess
17108
async fn create_sse_server(
@@ -32,35 +123,27 @@ async fn create_sse_server(
32123

33124
tracing::debug!("cmd: {:?}", cmd);
34125

35-
// Start the process
36-
let child = cmd.spawn()?;
126+
// Start the process with stdout/stderr redirected to null
127+
let child = cmd
128+
.stdout(std::process::Stdio::null())
129+
.stderr(std::process::Stdio::null())
130+
.spawn()?;
37131

38132
// Give the server time to start up
39133
sleep(Duration::from_millis(500)).await;
40-
41134
tracing::info!("{} server started successfully", server_name);
42135

43136
Ok((child, url))
44137
}
45138

46139
async fn protocol_initialization(server_name: &str) -> Result<()> {
47140
const BIND_ADDRESS: &str = "127.0.0.1:8181";
48-
// Start the SSE server
49-
let (mut server_handle, server_url) =
50-
create_sse_server(server_name, BIND_ADDRESS.parse()?).await?;
141+
let (server_handle, server_url) = create_sse_server(server_name, BIND_ADDRESS.parse()?).await?;
51142

52-
// Create a child process for the proxy
53-
let mut cmd = tokio::process::Command::new("./target/debug/mcp-proxy");
54-
cmd.arg(&server_url)
55-
.stdout(std::process::Stdio::piped())
56-
.stdin(std::process::Stdio::piped());
57-
58-
let mut child = cmd.spawn()?;
59-
60-
// Get stdin and stdout handles
61-
let mut stdin = child.stdin.take().unwrap();
62-
let stdout = child.stdout.take().unwrap();
63-
let mut reader = BufReader::new(stdout);
143+
// Create a child process for the proxy with stderr capture
144+
let (child, mut reader, stderr_reader, mut stdin) = spawn_proxy(&server_url, vec![]).await?;
145+
let stderr_buffer = collect_stderr(stderr_reader);
146+
let _guard = TestGuard::new(child, server_handle, stderr_buffer);
64147

65148
// Send initialization message
66149
let init_message = r#"{"jsonrpc":"2.0","id":"init-1","method":"initialize","params":{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"test","version":"0.1.0"}}}"#;
@@ -93,10 +176,6 @@ async fn protocol_initialization(server_name: &str) -> Result<()> {
93176
assert!(echo_response.contains("\"id\":\"call-1\""));
94177
assert!(echo_response.contains("Hey!"));
95178

96-
// Clean up
97-
child.kill().await?;
98-
server_handle.kill().await?;
99-
100179
Ok(())
101180
}
102181

@@ -119,22 +198,13 @@ async fn reconnection_handling(server_name: &str) -> Result<()> {
119198

120199
// Start the SSE server
121200
tracing::info!("Test: Starting initial SSE server");
122-
let (mut server_handle, server_url) =
123-
create_sse_server(server_name, BIND_ADDRESS.parse()?).await?;
201+
let (server_handle, server_url) = create_sse_server(server_name, BIND_ADDRESS.parse()?).await?;
124202

125203
// Create a child process for the proxy
126204
tracing::info!("Test: Creating proxy process");
127-
let mut cmd = tokio::process::Command::new("./target/debug/mcp-proxy");
128-
cmd.arg(&server_url)
129-
.stdout(std::process::Stdio::piped())
130-
.stdin(std::process::Stdio::piped());
131-
132-
let mut child = cmd.spawn()?;
133-
134-
// Get stdin and stdout handles
135-
let mut stdin = child.stdin.take().unwrap();
136-
let stdout = child.stdout.take().unwrap();
137-
let mut reader = BufReader::new(stdout);
205+
let (child, mut reader, stderr_reader, mut stdin) = spawn_proxy(&server_url, vec![]).await?;
206+
let stderr_buffer = collect_stderr(stderr_reader);
207+
let mut test_guard = TestGuard::new(child, server_handle, stderr_buffer);
138208

139209
// Send initialization message
140210
let init_message = r#"{"jsonrpc":"2.0","id":"init-1","method":"initialize","params":{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"test","version":"0.1.0"}}}"#;
@@ -163,20 +233,25 @@ async fn reconnection_handling(server_name: &str) -> Result<()> {
163233
);
164234

165235
// Shutdown the server
166-
server_handle.kill().await?;
236+
if let Some(mut server) = test_guard.server_handle.take() {
237+
server.kill().await?;
238+
}
167239

168240
// Give the server time to shut down
169241
sleep(Duration::from_millis(1000)).await;
170242

171243
// Create a new server on the same address
172244
tracing::info!("Test: Starting new SSE server");
173-
let (mut server_handle, new_url) =
245+
let (new_server_handle, new_url) =
174246
create_sse_server(server_name, BIND_ADDRESS.parse()?).await?;
175247
assert_eq!(
176248
server_url, new_url,
177249
"New server URL should match the original"
178250
);
179251

252+
// Update the test guard with the new server handle
253+
test_guard.server_handle = Some(new_server_handle);
254+
180255
// Give the proxy time to reconnect
181256
sleep(Duration::from_millis(3000)).await;
182257

@@ -197,11 +272,6 @@ async fn reconnection_handling(server_name: &str) -> Result<()> {
197272
"No response received after reconnection"
198273
);
199274

200-
// Clean up
201-
server_handle.kill().await?;
202-
sleep(Duration::from_millis(500)).await; // Give server time to shutdown
203-
child.kill().await?;
204-
205275
Ok(())
206276
}
207277

@@ -286,18 +356,10 @@ async fn initial_connection_retry(server_name: &str) -> Result<()> {
286356

287357
// 1. Start the proxy process BEFORE the server
288358
tracing::info!("Test: Starting proxy process...");
289-
let mut cmd = tokio::process::Command::new("./target/debug/mcp-proxy");
290-
cmd.arg(&server_url)
291-
.arg("--initial-retry-interval")
292-
.arg("1")
293-
.stdout(std::process::Stdio::piped())
294-
.stdin(std::process::Stdio::piped());
295-
let mut child = cmd.spawn()?;
359+
let (child, mut reader, stderr_reader, mut stdin) =
360+
spawn_proxy(&server_url, vec!["--initial-retry-interval", "1"]).await?;
296361

297-
// Get stdin and stdout handles
298-
let mut stdin = child.stdin.take().unwrap();
299-
let stdout = child.stdout.take().unwrap();
300-
let mut reader = BufReader::new(stdout);
362+
let stderr_buffer = collect_stderr(stderr_reader);
301363

302364
// 2. Wait for slightly longer than the proxy's retry delay
303365
// This ensures the proxy has attempted connection at least once and is retrying.
@@ -317,9 +379,11 @@ async fn initial_connection_retry(server_name: &str) -> Result<()> {
317379

318380
// 3. Start the SSE server AFTER the wait and AFTER sending init
319381
tracing::info!("Test: Starting SSE server on {}", BIND_ADDRESS);
320-
let (mut server_handle, returned_url) = create_sse_server(server_name, bind_addr).await?;
382+
let (server_handle, returned_url) = create_sse_server(server_name, bind_addr).await?;
321383
assert_eq!(server_url, returned_url, "Server URL mismatch");
322384

385+
let _test_guard = TestGuard::new(child, server_handle, stderr_buffer);
386+
323387
// 4. Proceed with initialization handshake (Proxy should now process buffered init)
324388
// Read the initialize response (with a timeout)
325389
tracing::info!("Test: Waiting for initialize response...");
@@ -377,12 +441,6 @@ async fn initial_connection_retry(server_name: &str) -> Result<()> {
377441
Err(_) => return Err(anyhow::anyhow!("Timed out waiting for echo response")),
378442
}
379443

380-
// 6. Cleanup
381-
tracing::info!("Test: Cleaning up...");
382-
child.kill().await?;
383-
server_handle.kill().await?;
384-
sleep(Duration::from_millis(500)).await; // Give server time to shutdown
385-
386444
tracing::info!("Test: Completed successfully");
387445
Ok(())
388446
}
@@ -405,23 +463,16 @@ async fn ping_when_disconnected(server_name: &str) -> Result<()> {
405463

406464
// 1. Start the SSE server
407465
tracing::info!("Test: Starting SSE server for ping test");
408-
let (mut server_handle, server_url) =
409-
create_sse_server(server_name, BIND_ADDRESS.parse()?).await?;
466+
let (server_handle, server_url) = create_sse_server(server_name, BIND_ADDRESS.parse()?).await?;
410467

411468
// Create a child process for the proxy
412469
tracing::info!("Test: Creating proxy process");
413-
let mut cmd = tokio::process::Command::new("./target/debug/mcp-proxy");
414-
cmd.arg(&server_url)
415-
.arg("--debug")
416-
.stdout(std::process::Stdio::piped())
417-
.stdin(std::process::Stdio::piped());
470+
let (child, mut reader, stderr_reader, mut stdin) =
471+
spawn_proxy(&server_url, vec!["--debug"]).await?;
418472

419-
let mut child = cmd.spawn()?;
473+
let stderr_buffer = collect_stderr(stderr_reader);
420474

421-
// Get stdin and stdout handles
422-
let mut stdin = child.stdin.take().unwrap();
423-
let stdout = child.stdout.take().unwrap();
424-
let mut reader = BufReader::new(stdout);
475+
let mut test_guard = TestGuard::new(child, server_handle, stderr_buffer);
425476

426477
// 2. Initializes everything
427478
let init_message = r#"{"jsonrpc":"2.0","id":"init-ping","method":"initialize","params":{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"ping-test","version":"0.1.0"}}}"#;
@@ -453,7 +504,9 @@ async fn ping_when_disconnected(server_name: &str) -> Result<()> {
453504

454505
// 3. Kills the SSE server
455506
tracing::info!("Test: Shutting down SSE server");
456-
server_handle.kill().await?;
507+
if let Some(mut server) = test_guard.server_handle.take() {
508+
server.kill().await?;
509+
}
457510
// Give the server time to shut down and the proxy time to notice
458511
sleep(Duration::from_secs(3)).await;
459512

@@ -482,16 +535,12 @@ async fn ping_when_disconnected(server_name: &str) -> Result<()> {
482535
Err(_) => panic!("Timed out waiting for ping response"),
483536
}
484537

485-
// Clean up
486-
tracing::info!("Test: Cleaning up proxy process");
487-
child.kill().await?;
488-
489538
Ok(())
490539
}
491540

492541
#[tokio::test]
493542
async fn test_ping_when_disconnected() -> Result<()> {
494-
// ping_when_disconnected("echo").await?;
543+
ping_when_disconnected("echo").await?;
495544
ping_when_disconnected("echo_streamable").await?;
496545

497546
Ok(())

0 commit comments

Comments
 (0)