Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import org.springframework.web.servlet.function.ServerRequest;
import org.springframework.web.servlet.function.ServerResponse;
import org.springframework.web.servlet.function.ServerResponse.SseBuilder;
import org.springframework.web.util.UriComponentsBuilder;

/**
* Server-side implementation of the Model Context Protocol (MCP) transport layer using
Expand Down Expand Up @@ -87,6 +88,8 @@ public class WebMvcSseServerTransportProvider implements McpServerTransportProvi
*/
public static final String ENDPOINT_EVENT_TYPE = "endpoint";

public static final String SESSION_ID = "sessionId";

/**
* Default SSE endpoint path as specified by the MCP transport specification.
*/
Expand Down Expand Up @@ -275,9 +278,7 @@ private ServerResponse handleSseConnection(ServerRequest request) {
this.sessions.put(sessionId, session);

try {
sseBuilder.id(sessionId)
.event(ENDPOINT_EVENT_TYPE)
.data(this.baseUrl + this.messageEndpoint + "?sessionId=" + sessionId);
sseBuilder.id(sessionId).event(ENDPOINT_EVENT_TYPE).data(buildEndpointUrl(sessionId));
}
catch (Exception e) {
logger.error("Failed to send initial endpoint event: {}", e.getMessage());
Expand All @@ -292,6 +293,14 @@ private ServerResponse handleSseConnection(ServerRequest request) {
}
}

private String buildEndpointUrl(String sessionId) {
return UriComponentsBuilder.fromUriString(baseUrl)
.path(messageEndpoint)
.queryParam(SESSION_ID, sessionId)
.build()
.toUriString();
}

/**
* Handles incoming JSON-RPC messages from clients. This method:
* <ul>
Expand All @@ -308,11 +317,11 @@ private ServerResponse handleMessage(ServerRequest request) {
return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down");
}

if (request.param("sessionId").isEmpty()) {
if (request.param(SESSION_ID).isEmpty()) {
return ServerResponse.badRequest().body(new McpError("Session ID missing in message endpoint"));
}

String sessionId = request.param("sessionId").get();
String sessionId = request.param(SESSION_ID).get();
McpServerSession session = sessions.get(sessionId);

if (session == null) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
/*
* Copyright 2024 - 2024 the original author or authors.
*/

package io.modelcontextprotocol.server.transport;

import io.modelcontextprotocol.client.McpClient;
import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport;
import io.modelcontextprotocol.common.McpTransportContext;
import io.modelcontextprotocol.json.McpJsonMapper;
import io.modelcontextprotocol.server.McpServer;
import io.modelcontextprotocol.server.TestUtil;
import io.modelcontextprotocol.server.TomcatTestUtil;
import io.modelcontextprotocol.spec.McpSchema;
import org.apache.catalina.LifecycleException;
import org.apache.catalina.LifecycleState;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.servlet.config.annotation.EnableWebMvc;
import org.springframework.web.servlet.function.RouterFunction;
import org.springframework.web.servlet.function.ServerResponse;

import static org.assertj.core.api.Assertions.assertThat;

/**
* Integration tests for WebMvcSseServerTransportProvider
*
* @author lance
*/
class WebMvcSseServerTransportProviderTests {

private static final int PORT = TestUtil.findAvailablePort();

private static final String CUSTOM_CONTEXT_PATH = "/";

private static final String MESSAGE_ENDPOINT = "/mcp/message";

private WebMvcSseServerTransportProvider mcpServerTransportProvider;

McpClient.SyncSpec clientBuilder;

private TomcatTestUtil.TomcatServer tomcatServer;

@BeforeEach
public void before() {
tomcatServer = TomcatTestUtil.createTomcatServer(CUSTOM_CONTEXT_PATH, PORT, TestConfig.class);

try {
tomcatServer.tomcat().start();
assertThat(tomcatServer.tomcat().getServer().getState()).isEqualTo(LifecycleState.STARTED);
}
catch (Exception e) {
throw new RuntimeException("Failed to start Tomcat", e);
}

HttpClientSseClientTransport transport = HttpClientSseClientTransport.builder("http://localhost:" + PORT)
.sseEndpoint(WebMvcSseServerTransportProvider.DEFAULT_SSE_ENDPOINT)
.build();

clientBuilder = McpClient.sync(transport);
mcpServerTransportProvider = tomcatServer.appContext().getBean(WebMvcSseServerTransportProvider.class);
}

@Test
void validBaseUrl() {
McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").build();
try (var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0"))
.build()) {
assertThat(client.initialize()).isNotNull();
}
}

@AfterEach
public void after() {
if (mcpServerTransportProvider != null) {
mcpServerTransportProvider.closeGracefully().block();
}
if (tomcatServer.appContext() != null) {
tomcatServer.appContext().close();
}
if (tomcatServer.tomcat() != null) {
try {
tomcatServer.tomcat().stop();
tomcatServer.tomcat().destroy();
}
catch (LifecycleException e) {
throw new RuntimeException("Failed to stop Tomcat", e);
}
}
}

@Configuration
@EnableWebMvc
static class TestConfig {

@Bean
public WebMvcSseServerTransportProvider webMvcSseServerTransportProvider() {

return WebMvcSseServerTransportProvider.builder()
.baseUrl("http://localhost:" + PORT + "/")
.messageEndpoint(MESSAGE_ENDPOINT)
.sseEndpoint(WebMvcSseServerTransportProvider.DEFAULT_SSE_ENDPOINT)
.jsonMapper(McpJsonMapper.getDefault())
.contextExtractor(req -> McpTransportContext.EMPTY)
.build();
}

@Bean
public RouterFunction<ServerResponse> routerFunction(WebMvcSseServerTransportProvider transportProvider) {
return transportProvider.getRouterFunction();
}

}

}