Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Schedulers;

import java.time.Duration;
import java.util.Map;

class DefaultMcpStatelessServerHandler implements McpStatelessServerHandler {
Expand All @@ -21,10 +23,13 @@ class DefaultMcpStatelessServerHandler implements McpStatelessServerHandler {

Map<String, McpStatelessNotificationHandler> notificationHandlers;

Duration requestTimeout;

public DefaultMcpStatelessServerHandler(Map<String, McpStatelessRequestHandler<?>> requestHandlers,
Map<String, McpStatelessNotificationHandler> notificationHandlers) {
Map<String, McpStatelessNotificationHandler> notificationHandlers, Duration requestTimeout) {
this.requestHandlers = requestHandlers;
this.notificationHandlers = notificationHandlers;
this.requestTimeout = requestTimeout;
}

@Override
Expand All @@ -35,6 +40,8 @@ public Mono<McpSchema.JSONRPCResponse> handleRequest(McpTransportContext transpo
return Mono.error(new McpError("Missing handler for request type: " + request.method()));
}
return requestHandler.handle(transportContext, request.params())
.subscribeOn(Schedulers.boundedElastic())
.timeout(this.requestTimeout)
.map(result -> new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), result, null))
.onErrorResume(t -> {
McpSchema.JSONRPCResponse.JSONRPCError error;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ public class McpStatelessAsyncServer {

private final JsonSchemaValidator jsonSchemaValidator;

private final Duration requestTimeout;

McpStatelessAsyncServer(McpStatelessServerTransport mcpTransport, McpJsonMapper jsonMapper,
McpStatelessServerFeatures.Async features, Duration requestTimeout,
McpUriTemplateManagerFactory uriTemplateManagerFactory, JsonSchemaValidator jsonSchemaValidator) {
Expand All @@ -93,6 +95,7 @@ public class McpStatelessAsyncServer {
this.completions.putAll(features.completions());
this.uriTemplateManagerFactory = uriTemplateManagerFactory;
this.jsonSchemaValidator = jsonSchemaValidator;
this.requestTimeout = requestTimeout;

Map<String, McpStatelessRequestHandler<?>> requestHandlers = new HashMap<>();

Expand Down Expand Up @@ -129,7 +132,8 @@ public class McpStatelessAsyncServer {

this.protocolVersions = new ArrayList<>(mcpTransport.protocolVersions());

McpStatelessServerHandler handler = new DefaultMcpStatelessServerHandler(requestHandlers, Map.of());
McpStatelessServerHandler handler = new DefaultMcpStatelessServerHandler(requestHandlers, Map.of(),
this.requestTimeout);
mcpTransport.setMcpHandler(handler);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -647,4 +647,101 @@ private double evaluateExpression(String expression) {
};
}

// ---------------------------------------
// Timeout Tests
// ---------------------------------------
@ParameterizedTest(name = "{0} : {displayName} ")
@ValueSource(strings = { "httpclient" })
void testRequestTimeoutWithSlowTool(String clientType) {
var clientBuilder = clientBuilders.get(clientType);

// Create a tool that takes longer than the timeout
McpStatelessServerFeatures.SyncToolSpecification slowTool = new McpStatelessServerFeatures.SyncToolSpecification(
Tool.builder()
.name("slow-tool")
.description("A tool that takes too long")
.inputSchema(EMPTY_JSON_SCHEMA)
.build(),
(transportContext, request) -> {
try {
// Sleep for 3 seconds, which is longer than our timeout
Thread.sleep(3000);
}
catch (InterruptedException e) {
Thread.currentThread().interrupt();
throw new RuntimeException("Interrupted", e);
}
return new CallToolResult(List.of(new TextContent("This should not be reached")), null);
});

// Create server with a 1-second request timeout
var mcpServer = McpServer.sync(mcpStatelessServerTransport)
.serverInfo("test-server", "1.0.0")
.capabilities(ServerCapabilities.builder().tools(true).build())
.requestTimeout(Duration.ofSeconds(1))
.tools(slowTool)
.build();

try (var mcpClient = clientBuilder.build()) {
InitializeResult initResult = mcpClient.initialize();
assertThat(initResult).isNotNull();

// Call the slow tool - should timeout and throw an exception
org.assertj.core.api.Assertions
.assertThatThrownBy(() -> mcpClient.callTool(new McpSchema.CallToolRequest("slow-tool", Map.of())))
.isInstanceOf(io.modelcontextprotocol.spec.McpError.class)
.satisfies(error -> {
String message = error.getMessage().toLowerCase();
assertThat(message).containsAnyOf("timeout", "timed out", "did not observe");
});
}
finally {
mcpServer.close();
}
}

@ParameterizedTest(name = "{0} : {displayName} ")
@ValueSource(strings = { "httpclient" })
void testRequestTimeoutWithFastTool(String clientType) {
var clientBuilder = clientBuilders.get(clientType);

// Create a tool that completes quickly
McpStatelessServerFeatures.SyncToolSpecification fastTool = new McpStatelessServerFeatures.SyncToolSpecification(
Tool.builder()
.name("fast-tool")
.description("A tool that completes quickly")
.inputSchema(EMPTY_JSON_SCHEMA)
.build(),
(transportContext, request) -> {
return new CallToolResult(List.of(new TextContent("Fast response")), null);
});

// Create server with a 5-second request timeout (plenty of time)
var mcpServer = McpServer.sync(mcpStatelessServerTransport)
.serverInfo("test-server", "1.0.0")
.capabilities(ServerCapabilities.builder().tools(true).build())
.requestTimeout(Duration.ofSeconds(5))
.tools(fastTool)
.build();

try (var mcpClient = clientBuilder.build()) {
InitializeResult initResult = mcpClient.initialize();
assertThat(initResult).isNotNull();

// Call the fast tool - should succeed
CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("fast-tool", Map.of()));

// Verify that we got a successful response
assertThat(response).isNotNull();
assertThat(response.isError()).isNotEqualTo(Boolean.TRUE);
assertThat(response.content()).isNotEmpty();

String message = ((TextContent) response.content().get(0)).text();
assertThat(message).isEqualTo("Fast response");
}
finally {
mcpServer.close();
}
}

}