From a84591f1e9249231c406c19828b3aef730cdfbbf Mon Sep 17 00:00:00 2001 From: rithin-pullela-aws Date: Thu, 30 Oct 2025 18:09:45 -0700 Subject: [PATCH] Add robust input validation to MCP Server Signed-off-by: rithin-pullela-aws --- common/build.gradle | 2 + .../requests/server/MLMcpServerRequest.java | 116 +++++++++- .../server/MLMcpServerRequestTest.java | 208 +++++++++++++----- .../mcpserver/TransportMcpServerAction.java | 11 +- .../rest/mcpserver/RestMcpServerAction.java | 73 +++--- .../TransportMcpServerActionTests.java | 41 ++-- .../mcpserver/RestMcpServerActionTests.java | 84 ++++--- 7 files changed, 367 insertions(+), 168 deletions(-) diff --git a/common/build.gradle b/common/build.gradle index 5cbf100289..04f1592c00 100644 --- a/common/build.gradle +++ b/common/build.gradle @@ -28,6 +28,8 @@ dependencies { compileOnly group: 'com.google.code.gson', name: 'gson', version: "${versions.gson}" compileOnly group: 'org.json', name: 'json', version: '20231013' testImplementation group: 'org.json', name: 'json', version: '20231013' + compileOnly('io.modelcontextprotocol.sdk:mcp:0.12.1') + testImplementation('io.modelcontextprotocol.sdk:mcp:0.12.1') implementation('com.google.guava:guava:32.1.3-jre') { exclude group: 'com.google.guava', module: 'failureaccess' exclude group: 'com.google.code.findbugs', module: 'jsr305' diff --git a/common/src/main/java/org/opensearch/ml/common/transport/mcpserver/requests/server/MLMcpServerRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/mcpserver/requests/server/MLMcpServerRequest.java index c844f09a83..cf9bb78e6d 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/mcpserver/requests/server/MLMcpServerRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/mcpserver/requests/server/MLMcpServerRequest.java @@ -9,6 +9,7 @@ import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.UncheckedIOException; +import java.util.Set; import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; @@ -16,25 +17,132 @@ import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.ml.common.utils.StringUtils; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; + +import io.modelcontextprotocol.spec.McpSchema; import lombok.Getter; +import lombok.extern.log4j.Log4j2; +@Log4j2 public class MLMcpServerRequest extends ActionRequest { + + private static final int MAX_ID_LENGTH = 1000; + private static final int MAX_REQUEST_SIZE = 10 * 1024 * 1024; + private static final Set VALID_METHODS = Set + .of( + McpSchema.METHOD_INITIALIZE, + McpSchema.METHOD_NOTIFICATION_INITIALIZED, + McpSchema.METHOD_PING, + McpSchema.METHOD_NOTIFICATION_PROGRESS, + McpSchema.METHOD_TOOLS_LIST, + McpSchema.METHOD_TOOLS_CALL, + McpSchema.METHOD_NOTIFICATION_TOOLS_LIST_CHANGED, + McpSchema.METHOD_RESOURCES_LIST, + McpSchema.METHOD_RESOURCES_READ, + McpSchema.METHOD_NOTIFICATION_RESOURCES_LIST_CHANGED, + McpSchema.METHOD_NOTIFICATION_RESOURCES_UPDATED, + McpSchema.METHOD_RESOURCES_TEMPLATES_LIST, + McpSchema.METHOD_RESOURCES_SUBSCRIBE, + McpSchema.METHOD_RESOURCES_UNSUBSCRIBE, + McpSchema.METHOD_PROMPT_LIST, + McpSchema.METHOD_PROMPT_GET, + McpSchema.METHOD_NOTIFICATION_PROMPTS_LIST_CHANGED, + McpSchema.METHOD_COMPLETION_COMPLETE, + McpSchema.METHOD_LOGGING_SET_LEVEL, + McpSchema.METHOD_NOTIFICATION_MESSAGE, + McpSchema.METHOD_ROOTS_LIST, + McpSchema.METHOD_NOTIFICATION_ROOTS_LIST_CHANGED, + McpSchema.METHOD_SAMPLING_CREATE_MESSAGE, + McpSchema.METHOD_ELICITATION_CREATE + ); + @Getter - private String requestBody; + private McpSchema.JSONRPCMessage message; public MLMcpServerRequest(StreamInput in) throws IOException { super(in); - this.requestBody = in.readString(); + validateAndParseRequest(in.readString()); } public MLMcpServerRequest(String requestBody) { - this.requestBody = requestBody; + validateAndParseRequest(requestBody); + } + + private void validateAndParseRequest(String requestBody) { + if (requestBody == null || requestBody.isEmpty()) { + throw new IllegalArgumentException("Request body cannot be null or empty"); + } + if (requestBody.length() > MAX_REQUEST_SIZE) { + throw new IllegalArgumentException("Request body exceeds maximum size of " + MAX_REQUEST_SIZE + " bytes"); + } + + try { + message = McpSchema.deserializeJsonRpcMessage(new ObjectMapper(), requestBody); + } catch (Exception e) { + log.error("Parse error: " + e.getMessage(), e); + throw new IllegalArgumentException("Failed to parse JSON-RPC message: " + e.getMessage(), e); + } + + validateMessage(); + } + + private void validateMessage() { + if (!McpSchema.JSONRPC_VERSION.equals(message.jsonrpc())) { + throw new IllegalArgumentException("Invalid jsonrpc version. Expected '2.0' but got '" + message.jsonrpc() + "'"); + } + + if (message instanceof McpSchema.JSONRPCRequest request) { + validateRequestId(request.id()); + validateMethod(request.method()); + } else if (message instanceof McpSchema.JSONRPCNotification notification) { + validateMethod(notification.method()); + } else if (message instanceof McpSchema.JSONRPCResponse) { + throw new IllegalArgumentException("JSON-RPC responses are not accepted as incoming messages"); + } else { + throw new IllegalArgumentException("Unknown JSON-RPC message type: " + message.getClass().getName()); + } + } + + private void validateRequestId(Object id) { + if (id == null) { + throw new IllegalArgumentException("Request ID cannot be null"); + } + if (!(id instanceof String || id instanceof Integer || id instanceof Long)) { + throw new IllegalArgumentException("Request ID must be a string or integer, but got: " + id.getClass().getSimpleName()); + } + if (id instanceof String) { + String idStr = (String) id; + if (idStr.length() > MAX_ID_LENGTH) { + throw new IllegalArgumentException("Request ID exceeds maximum length of " + MAX_ID_LENGTH + " characters"); + } + if (!StringUtils.matchesSafePattern(idStr)) { + throw new IllegalArgumentException("Request ID " + StringUtils.SAFE_INPUT_DESCRIPTION); + } + } + } + + private void validateMethod(String method) { + if (method == null || method.isEmpty()) { + throw new IllegalArgumentException("Method cannot be null or empty"); + } + if (!VALID_METHODS.contains(method)) { + throw new IllegalArgumentException("Invalid MCP method: '" + method + "'. Must be one of the supported MCP methods."); + } } public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); - out.writeString(requestBody); + // Serialize the message back to JSON string + ObjectMapper objectMapper = new ObjectMapper(); + try { + String jsonString = objectMapper.writeValueAsString(message); + out.writeString(jsonString); + } catch (JsonProcessingException e) { + throw new IOException("Failed to serialize JSON-RPC message", e); + } } public static MLMcpServerRequest fromActionRequest(ActionRequest actionRequest) { diff --git a/common/src/test/java/org/opensearch/ml/common/transport/mcpserver/requests/server/MLMcpServerRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/mcpserver/requests/server/MLMcpServerRequestTest.java index 67daf76db1..2ce1e3f633 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/mcpserver/requests/server/MLMcpServerRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/mcpserver/requests/server/MLMcpServerRequestTest.java @@ -9,108 +9,198 @@ import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertTrue; import java.io.IOException; -import java.io.UncheckedIOException; import org.junit.Before; +import org.junit.Rule; import org.junit.Test; -import org.opensearch.action.ActionRequest; +import org.junit.rules.ExpectedException; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.core.common.io.stream.StreamInput; -import org.opensearch.core.common.io.stream.StreamOutput; + +import io.modelcontextprotocol.spec.McpSchema; public class MLMcpServerRequestTest { - private MLMcpServerRequest mlMcpServerRequest; - private String testRequestBody; + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + private String validRequestWithStringId; + private String validRequestWithIntegerId; + private String validNotification; @Before public void setUp() { - testRequestBody = "{\"method\":\"initialize\",\"params\":{\"protocolVersion\":\"2024-11-05\",\"capabilities\":{}}}"; - mlMcpServerRequest = new MLMcpServerRequest(testRequestBody); + validRequestWithStringId = """ + { + "jsonrpc": "2.0", + "id": "test-123", + "method": "tools/list", + "params": {} + } + """; + + validRequestWithIntegerId = """ + { + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": {} + } + """; + + validNotification = """ + { + "jsonrpc": "2.0", + "method": "ping", + "params": {} + } + """; } @Test - public void testConstructor_withRequestBody() { - assertNotNull(mlMcpServerRequest); - assertEquals(testRequestBody, mlMcpServerRequest.getRequestBody()); + public void testConstructor_ValidRequestWithStringId() { + MLMcpServerRequest request = new MLMcpServerRequest(validRequestWithStringId); + + assertNotNull(request); + assertNotNull(request.getMessage()); + assertTrue(request.getMessage() instanceof McpSchema.JSONRPCRequest); + assertEquals("test-123", ((McpSchema.JSONRPCRequest) request.getMessage()).id()); } @Test - public void testConstructor_withStreamInput() throws IOException { - BytesStreamOutput output = new BytesStreamOutput(); - mlMcpServerRequest.writeTo(output); + public void testConstructor_ValidRequestWithIntegerId() { + MLMcpServerRequest request = new MLMcpServerRequest(validRequestWithIntegerId); - StreamInput input = output.bytes().streamInput(); - MLMcpServerRequest parsedRequest = new MLMcpServerRequest(input); - - assertNotNull(parsedRequest); - assertEquals(testRequestBody, parsedRequest.getRequestBody()); + assertNotNull(request); + assertEquals(1, ((McpSchema.JSONRPCRequest) request.getMessage()).id()); } @Test - public void testWriteTo() throws IOException { - BytesStreamOutput output = new BytesStreamOutput(); - mlMcpServerRequest.writeTo(output); + public void testConstructor_ValidNotification() { + MLMcpServerRequest request = new MLMcpServerRequest(validNotification); - StreamInput input = output.bytes().streamInput(); - MLMcpServerRequest parsedRequest = new MLMcpServerRequest(input); - - assertEquals(mlMcpServerRequest.getRequestBody(), parsedRequest.getRequestBody()); + assertNotNull(request); + assertTrue(request.getMessage() instanceof McpSchema.JSONRPCNotification); } @Test - public void testValidate() { - ActionRequestValidationException validationException = mlMcpServerRequest.validate(); - assertNull(validationException); + public void testConstructor_InvalidJsonRpcVersion() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("Invalid jsonrpc version"); + + String invalidRequest = """ + { + "jsonrpc": "1.0", + "id": 1, + "method": "ping" + } + """; + new MLMcpServerRequest(invalidRequest); } @Test - public void testFromActionRequest_withMLMcpServerRequest() { - MLMcpServerRequest result = MLMcpServerRequest.fromActionRequest(mlMcpServerRequest); - assertSame(mlMcpServerRequest, result); + public void testConstructor_InvalidMethod() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("Invalid MCP method"); + + String invalidRequest = """ + { + "jsonrpc": "2.0", + "id": 1, + "method": "invalid_method" + } + """; + new MLMcpServerRequest(invalidRequest); } @Test - public void testFromActionRequest_withOtherActionRequest() throws IOException { - MLMcpServerRequest mlMcpServerRequest = new MLMcpServerRequest(testRequestBody); - ActionRequest actionRequest = new ActionRequest() { - @Override - public ActionRequestValidationException validate() { - return null; + public void testConstructor_InvalidIdWithSpecialCharacters() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("can only contain"); + + String invalidRequest = """ + { + "jsonrpc": "2.0", + "id": "", + "method": "ping" } + """; + new MLMcpServerRequest(invalidRequest); + } - @Override - public void writeTo(StreamOutput out) throws IOException { - mlMcpServerRequest.writeTo(out); + @Test + public void testConstructor_ResponseRejected() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("JSON-RPC responses are not accepted"); + + String response = """ + { + "jsonrpc": "2.0", + "id": 1, + "result": {} } - }; - MLMcpServerRequest result = MLMcpServerRequest.fromActionRequest(actionRequest); - assertNotNull(result); - assertEquals(testRequestBody, result.getRequestBody()); + """; + new MLMcpServerRequest(response); } - @Test(expected = UncheckedIOException.class) - public void testFromActionRequest_withIOException() { - ActionRequest failingRequest = new ActionRequest() { - @Override - public ActionRequestValidationException validate() { - return null; - } + @Test + public void testConstructor_EmptyBody() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("cannot be null or empty"); - @Override - public void writeTo(StreamOutput out) throws IOException { - throw new IOException("Test IOException"); - } - }; + new MLMcpServerRequest(""); + } - MLMcpServerRequest.fromActionRequest(failingRequest); + @Test + public void testConstructor_NullBody() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("cannot be null or empty"); + + String nullBody = null; + new MLMcpServerRequest(nullBody); + } + + @Test + public void testConstructor_InvalidJson() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("Failed to parse JSON-RPC message"); + + new MLMcpServerRequest("invalid json"); } @Test - public void testGetRequestBody() { - assertEquals(testRequestBody, mlMcpServerRequest.getRequestBody()); + public void testWriteTo_Success() throws IOException { + MLMcpServerRequest original = new MLMcpServerRequest(validRequestWithIntegerId); + + BytesStreamOutput output = new BytesStreamOutput(); + original.writeTo(output); + + StreamInput input = output.bytes().streamInput(); + MLMcpServerRequest deserialized = new MLMcpServerRequest(input); + + assertNotNull(deserialized); + assertEquals(original.getMessage().jsonrpc(), deserialized.getMessage().jsonrpc()); + } + + @Test + public void testFromActionRequest_SameInstance() { + MLMcpServerRequest original = new MLMcpServerRequest(validRequestWithIntegerId); + + MLMcpServerRequest result = MLMcpServerRequest.fromActionRequest(original); + + assertSame(original, result); + } + + @Test + public void testValidate_Success() { + MLMcpServerRequest request = new MLMcpServerRequest(validRequestWithIntegerId); + + ActionRequestValidationException validation = request.validate(); + + assertNull(validation); } } diff --git a/plugin/src/main/java/org/opensearch/ml/action/mcpserver/TransportMcpServerAction.java b/plugin/src/main/java/org/opensearch/ml/action/mcpserver/TransportMcpServerAction.java index ff6a0ba593..6ee6456e5b 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/mcpserver/TransportMcpServerAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/mcpserver/TransportMcpServerAction.java @@ -8,7 +8,6 @@ import static org.opensearch.ml.common.CommonValue.ERROR_CODE_FIELD; import static org.opensearch.ml.common.CommonValue.ID_FIELD; import static org.opensearch.ml.common.CommonValue.JSON_RPC_INTERNAL_ERROR; -import static org.opensearch.ml.common.CommonValue.JSON_RPC_PARSE_ERROR; import static org.opensearch.ml.common.CommonValue.JSON_RPC_SERVER_NOT_READY_ERROR; import static org.opensearch.ml.common.CommonValue.MESSAGE_FIELD; import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_MCP_SERVER_DISABLED_MESSAGE; @@ -76,14 +75,8 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { channel.sendResponse(new BytesRestResponse(RestStatus.METHOD_NOT_ALLOWED, "", BytesArray.EMPTY)); }; } - return channel -> { - try { - if (request.content() == null) { - sendErrorResponse(channel, null, JSON_RPC_PARSE_ERROR, "Parse error: empty body"); - return; - } - final String requestBody = request.content().utf8ToString(); - if (requestBody == null || requestBody.isBlank()) { - sendErrorResponse(channel, null, JSON_RPC_PARSE_ERROR, "Parse error: empty body"); - return; - } + final String requestBody = request.content() != null ? request.content().utf8ToString() : null; + MLMcpServerRequest mcpRequest = new MLMcpServerRequest(requestBody); - // Create request and call transport action - MLMcpServerRequest mcpRequest = new MLMcpServerRequest(requestBody); - - client.execute(MLMcpServerAction.INSTANCE, mcpRequest, new ActionListener() { - @Override - public void onResponse(MLMcpServerResponse response) { - try { - if (response.getError() != null) { - // Handle error response - Map errorMap = response.getError(); - Object id = errorMap.get(ID_FIELD); - int code = (Integer) errorMap.get(ERROR_CODE_FIELD); - String message = (String) errorMap.get(MESSAGE_FIELD); - sendErrorResponse(channel, id, code, message); - } else if (response.getMcpResponse() != null) { - channel.sendResponse(new BytesRestResponse(RestStatus.OK, "application/json", response.getMcpResponse())); - } else { - channel.sendResponse(new BytesRestResponse(RestStatus.ACCEPTED, "", BytesArray.EMPTY)); - } - } catch (Exception e) { - log.error("Failed to send response", e); - sendErrorResponse(channel, null, JSON_RPC_INTERNAL_ERROR, "Failed to send response"); + return channel -> { + client.execute(MLMcpServerAction.INSTANCE, mcpRequest, new ActionListener() { + @Override + public void onResponse(MLMcpServerResponse response) { + try { + if (response.getError() != null) { + // Handle error response + Map errorMap = response.getError(); + Object id = errorMap.get(ID_FIELD); + int code = (Integer) errorMap.get(ERROR_CODE_FIELD); + String message = (String) errorMap.get(MESSAGE_FIELD); + sendErrorResponse(channel, id, code, message); + } else if (response.getMcpResponse() != null) { + channel.sendResponse(new BytesRestResponse(RestStatus.OK, "application/json", response.getMcpResponse())); + } else { + channel.sendResponse(new BytesRestResponse(RestStatus.ACCEPTED, "", BytesArray.EMPTY)); } + } catch (Exception e) { + log.error("Failed to send response", e); + sendErrorResponse(channel, null, JSON_RPC_INTERNAL_ERROR, "Failed to send response"); } + } - @Override - public void onFailure(Exception e) { - log.error("Failed to handle MCP request", e); - sendErrorResponse(channel, null, JSON_RPC_INTERNAL_ERROR, "Internal server error: " + e.getMessage()); - } - }); - - } catch (Exception e) { - log.error("Failed to handle MCP request", e); - sendErrorResponse(channel, null, JSON_RPC_INTERNAL_ERROR, "Internal server error"); - } + @Override + public void onFailure(Exception e) { + log.error("Failed to handle MCP request", e); + sendErrorResponse(channel, null, JSON_RPC_INTERNAL_ERROR, "Internal server error: " + e.getMessage()); + } + }); }; } diff --git a/plugin/src/test/java/org/opensearch/ml/action/mcpserver/TransportMcpServerActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/mcpserver/TransportMcpServerActionTests.java index a579923025..15b274d603 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/mcpserver/TransportMcpServerActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/mcpserver/TransportMcpServerActionTests.java @@ -12,7 +12,6 @@ import static org.opensearch.ml.common.CommonValue.ERROR_CODE_FIELD; import static org.opensearch.ml.common.CommonValue.ID_FIELD; import static org.opensearch.ml.common.CommonValue.JSON_RPC_INTERNAL_ERROR; -import static org.opensearch.ml.common.CommonValue.JSON_RPC_PARSE_ERROR; import static org.opensearch.ml.common.CommonValue.JSON_RPC_SERVER_NOT_READY_ERROR; import static org.opensearch.ml.common.CommonValue.MESSAGE_FIELD; @@ -68,7 +67,7 @@ public void setUp() throws Exception { public void test_doExecute_mcpServerDisabled() { when(mlFeatureEnabledSetting.isMcpServerEnabled()).thenReturn(false); - MLMcpServerRequest request = new MLMcpServerRequest("{\"jsonrpc\":\"2.0\",\"id\":1,\"method\":\"test\"}"); + MLMcpServerRequest request = new MLMcpServerRequest("{\"jsonrpc\":\"2.0\",\"id\":1,\"method\":\"ping\",\"params\":{}}"); action.doExecute(task, request, listener); @@ -79,7 +78,7 @@ public void test_doExecute_mcpServerDisabled() { public void test_doExecute_transportProviderNotReady() { when(mlFeatureEnabledSetting.isMcpServerEnabled()).thenReturn(true); when(mcpStatelessServerHolder.getMcpStatelessServerTransportProvider()).thenReturn(null); - MLMcpServerRequest request = new MLMcpServerRequest("{\"jsonrpc\":\"2.0\",\"id\":1,\"method\":\"test\"}"); + MLMcpServerRequest request = new MLMcpServerRequest("{\"jsonrpc\":\"2.0\",\"id\":1,\"method\":\"initialize\",\"params\":{}}"); action.doExecute(task, request, listener); @@ -95,21 +94,31 @@ public void test_doExecute_transportProviderNotReady() { } public void test_doExecute_invalidJsonRpcMessage() { - when(mlFeatureEnabledSetting.isMcpServerEnabled()).thenReturn(true); - when(mcpStatelessServerHolder.getMcpStatelessServerTransportProvider()).thenReturn(transportProvider); - MLMcpServerRequest request = new MLMcpServerRequest("invalid json"); + // Validation now happens during MLMcpServerRequest construction + // This test verifies that invalid JSON throws IllegalArgumentException + expectThrows(IllegalArgumentException.class, () -> { new MLMcpServerRequest("invalid json"); }); + } - action.doExecute(task, request, listener); + public void test_doExecute_invalidJsonRpcVersion() { + // Validation now happens during MLMcpServerRequest construction + expectThrows( + IllegalArgumentException.class, + () -> { new MLMcpServerRequest("{\"jsonrpc\":\"1.0\",\"id\":1,\"method\":\"ping\"}"); } + ); + } - ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(MLMcpServerResponse.class); - verify(listener).onResponse(responseCaptor.capture()); - - MLMcpServerResponse response = responseCaptor.getValue(); - assertFalse(response.getAcknowledgedResponse()); - assertNull(response.getMcpResponse()); - assertNotNull(response.getError()); - assertEquals(JSON_RPC_PARSE_ERROR, response.getError().get(ERROR_CODE_FIELD)); - assertTrue(response.getError().get(MESSAGE_FIELD).toString().contains("Parse error")); + public void test_doExecute_invalidMethod() { + // Validation now happens during MLMcpServerRequest construction + expectThrows(IllegalArgumentException.class, () -> { + new MLMcpServerRequest("{\"jsonrpc\":\"2.0\",\"id\":1,\"method\":\"invalid_method\"}"); + }); + } + + public void test_doExecute_invalidIdWithSpecialCharacters() { + // Validation now happens during MLMcpServerRequest construction + expectThrows(IllegalArgumentException.class, () -> { + new MLMcpServerRequest("{\"jsonrpc\":\"2.0\",\"id\":\"\",\"method\":\"ping\"}"); + }); } public void test_doExecute_jsonRpcNotification() { diff --git a/plugin/src/test/java/org/opensearch/ml/rest/mcpserver/RestMcpServerActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/mcpserver/RestMcpServerActionTests.java index 8105cab756..bc28d77ace 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/mcpserver/RestMcpServerActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/mcpserver/RestMcpServerActionTests.java @@ -150,8 +150,10 @@ public void test_prepareRequest_emptyBody() throws Exception { .withContent(new BytesArray(""), null) .build(); - executeRestChannelConsumer(request); - verifyErrorResponse("Parse error: empty body"); + Exception exception = expectThrows(IllegalArgumentException.class, () -> { + restMCPStatelessStreamingAction.prepareRequest(request, client); + }); + assertTrue(exception.getMessage().contains("Request body cannot be null or empty")); } @Test @@ -162,54 +164,66 @@ public void test_prepareRequest_nullContent() throws Exception { .withContent(null, null) .build(); - executeRestChannelConsumer(request); - verifyErrorResponse("Parse error: empty body"); + Exception exception = expectThrows(IllegalArgumentException.class, () -> { + restMCPStatelessStreamingAction.prepareRequest(request, client); + }); + assertTrue(exception.getMessage().contains("Request body cannot be null or empty")); } @Test public void test_prepareRequest_invalidJson() throws Exception { - // Mock the transport action to return an error response - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - Map errorMap = new HashMap<>(); - errorMap.put(ID_FIELD, null); - errorMap.put(ERROR_CODE_FIELD, -32700); - errorMap.put(MESSAGE_FIELD, "Parse error: invalid json"); - listener.onResponse(new MLMcpServerResponse(false, null, errorMap)); - return null; - }).when(client).execute(eq(MLMcpServerAction.INSTANCE), any(), any()); - RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) .withMethod(RestRequest.Method.POST) .withPath(RestMcpServerAction.MCP_SERVER_ENDPOINT) .withContent(new BytesArray("invalid json"), null) .build(); - executeRestChannelConsumer(request); - verifyErrorResponse("Parse error: invalid json"); + Exception exception = expectThrows(IllegalArgumentException.class, () -> { + restMCPStatelessStreamingAction.prepareRequest(request, client); + }); + assertTrue(exception.getMessage().contains("Failed to parse JSON-RPC message")); } @Test - public void test_prepareRequest_malformedJsonRpc() throws Exception { - // Mock the transport action to return an error response - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - Map errorMap = new HashMap<>(); - errorMap.put(ID_FIELD, 1); - errorMap.put(ERROR_CODE_FIELD, -32700); - errorMap.put(MESSAGE_FIELD, "Parse error: malformed JSON-RPC"); - listener.onResponse(new MLMcpServerResponse(false, null, errorMap)); - return null; - }).when(client).execute(eq(MLMcpServerAction.INSTANCE), any(), any()); + public void test_prepareRequest_invalidJsonRpcVersion() throws Exception { + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withMethod(RestRequest.Method.POST) + .withPath(RestMcpServerAction.MCP_SERVER_ENDPOINT) + .withContent(new BytesArray("{\"jsonrpc\":\"1.0\",\"id\":1,\"method\":\"ping\"}"), null) + .build(); + Exception exception = expectThrows(IllegalArgumentException.class, () -> { + restMCPStatelessStreamingAction.prepareRequest(request, client); + }); + assertTrue(exception.getMessage().contains("Invalid jsonrpc version")); + } + + @Test + public void test_prepareRequest_invalidMethod() throws Exception { RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) .withMethod(RestRequest.Method.POST) .withPath(RestMcpServerAction.MCP_SERVER_ENDPOINT) - .withContent(new BytesArray("{\"jsonrpc\":\"1.0\",\"id\":1}"), null) + .withContent(new BytesArray("{\"jsonrpc\":\"2.0\",\"id\":1,\"method\":\"invalid_method\"}"), null) .build(); - executeRestChannelConsumer(request); - verifyErrorResponse("Parse error: malformed JSON-RPC"); + Exception exception = expectThrows(IllegalArgumentException.class, () -> { + restMCPStatelessStreamingAction.prepareRequest(request, client); + }); + assertTrue(exception.getMessage().contains("Invalid MCP method")); + } + + @Test + public void test_prepareRequest_invalidIdWithSpecialCharacters() throws Exception { + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withMethod(RestRequest.Method.POST) + .withPath(RestMcpServerAction.MCP_SERVER_ENDPOINT) + .withContent(new BytesArray("{\"jsonrpc\":\"2.0\",\"id\":\"\",\"method\":\"ping\"}"), null) + .build(); + + Exception exception = expectThrows(IllegalArgumentException.class, () -> { + restMCPStatelessStreamingAction.prepareRequest(request, client); + }); + assertTrue(exception.getMessage().contains("can only contain")); } @Test @@ -224,7 +238,7 @@ public void test_prepareRequest_notificationMessage() throws Exception { RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) .withMethod(RestRequest.Method.POST) .withPath(RestMcpServerAction.MCP_SERVER_ENDPOINT) - .withContent(new BytesArray("{\"jsonrpc\":\"2.0\",\"method\":\"test\",\"params\":{}}"), null) + .withContent(new BytesArray("{\"jsonrpc\":\"2.0\",\"method\":\"ping\",\"params\":{}}"), null) .build(); executeRestChannelConsumer(request); @@ -247,7 +261,7 @@ public void test_prepareRequest_validRequest() throws Exception { RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) .withMethod(RestRequest.Method.POST) .withPath(RestMcpServerAction.MCP_SERVER_ENDPOINT) - .withContent(new BytesArray("{\"jsonrpc\":\"2.0\",\"id\":1,\"method\":\"test\",\"params\":{}}"), null) + .withContent(new BytesArray("{\"jsonrpc\":\"2.0\",\"id\":1,\"method\":\"tools/list\",\"params\":{}}"), null) .build(); executeRestChannelConsumer(request); @@ -274,7 +288,7 @@ public void test_prepareRequest_transportProviderNotReady() throws Exception { RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) .withMethod(RestRequest.Method.POST) .withPath(RestMcpServerAction.MCP_SERVER_ENDPOINT) - .withContent(new BytesArray("{\"jsonrpc\":\"2.0\",\"id\":1,\"method\":\"test\",\"params\":{}}"), null) + .withContent(new BytesArray("{\"jsonrpc\":\"2.0\",\"id\":1,\"method\":\"initialize\",\"params\":{}}"), null) .build(); executeRestChannelConsumer(request); @@ -293,7 +307,7 @@ public void test_prepareRequest_transportActionFailure() throws Exception { RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) .withMethod(RestRequest.Method.POST) .withPath(RestMcpServerAction.MCP_SERVER_ENDPOINT) - .withContent(new BytesArray("{\"jsonrpc\":\"2.0\",\"id\":1,\"method\":\"test\",\"params\":{}}"), null) + .withContent(new BytesArray("{\"jsonrpc\":\"2.0\",\"id\":1,\"method\":\"ping\",\"params\":{}}"), null) .build(); executeRestChannelConsumer(request);