diff --git a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java
index 0b71ddc1f..d6780993f 100644
--- a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java
+++ b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java
@@ -36,6 +36,7 @@
import org.springframework.web.servlet.function.ServerRequest;
import org.springframework.web.servlet.function.ServerResponse;
import org.springframework.web.servlet.function.ServerResponse.SseBuilder;
+import org.springframework.web.util.UriComponentsBuilder;
/**
* Server-side implementation of the Model Context Protocol (MCP) transport layer using
@@ -87,6 +88,8 @@ public class WebMvcSseServerTransportProvider implements McpServerTransportProvi
*/
public static final String ENDPOINT_EVENT_TYPE = "endpoint";
+ public static final String SESSION_ID = "sessionId";
+
/**
* Default SSE endpoint path as specified by the MCP transport specification.
*/
@@ -275,9 +278,7 @@ private ServerResponse handleSseConnection(ServerRequest request) {
this.sessions.put(sessionId, session);
try {
- sseBuilder.id(sessionId)
- .event(ENDPOINT_EVENT_TYPE)
- .data(this.baseUrl + this.messageEndpoint + "?sessionId=" + sessionId);
+ sseBuilder.id(sessionId).event(ENDPOINT_EVENT_TYPE).data(buildEndpointUrl(sessionId));
}
catch (Exception e) {
logger.error("Failed to send initial endpoint event: {}", e.getMessage());
@@ -292,6 +293,14 @@ private ServerResponse handleSseConnection(ServerRequest request) {
}
}
+ private String buildEndpointUrl(String sessionId) {
+ return UriComponentsBuilder.fromUriString(baseUrl)
+ .path(messageEndpoint)
+ .queryParam(SESSION_ID, sessionId)
+ .build()
+ .toUriString();
+ }
+
/**
* Handles incoming JSON-RPC messages from clients. This method:
*
@@ -308,11 +317,11 @@ private ServerResponse handleMessage(ServerRequest request) {
return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down");
}
- if (request.param("sessionId").isEmpty()) {
+ if (request.param(SESSION_ID).isEmpty()) {
return ServerResponse.badRequest().body(new McpError("Session ID missing in message endpoint"));
}
- String sessionId = request.param("sessionId").get();
+ String sessionId = request.param(SESSION_ID).get();
McpServerSession session = sessions.get(sessionId);
if (session == null) {
diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProviderTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProviderTests.java
new file mode 100644
index 000000000..17c0f9345
--- /dev/null
+++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProviderTests.java
@@ -0,0 +1,119 @@
+/*
+ * Copyright 2024 - 2024 the original author or authors.
+ */
+
+package io.modelcontextprotocol.server.transport;
+
+import io.modelcontextprotocol.client.McpClient;
+import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport;
+import io.modelcontextprotocol.common.McpTransportContext;
+import io.modelcontextprotocol.json.McpJsonMapper;
+import io.modelcontextprotocol.server.McpServer;
+import io.modelcontextprotocol.server.TestUtil;
+import io.modelcontextprotocol.server.TomcatTestUtil;
+import io.modelcontextprotocol.spec.McpSchema;
+import org.apache.catalina.LifecycleException;
+import org.apache.catalina.LifecycleState;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+import org.springframework.context.annotation.Bean;
+import org.springframework.context.annotation.Configuration;
+import org.springframework.web.servlet.config.annotation.EnableWebMvc;
+import org.springframework.web.servlet.function.RouterFunction;
+import org.springframework.web.servlet.function.ServerResponse;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/**
+ * Integration tests for WebMvcSseServerTransportProvider
+ *
+ * @author lance
+ */
+class WebMvcSseServerTransportProviderTests {
+
+ private static final int PORT = TestUtil.findAvailablePort();
+
+ private static final String CUSTOM_CONTEXT_PATH = "/";
+
+ private static final String MESSAGE_ENDPOINT = "/mcp/message";
+
+ private WebMvcSseServerTransportProvider mcpServerTransportProvider;
+
+ McpClient.SyncSpec clientBuilder;
+
+ private TomcatTestUtil.TomcatServer tomcatServer;
+
+ @BeforeEach
+ public void before() {
+ tomcatServer = TomcatTestUtil.createTomcatServer(CUSTOM_CONTEXT_PATH, PORT, TestConfig.class);
+
+ try {
+ tomcatServer.tomcat().start();
+ assertThat(tomcatServer.tomcat().getServer().getState()).isEqualTo(LifecycleState.STARTED);
+ }
+ catch (Exception e) {
+ throw new RuntimeException("Failed to start Tomcat", e);
+ }
+
+ HttpClientSseClientTransport transport = HttpClientSseClientTransport.builder("http://localhost:" + PORT)
+ .sseEndpoint(WebMvcSseServerTransportProvider.DEFAULT_SSE_ENDPOINT)
+ .build();
+
+ clientBuilder = McpClient.sync(transport);
+ mcpServerTransportProvider = tomcatServer.appContext().getBean(WebMvcSseServerTransportProvider.class);
+ }
+
+ @Test
+ void validBaseUrl() {
+ McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").build();
+ try (var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0"))
+ .build()) {
+ assertThat(client.initialize()).isNotNull();
+ }
+ }
+
+ @AfterEach
+ public void after() {
+ if (mcpServerTransportProvider != null) {
+ mcpServerTransportProvider.closeGracefully().block();
+ }
+ if (tomcatServer.appContext() != null) {
+ tomcatServer.appContext().close();
+ }
+ if (tomcatServer.tomcat() != null) {
+ try {
+ tomcatServer.tomcat().stop();
+ tomcatServer.tomcat().destroy();
+ }
+ catch (LifecycleException e) {
+ throw new RuntimeException("Failed to stop Tomcat", e);
+ }
+ }
+ }
+
+ @Configuration
+ @EnableWebMvc
+ static class TestConfig {
+
+ @Bean
+ public WebMvcSseServerTransportProvider webMvcSseServerTransportProvider() {
+
+ return WebMvcSseServerTransportProvider.builder()
+ .baseUrl("http://localhost:" + PORT + "/")
+ .messageEndpoint(MESSAGE_ENDPOINT)
+ .sseEndpoint(WebMvcSseServerTransportProvider.DEFAULT_SSE_ENDPOINT)
+ .jsonMapper(McpJsonMapper.getDefault())
+ .contextExtractor(req -> McpTransportContext.EMPTY)
+ .build();
+ }
+
+ @Bean
+ public RouterFunction routerFunction(WebMvcSseServerTransportProvider transportProvider) {
+ return transportProvider.getRouterFunction();
+ }
+
+ }
+
+}