From 285b6e20888d506046d09d8ee521da1734834a28 Mon Sep 17 00:00:00 2001 From: Yaliang Wu Date: Tue, 18 Nov 2025 11:48:07 -0800 Subject: [PATCH 1/5] Add support for custom named connector actions; add PUT/DELETE action (#4430) * Add support for custom named connector actions; add PUT/DELETE action Signed-off-by: Yaliang Wu * add more comment Signed-off-by: Yaliang Wu --------- Signed-off-by: Yaliang Wu --- .../org/opensearch/ml/common/CommonValue.java | 2 + .../common/connector/AbstractConnector.java | 7 +- .../ml/common/connector/ConnectorAction.java | 27 ++++ .../ml/common/connector/HttpConnector.java | 11 +- .../ml/common/utils/StringUtils.java | 34 ++++ .../ml/common/connector/AwsConnectorTest.java | 2 + .../common/connector/ConnectorActionTest.java | 42 ++++- .../common/connector/HttpConnectorTest.java | 153 ++++++++++++++++++ .../MLCreateConnectorInputTests.java | 2 + .../MLCreateConnectorRequestTests.java | 2 + .../ml/common/utils/StringUtilsTest.java | 89 ++++++++++ .../remote/AwsConnectorExecutor.java | 10 +- .../remote/HttpJsonConnectorExecutor.java | 10 +- .../ExecuteConnectorTransportAction.java | 16 +- .../ExecuteConnectorTransportActionTests.java | 8 + 15 files changed, 399 insertions(+), 16 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/CommonValue.java b/common/src/main/java/org/opensearch/ml/common/CommonValue.java index 351171ede6..5570b14ee1 100644 --- a/common/src/main/java/org/opensearch/ml/common/CommonValue.java +++ b/common/src/main/java/org/opensearch/ml/common/CommonValue.java @@ -97,6 +97,7 @@ public class CommonValue { public static final Version VERSION_3_1_0 = Version.fromString("3.1.0"); public static final Version VERSION_3_2_0 = Version.fromString("3.2.0"); public static final Version VERSION_3_3_0 = Version.fromString("3.3.0"); + public static final Version VERSION_3_4_0 = Version.fromString("3.4.0"); // Connector Constants public static final String NAME_FIELD = "name"; @@ -113,6 +114,7 @@ public class CommonValue { public static final String CLIENT_CONFIG_FIELD = "client_config"; public static final String URL_FIELD = "url"; public static final String HEADERS_FIELD = "headers"; + public static final String CONNECTOR_ACTION_FIELD = "connector_action"; // MCP Constants public static final String MCP_TOOL_NAME_FIELD = "name"; diff --git a/common/src/main/java/org/opensearch/ml/common/connector/AbstractConnector.java b/common/src/main/java/org/opensearch/ml/common/connector/AbstractConnector.java index 9a035230a0..05f2d3781b 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/AbstractConnector.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/AbstractConnector.java @@ -121,8 +121,11 @@ public void parseResponse(T response, List modelTensors, boolea @Override public Optional findAction(String action) { - if (actions != null) { - return actions.stream().filter(a -> a.getActionType().name().equalsIgnoreCase(action)).findFirst(); + if (actions != null && action != null) { + if (ConnectorAction.ActionType.isValidAction(action)) { + return actions.stream().filter(a -> a.getActionType().name().equalsIgnoreCase(action)).findFirst(); + } + return actions.stream().filter(a -> action.equals(a.getName())).findFirst(); } return Optional.empty(); } diff --git a/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java b/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java index c82f489296..8ff29a44f2 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java @@ -6,8 +6,10 @@ package org.opensearch.ml.common.connector; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.CommonValue.VERSION_3_4_0; import java.io.IOException; +import java.util.Arrays; import java.util.HashSet; import java.util.List; import java.util.Locale; @@ -33,6 +35,7 @@ public class ConnectorAction implements ToXContentObject, Writeable { public static final String ACTION_TYPE_FIELD = "action_type"; + public static final String NAME_FIELD = "name"; public static final String METHOD_FIELD = "method"; public static final String URL_FIELD = "url"; public static final String HEADERS_FIELD = "headers"; @@ -52,6 +55,7 @@ public class ConnectorAction implements ToXContentObject, Writeable { private static final Logger logger = LogManager.getLogger(ConnectorAction.class); private ActionType actionType; + private String name; private String method; private String url; private Map headers; @@ -62,6 +66,7 @@ public class ConnectorAction implements ToXContentObject, Writeable { @Builder(toBuilder = true) public ConnectorAction( ActionType actionType, + String name, String method, String url, Map headers, @@ -78,7 +83,15 @@ public ConnectorAction( if (method == null) { throw new IllegalArgumentException("method can't be null"); } + // The 'name' field is an optional identifier for this specific action within a connector. + // It allows running a specific action by name when a connector has multiple actions of the same type. + // We validate that 'name' is not an action type (PREDICT, EXECUTE, etc.) to avoid ambiguity + // when resolving actions. + if (name != null && ActionType.isValidAction(name)) { + throw new IllegalArgumentException("name can't be one of action type " + Arrays.toString(ActionType.values())); + } this.actionType = actionType; + this.name = name; this.method = method; this.url = url; this.headers = headers; @@ -97,6 +110,9 @@ public ConnectorAction(StreamInput input) throws IOException { this.requestBody = input.readOptionalString(); this.preProcessFunction = input.readOptionalString(); this.postProcessFunction = input.readOptionalString(); + if (input.getVersion().onOrAfter(VERSION_3_4_0)) { + this.name = input.readOptionalString(); + } } @Override @@ -113,6 +129,9 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalString(requestBody); out.writeOptionalString(preProcessFunction); out.writeOptionalString(postProcessFunction); + if (out.getVersion().onOrAfter(VERSION_3_4_0)) { + out.writeOptionalString(name); + } } @Override @@ -139,6 +158,9 @@ public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params if (postProcessFunction != null) { builder.field(ACTION_POST_PROCESS_FUNCTION, postProcessFunction); } + if (name != null) { + builder.field(NAME_FIELD, name); + } return builder.endObject(); } @@ -149,6 +171,7 @@ public static ConnectorAction fromStream(StreamInput in) throws IOException { public static ConnectorAction parse(XContentParser parser) throws IOException { ActionType actionType = null; + String name = null; String method = null; String url = null; Map headers = null; @@ -165,6 +188,9 @@ public static ConnectorAction parse(XContentParser parser) throws IOException { case ACTION_TYPE_FIELD: actionType = ActionType.valueOf(parser.text().toUpperCase(Locale.ROOT)); break; + case NAME_FIELD: + name = parser.text(); + break; case METHOD_FIELD: method = parser.text(); break; @@ -191,6 +217,7 @@ public static ConnectorAction parse(XContentParser parser) throws IOException { return ConnectorAction .builder() .actionType(actionType) + .name(name) .method(method) .url(url) .headers(headers) diff --git a/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java b/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java index 53f66ce384..c93a8e6abb 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java @@ -13,6 +13,7 @@ import static org.opensearch.ml.common.connector.ConnectorProtocols.validateProtocol; import static org.opensearch.ml.common.utils.StringUtils.getParameterMap; import static org.opensearch.ml.common.utils.StringUtils.isJson; +import static org.opensearch.ml.common.utils.StringUtils.isJsonOrNdjson; import static org.opensearch.ml.common.utils.StringUtils.parseParameters; import java.io.IOException; @@ -358,12 +359,14 @@ public T createPayload(String action, Map parameters) { StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}"); payload = substitutor.replace(payload); - if (!isJson(payload)) { + if (!isJsonOrNdjson(payload)) { throw new IllegalArgumentException("Invalid payload: " + payload); } else if (neededStreamParameterInPayload(parameters)) { - JsonObject jsonObject = JsonParser.parseString(payload).getAsJsonObject(); - jsonObject.addProperty("stream", true); - payload = jsonObject.toString(); + if (isJson(payload)) { + JsonObject jsonObject = JsonParser.parseString(payload).getAsJsonObject(); + jsonObject.addProperty("stream", true); + payload = jsonObject.toString(); + } } return (T) payload; } diff --git a/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java b/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java index a3f1a3b416..039b1c2bd1 100644 --- a/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java +++ b/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java @@ -136,6 +136,40 @@ public static boolean isJson(String json) { } } + /** + * Checks if the given string is valid JSON or NDJSON (newline-delimited JSON). + * NDJSON is commonly used for bulk operations in OpenSearch where each line is a separate JSON object. + * + * @param json the string to validate + * @return true if the string is valid JSON or NDJSON, false otherwise + */ + public static boolean isJsonOrNdjson(String json) { + if (json == null || json.isBlank()) { + return false; + } + + // First check if it's regular JSON + if (isJson(json)) { + return true; + } + + // Check if it's NDJSON (newline-delimited JSON) + String[] lines = json.split("\\r?\\n"); + if (lines.length == 0) { + return false; + } + + // Each non-empty line must be valid JSON + for (String line : lines) { + String trimmedLine = line.trim(); + if (!trimmedLine.isEmpty() && !isJson(trimmedLine)) { + return false; + } + } + + return true; + } + /** * Ensures that a string is properly JSON escaped. * diff --git a/common/src/test/java/org/opensearch/ml/common/connector/AwsConnectorTest.java b/common/src/test/java/org/opensearch/ml/common/connector/AwsConnectorTest.java index 2b679b8bbe..63033a2316 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/AwsConnectorTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/AwsConnectorTest.java @@ -213,6 +213,7 @@ private AwsConnector createAwsConnector() { private AwsConnector createAwsConnector(Map parameters, Map credential, String url) { ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT; + String name = null; String method = "POST"; Map headers = new HashMap<>(); headers.put("api_key", "${credential.key}"); @@ -222,6 +223,7 @@ private AwsConnector createAwsConnector(Map parameters, Map new ConnectorAction(null, TEST_METHOD_POST, URL, null, TEST_REQUEST_BODY, null, null) + () -> new ConnectorAction(null, null, TEST_METHOD_POST, URL, null, TEST_REQUEST_BODY, null, null) ); assertEquals("action type can't be null", exception.getMessage()); @@ -109,7 +109,7 @@ public void constructor_NullActionType() { public void constructor_NullUrl() { Throwable exception = assertThrows( IllegalArgumentException.class, - () -> new ConnectorAction(TEST_ACTION_TYPE, TEST_METHOD_POST, null, null, TEST_REQUEST_BODY, null, null) + () -> new ConnectorAction(TEST_ACTION_TYPE, null, TEST_METHOD_POST, null, null, TEST_REQUEST_BODY, null, null) ); assertEquals("url can't be null", exception.getMessage()); } @@ -118,14 +118,23 @@ public void constructor_NullUrl() { public void constructor_NullMethod() { Throwable exception = assertThrows( IllegalArgumentException.class, - () -> new ConnectorAction(TEST_ACTION_TYPE, null, URL, null, TEST_REQUEST_BODY, null, null) + () -> new ConnectorAction(TEST_ACTION_TYPE, null, null, URL, null, TEST_REQUEST_BODY, null, null) ); assertEquals("method can't be null", exception.getMessage()); } @Test public void testValidatePrePostProcessFunctionsWithNullPreProcessFunctionSuccess() { - ConnectorAction action = new ConnectorAction(TEST_ACTION_TYPE, TEST_METHOD_HTTP, OPENAI_URL, null, TEST_REQUEST_BODY, null, null); + ConnectorAction action = new ConnectorAction( + TEST_ACTION_TYPE, + null, + TEST_METHOD_HTTP, + OPENAI_URL, + null, + TEST_REQUEST_BODY, + null, + null + ); action.validatePrePostProcessFunctions(Map.of()); assertFalse(testAppender.getLogEvents().stream().anyMatch(event -> event.getLevel() == Level.WARN)); } @@ -134,6 +143,7 @@ public void testValidatePrePostProcessFunctionsWithNullPreProcessFunctionSuccess public void testValidatePrePostProcessFunctionsWithExternalServers() { ConnectorAction action = new ConnectorAction( TEST_ACTION_TYPE, + null, TEST_METHOD_HTTP, URL, null, @@ -151,6 +161,7 @@ public void testValidatePrePostProcessFunctionsWithCustomPainlessScriptPreProces "\"\\n StringBuilder builder = new StringBuilder();\\n builder.append(\\\"\\\\\\\"\\\");\\n String first = params.text_docs[0];\\n builder.append(first);\\n builder.append(\\\"\\\\\\\"\\\");\\n def parameters = \\\"{\\\" +\\\"\\\\\\\"text_inputs\\\\\\\":\\\" + builder + \\\"}\\\";\\n return \\\"{\\\" +\\\"\\\\\\\"parameters\\\\\\\":\\\" + parameters + \\\"}\\\";\""; ConnectorAction action = new ConnectorAction( TEST_ACTION_TYPE, + null, TEST_METHOD_HTTP, OPENAI_URL, null, @@ -166,6 +177,7 @@ public void testValidatePrePostProcessFunctionsWithCustomPainlessScriptPreProces public void testValidatePrePostProcessFunctionsWithOpenAIConnectorCorrectInBuiltPrePostProcessFunctionSuccess() { ConnectorAction action = new ConnectorAction( TEST_ACTION_TYPE, + null, TEST_METHOD_HTTP, "https://${parameters.endpoint}/v1/chat/completions", null, @@ -181,6 +193,7 @@ public void testValidatePrePostProcessFunctionsWithOpenAIConnectorCorrectInBuilt public void testValidatePrePostProcessFunctionsWithOpenAIConnectorWrongInBuiltPreProcessFunction() { ConnectorAction action = new ConnectorAction( TEST_ACTION_TYPE, + null, TEST_METHOD_HTTP, OPENAI_URL, null, @@ -206,6 +219,7 @@ public void testValidatePrePostProcessFunctionsWithOpenAIConnectorWrongInBuiltPr public void testValidatePrePostProcessFunctionsWithOpenAIConnectorWrongInBuiltPostProcessFunction() { ConnectorAction action = new ConnectorAction( TEST_ACTION_TYPE, + null, TEST_METHOD_HTTP, OPENAI_URL, null, @@ -231,6 +245,7 @@ public void testValidatePrePostProcessFunctionsWithOpenAIConnectorWrongInBuiltPo public void testValidatePrePostProcessFunctionsWithCohereConnectorCorrectInBuiltPrePostProcessFunctionSuccess() { ConnectorAction action = new ConnectorAction( TEST_ACTION_TYPE, + null, TEST_METHOD_HTTP, COHERE_URL, null, @@ -243,6 +258,7 @@ public void testValidatePrePostProcessFunctionsWithCohereConnectorCorrectInBuilt action = new ConnectorAction( TEST_ACTION_TYPE, + null, TEST_METHOD_HTTP, COHERE_URL, null, @@ -255,6 +271,7 @@ public void testValidatePrePostProcessFunctionsWithCohereConnectorCorrectInBuilt action = new ConnectorAction( TEST_ACTION_TYPE, + null, TEST_METHOD_HTTP, COHERE_URL, null, @@ -270,6 +287,7 @@ public void testValidatePrePostProcessFunctionsWithCohereConnectorCorrectInBuilt public void testValidatePrePostProcessFunctionsWithCohereConnectorWrongInBuiltPreProcessFunction() { ConnectorAction action = new ConnectorAction( TEST_ACTION_TYPE, + null, TEST_METHOD_HTTP, COHERE_URL, null, @@ -295,6 +313,7 @@ public void testValidatePrePostProcessFunctionsWithCohereConnectorWrongInBuiltPr public void testValidatePrePostProcessFunctionsWithCohereConnectorWrongInBuiltPostProcessFunction() { ConnectorAction action = new ConnectorAction( TEST_ACTION_TYPE, + null, TEST_METHOD_HTTP, COHERE_URL, null, @@ -320,6 +339,7 @@ public void testValidatePrePostProcessFunctionsWithCohereConnectorWrongInBuiltPo public void testValidatePrePostProcessFunctionsWithBedrockConnectorCorrectInBuiltPrePostProcessFunctionSuccess() { ConnectorAction action = new ConnectorAction( TEST_ACTION_TYPE, + null, TEST_METHOD_HTTP, BEDROCK_URL, null, @@ -332,6 +352,7 @@ public void testValidatePrePostProcessFunctionsWithBedrockConnectorCorrectInBuil action = new ConnectorAction( TEST_ACTION_TYPE, + null, TEST_METHOD_HTTP, BEDROCK_URL, null, @@ -344,6 +365,7 @@ public void testValidatePrePostProcessFunctionsWithBedrockConnectorCorrectInBuil action = new ConnectorAction( TEST_ACTION_TYPE, + null, TEST_METHOD_HTTP, BEDROCK_URL, null, @@ -359,6 +381,7 @@ public void testValidatePrePostProcessFunctionsWithBedrockConnectorCorrectInBuil public void testValidatePrePostProcessFunctionsWithBedrockConnectorWrongInBuiltPreProcessFunction() { ConnectorAction action = new ConnectorAction( TEST_ACTION_TYPE, + null, TEST_METHOD_HTTP, BEDROCK_URL, null, @@ -384,6 +407,7 @@ public void testValidatePrePostProcessFunctionsWithBedrockConnectorWrongInBuiltP public void testValidatePrePostProcessFunctionsWithBedrockConnectorWrongInBuiltPostProcessFunction() { ConnectorAction action = new ConnectorAction( TEST_ACTION_TYPE, + null, TEST_METHOD_HTTP, BEDROCK_URL, null, @@ -409,6 +433,7 @@ public void testValidatePrePostProcessFunctionsWithBedrockConnectorWrongInBuiltP public void testValidatePrePostProcessFunctionsWithSagemakerConnectorWithCorrectInBuiltPrePostProcessFunctionSuccess() { ConnectorAction action = new ConnectorAction( TEST_ACTION_TYPE, + null, TEST_METHOD_HTTP, SAGEMAKER_URL, null, @@ -421,6 +446,7 @@ public void testValidatePrePostProcessFunctionsWithSagemakerConnectorWithCorrect action = new ConnectorAction( TEST_ACTION_TYPE, + null, TEST_METHOD_HTTP, SAGEMAKER_URL, null, @@ -436,6 +462,7 @@ public void testValidatePrePostProcessFunctionsWithSagemakerConnectorWithCorrect public void testValidatePrePostProcessFunctionsWithSagemakerConnectorWrongInBuiltPreProcessFunction() { ConnectorAction action = new ConnectorAction( TEST_ACTION_TYPE, + null, TEST_METHOD_HTTP, SAGEMAKER_URL, null, @@ -463,6 +490,7 @@ public void testValidatePrePostProcessFunctionsWithSagemakerConnectorWrongInBuil public void testValidatePrePostProcessFunctionsWithSagemakerConnectorWrongInBuiltPostProcessFunction() { ConnectorAction action = new ConnectorAction( TEST_ACTION_TYPE, + null, TEST_METHOD_HTTP, SAGEMAKER_URL, null, @@ -488,7 +516,7 @@ public void testValidatePrePostProcessFunctionsWithSagemakerConnectorWrongInBuil @Test public void writeTo_NullValue() throws IOException { - ConnectorAction action = new ConnectorAction(TEST_ACTION_TYPE, TEST_METHOD_HTTP, URL, null, TEST_REQUEST_BODY, null, null); + ConnectorAction action = new ConnectorAction(TEST_ACTION_TYPE, null, TEST_METHOD_HTTP, URL, null, TEST_REQUEST_BODY, null, null); BytesStreamOutput output = new BytesStreamOutput(); action.writeTo(output); ConnectorAction action2 = new ConnectorAction(output.bytes().streamInput()); @@ -504,6 +532,7 @@ public void writeTo() throws IOException { ConnectorAction action = new ConnectorAction( TEST_ACTION_TYPE, + null, TEST_METHOD_HTTP, URL, headers, @@ -519,7 +548,7 @@ public void writeTo() throws IOException { @Test public void toXContent_NullValue() throws IOException { - ConnectorAction action = new ConnectorAction(TEST_ACTION_TYPE, TEST_METHOD_HTTP, URL, null, TEST_REQUEST_BODY, null, null); + ConnectorAction action = new ConnectorAction(TEST_ACTION_TYPE, null, TEST_METHOD_HTTP, URL, null, TEST_REQUEST_BODY, null, null); XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); action.toXContent(builder, ToXContent.EMPTY_PARAMS); @@ -540,6 +569,7 @@ public void toXContent() throws IOException { ConnectorAction action = new ConnectorAction( TEST_ACTION_TYPE, + null, TEST_METHOD_HTTP, URL, headers, diff --git a/common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java b/common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java index 1038006f2c..9d5d69dac3 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java @@ -15,6 +15,7 @@ import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.Optional; import java.util.function.BiFunction; import org.junit.Assert; @@ -379,6 +380,7 @@ public static HttpConnector createHttpConnector() { public static HttpConnector createHttpConnectorWithRequestBody(String requestBody) { ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT; + String name = null; String method = "POST"; String url = "https://test.com"; Map headers = new HashMap<>(); @@ -388,6 +390,7 @@ public static HttpConnector createHttpConnectorWithRequestBody(String requestBod ConnectorAction action = new ConnectorAction( actionType, + name, method, url, headers, @@ -531,4 +534,154 @@ public void testParseResponse_NonStringNonMapResponse() throws IOException { Assert.assertEquals(42, modelTensors.get(0).getDataAsMap().get("response")); } + @Test + public void testFindAction_WithValidActionType() { + HttpConnector connector = createHttpConnector(); + Optional action = connector.findAction("PREDICT"); + Assert.assertTrue(action.isPresent()); + Assert.assertEquals(PREDICT, action.get().getActionType()); + } + + @Test + public void testFindAction_WithValidActionTypeCaseInsensitive() { + HttpConnector connector = createHttpConnector(); + Optional action = connector.findAction("predict"); + Assert.assertTrue(action.isPresent()); + Assert.assertEquals(PREDICT, action.get().getActionType()); + } + + @Test + public void testFindAction_WithCustomActionName() { + String customActionName = "custom_action"; + ConnectorAction customAction = new ConnectorAction( + PREDICT, + customActionName, + "POST", + "https://custom.com", + null, + "{\"input\": \"test\"}", + null, + null + ); + + HttpConnector connector = HttpConnector + .builder() + .name("test_connector") + .protocol("http") + .actions(Arrays.asList(customAction)) + .build(); + + Optional action = connector.findAction(customActionName); + Assert.assertTrue(action.isPresent()); + Assert.assertEquals(customActionName, action.get().getName()); + Assert.assertEquals(PREDICT, action.get().getActionType()); + } + + @Test + public void testFindAction_WithNullAction() { + HttpConnector connector = createHttpConnector(); + Optional action = connector.findAction(null); + Assert.assertFalse(action.isPresent()); + } + + @Test + public void testFindAction_WithInvalidActionType() { + HttpConnector connector = createHttpConnector(); + Optional action = connector.findAction("INVALID_ACTION"); + Assert.assertFalse(action.isPresent()); + } + + @Test + public void testFindAction_WithNullActions() { + HttpConnector connector = HttpConnector.builder().name("test_connector").protocol("http").actions(null).build(); + Optional action = connector.findAction("PREDICT"); + Assert.assertFalse(action.isPresent()); + } + + @Test + public void testFindAction_CustomNameTakesPrecedenceOverActionType() { + String customActionName = "my_predict"; + ConnectorAction action1 = new ConnectorAction( + PREDICT, + null, + "POST", + "https://test1.com", + null, + "{\"input\": \"test1\"}", + null, + null + ); + ConnectorAction action2 = new ConnectorAction( + ConnectorAction.ActionType.EXECUTE, + customActionName, + "POST", + "https://test2.com", + null, + "{\"input\": \"test2\"}", + null, + null + ); + + HttpConnector connector = HttpConnector + .builder() + .name("test_connector") + .protocol("http") + .actions(Arrays.asList(action1, action2)) + .build(); + + // When searching by valid action type, should find by action type first + Optional foundByType = connector.findAction("PREDICT"); + Assert.assertTrue(foundByType.isPresent()); + Assert.assertEquals(PREDICT, foundByType.get().getActionType()); + Assert.assertEquals("https://test1.com", foundByType.get().getUrl()); + + // When searching by custom name, should find by name + Optional foundByName = connector.findAction(customActionName); + Assert.assertTrue(foundByName.isPresent()); + Assert.assertEquals(customActionName, foundByName.get().getName()); + Assert.assertEquals("https://test2.com", foundByName.get().getUrl()); + } + + @Test + public void testFindAction_MultipleActionsWithSameType() { + ConnectorAction action1 = new ConnectorAction( + PREDICT, + "predict_action_1", + "POST", + "https://test1.com", + null, + "{\"input\": \"test1\"}", + null, + null + ); + ConnectorAction action2 = new ConnectorAction( + PREDICT, + "predict_action_2", + "POST", + "https://test2.com", + null, + "{\"input\": \"test2\"}", + null, + null + ); + + HttpConnector connector = HttpConnector + .builder() + .name("test_connector") + .protocol("http") + .actions(Arrays.asList(action1, action2)) + .build(); + + // Should return the first matching action when searching by type + Optional foundByType = connector.findAction("PREDICT"); + Assert.assertTrue(foundByType.isPresent()); + Assert.assertEquals("predict_action_1", foundByType.get().getName()); + + // Should find specific action by custom name + Optional foundByName = connector.findAction("predict_action_2"); + Assert.assertTrue(foundByName.isPresent()); + Assert.assertEquals("predict_action_2", foundByName.get().getName()); + Assert.assertEquals("https://test2.com", foundByName.get().getUrl()); + } + } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInputTests.java b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInputTests.java index a7df00618a..b84caceb59 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInputTests.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInputTests.java @@ -73,6 +73,7 @@ public class MLCreateConnectorInputTests { @Before public void setUp() { ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT; + String name = null; String method = "POST"; String url = "https://test.com"; Map headers = new HashMap<>(); @@ -82,6 +83,7 @@ public void setUp() { String postProcessFunction = MLPostProcessFunction.OPENAI_EMBEDDING; ConnectorAction action = new ConnectorAction( actionType, + name, method, url, headers, diff --git a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorRequestTests.java b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorRequestTests.java index b4f7629689..9dbb083e76 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorRequestTests.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorRequestTests.java @@ -39,6 +39,7 @@ public class MLCreateConnectorRequestTests { @Before public void setUp() { ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT; + String name = null; String method = "POST"; String url = "https://test.com"; Map headers = new HashMap<>(); @@ -48,6 +49,7 @@ public void setUp() { String postProcessFunction = MLPostProcessFunction.OPENAI_EMBEDDING; ConnectorAction action = new ConnectorAction( actionType, + name, method, url, headers, diff --git a/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java b/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java index e81ccc54a3..4d6e2b7ecd 100644 --- a/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java +++ b/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java @@ -77,6 +77,95 @@ public void isJson_False() { assertFalse(StringUtils.isJson("[abc\n123]")); } + @Test + public void isJsonOrNdjson_NullInput() { + assertFalse(StringUtils.isJsonOrNdjson(null)); + } + + @Test + public void isJsonOrNdjson_BlankInput() { + assertFalse(StringUtils.isJsonOrNdjson("")); + assertFalse(StringUtils.isJsonOrNdjson(" ")); + assertFalse(StringUtils.isJsonOrNdjson("\n")); + assertFalse(StringUtils.isJsonOrNdjson("\t")); + } + + @Test + public void isJsonOrNdjson_ValidJson() { + // Valid JSON objects should return true + assertTrue(StringUtils.isJsonOrNdjson("{}")); + assertTrue(StringUtils.isJsonOrNdjson("[]")); + assertTrue(StringUtils.isJsonOrNdjson("{\"key\": \"value\"}")); + assertTrue(StringUtils.isJsonOrNdjson("{\"key\": 123}")); + assertTrue(StringUtils.isJsonOrNdjson("[1, 2, 3]")); + assertTrue(StringUtils.isJsonOrNdjson("[\"a\", \"b\"]")); + assertTrue(StringUtils.isJsonOrNdjson("{\"key1\": \"value\", \"key2\": 123}")); + } + + @Test + public void isJsonOrNdjson_ValidNdjson() { + // Valid NDJSON (newline-delimited JSON) should return true + assertTrue(StringUtils.isJsonOrNdjson("{\"key1\": \"value1\"}\n{\"key2\": \"value2\"}")); + assertTrue(StringUtils.isJsonOrNdjson("{\"index\": {}}\n{\"field\": \"value\"}")); + assertTrue(StringUtils.isJsonOrNdjson("[1, 2, 3]\n[4, 5, 6]")); + assertTrue(StringUtils.isJsonOrNdjson("{\"a\": 1}\n{\"b\": 2}\n{\"c\": 3}")); + } + + @Test + public void isJsonOrNdjson_ValidNdjsonWithCarriageReturn() { + // NDJSON with \r\n line endings should return true + assertTrue(StringUtils.isJsonOrNdjson("{\"key1\": \"value1\"}\r\n{\"key2\": \"value2\"}")); + assertTrue(StringUtils.isJsonOrNdjson("{\"a\": 1}\r\n{\"b\": 2}\r\n{\"c\": 3}")); + } + + @Test + public void isJsonOrNdjson_NdjsonWithEmptyLines() { + // NDJSON with empty lines should return true (empty lines are ignored) + assertTrue(StringUtils.isJsonOrNdjson("{\"key1\": \"value1\"}\n\n{\"key2\": \"value2\"}")); + assertTrue(StringUtils.isJsonOrNdjson("{\"a\": 1}\n \n{\"b\": 2}")); + assertTrue(StringUtils.isJsonOrNdjson("\n{\"key\": \"value\"}\n")); + assertTrue(StringUtils.isJsonOrNdjson("{\"key\": \"value\"}\n\n")); + } + + @Test + public void isJsonOrNdjson_InvalidJson() { + // Invalid JSON should return false + assertFalse(StringUtils.isJsonOrNdjson("{")); + assertFalse(StringUtils.isJsonOrNdjson("[")); + assertFalse(StringUtils.isJsonOrNdjson("{\"key\": \"value}")); + assertFalse(StringUtils.isJsonOrNdjson("[1, \"a]")); + assertFalse(StringUtils.isJsonOrNdjson("not json")); + assertFalse(StringUtils.isJsonOrNdjson("123abc")); + } + + @Test + public void isJsonOrNdjson_InvalidNdjson() { + // NDJSON with at least one invalid line should return false + assertFalse(StringUtils.isJsonOrNdjson("{\"key1\": \"value1\"}\ninvalid json")); + assertFalse(StringUtils.isJsonOrNdjson("{\"key1\": \"value1\"}\n{\"key2\": \"value2}\n{\"key3\": \"value3\"}")); + assertFalse(StringUtils.isJsonOrNdjson("invalid\n{\"key\": \"value\"}")); + assertFalse(StringUtils.isJsonOrNdjson("{\"a\": 1}\n{\"b\": 2\n{\"c\": 3}")); + } + + @Test + public void isJsonOrNdjson_MixedValidInvalidLines() { + // Mix of valid and invalid JSON lines should return false + assertFalse(StringUtils.isJsonOrNdjson("{\"valid\": true}\n{invalid}\n{\"also_valid\": true}")); + assertFalse(StringUtils.isJsonOrNdjson("[1, 2, 3]\nplain text\n[4, 5, 6]")); + } + + @Test + public void isJsonOrNdjson_OpenSearchBulkFormat() { + // OpenSearch bulk API format (action/metadata line followed by document) + assertTrue(StringUtils.isJsonOrNdjson("{\"index\": {\"_index\": \"test\"}}\n{\"field\": \"value\"}")); + assertTrue( + StringUtils + .isJsonOrNdjson( + "{\"index\": {\"_index\": \"test\", \"_id\": \"1\"}}\n{\"field1\": \"value1\"}\n{\"index\": {\"_index\": \"test\", \"_id\": \"2\"}}\n{\"field2\": \"value2\"}" + ) + ); + } + @Test public void toUTF8() { String rawString = "\uD83D\uDE00\uD83D\uDE0D\uD83D\uDE1C"; diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java index 3b53935aaf..f500ae32d1 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java @@ -9,8 +9,10 @@ import static org.opensearch.ml.common.connector.ConnectorProtocols.AWS_SIGV4; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_INTERFACE_BEDROCK_CONVERSE_CLAUDE; import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.LLM_INTERFACE; +import static software.amazon.awssdk.http.SdkHttpMethod.DELETE; import static software.amazon.awssdk.http.SdkHttpMethod.GET; import static software.amazon.awssdk.http.SdkHttpMethod.POST; +import static software.amazon.awssdk.http.SdkHttpMethod.PUT; import java.security.AccessController; import java.security.PrivilegedExceptionAction; @@ -110,7 +112,13 @@ public void invokeRemoteService( request = ConnectorUtils.buildSdkRequest(action, connector, parameters, payload, POST); break; case "GET": - request = ConnectorUtils.buildSdkRequest(action, connector, parameters, null, GET); + request = ConnectorUtils.buildSdkRequest(action, connector, parameters, payload, GET); + break; + case "PUT": + request = ConnectorUtils.buildSdkRequest(action, connector, parameters, payload, PUT); + break; + case "DELETE": + request = ConnectorUtils.buildSdkRequest(action, connector, parameters, payload, DELETE); break; default: throw new IllegalArgumentException("unsupported http method"); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java index 7804770258..45b318bc6c 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java @@ -8,8 +8,10 @@ import static org.opensearch.ml.common.connector.ConnectorProtocols.HTTP; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_INTERFACE_OPENAI_V1_CHAT_COMPLETIONS; import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.LLM_INTERFACE; +import static software.amazon.awssdk.http.SdkHttpMethod.DELETE; import static software.amazon.awssdk.http.SdkHttpMethod.GET; import static software.amazon.awssdk.http.SdkHttpMethod.POST; +import static software.amazon.awssdk.http.SdkHttpMethod.PUT; import java.security.AccessController; import java.security.PrivilegedExceptionAction; @@ -109,7 +111,13 @@ public void invokeRemoteService( request = ConnectorUtils.buildSdkRequest(action, connector, parameters, payload, POST); break; case "GET": - request = ConnectorUtils.buildSdkRequest(action, connector, parameters, null, GET); + request = ConnectorUtils.buildSdkRequest(action, connector, parameters, payload, GET); + break; + case "PUT": + request = ConnectorUtils.buildSdkRequest(action, connector, parameters, payload, PUT); + break; + case "DELETE": + request = ConnectorUtils.buildSdkRequest(action, connector, parameters, payload, DELETE); break; default: throw new IllegalArgumentException("unsupported http method"); diff --git a/plugin/src/main/java/org/opensearch/ml/action/connector/ExecuteConnectorTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/connector/ExecuteConnectorTransportAction.java index c1b5db1778..54be390178 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/connector/ExecuteConnectorTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/connector/ExecuteConnectorTransportAction.java @@ -5,6 +5,7 @@ package org.opensearch.ml.action.connector; +import static org.opensearch.ml.common.CommonValue.CONNECTOR_ACTION_FIELD; import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; import org.opensearch.ResourceNotFoundException; @@ -18,6 +19,7 @@ import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.ml.common.connector.Connector; import org.opensearch.ml.common.connector.ConnectorAction; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; import org.opensearch.ml.common.transport.MLTaskResponse; import org.opensearch.ml.common.transport.connector.MLConnectorDeleteRequest; @@ -73,14 +75,24 @@ public ExecuteConnectorTransportAction( protected void doExecute(Task task, ActionRequest request, ActionListener actionListener) { MLExecuteConnectorRequest executeConnectorRequest = MLExecuteConnectorRequest.fromActionRequest(request); String connectorId = executeConnectorRequest.getConnectorId(); + if (executeConnectorRequest.getMlInput() == null) { + actionListener.onFailure(new IllegalArgumentException("MLInput cannot be null")); + return; + } + + RemoteInferenceInputDataSet inputDataset = (RemoteInferenceInputDataSet) executeConnectorRequest.getMlInput().getInputDataset(); String connectorAction = ConnectorAction.ActionType.EXECUTE.name(); + if (inputDataset.getParameters() != null && inputDataset.getParameters().get(CONNECTOR_ACTION_FIELD) != null) { + connectorAction = inputDataset.getParameters().get(CONNECTOR_ACTION_FIELD); + } if (MLIndicesHandler .doesMultiTenantIndexExist(clusterService, mlFeatureEnabledSetting.isMultiTenancyEnabled(), ML_CONNECTOR_INDEX)) { + String finalConnectorAction = connectorAction; ActionListener listener = ActionListener.wrap(connector -> { if (connectorAccessControlHelper.validateConnectorAccess(client, connector)) { // adding tenantID as null, because we are not implement multi-tenancy for this feature yet. - connector.decrypt(connectorAction, (credential, tenantId) -> encryptor.decrypt(credential, null), null); + connector.decrypt(finalConnectorAction, (credential, tenantId) -> encryptor.decrypt(credential, null), null); RemoteConnectorExecutor connectorExecutor = MLEngineClassLoader .initInstance(connector.getProtocol(), connector, Connector.class); connectorExecutor.setConnectorPrivateIpEnabled(mlFeatureEnabledSetting.isConnectorPrivateIpEnabled()); @@ -89,7 +101,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { + .executeAction(finalConnectorAction, executeConnectorRequest.getMlInput(), ActionListener.wrap(taskResponse -> { actionListener.onResponse(taskResponse); }, e -> { actionListener.onFailure(e); })); } diff --git a/plugin/src/test/java/org/opensearch/ml/action/connector/ExecuteConnectorTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/connector/ExecuteConnectorTransportActionTests.java index 4383cc0f86..6d401a8a3a 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/connector/ExecuteConnectorTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/connector/ExecuteConnectorTransportActionTests.java @@ -119,6 +119,10 @@ public void setup() { public void testExecute_NoConnectorIndex() { when(connectorAccessControlHelper.validateConnectorAccess(eq(client), any())).thenReturn(true); + when(request.getMlInput()).thenReturn(org.opensearch.ml.common.input.MLInput.builder() + .algorithm(org.opensearch.ml.common.FunctionName.REMOTE) + .inputDataset(new org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet(Map.of(), null)) + .build()); action.doExecute(task, request, actionListener); ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener, times(1)).onFailure(argCaptor.capture()); @@ -128,6 +132,10 @@ public void testExecute_NoConnectorIndex() { public void testExecute_FailedToGetConnector() { when(connectorAccessControlHelper.validateConnectorAccess(eq(client), any())).thenReturn(true); when(metaData.hasIndex(anyString())).thenReturn(true); + when(request.getMlInput()).thenReturn(org.opensearch.ml.common.input.MLInput.builder() + .algorithm(org.opensearch.ml.common.FunctionName.REMOTE) + .inputDataset(new org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet(Map.of(), null)) + .build()); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); From f4ac35cee8c40697b34f835e4dc3dd6a5730540b Mon Sep 17 00:00:00 2001 From: Mingshi Liu Date: Tue, 18 Nov 2025 12:02:14 -0800 Subject: [PATCH 2/5] [3.4 Feature Branch] Introduce hook and context management to OpenSearch Agents (#4432) * add hooks in ml-commons (#4326) Signed-off-by: Xun Zhang * initiate context management api with hook implementation (#4345) * initiate context management api with hook implementation Signed-off-by: Mingshi Liu * apply spotless Signed-off-by: Mingshi Liu --------- Signed-off-by: Mingshi Liu * Add Context Manager to PER (#4379) * add pre_llm hook to per agent Signed-off-by: Mingshi Liu change context management passing from query parameters to payload Signed-off-by: Mingshi Liu pass hook registery into PER Signed-off-by: Mingshi Liu apply spotless Signed-off-by: Mingshi Liu initiate context management api with hook implementation Signed-off-by: Mingshi Liu * add comment Signed-off-by: Mingshi Liu * format Signed-off-by: Mingshi Liu * add validation Signed-off-by: Mingshi Liu --------- Signed-off-by: Mingshi Liu * add inner create context management to agent register api Signed-off-by: Mingshi Liu * add code coverage Signed-off-by: Mingshi Liu * allow context management hook register in during agent execute Signed-off-by: Mingshi Liu * add code coverage Signed-off-by: Mingshi Liu * add more code coverage Signed-off-by: Mingshi Liu * add validation check Signed-off-by: Mingshi Liu * adapt to inplace update for context Signed-off-by: Mingshi Liu * Allow context management inline create in register agent without storing in index (#4403) * allow inline create context management without storing in agent register Signed-off-by: Mingshi Liu * make ML_COMMONS_MULTI_TENANCY_ENABLED default is false Signed-off-by: Mingshi Liu --------- Signed-off-by: Mingshi Liu * Update the POST_TOOL hook emit saving to agentic memory (#4408) * Fix POST_TOOL hook interaction updates and add tenant ID support Signed-off-by: Mingshi Liu - Fix POST_TOOL hook to return full ContextManagerContext like PRE_LLM hook - Update MLChatAgentRunner to properly handle interaction updates from POST_TOOL hook - Ensure interactions list and tmpParameters.INTERACTIONS stay synchronized - Add tenant ID support to MLPredictionTaskRequest in ModelGuardrail and SummarizationManager Signed-off-by: Mingshi Liu * fix error message escaping Signed-off-by: Mingshi Liu * consolicate post_hook logic Signed-off-by: Mingshi Liu --------- Signed-off-by: Mingshi Liu * bump supported version to 3.4 Signed-off-by: Mingshi Liu --------- Signed-off-by: Xun Zhang Signed-off-by: Mingshi Liu Co-authored-by: Xun Zhang --- .../org/opensearch/ml/common/CommonValue.java | 2 + .../opensearch/ml/common/agent/MLAgent.java | 95 + .../common/contextmanager/ActivationRule.java | 26 + .../contextmanager/ActivationRuleFactory.java | 148 ++ .../CharacterBasedTokenCounter.java | 89 + .../ContextManagementTemplate.java | 260 +++ .../common/contextmanager/ContextManager.java | 42 + .../contextmanager/ContextManagerConfig.java | 127 ++ .../contextmanager/ContextManagerContext.java | 180 ++ .../ContextManagerHookProvider.java | 199 ++ .../MessageCountExceedRule.java | 34 + .../common/contextmanager/TokenCounter.java | 45 + .../contextmanager/TokensExceedRule.java | 34 + .../common/contextmanager/package-info.java | 21 + .../common/hooks/EnhancedPostToolEvent.java | 46 + .../ml/common/hooks/HookCallback.java | 23 + .../opensearch/ml/common/hooks/HookEvent.java | 33 + .../ml/common/hooks/HookProvider.java | 20 + .../ml/common/hooks/HookRegistry.java | 93 + .../ml/common/hooks/PostMemoryEvent.java | 50 + .../ml/common/hooks/PostToolEvent.java | 46 + .../ml/common/hooks/PreInvocationEvent.java | 23 + .../ml/common/hooks/PreLLMEvent.java | 37 + .../input/execute/agent/AgentMLInput.java | 17 +- .../ml/common/model/ModelGuardrail.java | 6 +- .../agent/MLRegisterAgentRequest.java | 51 + ...CreateContextManagementTemplateAction.java | 17 + ...reateContextManagementTemplateRequest.java | 90 + ...eateContextManagementTemplateResponse.java | 71 + ...DeleteContextManagementTemplateAction.java | 17 + ...eleteContextManagementTemplateRequest.java | 74 + ...leteContextManagementTemplateResponse.java | 71 + .../MLGetContextManagementTemplateAction.java | 17 + ...MLGetContextManagementTemplateRequest.java | 74 + ...LGetContextManagementTemplateResponse.java | 62 + ...LListContextManagementTemplatesAction.java | 17 + ...ListContextManagementTemplatesRequest.java | 74 + ...istContextManagementTemplatesResponse.java | 77 + .../resources/index-mappings/ml_agent.json | 18 + .../ml_context_management_templates.json | 26 + .../ml/common/agent/MLAgentTest.java | 369 +++- .../CharacterBasedTokenCounterTest.java | 164 ++ .../agent/MLAgentGetResponseTest.java | 4 +- .../agent/MLRegisterAgentRequestTest.java | 261 +++ .../ml/engine/agents/AgentContextUtil.java | 179 ++ .../algorithms/agent/MLAgentExecutor.java | 283 ++- .../algorithms/agent/MLChatAgentRunner.java | 128 +- .../MLPlanExecuteAndReflectAgentRunner.java | 52 +- .../contextmanager/SlidingWindowManager.java | 191 ++ .../contextmanager/SummarizationManager.java | 442 +++++ .../ToolsOutputTruncateManager.java | 134 ++ .../algorithms/agent/MLAgentExecutorTest.java | 1668 ++--------------- .../agent/MLChatAgentRunnerTest.java | 8 +- ...LPlanExecuteAndReflectAgentRunnerTest.java | 4 +- .../SlidingWindowManagerTest.java | 234 +++ .../SummarizationManagerTest.java | 326 ++++ plugin/build.gradle | 10 + .../agent/MLAgentRegistrationValidator.java | 261 +++ .../agents/TransportRegisterAgentAction.java | 60 +- .../ContextManagementIndexUtils.java | 96 + .../ContextManagementTemplateService.java | 316 ++++ .../ContextManagerFactory.java | 120 ++ ...textManagementTemplateTransportAction.java | 67 + ...textManagementTemplateTransportAction.java | 67 + ...textManagementTemplateTransportAction.java | 67 + ...extManagementTemplatesTransportAction.java | 63 + .../ml/plugin/MachineLearningPlugin.java | 52 +- ...CreateContextManagementTemplateAction.java | 89 + ...DeleteContextManagementTemplateAction.java | 79 + ...tMLGetContextManagementTemplateAction.java | 78 + ...LListContextManagementTemplatesAction.java | 70 + .../ml/task/MLExecuteTaskRunner.java | 259 ++- .../opensearch/ml/utils/RestActionUtils.java | 1 + .../MLAgentRegistrationValidatorTests.java | 413 ++++ .../DeleteAgentTransportActionTests.java | 2 + .../agents/GetAgentTransportActionTests.java | 2 + .../RegisterAgentTransportActionTests.java | 10 +- .../ContextManagementIndexUtilsTests.java | 231 +++ ...ContextManagementTemplateServiceTests.java | 351 ++++ .../ContextManagerFactoryTests.java | 167 ++ ...anagementTemplateTransportActionTests.java | 196 ++ ...anagementTemplateTransportActionTests.java | 174 ++ ...anagementTemplateTransportActionTests.java | 192 ++ ...nagementTemplatesTransportActionTests.java | 235 +++ ...eContextManagementTemplateActionTests.java | 216 +++ ...eContextManagementTemplateActionTests.java | 181 ++ ...tContextManagementTemplateActionTests.java | 173 ++ ...ContextManagementTemplatesActionTests.java | 184 ++ .../ml/task/MLExecuteTaskRunnerTests.java | 10 +- 89 files changed, 9840 insertions(+), 1551 deletions(-) create mode 100644 common/src/main/java/org/opensearch/ml/common/contextmanager/ActivationRule.java create mode 100644 common/src/main/java/org/opensearch/ml/common/contextmanager/ActivationRuleFactory.java create mode 100644 common/src/main/java/org/opensearch/ml/common/contextmanager/CharacterBasedTokenCounter.java create mode 100644 common/src/main/java/org/opensearch/ml/common/contextmanager/ContextManagementTemplate.java create mode 100644 common/src/main/java/org/opensearch/ml/common/contextmanager/ContextManager.java create mode 100644 common/src/main/java/org/opensearch/ml/common/contextmanager/ContextManagerConfig.java create mode 100644 common/src/main/java/org/opensearch/ml/common/contextmanager/ContextManagerContext.java create mode 100644 common/src/main/java/org/opensearch/ml/common/contextmanager/ContextManagerHookProvider.java create mode 100644 common/src/main/java/org/opensearch/ml/common/contextmanager/MessageCountExceedRule.java create mode 100644 common/src/main/java/org/opensearch/ml/common/contextmanager/TokenCounter.java create mode 100644 common/src/main/java/org/opensearch/ml/common/contextmanager/TokensExceedRule.java create mode 100644 common/src/main/java/org/opensearch/ml/common/contextmanager/package-info.java create mode 100644 common/src/main/java/org/opensearch/ml/common/hooks/EnhancedPostToolEvent.java create mode 100644 common/src/main/java/org/opensearch/ml/common/hooks/HookCallback.java create mode 100644 common/src/main/java/org/opensearch/ml/common/hooks/HookEvent.java create mode 100644 common/src/main/java/org/opensearch/ml/common/hooks/HookProvider.java create mode 100644 common/src/main/java/org/opensearch/ml/common/hooks/HookRegistry.java create mode 100644 common/src/main/java/org/opensearch/ml/common/hooks/PostMemoryEvent.java create mode 100644 common/src/main/java/org/opensearch/ml/common/hooks/PostToolEvent.java create mode 100644 common/src/main/java/org/opensearch/ml/common/hooks/PreInvocationEvent.java create mode 100644 common/src/main/java/org/opensearch/ml/common/hooks/PreLLMEvent.java create mode 100644 common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLCreateContextManagementTemplateAction.java create mode 100644 common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLCreateContextManagementTemplateRequest.java create mode 100644 common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLCreateContextManagementTemplateResponse.java create mode 100644 common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLDeleteContextManagementTemplateAction.java create mode 100644 common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLDeleteContextManagementTemplateRequest.java create mode 100644 common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLDeleteContextManagementTemplateResponse.java create mode 100644 common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLGetContextManagementTemplateAction.java create mode 100644 common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLGetContextManagementTemplateRequest.java create mode 100644 common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLGetContextManagementTemplateResponse.java create mode 100644 common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLListContextManagementTemplatesAction.java create mode 100644 common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLListContextManagementTemplatesRequest.java create mode 100644 common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLListContextManagementTemplatesResponse.java create mode 100644 common/src/main/resources/index-mappings/ml_context_management_templates.json create mode 100644 common/src/test/java/org/opensearch/ml/common/contextmanager/CharacterBasedTokenCounterTest.java create mode 100644 ml-algorithms/src/main/java/org/opensearch/ml/engine/agents/AgentContextUtil.java create mode 100644 ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/contextmanager/SlidingWindowManager.java create mode 100644 ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/contextmanager/SummarizationManager.java create mode 100644 ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/contextmanager/ToolsOutputTruncateManager.java create mode 100644 ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/contextmanager/SlidingWindowManagerTest.java create mode 100644 ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/contextmanager/SummarizationManagerTest.java create mode 100644 plugin/src/main/java/org/opensearch/ml/action/agent/MLAgentRegistrationValidator.java create mode 100644 plugin/src/main/java/org/opensearch/ml/action/contextmanagement/ContextManagementIndexUtils.java create mode 100644 plugin/src/main/java/org/opensearch/ml/action/contextmanagement/ContextManagementTemplateService.java create mode 100644 plugin/src/main/java/org/opensearch/ml/action/contextmanagement/ContextManagerFactory.java create mode 100644 plugin/src/main/java/org/opensearch/ml/action/contextmanagement/CreateContextManagementTemplateTransportAction.java create mode 100644 plugin/src/main/java/org/opensearch/ml/action/contextmanagement/DeleteContextManagementTemplateTransportAction.java create mode 100644 plugin/src/main/java/org/opensearch/ml/action/contextmanagement/GetContextManagementTemplateTransportAction.java create mode 100644 plugin/src/main/java/org/opensearch/ml/action/contextmanagement/ListContextManagementTemplatesTransportAction.java create mode 100644 plugin/src/main/java/org/opensearch/ml/rest/RestMLCreateContextManagementTemplateAction.java create mode 100644 plugin/src/main/java/org/opensearch/ml/rest/RestMLDeleteContextManagementTemplateAction.java create mode 100644 plugin/src/main/java/org/opensearch/ml/rest/RestMLGetContextManagementTemplateAction.java create mode 100644 plugin/src/main/java/org/opensearch/ml/rest/RestMLListContextManagementTemplatesAction.java create mode 100644 plugin/src/test/java/org/opensearch/ml/action/agent/MLAgentRegistrationValidatorTests.java create mode 100644 plugin/src/test/java/org/opensearch/ml/action/contextmanagement/ContextManagementIndexUtilsTests.java create mode 100644 plugin/src/test/java/org/opensearch/ml/action/contextmanagement/ContextManagementTemplateServiceTests.java create mode 100644 plugin/src/test/java/org/opensearch/ml/action/contextmanagement/ContextManagerFactoryTests.java create mode 100644 plugin/src/test/java/org/opensearch/ml/action/contextmanagement/CreateContextManagementTemplateTransportActionTests.java create mode 100644 plugin/src/test/java/org/opensearch/ml/action/contextmanagement/DeleteContextManagementTemplateTransportActionTests.java create mode 100644 plugin/src/test/java/org/opensearch/ml/action/contextmanagement/GetContextManagementTemplateTransportActionTests.java create mode 100644 plugin/src/test/java/org/opensearch/ml/action/contextmanagement/ListContextManagementTemplatesTransportActionTests.java create mode 100644 plugin/src/test/java/org/opensearch/ml/rest/RestMLCreateContextManagementTemplateActionTests.java create mode 100644 plugin/src/test/java/org/opensearch/ml/rest/RestMLDeleteContextManagementTemplateActionTests.java create mode 100644 plugin/src/test/java/org/opensearch/ml/rest/RestMLGetContextManagementTemplateActionTests.java create mode 100644 plugin/src/test/java/org/opensearch/ml/rest/RestMLListContextManagementTemplatesActionTests.java diff --git a/common/src/main/java/org/opensearch/ml/common/CommonValue.java b/common/src/main/java/org/opensearch/ml/common/CommonValue.java index 5570b14ee1..0351f2cdef 100644 --- a/common/src/main/java/org/opensearch/ml/common/CommonValue.java +++ b/common/src/main/java/org/opensearch/ml/common/CommonValue.java @@ -54,6 +54,7 @@ public class CommonValue { public static final String TASK_POLLING_JOB_INDEX = ".ml_commons_task_polling_job"; public static final String MCP_SESSION_MANAGEMENT_INDEX = ".plugins-ml-mcp-session-management"; public static final String MCP_TOOLS_INDEX = ".plugins-ml-mcp-tools"; + public static final String ML_CONTEXT_MANAGEMENT_TEMPLATES_INDEX = ".plugins-ml-context-management-templates"; // index created in 3.1 to track all ml jobs created via job scheduler public static final String ML_JOBS_INDEX = ".plugins-ml-jobs"; public static final Set stopWordsIndices = ImmutableSet.of(".plugins-ml-stop-words"); @@ -76,6 +77,7 @@ public class CommonValue { public static final String ML_LONG_MEMORY_HISTORY_INDEX_MAPPING_PATH = "index-mappings/ml_memory_long_term_history.json"; public static final String ML_MCP_SESSION_MANAGEMENT_INDEX_MAPPING_PATH = "index-mappings/ml_mcp_session_management.json"; public static final String ML_MCP_TOOLS_INDEX_MAPPING_PATH = "index-mappings/ml_mcp_tools.json"; + public static final String ML_CONTEXT_MANAGEMENT_TEMPLATES_INDEX_MAPPING_PATH = "index-mappings/ml_context_management_templates.json"; public static final String ML_JOBS_INDEX_MAPPING_PATH = "index-mappings/ml_jobs.json"; public static final String ML_INDEX_INSIGHT_CONFIG_INDEX_MAPPING_PATH = "index-mappings/ml_index_insight_config.json"; public static final String ML_INDEX_INSIGHT_STORAGE_INDEX_MAPPING_PATH = "index-mappings/ml_index_insight_storage.json"; diff --git a/common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java b/common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java index b66a23f11e..5e512a394f 100644 --- a/common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java +++ b/common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java @@ -30,6 +30,7 @@ import org.opensearch.ml.common.CommonValue; import org.opensearch.ml.common.MLAgentType; import org.opensearch.ml.common.MLModel; +import org.opensearch.ml.common.contextmanager.ContextManagementTemplate; import org.opensearch.telemetry.metrics.tags.Tags; import lombok.Builder; @@ -51,6 +52,8 @@ public class MLAgent implements ToXContentObject, Writeable { public static final String LAST_UPDATED_TIME_FIELD = "last_updated_time"; public static final String APP_TYPE_FIELD = "app_type"; public static final String IS_HIDDEN_FIELD = "is_hidden"; + public static final String CONTEXT_MANAGEMENT_NAME_FIELD = "context_management_name"; + public static final String CONTEXT_MANAGEMENT_FIELD = "context_management"; private static final String LLM_INTERFACE_FIELD = "_llm_interface"; private static final String TAG_VALUE_UNKNOWN = "unknown"; private static final String TAG_MEMORY_TYPE = "memory_type"; @@ -58,6 +61,7 @@ public class MLAgent implements ToXContentObject, Writeable { public static final int AGENT_NAME_MAX_LENGTH = 128; private static final Version MINIMAL_SUPPORTED_VERSION_FOR_HIDDEN_AGENT = CommonValue.VERSION_2_13_0; + private static final Version MINIMAL_SUPPORTED_VERSION_FOR_CONTEXT_MANAGEMENT = CommonValue.VERSION_3_3_0; private String name; private String type; @@ -71,6 +75,8 @@ public class MLAgent implements ToXContentObject, Writeable { private Instant lastUpdateTime; private String appType; private Boolean isHidden; + private String contextManagementName; + private ContextManagementTemplate contextManagement; private final String tenantId; @Builder(toBuilder = true) @@ -86,6 +92,8 @@ public MLAgent( Instant lastUpdateTime, String appType, Boolean isHidden, + String contextManagementName, + ContextManagementTemplate contextManagement, String tenantId ) { this.name = name; @@ -100,6 +108,8 @@ public MLAgent( this.appType = appType; // is_hidden field isn't going to be set by user. It will be set by the code. this.isHidden = isHidden; + this.contextManagementName = contextManagementName; + this.contextManagement = contextManagement; this.tenantId = tenantId; validate(); } @@ -128,6 +138,17 @@ private void validate() { } } } + validateContextManagement(); + } + + private void validateContextManagement() { + if (contextManagementName != null && contextManagement != null) { + throw new IllegalArgumentException("Cannot specify both context_management_name and context_management"); + } + + if (contextManagement != null && !contextManagement.isValid()) { + throw new IllegalArgumentException("Invalid context management configuration"); + } } private void validateMLAgentType(String agentType) { @@ -171,6 +192,12 @@ public MLAgent(StreamInput input) throws IOException { if (streamInputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_HIDDEN_AGENT)) { isHidden = input.readOptionalBoolean(); } + if (streamInputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_CONTEXT_MANAGEMENT)) { + contextManagementName = input.readOptionalString(); + if (input.readBoolean()) { + contextManagement = new ContextManagementTemplate(input); + } + } this.tenantId = streamInputVersion.onOrAfter(VERSION_2_19_0) ? input.readOptionalString() : null; validate(); } @@ -214,6 +241,15 @@ public void writeTo(StreamOutput out) throws IOException { if (streamOutputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_HIDDEN_AGENT)) { out.writeOptionalBoolean(isHidden); } + if (streamOutputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_CONTEXT_MANAGEMENT)) { + out.writeOptionalString(contextManagementName); + if (contextManagement != null) { + out.writeBoolean(true); + contextManagement.writeTo(out); + } else { + out.writeBoolean(false); + } + } if (streamOutputVersion.onOrAfter(VERSION_2_19_0)) { out.writeOptionalString(tenantId); } @@ -256,6 +292,12 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (isHidden != null) { builder.field(MLModel.IS_HIDDEN_FIELD, isHidden); } + if (contextManagementName != null) { + builder.field(CONTEXT_MANAGEMENT_NAME_FIELD, contextManagementName); + } + if (contextManagement != null) { + builder.field(CONTEXT_MANAGEMENT_FIELD, contextManagement); + } if (tenantId != null) { builder.field(TENANT_ID_FIELD, tenantId); } @@ -283,6 +325,8 @@ private static MLAgent parseCommonFields(XContentParser parser, boolean parseHid Instant lastUpdateTime = null; String appType = null; boolean isHidden = false; + String contextManagementName = null; + ContextManagementTemplate contextManagement = null; String tenantId = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); @@ -329,6 +373,12 @@ private static MLAgent parseCommonFields(XContentParser parser, boolean parseHid if (parseHidden) isHidden = parser.booleanValue(); break; + case CONTEXT_MANAGEMENT_NAME_FIELD: + contextManagementName = parser.text(); + break; + case CONTEXT_MANAGEMENT_FIELD: + contextManagement = ContextManagementTemplate.parse(parser); + break; case TENANT_ID_FIELD: tenantId = parser.textOrNull(); break; @@ -351,6 +401,8 @@ private static MLAgent parseCommonFields(XContentParser parser, boolean parseHid .lastUpdateTime(lastUpdateTime) .appType(appType) .isHidden(isHidden) + .contextManagementName(contextManagementName) + .contextManagement(contextManagement) .tenantId(tenantId) .build(); } @@ -384,4 +436,47 @@ public Tags getTags() { return tags; } + + /** + * Check if this agent has context management configuration + * @return true if agent has either context management name or inline configuration + */ + public boolean hasContextManagement() { + return contextManagementName != null || contextManagement != null; + } + + /** + * Get the effective context management configuration for this agent. + * This method prioritizes inline configuration over template reference. + * Note: Template resolution requires external service call and should be handled by the caller. + * + * @return the inline context management configuration, or null if using template reference or no configuration + */ + public ContextManagementTemplate getInlineContextManagement() { + return contextManagement; + } + + /** + * Check if this agent uses a context management template reference + * @return true if agent references a context management template by name + */ + public boolean hasContextManagementTemplate() { + return contextManagementName != null; + } + + /** + * Check if this agent has inline context management configuration + * @return true if agent has inline context management configuration + */ + public boolean hasInlineContextManagement() { + return contextManagement != null; + } + + /** + * Get the context management template name if this agent references one + * @return the template name, or null if no template reference + */ + public String getContextManagementTemplateName() { + return contextManagementName; + } } diff --git a/common/src/main/java/org/opensearch/ml/common/contextmanager/ActivationRule.java b/common/src/main/java/org/opensearch/ml/common/contextmanager/ActivationRule.java new file mode 100644 index 0000000000..c1529e6eda --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/contextmanager/ActivationRule.java @@ -0,0 +1,26 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.contextmanager; + +/** + * Interface for activation rules that determine when a context manager should execute. + * Activation rules evaluate runtime conditions based on the current context state. + */ +public interface ActivationRule { + + /** + * Evaluate whether the activation condition is met. + * @param context the current context state + * @return true if the condition is met and the manager should activate, false otherwise + */ + boolean evaluate(ContextManagerContext context); + + /** + * Get a description of this activation rule for logging and debugging. + * @return a human-readable description of the rule + */ + String getDescription(); +} diff --git a/common/src/main/java/org/opensearch/ml/common/contextmanager/ActivationRuleFactory.java b/common/src/main/java/org/opensearch/ml/common/contextmanager/ActivationRuleFactory.java new file mode 100644 index 0000000000..ccae1a4bf7 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/contextmanager/ActivationRuleFactory.java @@ -0,0 +1,148 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.contextmanager; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +import lombok.extern.log4j.Log4j2; + +/** + * Factory class for creating activation rules from configuration. + * Supports creating rules from configuration maps and combining multiple rules. + */ +@Log4j2 +public class ActivationRuleFactory { + + public static final String TOKENS_EXCEED_KEY = "tokens_exceed"; + public static final String MESSAGE_COUNT_EXCEED_KEY = "message_count_exceed"; + + /** + * Create activation rules from a configuration map. + * @param activationConfig the configuration map containing rule definitions + * @return a list of activation rules, or empty list if no valid rules found + */ + public static List createRules(Map activationConfig) { + List rules = new ArrayList<>(); + + if (activationConfig == null || activationConfig.isEmpty()) { + return rules; + } + + // Create tokens_exceed rule + if (activationConfig.containsKey(TOKENS_EXCEED_KEY)) { + try { + Object tokenValue = activationConfig.get(TOKENS_EXCEED_KEY); + int tokenThreshold = parseIntegerValue(tokenValue, TOKENS_EXCEED_KEY); + if (tokenThreshold > 0) { + rules.add(new TokensExceedRule(tokenThreshold)); + log.debug("Created TokensExceedRule with threshold: {}", tokenThreshold); + } else { + throw new IllegalArgumentException("Invalid token threshold value: " + tokenValue + ". Must be positive integer."); + } + } catch (Exception e) { + log.error("Failed to create TokensExceedRule: {}", e.getMessage()); + } + } + + // Create message_count_exceed rule + if (activationConfig.containsKey(MESSAGE_COUNT_EXCEED_KEY)) { + try { + Object messageValue = activationConfig.get(MESSAGE_COUNT_EXCEED_KEY); + int messageThreshold = parseIntegerValue(messageValue, MESSAGE_COUNT_EXCEED_KEY); + if (messageThreshold > 0) { + rules.add(new MessageCountExceedRule(messageThreshold)); + log.debug("Created MessageCountExceedRule with threshold: {}", messageThreshold); + } else { + throw new IllegalArgumentException( + "Invalid message count threshold value: " + messageValue + ". Must be positive integer." + ); + } + } catch (Exception e) { + log.error("Failed to create MessageCountExceedRule: {}", e.getMessage()); + } + } + + return rules; + } + + /** + * Create a composite rule that requires ALL rules to be satisfied (AND logic). + * @param rules the list of rules to combine + * @return a composite rule, or null if the list is empty + */ + public static ActivationRule createCompositeRule(List rules) { + if (rules == null || rules.isEmpty()) { + return null; + } + + if (rules.size() == 1) { + return rules.get(0); + } + + return new CompositeActivationRule(rules); + } + + /** + * Parse an integer value from configuration, handling various input types. + * @param value the value to parse + * @param fieldName the field name for error reporting + * @return the parsed integer value + * @throws IllegalArgumentException if the value cannot be parsed + */ + private static int parseIntegerValue(Object value, String fieldName) { + if (value instanceof Integer) { + return (Integer) value; + } else if (value instanceof Number) { + return ((Number) value).intValue(); + } else if (value instanceof String) { + try { + return Integer.parseInt((String) value); + } catch (NumberFormatException e) { + throw new IllegalArgumentException("Invalid integer value for " + fieldName + ": " + value); + } + } else { + throw new IllegalArgumentException("Unsupported value type for " + fieldName + ": " + value.getClass().getSimpleName()); + } + } + + /** + * Composite activation rule that implements AND logic for multiple rules. + */ + private static class CompositeActivationRule implements ActivationRule { + private final List rules; + + public CompositeActivationRule(List rules) { + this.rules = new ArrayList<>(rules); + } + + @Override + public boolean evaluate(ContextManagerContext context) { + // All rules must evaluate to true (AND logic) + for (ActivationRule rule : rules) { + if (!rule.evaluate(context)) { + return false; + } + } + return true; + } + + @Override + public String getDescription() { + StringBuilder sb = new StringBuilder(); + sb.append("composite_rule: ["); + for (int i = 0; i < rules.size(); i++) { + if (i > 0) { + sb.append(" AND "); + } + sb.append(rules.get(i).getDescription()); + } + sb.append("]"); + return sb.toString(); + } + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/contextmanager/CharacterBasedTokenCounter.java b/common/src/main/java/org/opensearch/ml/common/contextmanager/CharacterBasedTokenCounter.java new file mode 100644 index 0000000000..e9b87a20bc --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/contextmanager/CharacterBasedTokenCounter.java @@ -0,0 +1,89 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.contextmanager; + +import lombok.extern.log4j.Log4j2; + +/** + * Character-based token counter implementation. + * Uses a simple heuristic of approximately 4 characters per token. + * This is a fallback implementation when more sophisticated token counting is not available. + */ +@Log4j2 +public class CharacterBasedTokenCounter implements TokenCounter { + + private static final double CHARS_PER_TOKEN = 4.0; + + @Override + public int count(String text) { + if (text == null || text.isEmpty()) { + return 0; + } + return (int) Math.ceil(text.length() / CHARS_PER_TOKEN); + } + + @Override + public String truncateFromEnd(String text, int maxTokens) { + if (text == null || text.isEmpty()) { + return text; + } + + int currentTokens = count(text); + if (currentTokens <= maxTokens) { + return text; + } + + int maxChars = (int) (maxTokens * CHARS_PER_TOKEN); + if (maxChars >= text.length()) { + return text; + } + + return text.substring(0, maxChars); + } + + @Override + public String truncateFromBeginning(String text, int maxTokens) { + if (text == null || text.isEmpty()) { + return text; + } + + int currentTokens = count(text); + if (currentTokens <= maxTokens) { + return text; + } + + int maxChars = (int) (maxTokens * CHARS_PER_TOKEN); + if (maxChars >= text.length()) { + return text; + } + + return text.substring(text.length() - maxChars); + } + + @Override + public String truncateMiddle(String text, int maxTokens) { + if (text == null || text.isEmpty()) { + return text; + } + + int currentTokens = count(text); + if (currentTokens <= maxTokens) { + return text; + } + + int maxChars = (int) (maxTokens * CHARS_PER_TOKEN); + if (maxChars >= text.length()) { + return text; + } + + // Keep equal portions from beginning and end + int halfChars = maxChars / 2; + String beginning = text.substring(0, halfChars); + String end = text.substring(text.length() - halfChars); + + return beginning + end; + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/contextmanager/ContextManagementTemplate.java b/common/src/main/java/org/opensearch/ml/common/contextmanager/ContextManagementTemplate.java new file mode 100644 index 0000000000..40969b8c9a --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/contextmanager/ContextManagementTemplate.java @@ -0,0 +1,260 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.contextmanager; + +import java.io.IOException; +import java.time.Instant; +import java.util.List; +import java.util.Map; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; + +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + +/** + * Context Management Template defines which context managers to use and when. + * This class represents a registered configuration that can be applied to + * agent execution to enable dynamic context optimization. + */ +@Data +@NoArgsConstructor +@AllArgsConstructor +@Builder(toBuilder = true) +public class ContextManagementTemplate implements ToXContentObject, Writeable { + + public static final String NAME_FIELD = "name"; + public static final String DESCRIPTION_FIELD = "description"; + public static final String HOOKS_FIELD = "hooks"; + public static final String CREATED_TIME_FIELD = "created_time"; + public static final String LAST_MODIFIED_FIELD = "last_modified"; + public static final String CREATED_BY_FIELD = "created_by"; + + /** + * Unique name for the context management template + */ + private String name; + + /** + * Human-readable description of what this template does + */ + private String description; + + /** + * Map of hook names to lists of context manager configurations + */ + private Map> hooks; + + /** + * When this template was created + */ + private Instant createdTime; + + /** + * When this template was last modified + */ + private Instant lastModified; + + /** + * Who created this template + */ + private String createdBy; + + /** + * Constructor from StreamInput + */ + public ContextManagementTemplate(StreamInput input) throws IOException { + this.name = input.readString(); + this.description = input.readOptionalString(); + + // Read hooks map + int hooksSize = input.readInt(); + if (hooksSize > 0) { + this.hooks = input.readMap(StreamInput::readString, in -> { + try { + int listSize = in.readInt(); + List configs = new java.util.ArrayList<>(); + for (int i = 0; i < listSize; i++) { + configs.add(new ContextManagerConfig(in)); + } + return configs; + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + } + + this.createdTime = input.readOptionalInstant(); + this.lastModified = input.readOptionalInstant(); + this.createdBy = input.readOptionalString(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(name); + out.writeOptionalString(description); + + // Write hooks map + if (hooks != null) { + out.writeInt(hooks.size()); + out.writeMap(hooks, StreamOutput::writeString, (output, configs) -> { + try { + output.writeInt(configs.size()); + for (ContextManagerConfig config : configs) { + config.writeTo(output); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + } else { + out.writeInt(0); + } + + out.writeOptionalInstant(createdTime); + out.writeOptionalInstant(lastModified); + out.writeOptionalString(createdBy); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + + if (name != null) { + builder.field(NAME_FIELD, name); + } + if (description != null) { + builder.field(DESCRIPTION_FIELD, description); + } + if (hooks != null && !hooks.isEmpty()) { + builder.field(HOOKS_FIELD, hooks); + } + if (createdTime != null) { + builder.field(CREATED_TIME_FIELD, createdTime.toEpochMilli()); + } + if (lastModified != null) { + builder.field(LAST_MODIFIED_FIELD, lastModified.toEpochMilli()); + } + if (createdBy != null) { + builder.field(CREATED_BY_FIELD, createdBy); + } + + builder.endObject(); + return builder; + } + + /** + * Parse ContextManagementTemplate from XContentParser + */ + public static ContextManagementTemplate parse(XContentParser parser) throws IOException { + String name = null; + String description = null; + Map> hooks = null; + Instant createdTime = null; + Instant lastModified = null; + String createdBy = null; + + if (parser.currentToken() != XContentParser.Token.START_OBJECT) { + parser.nextToken(); + } + + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case NAME_FIELD: + name = parser.text(); + break; + case DESCRIPTION_FIELD: + description = parser.text(); + break; + case HOOKS_FIELD: + hooks = parseHooks(parser); + break; + case CREATED_TIME_FIELD: + createdTime = Instant.ofEpochMilli(parser.longValue()); + break; + case LAST_MODIFIED_FIELD: + lastModified = Instant.ofEpochMilli(parser.longValue()); + break; + case CREATED_BY_FIELD: + createdBy = parser.text(); + break; + default: + parser.skipChildren(); + break; + } + } + + return ContextManagementTemplate + .builder() + .name(name) + .description(description) + .hooks(hooks) + .createdTime(createdTime) + .lastModified(lastModified) + .createdBy(createdBy) + .build(); + } + + /** + * Parse hooks configuration from XContentParser + */ + private static Map> parseHooks(XContentParser parser) throws IOException { + Map> hooks = new java.util.HashMap<>(); + + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String hookName = parser.currentName(); + parser.nextToken(); // Move to START_ARRAY + + List configs = new java.util.ArrayList<>(); + while (parser.nextToken() != XContentParser.Token.END_ARRAY) { + configs.add(ContextManagerConfig.parse(parser)); + } + + hooks.put(hookName, configs); + } + + return hooks; + } + + /** + * Validate the template configuration + */ + public boolean isValid() { + if (name == null || name.trim().isEmpty()) { + return false; + } + + // Allow null hooks (no context management) but not empty hooks map (misconfiguration) + if (hooks != null) { + if (hooks.isEmpty()) { + return false; + } + + // Validate all context manager configs + for (List configs : hooks.values()) { + if (configs != null) { + for (ContextManagerConfig config : configs) { + if (!config.isValid()) { + return false; + } + } + } + } + } + + return true; + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/contextmanager/ContextManager.java b/common/src/main/java/org/opensearch/ml/common/contextmanager/ContextManager.java new file mode 100644 index 0000000000..325f98900a --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/contextmanager/ContextManager.java @@ -0,0 +1,42 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.contextmanager; + +import java.util.Map; + +/** + * Base interface for all context managers. + * Context managers are pluggable components that inspect and transform + * agent context components before they are sent to an LLM. + */ +public interface ContextManager { + + /** + * Get the type identifier for this context manager + * @return String identifying the manager type + */ + String getType(); + + /** + * Initialize the context manager with configuration + * @param config Configuration map for the manager + */ + void initialize(Map config); + + /** + * Check if this context manager should activate based on current context + * @param context The current context manager context + * @return true if the manager should execute, false otherwise + */ + boolean shouldActivate(ContextManagerContext context); + + /** + * Execute the context transformation + * @param context The context manager context to transform + */ + void execute(ContextManagerContext context); + +} diff --git a/common/src/main/java/org/opensearch/ml/common/contextmanager/ContextManagerConfig.java b/common/src/main/java/org/opensearch/ml/common/contextmanager/ContextManagerConfig.java new file mode 100644 index 0000000000..92755cb243 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/contextmanager/ContextManagerConfig.java @@ -0,0 +1,127 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.contextmanager; + +import java.io.IOException; +import java.util.Map; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; + +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; + +/** + * Configuration for a context manager within a context management template. + * This class holds the configuration details for how a specific context manager + * should be configured and when it should activate. + */ +@Data +@NoArgsConstructor +@AllArgsConstructor +public class ContextManagerConfig implements ToXContentObject, Writeable { + + public static final String TYPE_FIELD = "type"; + public static final String ACTIVATION_FIELD = "activation"; + public static final String CONFIG_FIELD = "config"; + + /** + * The type of context manager (e.g., "ToolsOutputTruncateManager") + */ + private String type; + + /** + * Activation conditions that determine when this manager should execute + */ + private Map activation; + + /** + * Configuration parameters specific to this manager type + */ + private Map config; + + /** + * Constructor from StreamInput + */ + public ContextManagerConfig(StreamInput input) throws IOException { + this.type = input.readString(); + this.activation = input.readMap(); + this.config = input.readMap(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(type); + out.writeMap(activation); + out.writeMap(config); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + + if (type != null) { + builder.field(TYPE_FIELD, type); + } + if (activation != null && !activation.isEmpty()) { + builder.field(ACTIVATION_FIELD, activation); + } + if (config != null && !config.isEmpty()) { + builder.field(CONFIG_FIELD, config); + } + + builder.endObject(); + return builder; + } + + /** + * Parse ContextManagerConfig from XContentParser + */ + public static ContextManagerConfig parse(XContentParser parser) throws IOException { + String type = null; + Map activation = null; + Map config = null; + + if (parser.currentToken() != XContentParser.Token.START_OBJECT) { + parser.nextToken(); + } + + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case TYPE_FIELD: + type = parser.text(); + break; + case ACTIVATION_FIELD: + activation = parser.map(); + break; + case CONFIG_FIELD: + config = parser.map(); + break; + default: + parser.skipChildren(); + break; + } + } + + return new ContextManagerConfig(type, activation, config); + } + + /** + * Validate the configuration + * @return true if configuration is valid, false otherwise + */ + public boolean isValid() { + return type != null && !type.trim().isEmpty(); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/contextmanager/ContextManagerContext.java b/common/src/main/java/org/opensearch/ml/common/contextmanager/ContextManagerContext.java new file mode 100644 index 0000000000..9854b78dba --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/contextmanager/ContextManagerContext.java @@ -0,0 +1,180 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.contextmanager; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.opensearch.ml.common.agent.MLToolSpec; +import org.opensearch.ml.common.conversation.Interaction; + +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + +/** + * Context object that contains all components of the agent execution context. + * This object is passed to context managers for inspection and transformation. + */ +@Data +@Builder +@NoArgsConstructor +@AllArgsConstructor +public class ContextManagerContext { + + /** + * The invocation state from the hook system + */ + private Map invocationState; + + /** + * The system prompt for the LLM + */ + private String systemPrompt; + + /** + * The chat history as a list of interactions + */ + @Builder.Default + private List chatHistory = new ArrayList<>(); + + /** + * The current user prompt/input + */ + private String userPrompt; + + /** + * The tool configurations available to the agent + */ + @Builder.Default + private List toolConfigs = new ArrayList<>(); + + /** + * The tool interactions/results from tool executions + */ + @Builder.Default + private List toolInteractions = new ArrayList<>(); + + /** + * Additional parameters for context processing + */ + @Builder.Default + private Map parameters = new HashMap<>(); + + /** + * Get the total token count for the current context. + * This is a utility method that can be used by context managers. + * @return estimated token count + */ + public int getEstimatedTokenCount() { + int tokenCount = 0; + + // Estimate tokens for system prompt + if (systemPrompt != null) { + tokenCount += estimateTokens(systemPrompt); + } + + // Estimate tokens for user prompt + if (userPrompt != null) { + tokenCount += estimateTokens(userPrompt); + } + + // Estimate tokens for chat history + for (Interaction interaction : chatHistory) { + if (interaction.getInput() != null) { + tokenCount += estimateTokens(interaction.getInput()); + } + if (interaction.getResponse() != null) { + tokenCount += estimateTokens(interaction.getResponse()); + } + } + + // Estimate tokens for tool interactions + for (String interaction : toolInteractions) { + tokenCount += estimateTokens(interaction); + } + + return tokenCount; + } + + /** + * Get the message count in chat history. + * @return number of messages in chat history + */ + public int getMessageCount() { + return chatHistory.size(); + } + + /** + * Simple token estimation based on character count. + * This is a fallback method - more sophisticated token counting should be implemented + * in dedicated TokenCounter implementations. + * @param text the text to estimate tokens for + * @return estimated token count + */ + private int estimateTokens(String text) { + if (text == null || text.isEmpty()) { + return 0; + } + // Rough estimation: 1 token per 4 characters + return (int) Math.ceil(text.length() / 4.0); + } + + /** + * Add a tool interaction to the context. + * @param interaction the tool interaction to add + */ + public void addToolInteraction(String interaction) { + if (toolInteractions == null) { + toolInteractions = new ArrayList<>(); + } + toolInteractions.add(interaction); + } + + /** + * Add an interaction to the chat history. + * @param interaction the interaction to add + */ + public void addChatHistoryInteraction(Interaction interaction) { + if (chatHistory == null) { + chatHistory = new ArrayList<>(); + } + chatHistory.add(interaction); + } + + /** + * Clear the chat history. + */ + public void clearChatHistory() { + if (chatHistory != null) { + chatHistory.clear(); + } + } + + /** + * Get a parameter value by key. + * @param key the parameter key + * @return the parameter value, or null if not found + */ + public Object getParameter(String key) { + return parameters != null ? parameters.get(key) : null; + } + + /** + * Set a parameter value. + * @param key the parameter key + * @param value the parameter value + */ + public void setParameter(String key, String value) { + if (parameters == null) { + parameters = new HashMap<>(); + } + parameters.put(key, value); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/contextmanager/ContextManagerHookProvider.java b/common/src/main/java/org/opensearch/ml/common/contextmanager/ContextManagerHookProvider.java new file mode 100644 index 0000000000..dfef018c87 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/contextmanager/ContextManagerHookProvider.java @@ -0,0 +1,199 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.contextmanager; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.opensearch.ml.common.hooks.EnhancedPostToolEvent; +import org.opensearch.ml.common.hooks.HookProvider; +import org.opensearch.ml.common.hooks.HookRegistry; +import org.opensearch.ml.common.hooks.PostMemoryEvent; +import org.opensearch.ml.common.hooks.PreLLMEvent; + +import lombok.extern.log4j.Log4j2; + +/** + * Hook provider that integrates context managers with the hook registry. + * This class manages the execution of context managers based on hook events. + */ +@Log4j2 +public class ContextManagerHookProvider implements HookProvider { + private final List contextManagers; + private final Map> hookToManagersMap; + + /** + * Constructor for ContextManagerHookProvider + * @param contextManagers List of context managers to register + */ + public ContextManagerHookProvider(List contextManagers) { + this.contextManagers = new ArrayList<>(contextManagers); + this.hookToManagersMap = new HashMap<>(); + + // Group managers by hook type based on their configuration + // This would typically be done based on the template configuration + // For now, we'll organize them by common hook types + organizeManagersByHook(); + } + + /** + * Register hook callbacks with the provided registry + * @param registry The HookRegistry to register callbacks with + */ + @Override + public void registerHooks(HookRegistry registry) { + // Only register callbacks for hooks that have managers configured + if (hookToManagersMap.containsKey("PRE_LLM")) { + registry.addCallback(PreLLMEvent.class, this::handlePreLLM); + } + if (hookToManagersMap.containsKey("POST_TOOL")) { + registry.addCallback(EnhancedPostToolEvent.class, this::handlePostTool); + } + if (hookToManagersMap.containsKey("POST_MEMORY")) { + registry.addCallback(PostMemoryEvent.class, this::handlePostMemory); + } + + log.info("Registered context manager hooks for {} managers", contextManagers.size()); + } + + /** + * Handle PreLLM hook events + * @param event The PreLLM event + */ + private void handlePreLLM(PreLLMEvent event) { + log.debug("Handling PreLLM event"); + executeManagersForHook("PRE_LLM", event.getContext()); + } + + /** + * Handle PostTool hook events + * @param event The EnhancedPostTool event + */ + private void handlePostTool(EnhancedPostToolEvent event) { + log.debug("Handling PostTool event"); + executeManagersForHook("POST_TOOL", event.getContext()); + } + + /** + * Handle PostMemory hook events + * @param event The PostMemory event + */ + private void handlePostMemory(PostMemoryEvent event) { + log.debug("Handling PostMemory event"); + executeManagersForHook("POST_MEMORY", event.getContext()); + } + + /** + * Execute context managers for a specific hook + * @param hookName The name of the hook + * @param context The context manager context + */ + private void executeManagersForHook(String hookName, ContextManagerContext context) { + List managers = hookToManagersMap.get(hookName); + if (managers != null && !managers.isEmpty()) { + log.debug("Executing {} context managers for hook: {}", managers.size(), hookName); + + for (ContextManager manager : managers) { + try { + if (manager.shouldActivate(context)) { + log.debug("Executing context manager: {}", manager.getType()); + manager.execute(context); + log.debug("Successfully executed context manager: {}", manager.getType()); + } else { + log.debug("Context manager {} activation conditions not met, skipping", manager.getType()); + } + } catch (Exception e) { + log.error("Context manager {} failed: {}", manager.getType(), e.getMessage(), e); + // Continue with other managers even if one fails + } + } + } else { + log.debug("No context managers registered for hook: {}", hookName); + } + } + + /** + * Organize managers by hook type + * This is a simplified implementation - in practice, this would be based on + * the context management template configuration + */ + private void organizeManagersByHook() { + // For now, we'll assign managers to hooks based on their type + // This would be replaced with actual template-based configuration + for (ContextManager manager : contextManagers) { + String managerType = manager.getType(); + + // Assign managers to appropriate hooks based on their type + if ("ToolsOutputTruncateManager".equals(managerType)) { + addManagerToHook("POST_TOOL", manager); + } else if ("SlidingWindowManager".equals(managerType) || "SummarizingManager".equals(managerType)) { + addManagerToHook("POST_MEMORY", manager); + addManagerToHook("PRE_LLM", manager); + } else if ("SystemPromptAugmentationManager".equals(managerType)) { + addManagerToHook("PRE_LLM", manager); + } else { + // Default to PRE_LLM for unknown types + addManagerToHook("PRE_LLM", manager); + } + } + } + + /** + * Add a manager to a specific hook + * @param hookName The hook name + * @param manager The context manager + */ + private void addManagerToHook(String hookName, ContextManager manager) { + hookToManagersMap.computeIfAbsent(hookName, k -> new ArrayList<>()).add(manager); + log.debug("Added manager {} to hook {}", manager.getType(), hookName); + } + + /** + * Update the hook-to-managers mapping based on template configuration + * @param hookConfiguration Map of hook names to manager configurations + */ + public void updateHookConfiguration(Map> hookConfiguration) { + hookToManagersMap.clear(); + + for (Map.Entry> entry : hookConfiguration.entrySet()) { + String hookName = entry.getKey(); + List configs = entry.getValue(); + + for (ContextManagerConfig config : configs) { + // Find the corresponding context manager + ContextManager manager = findManagerByType(config.getType()); + if (manager != null) { + addManagerToHook(hookName, manager); + } else { + log.warn("Context manager of type {} not found", config.getType()); + } + } + } + + log.info("Updated hook configuration with {} hooks", hookConfiguration.size()); + } + + /** + * Find a context manager by its type + * @param type The manager type + * @return The context manager or null if not found + */ + private ContextManager findManagerByType(String type) { + return contextManagers.stream().filter(manager -> type.equals(manager.getType())).findFirst().orElse(null); + } + + /** + * Get the number of managers registered for a specific hook + * @param hookName The hook name + * @return Number of managers + */ + public int getManagerCount(String hookName) { + List managers = hookToManagersMap.get(hookName); + return managers != null ? managers.size() : 0; + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/contextmanager/MessageCountExceedRule.java b/common/src/main/java/org/opensearch/ml/common/contextmanager/MessageCountExceedRule.java new file mode 100644 index 0000000000..f3a4d5c57e --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/contextmanager/MessageCountExceedRule.java @@ -0,0 +1,34 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.contextmanager; + +import lombok.AllArgsConstructor; +import lombok.Getter; + +/** + * Activation rule that triggers when the chat history message count exceeds a specified threshold. + */ +@AllArgsConstructor +@Getter +public class MessageCountExceedRule implements ActivationRule { + + private final int messageThreshold; + + @Override + public boolean evaluate(ContextManagerContext context) { + if (context == null) { + return false; + } + + int currentMessageCount = context.getMessageCount(); + return currentMessageCount > messageThreshold; + } + + @Override + public String getDescription() { + return "message_count_exceed: " + messageThreshold; + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/contextmanager/TokenCounter.java b/common/src/main/java/org/opensearch/ml/common/contextmanager/TokenCounter.java new file mode 100644 index 0000000000..42bbd813ee --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/contextmanager/TokenCounter.java @@ -0,0 +1,45 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.contextmanager; + +/** + * Interface for counting and truncating tokens in text. + * Provides methods for accurate token counting and various truncation strategies. + */ +public interface TokenCounter { + + /** + * Count the number of tokens in the given text. + * @param text the text to count tokens for + * @return the number of tokens + */ + int count(String text); + + /** + * Truncate text from the end to fit within the specified token limit. + * @param text the text to truncate + * @param maxTokens the maximum number of tokens to keep + * @return the truncated text + */ + String truncateFromEnd(String text, int maxTokens); + + /** + * Truncate text from the beginning to fit within the specified token limit. + * @param text the text to truncate + * @param maxTokens the maximum number of tokens to keep + * @return the truncated text + */ + String truncateFromBeginning(String text, int maxTokens); + + /** + * Truncate text from the middle to fit within the specified token limit. + * Preserves both beginning and end portions of the text. + * @param text the text to truncate + * @param maxTokens the maximum number of tokens to keep + * @return the truncated text + */ + String truncateMiddle(String text, int maxTokens); +} diff --git a/common/src/main/java/org/opensearch/ml/common/contextmanager/TokensExceedRule.java b/common/src/main/java/org/opensearch/ml/common/contextmanager/TokensExceedRule.java new file mode 100644 index 0000000000..e4bc0544f0 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/contextmanager/TokensExceedRule.java @@ -0,0 +1,34 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.contextmanager; + +import lombok.AllArgsConstructor; +import lombok.Getter; + +/** + * Activation rule that triggers when the context token count exceeds a specified threshold. + */ +@AllArgsConstructor +@Getter +public class TokensExceedRule implements ActivationRule { + + private final int tokenThreshold; + + @Override + public boolean evaluate(ContextManagerContext context) { + if (context == null) { + return false; + } + + int currentTokenCount = context.getEstimatedTokenCount(); + return currentTokenCount > tokenThreshold; + } + + @Override + public String getDescription() { + return "tokens_exceed: " + tokenThreshold; + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/contextmanager/package-info.java b/common/src/main/java/org/opensearch/ml/common/contextmanager/package-info.java new file mode 100644 index 0000000000..b8d8cb5cc4 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/contextmanager/package-info.java @@ -0,0 +1,21 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +/** + * Context management framework for OpenSearch ML-Commons. + * + * This package provides a pluggable context management system that allows for dynamic + * optimization of LLM context windows through configurable context managers. + * + * Key components: + * - {@link org.opensearch.ml.common.contextmanager.ContextManager}: Base interface for all context managers + * - {@link org.opensearch.ml.common.contextmanager.ContextManagerContext}: Context object containing all agent execution state + * - {@link org.opensearch.ml.common.contextmanager.ActivationRule}: Interface for rules that determine when managers should execute + * - {@link org.opensearch.ml.common.contextmanager.ActivationRuleFactory}: Factory for creating activation rules from configuration + * + * The system integrates with the existing hook framework to provide seamless context optimization + * during agent execution without breaking existing functionality. + */ +package org.opensearch.ml.common.contextmanager; diff --git a/common/src/main/java/org/opensearch/ml/common/hooks/EnhancedPostToolEvent.java b/common/src/main/java/org/opensearch/ml/common/hooks/EnhancedPostToolEvent.java new file mode 100644 index 0000000000..7db6341c9e --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/hooks/EnhancedPostToolEvent.java @@ -0,0 +1,46 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.hooks; + +import java.util.List; +import java.util.Map; + +import org.opensearch.ml.common.contextmanager.ContextManagerContext; + +/** + * Enhanced version of PostToolEvent that includes context manager context. + * This event is triggered after tool execution and provides access to both + * tool results and the full context, allowing context managers to modify + * tool outputs and other context components. + */ +public class EnhancedPostToolEvent extends PostToolEvent { + private final ContextManagerContext context; + + /** + * Constructor for EnhancedPostToolEvent + * @param toolResults List of tool execution results + * @param error Exception that occurred during tool execution, null if successful + * @param context The context manager context containing all context components + * @param invocationState The current state of the agent invocation + */ + public EnhancedPostToolEvent( + List> toolResults, + Exception error, + ContextManagerContext context, + Map invocationState + ) { + super(toolResults, error, invocationState); + this.context = context; + } + + /** + * Get the context manager context + * @return ContextManagerContext containing all context components + */ + public ContextManagerContext getContext() { + return context; + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/hooks/HookCallback.java b/common/src/main/java/org/opensearch/ml/common/hooks/HookCallback.java new file mode 100644 index 0000000000..13e7299e01 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/hooks/HookCallback.java @@ -0,0 +1,23 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.hooks; + +/** + * Functional interface for handling specific hook events. + * Implementations of this interface define the behavior to execute + * when a particular hook event is triggered. + * + * @param The type of HookEvent this callback handles + */ +@FunctionalInterface +public interface HookCallback { + + /** + * Handle the hook event + * @param event The hook event to handle + */ + void handle(T event); +} diff --git a/common/src/main/java/org/opensearch/ml/common/hooks/HookEvent.java b/common/src/main/java/org/opensearch/ml/common/hooks/HookEvent.java new file mode 100644 index 0000000000..c7f1503b61 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/hooks/HookEvent.java @@ -0,0 +1,33 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.hooks; + +import java.util.Map; + +/** + * Base class for all hook events in the ML agent lifecycle. + * Hook events are strongly-typed events that carry context information + * for different stages of agent execution. + */ +public abstract class HookEvent { + private final Map invocationState; + + /** + * Constructor for HookEvent + * @param invocationState The current state of the agent invocation + */ + protected HookEvent(Map invocationState) { + this.invocationState = invocationState; + } + + /** + * Get the invocation state + * @return Map containing the current invocation state + */ + public Map getInvocationState() { + return invocationState; + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/hooks/HookProvider.java b/common/src/main/java/org/opensearch/ml/common/hooks/HookProvider.java new file mode 100644 index 0000000000..d6612f6749 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/hooks/HookProvider.java @@ -0,0 +1,20 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.hooks; + +/** + * Interface for providers that register hook callbacks with the HookRegistry. + * Implementations of this interface define which hooks they want to listen to + * and provide the callback implementations. + */ +public interface HookProvider { + + /** + * Register hook callbacks with the provided registry + * @param registry The HookRegistry to register callbacks with + */ + void registerHooks(HookRegistry registry); +} diff --git a/common/src/main/java/org/opensearch/ml/common/hooks/HookRegistry.java b/common/src/main/java/org/opensearch/ml/common/hooks/HookRegistry.java new file mode 100644 index 0000000000..32076d0d78 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/hooks/HookRegistry.java @@ -0,0 +1,93 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.hooks; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +import lombok.extern.log4j.Log4j2; + +/** + * Registry for managing hook callbacks and event emission. + * This class manages the registration of callbacks for different hook event types + * and provides methods to emit events to registered callbacks. + */ +@Log4j2 +public class HookRegistry { + private final Map, List>> callbacks; + + /** + * Constructor for HookRegistry + */ + public HookRegistry() { + this.callbacks = new ConcurrentHashMap<>(); + } + + /** + * Add a callback for a specific hook event type + * @param eventType The class of the hook event + * @param callback The callback to execute when the event is emitted + * @param The type of hook event + */ + public void addCallback(Class eventType, HookCallback callback) { + callbacks.computeIfAbsent(eventType, k -> new ArrayList<>()).add(callback); + log.debug("Registered callback for event type: {}", eventType.getSimpleName()); + } + + /** + * Emit an event to all registered callbacks for that event type + * @param event The hook event to emit + * @param The type of hook event + */ + @SuppressWarnings("unchecked") + public void emit(T event) { + Class eventType = event.getClass(); + List> eventCallbacks = callbacks.get(eventType); + + log + .info( + "HookRegistry.emit() called for event type: {}, callbacks available: {}", + eventType.getSimpleName(), + eventCallbacks != null ? eventCallbacks.size() : 0 + ); + + if (eventCallbacks != null) { + log.info("Emitting {} event to {} callbacks", eventType.getSimpleName(), eventCallbacks.size()); + + for (HookCallback callback : eventCallbacks) { + try { + log.info("Executing callback: {}", callback.getClass().getSimpleName()); + ((HookCallback) callback).handle(event); + } catch (Exception e) { + log.error("Error executing hook callback for event type {}: {}", eventType.getSimpleName(), e.getMessage(), e); + // Continue with other callbacks even if one fails + } + } + } else { + log.warn("No callbacks registered for event type: {}", eventType.getSimpleName()); + } + } + + /** + * Get the number of registered callbacks for a specific event type + * @param eventType The class of the hook event + * @return Number of registered callbacks + */ + public int getCallbackCount(Class eventType) { + List> eventCallbacks = callbacks.get(eventType); + return eventCallbacks != null ? eventCallbacks.size() : 0; + } + + /** + * Clear all registered callbacks + */ + public void clear() { + callbacks.clear(); + log.debug("Cleared all hook callbacks"); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/hooks/PostMemoryEvent.java b/common/src/main/java/org/opensearch/ml/common/hooks/PostMemoryEvent.java new file mode 100644 index 0000000000..006f6e8069 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/hooks/PostMemoryEvent.java @@ -0,0 +1,50 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.hooks; + +import java.util.List; +import java.util.Map; + +import org.opensearch.ml.common.contextmanager.ContextManagerContext; +import org.opensearch.ml.common.conversation.Interaction; + +/** + * Hook event triggered after memory retrieval in the agent lifecycle. + * This event provides access to the retrieved chat history and context, + * allowing context managers to modify the memory before it's used. + */ +public class PostMemoryEvent extends HookEvent { + private final ContextManagerContext context; + private final List retrievedHistory; + + /** + * Constructor for PostMemoryEvent + * @param context The context manager context containing all context components + * @param retrievedHistory The chat history retrieved from memory + * @param invocationState The current state of the agent invocation + */ + public PostMemoryEvent(ContextManagerContext context, List retrievedHistory, Map invocationState) { + super(invocationState); + this.context = context; + this.retrievedHistory = retrievedHistory; + } + + /** + * Get the context manager context + * @return ContextManagerContext containing all context components + */ + public ContextManagerContext getContext() { + return context; + } + + /** + * Get the retrieved chat history + * @return List of interactions retrieved from memory + */ + public List getRetrievedHistory() { + return retrievedHistory; + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/hooks/PostToolEvent.java b/common/src/main/java/org/opensearch/ml/common/hooks/PostToolEvent.java new file mode 100644 index 0000000000..609d6028da --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/hooks/PostToolEvent.java @@ -0,0 +1,46 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.hooks; + +import java.util.List; +import java.util.Map; + +/** + * Hook event triggered after tool execution in the agent lifecycle. + * This event provides access to tool results and any errors that occurred. + */ +public class PostToolEvent extends HookEvent { + private final List> toolResults; + private final Exception error; + + /** + * Constructor for PostToolEvent + * @param toolResults List of tool execution results + * @param error Exception that occurred during tool execution, null if successful + * @param invocationState The current state of the agent invocation + */ + public PostToolEvent(List> toolResults, Exception error, Map invocationState) { + super(invocationState); + this.toolResults = toolResults; + this.error = error; + } + + /** + * Get the tool execution results + * @return List of tool results + */ + public List> getToolResults() { + return toolResults; + } + + /** + * Get the error that occurred during tool execution + * @return Exception if an error occurred, null otherwise + */ + public Exception getError() { + return error; + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/hooks/PreInvocationEvent.java b/common/src/main/java/org/opensearch/ml/common/hooks/PreInvocationEvent.java new file mode 100644 index 0000000000..42e0cc1d0c --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/hooks/PreInvocationEvent.java @@ -0,0 +1,23 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.hooks; + +import java.util.Map; + +import org.opensearch.ml.common.input.Input; + +public class PreInvocationEvent extends HookEvent { + private final Input input; + + public PreInvocationEvent(Input input, Map invocationState) { + super(invocationState); + this.input = input; + } + + public Input getInput() { + return input; + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/hooks/PreLLMEvent.java b/common/src/main/java/org/opensearch/ml/common/hooks/PreLLMEvent.java new file mode 100644 index 0000000000..1b82b04512 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/hooks/PreLLMEvent.java @@ -0,0 +1,37 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.hooks; + +import java.util.Map; + +import org.opensearch.ml.common.contextmanager.ContextManagerContext; + +/** + * Hook event triggered before LLM invocation in the agent lifecycle. + * This event provides access to the context that will be sent to the LLM, + * allowing context managers to modify it before the LLM call. + */ +public class PreLLMEvent extends HookEvent { + private final ContextManagerContext context; + + /** + * Constructor for PreLLMEvent + * @param context The context manager context containing all context components + * @param invocationState The current state of the agent invocation + */ + public PreLLMEvent(ContextManagerContext context, Map invocationState) { + super(invocationState); + this.context = context; + } + + /** + * Get the context manager context + * @return ContextManagerContext containing all context components + */ + public ContextManagerContext getContext() { + return context; + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/input/execute/agent/AgentMLInput.java b/common/src/main/java/org/opensearch/ml/common/input/execute/agent/AgentMLInput.java index 986d6eefef..c92ca43846 100644 --- a/common/src/main/java/org/opensearch/ml/common/input/execute/agent/AgentMLInput.java +++ b/common/src/main/java/org/opensearch/ml/common/input/execute/agent/AgentMLInput.java @@ -20,6 +20,7 @@ import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.dataset.MLInputDataset; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.hooks.HookRegistry; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.utils.StringUtils; @@ -47,6 +48,14 @@ public class AgentMLInput extends MLInput { @Setter private Boolean isAsync; + @Getter + @Setter + private HookRegistry hookRegistry; + + @Getter + @Setter + private String contextManagementName; + @Builder(builderMethodName = "AgentMLInputBuilder") public AgentMLInput(String agentId, String tenantId, FunctionName functionName, MLInputDataset inputDataset) { this(agentId, tenantId, functionName, inputDataset, false); @@ -72,6 +81,7 @@ public void writeTo(StreamOutput out) throws IOException { if (streamOutputVersion.onOrAfter(AgentMLInput.MINIMAL_SUPPORTED_VERSION_FOR_ASYNC_EXECUTION)) { out.writeOptionalBoolean(isAsync); } + // Note: contextManagementName and hookRegistry are not serialized as they are runtime-only fields } public AgentMLInput(StreamInput in) throws IOException { @@ -100,7 +110,12 @@ public AgentMLInput(XContentParser parser, FunctionName functionName) throws IOE tenantId = parser.textOrNull(); break; case PARAMETERS_FIELD: - Map parameters = StringUtils.getParameterMap(parser.map()); + Map parameterObjs = parser.map(); + Map parameters = StringUtils.getParameterMap(parameterObjs); + // Extract context_management from parameters + if (parameterObjs.containsKey("context_management")) { + contextManagementName = (String) parameterObjs.get("context_management"); + } inputDataset = new RemoteInferenceInputDataSet(parameters); break; case ASYNC_FIELD: diff --git a/common/src/main/java/org/opensearch/ml/common/model/ModelGuardrail.java b/common/src/main/java/org/opensearch/ml/common/model/ModelGuardrail.java index 9b1b6c6a81..b32e6471d7 100644 --- a/common/src/main/java/org/opensearch/ml/common/model/ModelGuardrail.java +++ b/common/src/main/java/org/opensearch/ml/common/model/ModelGuardrail.java @@ -7,6 +7,7 @@ import static java.util.concurrent.TimeUnit.SECONDS; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD; import static org.opensearch.ml.common.utils.StringUtils.gson; import java.io.IOException; @@ -125,13 +126,16 @@ public Boolean validate(String in, Map parameters) { guardrailModelParams.put("response_filter", responseFilter); } log.info("Guardrail resFilter: {}", responseFilter); + String tenantId = parameters != null ? parameters.get(TENANT_ID_FIELD) : null; ActionRequest request = new MLPredictionTaskRequest( modelId, RemoteInferenceMLInput .builder() .algorithm(FunctionName.REMOTE) .inputDataset(RemoteInferenceInputDataSet.builder().parameters(guardrailModelParams).build()) - .build() + .build(), + null, + tenantId ); client.execute(MLPredictionTaskAction.INSTANCE, request, new LatchedActionListener(actionListener, latch)); try { diff --git a/common/src/main/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentRequest.java index c73f2150aa..90096044f0 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentRequest.java @@ -48,11 +48,62 @@ public ActionRequestValidationException validate() { ActionRequestValidationException exception = null; if (mlAgent == null) { exception = addValidationError("ML agent can't be null", exception); + } else { + // Basic validation - check for conflicting configuration (following connector pattern) + if (mlAgent.getContextManagementName() != null && mlAgent.getContextManagement() != null) { + exception = addValidationError("Cannot specify both context_management_name and context_management", exception); + } + + // Validate context management template name + if (mlAgent.getContextManagementName() != null) { + exception = validateContextManagementTemplateName(mlAgent.getContextManagementName(), exception); + } + + // Validate inline context management configuration + if (mlAgent.getContextManagement() != null) { + exception = validateInlineContextManagement(mlAgent.getContextManagement(), exception); + } + } + + return exception; + } + + private ActionRequestValidationException validateContextManagementTemplateName( + String templateName, + ActionRequestValidationException exception + ) { + if (templateName == null || templateName.trim().isEmpty()) { + exception = addValidationError("Context management template name cannot be null or empty", exception); + } else if (templateName.length() > 256) { + exception = addValidationError("Context management template name cannot exceed 256 characters", exception); + } else if (!templateName.matches("^[a-zA-Z0-9._-]+$")) { + exception = addValidationError( + "Context management template name can only contain letters, numbers, underscores, hyphens, and dots", + exception + ); } + return exception; + } + private ActionRequestValidationException validateInlineContextManagement( + org.opensearch.ml.common.contextmanager.ContextManagementTemplate contextManagement, + ActionRequestValidationException exception + ) { + if (contextManagement.getHooks() != null) { + for (String hookName : contextManagement.getHooks().keySet()) { + if (!isValidHookName(hookName)) { + exception = addValidationError("Invalid hook name: " + hookName, exception); + } + } + } return exception; } + private boolean isValidHookName(String hookName) { + // Define valid hook names based on the system's supported hooks + return hookName.equals("POST_TOOL") || hookName.equals("PRE_LLM") || hookName.equals("PRE_TOOL") || hookName.equals("POST_LLM"); + } + @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLCreateContextManagementTemplateAction.java b/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLCreateContextManagementTemplateAction.java new file mode 100644 index 0000000000..b6116afa4f --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLCreateContextManagementTemplateAction.java @@ -0,0 +1,17 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.contextmanagement; + +import org.opensearch.action.ActionType; + +public class MLCreateContextManagementTemplateAction extends ActionType { + public static MLCreateContextManagementTemplateAction INSTANCE = new MLCreateContextManagementTemplateAction(); + public static final String NAME = "cluster:admin/opensearch/ml/context_management/create"; + + private MLCreateContextManagementTemplateAction() { + super(NAME, MLCreateContextManagementTemplateResponse::new); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLCreateContextManagementTemplateRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLCreateContextManagementTemplateRequest.java new file mode 100644 index 0000000000..ee98607505 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLCreateContextManagementTemplateRequest.java @@ -0,0 +1,90 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.contextmanagement; + +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.ml.common.contextmanager.ContextManagementTemplate; + +import lombok.AccessLevel; +import lombok.Builder; +import lombok.Getter; +import lombok.ToString; +import lombok.experimental.FieldDefaults; + +@Getter +@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) +@ToString +public class MLCreateContextManagementTemplateRequest extends ActionRequest { + + String templateName; + ContextManagementTemplate template; + + @Builder + public MLCreateContextManagementTemplateRequest(String templateName, ContextManagementTemplate template) { + this.templateName = templateName; + this.template = template; + } + + public MLCreateContextManagementTemplateRequest(StreamInput in) throws IOException { + super(in); + this.templateName = in.readString(); + this.template = new ContextManagementTemplate(in); + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException exception = null; + if (templateName == null || templateName.trim().isEmpty()) { + exception = addValidationError("Template name cannot be null or empty", exception); + } + if (template == null) { + exception = addValidationError("Context management template cannot be null", exception); + } else { + // Validate template structure + if (template.getName() == null || template.getName().trim().isEmpty()) { + exception = addValidationError("Template name in body cannot be null or empty", exception); + } + if (template.getHooks() == null || template.getHooks().isEmpty()) { + exception = addValidationError("Template must define at least one hook", exception); + } + } + return exception; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(templateName); + template.writeTo(out); + } + + public static MLCreateContextManagementTemplateRequest fromActionRequest(ActionRequest actionRequest) { + if (actionRequest instanceof MLCreateContextManagementTemplateRequest) { + return (MLCreateContextManagementTemplateRequest) actionRequest; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionRequest.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLCreateContextManagementTemplateRequest(input); + } + } catch (IOException e) { + throw new UncheckedIOException("Failed to parse ActionRequest into MLCreateContextManagementTemplateRequest", e); + } + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLCreateContextManagementTemplateResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLCreateContextManagementTemplateResponse.java new file mode 100644 index 0000000000..85265bb333 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLCreateContextManagementTemplateResponse.java @@ -0,0 +1,71 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.contextmanagement; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; + +import lombok.Getter; + +@Getter +public class MLCreateContextManagementTemplateResponse extends ActionResponse implements ToXContentObject { + public static final String TEMPLATE_NAME_FIELD = "template_name"; + public static final String STATUS_FIELD = "status"; + + private String templateName; + private String status; + + public MLCreateContextManagementTemplateResponse(StreamInput in) throws IOException { + super(in); + this.templateName = in.readString(); + this.status = in.readString(); + } + + public MLCreateContextManagementTemplateResponse(String templateName, String status) { + this.templateName = templateName; + this.status = status; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(templateName); + out.writeString(status); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(TEMPLATE_NAME_FIELD, templateName); + builder.field(STATUS_FIELD, status); + builder.endObject(); + return builder; + } + + public static MLCreateContextManagementTemplateResponse fromActionResponse(ActionResponse actionResponse) { + if (actionResponse instanceof MLCreateContextManagementTemplateResponse) { + return (MLCreateContextManagementTemplateResponse) actionResponse; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionResponse.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLCreateContextManagementTemplateResponse(input); + } + } catch (IOException e) { + throw new UncheckedIOException("Failed to parse ActionResponse into MLCreateContextManagementTemplateResponse", e); + } + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLDeleteContextManagementTemplateAction.java b/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLDeleteContextManagementTemplateAction.java new file mode 100644 index 0000000000..6074891afa --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLDeleteContextManagementTemplateAction.java @@ -0,0 +1,17 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.contextmanagement; + +import org.opensearch.action.ActionType; + +public class MLDeleteContextManagementTemplateAction extends ActionType { + public static MLDeleteContextManagementTemplateAction INSTANCE = new MLDeleteContextManagementTemplateAction(); + public static final String NAME = "cluster:admin/opensearch/ml/context_management/delete"; + + private MLDeleteContextManagementTemplateAction() { + super(NAME, MLDeleteContextManagementTemplateResponse::new); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLDeleteContextManagementTemplateRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLDeleteContextManagementTemplateRequest.java new file mode 100644 index 0000000000..e7b6e69200 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLDeleteContextManagementTemplateRequest.java @@ -0,0 +1,74 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.contextmanagement; + +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; + +import lombok.AccessLevel; +import lombok.Builder; +import lombok.Getter; +import lombok.ToString; +import lombok.experimental.FieldDefaults; + +@Getter +@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) +@ToString +public class MLDeleteContextManagementTemplateRequest extends ActionRequest { + + String templateName; + + @Builder + public MLDeleteContextManagementTemplateRequest(String templateName) { + this.templateName = templateName; + } + + public MLDeleteContextManagementTemplateRequest(StreamInput in) throws IOException { + super(in); + this.templateName = in.readString(); + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException exception = null; + if (templateName == null || templateName.trim().isEmpty()) { + exception = addValidationError("Template name cannot be null or empty", exception); + } + return exception; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(templateName); + } + + public static MLDeleteContextManagementTemplateRequest fromActionRequest(ActionRequest actionRequest) { + if (actionRequest instanceof MLDeleteContextManagementTemplateRequest) { + return (MLDeleteContextManagementTemplateRequest) actionRequest; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionRequest.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLDeleteContextManagementTemplateRequest(input); + } + } catch (IOException e) { + throw new UncheckedIOException("Failed to parse ActionRequest into MLDeleteContextManagementTemplateRequest", e); + } + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLDeleteContextManagementTemplateResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLDeleteContextManagementTemplateResponse.java new file mode 100644 index 0000000000..415323f932 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLDeleteContextManagementTemplateResponse.java @@ -0,0 +1,71 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.contextmanagement; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; + +import lombok.Getter; + +@Getter +public class MLDeleteContextManagementTemplateResponse extends ActionResponse implements ToXContentObject { + public static final String TEMPLATE_NAME_FIELD = "template_name"; + public static final String STATUS_FIELD = "status"; + + private String templateName; + private String status; + + public MLDeleteContextManagementTemplateResponse(StreamInput in) throws IOException { + super(in); + this.templateName = in.readString(); + this.status = in.readString(); + } + + public MLDeleteContextManagementTemplateResponse(String templateName, String status) { + this.templateName = templateName; + this.status = status; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(templateName); + out.writeString(status); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(TEMPLATE_NAME_FIELD, templateName); + builder.field(STATUS_FIELD, status); + builder.endObject(); + return builder; + } + + public static MLDeleteContextManagementTemplateResponse fromActionResponse(ActionResponse actionResponse) { + if (actionResponse instanceof MLDeleteContextManagementTemplateResponse) { + return (MLDeleteContextManagementTemplateResponse) actionResponse; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionResponse.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLDeleteContextManagementTemplateResponse(input); + } + } catch (IOException e) { + throw new UncheckedIOException("Failed to parse ActionResponse into MLDeleteContextManagementTemplateResponse", e); + } + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLGetContextManagementTemplateAction.java b/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLGetContextManagementTemplateAction.java new file mode 100644 index 0000000000..4220dafe25 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLGetContextManagementTemplateAction.java @@ -0,0 +1,17 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.contextmanagement; + +import org.opensearch.action.ActionType; + +public class MLGetContextManagementTemplateAction extends ActionType { + public static MLGetContextManagementTemplateAction INSTANCE = new MLGetContextManagementTemplateAction(); + public static final String NAME = "cluster:admin/opensearch/ml/context_management/get"; + + private MLGetContextManagementTemplateAction() { + super(NAME, MLGetContextManagementTemplateResponse::new); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLGetContextManagementTemplateRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLGetContextManagementTemplateRequest.java new file mode 100644 index 0000000000..f8f8061868 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLGetContextManagementTemplateRequest.java @@ -0,0 +1,74 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.contextmanagement; + +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; + +import lombok.AccessLevel; +import lombok.Builder; +import lombok.Getter; +import lombok.ToString; +import lombok.experimental.FieldDefaults; + +@Getter +@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) +@ToString +public class MLGetContextManagementTemplateRequest extends ActionRequest { + + String templateName; + + @Builder + public MLGetContextManagementTemplateRequest(String templateName) { + this.templateName = templateName; + } + + public MLGetContextManagementTemplateRequest(StreamInput in) throws IOException { + super(in); + this.templateName = in.readString(); + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException exception = null; + if (templateName == null || templateName.trim().isEmpty()) { + exception = addValidationError("Template name cannot be null or empty", exception); + } + return exception; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(templateName); + } + + public static MLGetContextManagementTemplateRequest fromActionRequest(ActionRequest actionRequest) { + if (actionRequest instanceof MLGetContextManagementTemplateRequest) { + return (MLGetContextManagementTemplateRequest) actionRequest; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionRequest.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLGetContextManagementTemplateRequest(input); + } + } catch (IOException e) { + throw new UncheckedIOException("Failed to parse ActionRequest into MLGetContextManagementTemplateRequest", e); + } + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLGetContextManagementTemplateResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLGetContextManagementTemplateResponse.java new file mode 100644 index 0000000000..309d4c88af --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLGetContextManagementTemplateResponse.java @@ -0,0 +1,62 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.contextmanagement; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.contextmanager.ContextManagementTemplate; + +import lombok.Getter; + +@Getter +public class MLGetContextManagementTemplateResponse extends ActionResponse implements ToXContentObject { + + private ContextManagementTemplate template; + + public MLGetContextManagementTemplateResponse(StreamInput in) throws IOException { + super(in); + this.template = new ContextManagementTemplate(in); + } + + public MLGetContextManagementTemplateResponse(ContextManagementTemplate template) { + this.template = template; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + template.writeTo(out); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + return template.toXContent(builder, params); + } + + public static MLGetContextManagementTemplateResponse fromActionResponse(ActionResponse actionResponse) { + if (actionResponse instanceof MLGetContextManagementTemplateResponse) { + return (MLGetContextManagementTemplateResponse) actionResponse; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionResponse.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLGetContextManagementTemplateResponse(input); + } + } catch (IOException e) { + throw new UncheckedIOException("Failed to parse ActionResponse into MLGetContextManagementTemplateResponse", e); + } + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLListContextManagementTemplatesAction.java b/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLListContextManagementTemplatesAction.java new file mode 100644 index 0000000000..2b18f92e20 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLListContextManagementTemplatesAction.java @@ -0,0 +1,17 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.contextmanagement; + +import org.opensearch.action.ActionType; + +public class MLListContextManagementTemplatesAction extends ActionType { + public static MLListContextManagementTemplatesAction INSTANCE = new MLListContextManagementTemplatesAction(); + public static final String NAME = "cluster:admin/opensearch/ml/context_management/list"; + + private MLListContextManagementTemplatesAction() { + super(NAME, MLListContextManagementTemplatesResponse::new); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLListContextManagementTemplatesRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLListContextManagementTemplatesRequest.java new file mode 100644 index 0000000000..7f86ad63f6 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLListContextManagementTemplatesRequest.java @@ -0,0 +1,74 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.contextmanagement; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; + +import lombok.AccessLevel; +import lombok.Builder; +import lombok.Getter; +import lombok.ToString; +import lombok.experimental.FieldDefaults; + +@Getter +@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) +@ToString +public class MLListContextManagementTemplatesRequest extends ActionRequest { + + int from; + int size; + + @Builder + public MLListContextManagementTemplatesRequest(int from, int size) { + this.from = from; + this.size = size; + } + + public MLListContextManagementTemplatesRequest(StreamInput in) throws IOException { + super(in); + this.from = in.readInt(); + this.size = in.readInt(); + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException exception = null; + // No specific validation needed for list request + return exception; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeInt(from); + out.writeInt(size); + } + + public static MLListContextManagementTemplatesRequest fromActionRequest(ActionRequest actionRequest) { + if (actionRequest instanceof MLListContextManagementTemplatesRequest) { + return (MLListContextManagementTemplatesRequest) actionRequest; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionRequest.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLListContextManagementTemplatesRequest(input); + } + } catch (IOException e) { + throw new UncheckedIOException("Failed to parse ActionRequest into MLListContextManagementTemplatesRequest", e); + } + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLListContextManagementTemplatesResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLListContextManagementTemplatesResponse.java new file mode 100644 index 0000000000..bc66395100 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLListContextManagementTemplatesResponse.java @@ -0,0 +1,77 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.contextmanagement; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.List; + +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.contextmanager.ContextManagementTemplate; + +import lombok.Getter; + +@Getter +public class MLListContextManagementTemplatesResponse extends ActionResponse implements ToXContentObject { + public static final String TEMPLATES_FIELD = "templates"; + public static final String TOTAL_FIELD = "total"; + + private List templates; + private long total; + + public MLListContextManagementTemplatesResponse(StreamInput in) throws IOException { + super(in); + this.templates = in.readList(ContextManagementTemplate::new); + this.total = in.readLong(); + } + + public MLListContextManagementTemplatesResponse(List templates, long total) { + this.templates = templates; + this.total = total; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeList(templates); + out.writeLong(total); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(TOTAL_FIELD, total); + builder.startArray(TEMPLATES_FIELD); + for (ContextManagementTemplate template : templates) { + template.toXContent(builder, params); + } + builder.endArray(); + builder.endObject(); + return builder; + } + + public static MLListContextManagementTemplatesResponse fromActionResponse(ActionResponse actionResponse) { + if (actionResponse instanceof MLListContextManagementTemplatesResponse) { + return (MLListContextManagementTemplatesResponse) actionResponse; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionResponse.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLListContextManagementTemplatesResponse(input); + } + } catch (IOException e) { + throw new UncheckedIOException("Failed to parse ActionResponse into MLListContextManagementTemplatesResponse", e); + } + } +} diff --git a/common/src/main/resources/index-mappings/ml_agent.json b/common/src/main/resources/index-mappings/ml_agent.json index 9d4deeca51..c530711fb2 100644 --- a/common/src/main/resources/index-mappings/ml_agent.json +++ b/common/src/main/resources/index-mappings/ml_agent.json @@ -43,6 +43,24 @@ "last_updated_time": { "type": "date", "format": "strict_date_time||epoch_millis" + }, + "context_management_name": { + "type": "keyword" + }, + "context_management": { + "type": "object", + "properties": { + "name": { + "type": "keyword" + }, + "description": { + "type": "text" + }, + "hooks": { + "type": "object", + "enabled": false + } + } } } } diff --git a/common/src/main/resources/index-mappings/ml_context_management_templates.json b/common/src/main/resources/index-mappings/ml_context_management_templates.json new file mode 100644 index 0000000000..534be6702d --- /dev/null +++ b/common/src/main/resources/index-mappings/ml_context_management_templates.json @@ -0,0 +1,26 @@ +{ + "dynamic": false, + "properties": { + "name": { + "type": "keyword" + }, + "description": { + "type": "text" + }, + "hooks": { + "type": "object", + "enabled": false + }, + "created_time": { + "type": "date", + "format": "strict_date_time||epoch_millis" + }, + "last_modified": { + "type": "date", + "format": "strict_date_time||epoch_millis" + }, + "created_by": { + "type": "keyword" + } + } +} \ No newline at end of file diff --git a/common/src/test/java/org/opensearch/ml/common/agent/MLAgentTest.java b/common/src/test/java/org/opensearch/ml/common/agent/MLAgentTest.java index da2f5f5c1e..cf0747603e 100644 --- a/common/src/test/java/org/opensearch/ml/common/agent/MLAgentTest.java +++ b/common/src/test/java/org/opensearch/ml/common/agent/MLAgentTest.java @@ -29,6 +29,7 @@ import org.opensearch.ml.common.CommonValue; import org.opensearch.ml.common.MLAgentType; import org.opensearch.ml.common.TestHelper; +import org.opensearch.ml.common.contextmanager.ContextManagementTemplate; import org.opensearch.search.SearchModule; public class MLAgentTest { @@ -65,6 +66,8 @@ public void constructor_NullName() { Instant.EPOCH, "test", false, + null, + null, null ); } @@ -86,6 +89,8 @@ public void constructor_NullType() { Instant.EPOCH, "test", false, + null, + null, null ); } @@ -107,6 +112,8 @@ public void constructor_NullLLMSpec() { Instant.EPOCH, "test", false, + null, + null, null ); } @@ -128,6 +135,8 @@ public void constructor_DuplicateTool() { Instant.EPOCH, "test", false, + null, + null, null ); } @@ -146,6 +155,8 @@ public void writeTo() throws IOException { Instant.EPOCH, "test", false, + null, + null, null ); BytesStreamOutput output = new BytesStreamOutput(); @@ -174,6 +185,8 @@ public void writeTo_NullLLM() throws IOException { Instant.EPOCH, "test", false, + null, + null, null ); BytesStreamOutput output = new BytesStreamOutput(); @@ -197,6 +210,8 @@ public void writeTo_NullTools() throws IOException { Instant.EPOCH, "test", false, + null, + null, null ); BytesStreamOutput output = new BytesStreamOutput(); @@ -220,6 +235,8 @@ public void writeTo_NullParameters() throws IOException { Instant.EPOCH, "test", false, + null, + null, null ); BytesStreamOutput output = new BytesStreamOutput(); @@ -243,6 +260,8 @@ public void writeTo_NullMemory() throws IOException { Instant.EPOCH, "test", false, + null, + null, null ); BytesStreamOutput output = new BytesStreamOutput(); @@ -279,6 +298,8 @@ public void toXContent() throws IOException { Instant.EPOCH, "test", false, + null, + null, null ); XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); @@ -336,6 +357,8 @@ public void fromStream() throws IOException { Instant.EPOCH, "test", false, + null, + null, null ); BytesStreamOutput output = new BytesStreamOutput(); @@ -367,6 +390,8 @@ public void constructor_InvalidAgentType() { Instant.EPOCH, "test", false, + null, + null, null ); } @@ -386,6 +411,8 @@ public void constructor_NonConversationalNoLLM() { Instant.EPOCH, "test", false, + null, + null, null ); assertNotNull(agent); // Ensuring object creation was successful without throwing an exception @@ -396,7 +423,22 @@ public void constructor_NonConversationalNoLLM() { @Test public void writeTo_ReadFrom_HiddenFlag_VersionCompatibility() throws IOException { - MLAgent agent = new MLAgent("test", "FLOW", "test", null, null, null, null, Instant.EPOCH, Instant.EPOCH, "test", true, null); + MLAgent agent = new MLAgent( + "test", + "FLOW", + "test", + null, + null, + null, + null, + Instant.EPOCH, + Instant.EPOCH, + "test", + true, + null, + null, + null + ); // Serialize and deserialize with an older version BytesStreamOutput output = new BytesStreamOutput(); @@ -460,6 +502,8 @@ public void getTags() { Instant.EPOCH, "test_app", true, + null, + null, null ); @@ -486,6 +530,8 @@ public void getTags_NullValues() { Instant.EPOCH, "test_app", null, + null, + null, null ); @@ -497,4 +543,325 @@ public void getTags_NullValues() { assertFalse(tagsMap.containsKey("memory_type")); assertFalse(tagsMap.containsKey("_llm_interface")); } + + @Test + public void constructor_ConflictingContextManagement() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("Cannot specify both context_management_name and context_management"); + + MLAgent agent = new MLAgent( + "test_agent", + MLAgentType.FLOW.name(), + "test description", + null, + null, + null, + null, + Instant.EPOCH, + Instant.EPOCH, + "test_app", + false, + "template_name", + new ContextManagementTemplate(), + null + ); + } + + @Test + public void hasContextManagement_WithTemplateName() { + MLAgent agent = new MLAgent( + "test_agent", + MLAgentType.FLOW.name(), + "test description", + null, + null, + null, + null, + Instant.EPOCH, + Instant.EPOCH, + "test_app", + false, + "template_name", + null, + null + ); + + assertTrue(agent.hasContextManagement()); + assertTrue(agent.hasContextManagementTemplate()); + assertEquals("template_name", agent.getContextManagementTemplateName()); + assertNull(agent.getInlineContextManagement()); + } + + @Test + public void hasContextManagement_WithInlineConfig() { + ContextManagementTemplate template = ContextManagementTemplate + .builder() + .name("test_template") + .description("test description") + .hooks(Map.of("POST_TOOL", List.of())) + .build(); + + MLAgent agent = new MLAgent( + "test_agent", + MLAgentType.FLOW.name(), + "test description", + null, + null, + null, + null, + Instant.EPOCH, + Instant.EPOCH, + "test_app", + false, + null, + template, + null + ); + + assertTrue(agent.hasContextManagement()); + assertFalse(agent.hasContextManagementTemplate()); + assertNull(agent.getContextManagementTemplateName()); + assertEquals(template, agent.getInlineContextManagement()); + } + + @Test + public void hasContextManagement_NoContextManagement() { + MLAgent agent = new MLAgent( + "test_agent", + MLAgentType.FLOW.name(), + "test description", + null, + null, + null, + null, + Instant.EPOCH, + Instant.EPOCH, + "test_app", + false, + null, + null, + null + ); + + assertFalse(agent.hasContextManagement()); + assertFalse(agent.hasContextManagementTemplate()); + assertNull(agent.getContextManagementTemplateName()); + assertNull(agent.getInlineContextManagement()); + } + + @Test + public void writeTo_ReadFrom_ContextManagementName() throws IOException { + MLAgent agent = new MLAgent( + "test_agent", + MLAgentType.FLOW.name(), + "test description", + null, + null, + null, + null, + Instant.EPOCH, + Instant.EPOCH, + "test_app", + false, + "template_name", + null, + null + ); + + BytesStreamOutput output = new BytesStreamOutput(); + output.setVersion(CommonValue.VERSION_3_3_0); + agent.writeTo(output); + + StreamInput streamInput = output.bytes().streamInput(); + streamInput.setVersion(CommonValue.VERSION_3_3_0); + MLAgent deserializedAgent = new MLAgent(streamInput); + + assertEquals("template_name", deserializedAgent.getContextManagementTemplateName()); + assertNull(deserializedAgent.getInlineContextManagement()); + assertTrue(deserializedAgent.hasContextManagement()); + assertTrue(deserializedAgent.hasContextManagementTemplate()); + } + + @Test + public void writeTo_ReadFrom_ContextManagementInline() throws IOException { + ContextManagementTemplate template = ContextManagementTemplate + .builder() + .name("test_template") + .description("test description") + .hooks(Map.of("POST_TOOL", List.of())) + .build(); + + MLAgent agent = new MLAgent( + "test_agent", + MLAgentType.FLOW.name(), + "test description", + null, + null, + null, + null, + Instant.EPOCH, + Instant.EPOCH, + "test_app", + false, + null, + template, + null + ); + + BytesStreamOutput output = new BytesStreamOutput(); + output.setVersion(CommonValue.VERSION_3_3_0); + agent.writeTo(output); + + StreamInput streamInput = output.bytes().streamInput(); + streamInput.setVersion(CommonValue.VERSION_3_3_0); + MLAgent deserializedAgent = new MLAgent(streamInput); + + assertNull(deserializedAgent.getContextManagementTemplateName()); + assertNotNull(deserializedAgent.getInlineContextManagement()); + assertEquals("test_template", deserializedAgent.getInlineContextManagement().getName()); + assertEquals("test description", deserializedAgent.getInlineContextManagement().getDescription()); + assertTrue(deserializedAgent.hasContextManagement()); + assertFalse(deserializedAgent.hasContextManagementTemplate()); + } + + @Test + public void writeTo_ReadFrom_ContextManagement_VersionCompatibility() throws IOException { + MLAgent agent = new MLAgent( + "test_agent", + MLAgentType.FLOW.name(), + "test description", + null, + null, + null, + null, + Instant.EPOCH, + Instant.EPOCH, + "test_app", + false, + "template_name", + null, + null + ); + + // Serialize with older version (before context management support) + BytesStreamOutput output = new BytesStreamOutput(); + output.setVersion(CommonValue.VERSION_3_2_0); + agent.writeTo(output); + + StreamInput streamInput = output.bytes().streamInput(); + streamInput.setVersion(CommonValue.VERSION_3_2_0); + MLAgent deserializedAgent = new MLAgent(streamInput); + + // Context management fields should be null for older versions + assertNull(deserializedAgent.getContextManagementTemplateName()); + assertNull(deserializedAgent.getInlineContextManagement()); + assertFalse(deserializedAgent.hasContextManagement()); + } + + @Test + public void parse_WithContextManagementName() throws IOException { + String jsonStr = "{\"name\":\"test\",\"type\":\"FLOW\",\"context_management_name\":\"template_name\"}"; + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + null, + jsonStr + ); + parser.nextToken(); + MLAgent agent = MLAgent.parseFromUserInput(parser); + + assertEquals("test", agent.getName()); + assertEquals("FLOW", agent.getType()); + assertEquals("template_name", agent.getContextManagementTemplateName()); + assertNull(agent.getInlineContextManagement()); + assertTrue(agent.hasContextManagement()); + assertTrue(agent.hasContextManagementTemplate()); + } + + @Test + public void parse_WithInlineContextManagement() throws IOException { + String jsonStr = + "{\"name\":\"test\",\"type\":\"FLOW\",\"context_management\":{\"name\":\"inline_template\",\"description\":\"test\",\"hooks\":{\"POST_TOOL\":[]}}}"; + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + null, + jsonStr + ); + parser.nextToken(); + MLAgent agent = MLAgent.parseFromUserInput(parser); + + assertEquals("test", agent.getName()); + assertEquals("FLOW", agent.getType()); + assertNull(agent.getContextManagementTemplateName()); + assertNotNull(agent.getInlineContextManagement()); + assertEquals("inline_template", agent.getInlineContextManagement().getName()); + assertEquals("test", agent.getInlineContextManagement().getDescription()); + assertTrue(agent.hasContextManagement()); + assertFalse(agent.hasContextManagementTemplate()); + } + + @Test + public void toXContent_WithContextManagementName() throws IOException { + MLAgent agent = new MLAgent( + "test", + "FLOW", + "test description", + null, + null, + null, + null, + Instant.EPOCH, + Instant.EPOCH, + "test_app", + false, + "template_name", + null, + null + ); + + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + agent.toXContent(builder, ToXContent.EMPTY_PARAMS); + String content = TestHelper.xContentBuilderToString(builder); + + assertTrue(content.contains("\"context_management_name\":\"template_name\"")); + assertFalse(content.contains("\"context_management\":")); + } + + @Test + public void toXContent_WithInlineContextManagement() throws IOException { + ContextManagementTemplate template = ContextManagementTemplate + .builder() + .name("inline_template") + .description("test description") + .hooks(Map.of("POST_TOOL", List.of())) + .build(); + + MLAgent agent = new MLAgent( + "test", + "FLOW", + "test description", + null, + null, + null, + null, + Instant.EPOCH, + Instant.EPOCH, + "test_app", + false, + null, + template, + null + ); + + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + agent.toXContent(builder, ToXContent.EMPTY_PARAMS); + String content = TestHelper.xContentBuilderToString(builder); + + assertFalse(content.contains("\"context_management_name\":")); + assertTrue(content.contains("\"context_management\":")); + assertTrue(content.contains("\"inline_template\"")); + } } diff --git a/common/src/test/java/org/opensearch/ml/common/contextmanager/CharacterBasedTokenCounterTest.java b/common/src/test/java/org/opensearch/ml/common/contextmanager/CharacterBasedTokenCounterTest.java new file mode 100644 index 0000000000..8eb5d7978f --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/contextmanager/CharacterBasedTokenCounterTest.java @@ -0,0 +1,164 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.contextmanager; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +/** + * Unit tests for CharacterBasedTokenCounter. + */ +public class CharacterBasedTokenCounterTest { + + private CharacterBasedTokenCounter tokenCounter; + + @Before + public void setUp() { + tokenCounter = new CharacterBasedTokenCounter(); + } + + @Test + public void testCountWithNullText() { + Assert.assertEquals(0, tokenCounter.count(null)); + } + + @Test + public void testCountWithEmptyText() { + Assert.assertEquals(0, tokenCounter.count("")); + } + + @Test + public void testCountWithShortText() { + String text = "Hi"; + int expectedTokens = (int) Math.ceil(text.length() / 4.0); + Assert.assertEquals(expectedTokens, tokenCounter.count(text)); + } + + @Test + public void testCountWithMediumText() { + String text = "This is a test message"; + int expectedTokens = (int) Math.ceil(text.length() / 4.0); + Assert.assertEquals(expectedTokens, tokenCounter.count(text)); + } + + @Test + public void testCountWithLongText() { + String text = "This is a very long text that should result in multiple tokens when counted using the character-based approach."; + int expectedTokens = (int) Math.ceil(text.length() / 4.0); + Assert.assertEquals(expectedTokens, tokenCounter.count(text)); + } + + @Test + public void testTruncateFromEndWithNullText() { + Assert.assertNull(tokenCounter.truncateFromEnd(null, 10)); + } + + @Test + public void testTruncateFromEndWithEmptyText() { + Assert.assertEquals("", tokenCounter.truncateFromEnd("", 10)); + } + + @Test + public void testTruncateFromEndWithShortText() { + String text = "Short"; + String result = tokenCounter.truncateFromEnd(text, 10); + Assert.assertEquals(text, result); + } + + @Test + public void testTruncateFromEndWithLongText() { + String text = "This is a very long text that needs to be truncated"; + String result = tokenCounter.truncateFromEnd(text, 5); + + Assert.assertNotNull(result); + Assert.assertTrue(result.length() < text.length()); + Assert.assertTrue(result.length() <= 5 * 4); // 5 tokens * 4 chars per token + Assert.assertTrue(text.startsWith(result)); + } + + @Test + public void testTruncateFromBeginningWithNullText() { + Assert.assertNull(tokenCounter.truncateFromBeginning(null, 10)); + } + + @Test + public void testTruncateFromBeginningWithEmptyText() { + Assert.assertEquals("", tokenCounter.truncateFromBeginning("", 10)); + } + + @Test + public void testTruncateFromBeginningWithShortText() { + String text = "Short"; + String result = tokenCounter.truncateFromBeginning(text, 10); + Assert.assertEquals(text, result); + } + + @Test + public void testTruncateFromBeginningWithLongText() { + String text = "This is a very long text that needs to be truncated"; + String result = tokenCounter.truncateFromBeginning(text, 5); + + Assert.assertNotNull(result); + Assert.assertTrue(result.length() < text.length()); + Assert.assertTrue(result.length() <= 5 * 4); // 5 tokens * 4 chars per token + Assert.assertTrue(text.endsWith(result)); + } + + @Test + public void testTruncateMiddleWithNullText() { + Assert.assertNull(tokenCounter.truncateMiddle(null, 10)); + } + + @Test + public void testTruncateMiddleWithEmptyText() { + Assert.assertEquals("", tokenCounter.truncateMiddle("", 10)); + } + + @Test + public void testTruncateMiddleWithShortText() { + String text = "Short"; + String result = tokenCounter.truncateMiddle(text, 10); + Assert.assertEquals(text, result); + } + + @Test + public void testTruncateMiddleWithLongText() { + String text = "This is a very long text that needs to be truncated from the middle"; + String result = tokenCounter.truncateMiddle(text, 5); + + Assert.assertNotNull(result); + Assert.assertTrue(result.length() < text.length()); + Assert.assertTrue(result.length() <= 5 * 4); // 5 tokens * 4 chars per token + + // Result should contain parts from both beginning and end + int halfChars = (5 * 4) / 2; + String expectedBeginning = text.substring(0, halfChars); + String expectedEnd = text.substring(text.length() - halfChars); + + Assert.assertTrue(result.startsWith(expectedBeginning)); + Assert.assertTrue(result.endsWith(expectedEnd)); + } + + @Test + public void testTruncateConsistency() { + String text = "This is a test text for truncation consistency"; + int maxTokens = 3; + + String fromEnd = tokenCounter.truncateFromEnd(text, maxTokens); + String fromBeginning = tokenCounter.truncateFromBeginning(text, maxTokens); + String fromMiddle = tokenCounter.truncateMiddle(text, maxTokens); + + // All truncated results should have similar token counts + int tokensFromEnd = tokenCounter.count(fromEnd); + int tokensFromBeginning = tokenCounter.count(fromBeginning); + int tokensFromMiddle = tokenCounter.count(fromMiddle); + + Assert.assertTrue(tokensFromEnd <= maxTokens); + Assert.assertTrue(tokensFromBeginning <= maxTokens); + Assert.assertTrue(tokensFromMiddle <= maxTokens); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentGetResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentGetResponseTest.java index 81a173dfde..58921f71b8 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentGetResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentGetResponseTest.java @@ -96,6 +96,8 @@ public void writeTo() throws IOException { Instant.EPOCH, "test", false, + null, + null, null ); MLAgentGetResponse mlAgentGetResponse = MLAgentGetResponse.builder().mlAgent(mlAgent).build(); @@ -115,7 +117,7 @@ public void writeTo() throws IOException { @Test public void toXContent() throws IOException { - mlAgent = new MLAgent("mock", MLAgentType.FLOW.name(), "test", null, null, null, null, null, null, "test", false, null); + mlAgent = new MLAgent("mock", MLAgentType.FLOW.name(), "test", null, null, null, null, null, null, "test", false, null, null, null); MLAgentGetResponse mlAgentGetResponse = MLAgentGetResponse.builder().mlAgent(mlAgent).build(); XContentBuilder builder = XContentFactory.jsonBuilder(); ToXContent.Params params = EMPTY_PARAMS; diff --git a/common/src/test/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentRequestTest.java index 94cbbeb7dd..da7f3d5623 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentRequestTest.java @@ -10,6 +10,9 @@ import java.io.IOException; import java.io.UncheckedIOException; import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import org.junit.Before; import org.junit.Rule; @@ -21,6 +24,8 @@ import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.ml.common.agent.MLAgent; import org.opensearch.ml.common.agent.MLToolSpec; +import org.opensearch.ml.common.contextmanager.ContextManagementTemplate; +import org.opensearch.ml.common.contextmanager.ContextManagerConfig; public class MLRegisterAgentRequestTest { @@ -111,4 +116,260 @@ public void writeTo(StreamOutput out) throws IOException { }; MLRegisterAgentRequest.fromActionRequest(actionRequest); } + + @Test + public void validate_ContextManagementConflict() { + // Create agent with both context management name and inline configuration + ContextManagementTemplate contextManagement = ContextManagementTemplate + .builder() + .name("test_template") + .hooks(createValidHooks()) + .build(); + + // This should throw an exception during MLAgent construction + try { + MLAgent agentWithConflict = MLAgent + .builder() + .name("test_agent") + .type("flow") + .contextManagementName("template_name") + .contextManagement(contextManagement) + .build(); + fail("Expected IllegalArgumentException to be thrown"); + } catch (IllegalArgumentException e) { + assertTrue(e.getMessage().contains("Cannot specify both context_management_name and context_management")); + } + } + + @Test + public void validate_ContextManagementTemplateName_Valid() { + MLAgent agentWithTemplateName = MLAgent + .builder() + .name("test_agent") + .type("flow") + .contextManagementName("valid_template_name") + .build(); + + MLRegisterAgentRequest request = new MLRegisterAgentRequest(agentWithTemplateName); + ActionRequestValidationException exception = request.validate(); + + assertNull(exception); + } + + @Test + public void validate_ContextManagementTemplateName_Empty() { + // Test empty template name - this should be caught at request validation level + MLAgent agentWithEmptyName = MLAgent.builder().name("test_agent").type("flow").contextManagementName("").build(); + + MLRegisterAgentRequest request = new MLRegisterAgentRequest(agentWithEmptyName); + ActionRequestValidationException exception = request.validate(); + + assertNotNull(exception); + assertTrue(exception.toString().contains("Context management template name cannot be null or empty")); + } + + @Test + public void validate_ContextManagementTemplateName_TooLong() { + // Test template name that's too long + String longName = "a".repeat(257); + MLAgent agentWithLongName = MLAgent.builder().name("test_agent").type("flow").contextManagementName(longName).build(); + + MLRegisterAgentRequest request = new MLRegisterAgentRequest(agentWithLongName); + ActionRequestValidationException exception = request.validate(); + + assertNotNull(exception); + assertTrue(exception.toString().contains("Context management template name cannot exceed 256 characters")); + } + + @Test + public void validate_ContextManagementTemplateName_InvalidCharacters() { + // Test template name with invalid characters + MLAgent agentWithInvalidName = MLAgent.builder().name("test_agent").type("flow").contextManagementName("invalid@name#").build(); + + MLRegisterAgentRequest request = new MLRegisterAgentRequest(agentWithInvalidName); + ActionRequestValidationException exception = request.validate(); + + assertNotNull(exception); + assertTrue( + exception + .toString() + .contains("Context management template name can only contain letters, numbers, underscores, hyphens, and dots") + ); + } + + @Test + public void validate_InlineContextManagement_Valid() { + ContextManagementTemplate validContextManagement = ContextManagementTemplate + .builder() + .name("test_template") + .hooks(createValidHooks()) + .build(); + + MLAgent agentWithInlineConfig = MLAgent.builder().name("test_agent").type("flow").contextManagement(validContextManagement).build(); + + MLRegisterAgentRequest request = new MLRegisterAgentRequest(agentWithInlineConfig); + ActionRequestValidationException exception = request.validate(); + + assertNull(exception); + } + + @Test + public void validate_InlineContextManagement_InvalidHookName() { + // Create a context management template with invalid hook name but valid structure + // This should pass MLAgent validation but fail request validation + Map> invalidHooks = new HashMap<>(); + invalidHooks.put("INVALID_HOOK", Arrays.asList(new ContextManagerConfig("ToolsOutputTruncateManager", null, null))); + + ContextManagementTemplate invalidContextManagement = ContextManagementTemplate + .builder() + .name("test_template") + .hooks(invalidHooks) + .build(); + + MLAgent agentWithInvalidConfig = MLAgent + .builder() + .name("test_agent") + .type("flow") + .contextManagement(invalidContextManagement) + .build(); + + MLRegisterAgentRequest request = new MLRegisterAgentRequest(agentWithInvalidConfig); + ActionRequestValidationException exception = request.validate(); + + assertNotNull(exception); + assertTrue(exception.toString().contains("Invalid hook name: INVALID_HOOK")); + } + + @Test + public void validate_InlineContextManagement_EmptyHooks() { + ContextManagementTemplate emptyHooksTemplate = ContextManagementTemplate + .builder() + .name("test_template") + .hooks(new HashMap<>()) + .build(); + + // This should throw an exception during MLAgent construction due to invalid context management + try { + MLAgent agentWithEmptyHooks = MLAgent.builder().name("test_agent").type("flow").contextManagement(emptyHooksTemplate).build(); + fail("Expected IllegalArgumentException to be thrown"); + } catch (IllegalArgumentException e) { + assertTrue(e.getMessage().contains("Invalid context management configuration")); + } + } + + @Test + public void validate_InlineContextManagement_InvalidContextManagerConfig() { + Map> hooksWithInvalidConfig = new HashMap<>(); + hooksWithInvalidConfig + .put( + "POST_TOOL", + Arrays + .asList( + new ContextManagerConfig(null, null, null) // Invalid: null type + ) + ); + + ContextManagementTemplate invalidTemplate = ContextManagementTemplate + .builder() + .name("test_template") + .hooks(hooksWithInvalidConfig) + .build(); + + // This should throw an exception during MLAgent construction due to invalid context management + try { + MLAgent agentWithInvalidConfig = MLAgent.builder().name("test_agent").type("flow").contextManagement(invalidTemplate).build(); + fail("Expected IllegalArgumentException to be thrown"); + } catch (IllegalArgumentException e) { + assertTrue(e.getMessage().contains("Invalid context management configuration")); + } + } + + @Test + public void validate_NoContextManagement_Valid() { + MLAgent agentWithoutContextManagement = MLAgent.builder().name("test_agent").type("flow").build(); + + MLRegisterAgentRequest request = new MLRegisterAgentRequest(agentWithoutContextManagement); + ActionRequestValidationException exception = request.validate(); + + assertNull(exception); + } + + @Test + public void validate_ContextManagementTemplateName_NullValue() { + // Test null template name - this should pass validation since null is acceptable + MLAgent agentWithNullName = MLAgent.builder().name("test_agent").type("flow").contextManagementName(null).build(); + + MLRegisterAgentRequest request = new MLRegisterAgentRequest(agentWithNullName); + ActionRequestValidationException exception = request.validate(); + + assertNull(exception); + } + + @Test + public void validate_ContextManagementTemplateName_Null() { + // Test null template name validation + MLAgent agentWithNullName = MLAgent.builder().name("test_agent").type("flow").contextManagementName(null).build(); + + MLRegisterAgentRequest request = new MLRegisterAgentRequest(agentWithNullName); + ActionRequestValidationException exception = request.validate(); + + // This should pass since null is handled differently than empty + assertNull(exception); + } + + @Test + public void validate_InlineContextManagement_NullHooks() { + // Test inline context management with null hooks + ContextManagementTemplate contextManagementWithNullHooks = ContextManagementTemplate + .builder() + .name("test_template") + .hooks(null) + .build(); + + MLAgent agentWithNullHooks = MLAgent + .builder() + .name("test_agent") + .type("flow") + .contextManagement(contextManagementWithNullHooks) + .build(); + + MLRegisterAgentRequest request = new MLRegisterAgentRequest(agentWithNullHooks); + ActionRequestValidationException exception = request.validate(); + + // Should pass since null hooks are handled gracefully + assertNull(exception); + } + + @Test + public void validate_HookName_AllValidTypes() { + // Test all valid hook names to improve branch coverage + Map> allValidHooks = new HashMap<>(); + allValidHooks.put("POST_TOOL", Arrays.asList(new ContextManagerConfig("ToolsOutputTruncateManager", null, null))); + allValidHooks.put("PRE_LLM", Arrays.asList(new ContextManagerConfig("SummarizationManager", null, null))); + allValidHooks.put("PRE_TOOL", Arrays.asList(new ContextManagerConfig("MemoryManager", null, null))); + allValidHooks.put("POST_LLM", Arrays.asList(new ContextManagerConfig("ConversationManager", null, null))); + + ContextManagementTemplate contextManagement = ContextManagementTemplate + .builder() + .name("test_template") + .hooks(allValidHooks) + .build(); + + MLAgent agentWithAllHooks = MLAgent.builder().name("test_agent").type("flow").contextManagement(contextManagement).build(); + + MLRegisterAgentRequest request = new MLRegisterAgentRequest(agentWithAllHooks); + ActionRequestValidationException exception = request.validate(); + + assertNull(exception); + } + + /** + * Helper method to create valid hooks configuration for testing + */ + private Map> createValidHooks() { + Map> hooks = new HashMap<>(); + hooks.put("POST_TOOL", Arrays.asList(new ContextManagerConfig("ToolsOutputTruncateManager", null, null))); + hooks.put("PRE_LLM", Arrays.asList(new ContextManagerConfig("SummarizationManager", null, null))); + return hooks; + } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/agents/AgentContextUtil.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/agents/AgentContextUtil.java new file mode 100644 index 0000000000..da2fd985f3 --- /dev/null +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/agents/AgentContextUtil.java @@ -0,0 +1,179 @@ +package org.opensearch.ml.engine.agents; + +import static org.opensearch.ml.engine.algorithms.agent.MLAgentExecutor.QUESTION; +import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.CHAT_HISTORY; +import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.INTERACTIONS; +import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.SYSTEM_PROMPT_FIELD; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ml.common.agent.MLToolSpec; +import org.opensearch.ml.common.contextmanager.ContextManagerContext; +import org.opensearch.ml.common.hooks.EnhancedPostToolEvent; +import org.opensearch.ml.common.hooks.HookRegistry; +import org.opensearch.ml.common.hooks.PreLLMEvent; +import org.opensearch.ml.common.spi.memory.Memory; +import org.opensearch.ml.common.utils.StringUtils; +import org.opensearch.ml.engine.memory.ConversationIndexMemory; + +public class AgentContextUtil { + private static final Logger log = LogManager.getLogger(AgentContextUtil.class); + + public static ContextManagerContext buildContextManagerContextForToolOutput( + String toolOutput, + Map parameters, + List toolSpecs, + Memory memory + ) { + ContextManagerContext.ContextManagerContextBuilder builder = ContextManagerContext.builder(); + + String systemPrompt = parameters.get(SYSTEM_PROMPT_FIELD); + if (systemPrompt != null) { + builder.systemPrompt(systemPrompt); + } + + String userPrompt = parameters.get(QUESTION); + if (userPrompt != null) { + builder.userPrompt(userPrompt); + } + + if (toolSpecs != null) { + builder.toolConfigs(toolSpecs); + } + + Map contextParameters = new HashMap<>(); + contextParameters.putAll(parameters); + contextParameters.put("_current_tool_output", toolOutput); + builder.parameters(contextParameters); + + return builder.build(); + } + + public static Object extractProcessedToolOutput(ContextManagerContext context) { + if (context.getParameters() != null) { + return context.getParameters().get("_current_tool_output"); + } + return null; + } + + public static Object extractFromContext(ContextManagerContext context, String key) { + if (context.getParameters() != null) { + return context.getParameters().get(key); + } + return null; + } + + public static ContextManagerContext buildContextManagerContext( + Map parameters, + List interactions, + List toolSpecs, + Memory memory + ) { + ContextManagerContext.ContextManagerContextBuilder builder = ContextManagerContext.builder(); + + String systemPrompt = parameters.get(SYSTEM_PROMPT_FIELD); + if (systemPrompt != null) { + builder.systemPrompt(systemPrompt); + } + + String userPrompt = parameters.get(QUESTION); + if (userPrompt != null) { + builder.userPrompt(userPrompt); + } + + if (memory instanceof ConversationIndexMemory) { + String chatHistory = parameters.get(CHAT_HISTORY); + // TODO to add chatHistory into context, currently there is no context manager working on chat_history + } + + if (toolSpecs != null) { + builder.toolConfigs(toolSpecs); + } + + builder.toolInteractions(interactions != null ? interactions : new ArrayList<>()); + + Map contextParameters = new HashMap<>(); + contextParameters.putAll(parameters); + builder.parameters(contextParameters); + + return builder.build(); + } + + public static ContextManagerContext emitPostToolHook( + Object toolOutput, + Map parameters, + List toolSpecs, + Memory memory, + HookRegistry hookRegistry + ) { + ContextManagerContext context = buildContextManagerContextForToolOutput( + StringUtils.toJson(toolOutput), + parameters, + toolSpecs, + memory + ); + + if (hookRegistry != null) { + try { + if (toolOutput == null) { + log.warn("Tool output is null, skipping POST_TOOL hook"); + return context; + } + EnhancedPostToolEvent event = new EnhancedPostToolEvent(null, null, context, new HashMap<>()); + hookRegistry.emit(event); + } catch (Exception e) { + log.error("Failed to emit POST_TOOL hook event", e); + } + } + return context; + } + + public static ContextManagerContext emitPreLLMHook( + Map parameters, + List interactions, + List toolSpecs, + Memory memory, + HookRegistry hookRegistry + ) { + ContextManagerContext context = buildContextManagerContext(parameters, interactions, toolSpecs, memory); + + try { + PreLLMEvent event = new PreLLMEvent(context, new HashMap<>()); + hookRegistry.emit(event); + return context; + + } catch (Exception e) { + log.error("Failed to emit PRE_LLM hook event", e); + return context; + } + } + + public static void updateParametersFromContext(Map parameters, ContextManagerContext context) { + if (context.getSystemPrompt() != null) { + parameters.put(SYSTEM_PROMPT_FIELD, context.getSystemPrompt()); + } + + if (context.getUserPrompt() != null) { + parameters.put(QUESTION, context.getUserPrompt()); + } + + if (context.getChatHistory() != null && !context.getChatHistory().isEmpty()) { + } + + if (context.getToolInteractions() != null && !context.getToolInteractions().isEmpty()) { + parameters.put(INTERACTIONS, ", " + String.join(", ", context.getToolInteractions())); + } + + if (context.getParameters() != null) { + for (Map.Entry entry : context.getParameters().entrySet()) { + parameters.put(entry.getKey(), entry.getValue()); + + } + } + } +} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java index 1594506cf4..39d3fe6263 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java @@ -51,7 +51,9 @@ import org.opensearch.ml.common.MLTaskType; import org.opensearch.ml.common.agent.MLAgent; import org.opensearch.ml.common.agent.MLMemorySpec; +import org.opensearch.ml.common.contextmanager.ContextManagementTemplate; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.hooks.HookRegistry; import org.opensearch.ml.common.input.Input; import org.opensearch.ml.common.input.execute.agent.AgentMLInput; import org.opensearch.ml.common.output.MLTaskOutput; @@ -64,6 +66,9 @@ import org.opensearch.ml.common.spi.memory.Memory; import org.opensearch.ml.common.spi.tools.Tool; import org.opensearch.ml.engine.Executable; +import org.opensearch.ml.engine.algorithms.contextmanager.SlidingWindowManager; +import org.opensearch.ml.engine.algorithms.contextmanager.SummarizationManager; +import org.opensearch.ml.engine.algorithms.contextmanager.ToolsOutputTruncateManager; import org.opensearch.ml.engine.annotation.Function; import org.opensearch.ml.engine.encryptor.Encryptor; import org.opensearch.ml.engine.indices.MLIndicesHandler; @@ -204,6 +209,12 @@ public void execute(Input input, ActionListener listener, TransportChann ) { ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); MLAgent mlAgent = MLAgent.parse(parser); + // Use existing HookRegistry from AgentMLInput if available (set by MLExecuteTaskRunner for template + // references) + // Otherwise create a fresh HookRegistry for agent execution + final HookRegistry hookRegistry = agentMLInput.getHookRegistry() != null + ? agentMLInput.getHookRegistry() + : new HookRegistry(); if (isMultiTenancyEnabled && !Objects.equals(tenantId, mlAgent.getTenantId())) { listener .onFailure( @@ -270,7 +281,8 @@ public void execute(Input input, ActionListener listener, TransportChann outputs, modelTensors, mlAgent, - channel + channel, + hookRegistry ); }, e -> { log.error("Failed to get existing interaction for regeneration", e); @@ -287,7 +299,8 @@ public void execute(Input input, ActionListener listener, TransportChann outputs, modelTensors, mlAgent, - channel + channel, + hookRegistry ); } }, ex -> { @@ -300,7 +313,8 @@ public void execute(Input input, ActionListener listener, TransportChann ConversationIndexMemory.Factory factory = (ConversationIndexMemory.Factory) memoryFactoryMap .get(memorySpec.getType()); if (factory != null) { - // memoryId exists, so create returns an object with existing memory, therefore name can + // memoryId exists, so create returns an object with existing + // memory, therefore name can // be null factory .create( @@ -319,7 +333,8 @@ public void execute(Input input, ActionListener listener, TransportChann modelTensors, listener, createdMemory, - channel + channel, + hookRegistry ), ex -> { log.error("Failed to find memory with memory_id: {}", memoryId, ex); @@ -340,7 +355,8 @@ public void execute(Input input, ActionListener listener, TransportChann modelTensors, listener, null, - channel + channel, + hookRegistry ); } } catch (Exception e) { @@ -370,10 +386,11 @@ public void execute(Input input, ActionListener listener, TransportChann /** * save root interaction and start execute the agent - * @param listener callback listener - * @param memory memory instance + * + * @param listener callback listener + * @param memory memory instance * @param inputDataSet input - * @param mlAgent agent to run + * @param mlAgent agent to run */ private void saveRootInteractionAndExecute( ActionListener listener, @@ -384,7 +401,8 @@ private void saveRootInteractionAndExecute( List outputs, List modelTensors, MLAgent mlAgent, - TransportChannel channel + TransportChannel channel, + HookRegistry hookRegistry ) { String appType = mlAgent.getAppType(); String question = inputDataSet.getParameters().get(QUESTION); @@ -419,7 +437,8 @@ private void saveRootInteractionAndExecute( modelTensors, listener, memory, - channel + channel, + hookRegistry ), e -> { log.error("Failed to regenerate for interaction {}", regenerateInteractionId, e); @@ -438,7 +457,8 @@ private void saveRootInteractionAndExecute( modelTensors, listener, memory, - channel + channel, + hookRegistry ); } }, ex -> { @@ -447,6 +467,224 @@ private void saveRootInteractionAndExecute( })); } + /** + * Process context management configuration and register context managers in + * hook registry + * + * @param mlAgent the ML agent with context management configuration + * @param hookRegistry the hook registry to register context managers with + * @param inputDataSet the input dataset to update with context management info + */ + private void processContextManagement(MLAgent mlAgent, HookRegistry hookRegistry, RemoteInferenceInputDataSet inputDataSet) { + try { + // Check if context_management is already specified in runtime parameters + String runtimeContextManagement = inputDataSet.getParameters().get("context_management"); + if (runtimeContextManagement != null && !runtimeContextManagement.trim().isEmpty()) { + log.info("Using runtime context management parameter: {}", runtimeContextManagement); + return; // Runtime parameter takes precedence, let MLExecuteTaskRunner handle it + } + + // Check if already processed to avoid duplicate registrations + if ("true".equals(inputDataSet.getParameters().get("context_management_processed"))) { + log.debug("Context management already processed for this execution, skipping"); + return; + } + + // Check if HookRegistry already has callbacks (from previous runtime setup) + // Don't override with inline configuration if runtime config is already active + if (hookRegistry.getCallbackCount(org.opensearch.ml.common.hooks.EnhancedPostToolEvent.class) > 0 + || hookRegistry.getCallbackCount(org.opensearch.ml.common.hooks.PreLLMEvent.class) > 0) { + log.info("HookRegistry already has active configuration, skipping inline context management"); + return; + } + + ContextManagementTemplate template = null; + String templateName = null; + + if (mlAgent.hasContextManagementTemplate()) { + // Template reference - would need to be resolved from template service + templateName = mlAgent.getContextManagementTemplateName(); + log.info("Agent '{}' has context management template reference: {}", mlAgent.getName(), templateName); + // For now, we'll pass the template name to parameters for MLExecuteTaskRunner + // to handle + inputDataSet.getParameters().put("context_management", templateName); + return; // Let MLExecuteTaskRunner handle template resolution + } else if (mlAgent.getInlineContextManagement() != null) { + // Inline template - process directly + template = mlAgent.getInlineContextManagement(); + templateName = template.getName(); + log.info("Agent '{}' has inline context management configuration: {}", mlAgent.getName(), templateName); + } + + if (template != null) { + // Process inline context management template + processInlineContextManagement(template, hookRegistry); + // Mark as processed to prevent MLExecuteTaskRunner from processing it again + inputDataSet.getParameters().put("context_management_processed", "true"); + inputDataSet.getParameters().put("context_management", templateName); + } + } catch (Exception e) { + log.error("Failed to process context management for agent '{}': {}", mlAgent.getName(), e.getMessage(), e); + // Don't fail the entire execution, just log the error + } + } + + /** + * Process inline context management template and register context managers + * + * @param template the context management template + * @param hookRegistry the hook registry to register with + */ + private void processInlineContextManagement(ContextManagementTemplate template, HookRegistry hookRegistry) { + try { + log.debug("Processing inline context management template: {}", template.getName()); + + // Fresh HookRegistry ensures no duplicate registrations + + // Create context managers from template configuration + List contextManagers = createContextManagers(template); + + if (!contextManagers.isEmpty()) { + // Create hook provider and register with hook registry + org.opensearch.ml.common.contextmanager.ContextManagerHookProvider hookProvider = + new org.opensearch.ml.common.contextmanager.ContextManagerHookProvider(contextManagers); + + // Update hook configuration based on template + hookProvider.updateHookConfiguration(template.getHooks()); + + // Register hooks with the registry + hookProvider.registerHooks(hookRegistry); + + log.info("Successfully registered {} context managers from template '{}'", contextManagers.size(), template.getName()); + } else { + log.warn("No context managers created from template '{}'", template.getName()); + } + } catch (Exception e) { + log.error("Failed to process inline context management template '{}': {}", template.getName(), e.getMessage(), e); + } + } + + /** + * Create context managers from template configuration + * + * @param template the context management template + * @return list of created context managers + */ + private List createContextManagers(ContextManagementTemplate template) { + List managers = new ArrayList<>(); + + try { + // Iterate through all hooks and their configurations + for (Map.Entry> entry : template + .getHooks() + .entrySet()) { + String hookName = entry.getKey(); + List configs = entry.getValue(); + + log.debug("Processing hook '{}' with {} configurations", hookName, configs.size()); + + for (org.opensearch.ml.common.contextmanager.ContextManagerConfig config : configs) { + try { + org.opensearch.ml.common.contextmanager.ContextManager manager = createContextManager(config); + if (manager != null) { + managers.add(manager); + log.debug("Created context manager: {} for hook: {}", config.getType(), hookName); + } + } catch (Exception e) { + log + .error( + "Failed to create context manager of type '{}' for hook '{}': {}", + config.getType(), + hookName, + e.getMessage(), + e + ); + } + } + } + } catch (Exception e) { + log.error("Failed to create context managers from template: {}", e.getMessage(), e); + } + + return managers; + } + + /** + * Create a single context manager from configuration + * + * @param config the context manager configuration + * @return the created context manager or null if creation failed + */ + private org.opensearch.ml.common.contextmanager.ContextManager createContextManager( + org.opensearch.ml.common.contextmanager.ContextManagerConfig config + ) { + try { + String type = config.getType(); + Map managerConfig = config.getConfig(); + + log.debug("Creating context manager of type: {}", type); + + // Create context manager based on type + switch (type) { + case "ToolsOutputTruncateManager": + return createToolsOutputTruncateManager(managerConfig); + case "SummarizationManager": + case "SummarizingManager": + return createSummarizationManager(managerConfig); + case "MemoryManager": + return createMemoryManager(managerConfig); + case "ConversationManager": + return createConversationManager(managerConfig); + default: + log.warn("Unknown context manager type: {}", type); + return null; + } + } catch (Exception e) { + log.error("Failed to create context manager: {}", e.getMessage(), e); + return null; + } + } + + /** + * Create ToolsOutputTruncateManager + */ + private org.opensearch.ml.common.contextmanager.ContextManager createToolsOutputTruncateManager(Map config) { + log.debug("Creating ToolsOutputTruncateManager with config: {}", config); + ToolsOutputTruncateManager manager = new ToolsOutputTruncateManager(); + manager.initialize(config != null ? config : new HashMap<>()); + return manager; + } + + /** + * Create SummarizationManager + */ + private org.opensearch.ml.common.contextmanager.ContextManager createSummarizationManager(Map config) { + log.debug("Creating SummarizationManager with config: {}", config); + SummarizationManager manager = new SummarizationManager(client); + manager.initialize(config != null ? config : new HashMap<>()); + return manager; + } + + /** + * Create SlidingWindowManager (used for MemoryManager type) + */ + private org.opensearch.ml.common.contextmanager.ContextManager createMemoryManager(Map config) { + log.debug("Creating SlidingWindowManager (MemoryManager) with config: {}", config); + SlidingWindowManager manager = new SlidingWindowManager(); + manager.initialize(config != null ? config : new HashMap<>()); + return manager; + } + + /** + * Create ConversationManager (placeholder - using SummarizationManager for now) + */ + private org.opensearch.ml.common.contextmanager.ContextManager createConversationManager(Map config) { + log.debug("Creating ConversationManager (using SummarizationManager as placeholder) with config: {}", config); + SummarizationManager manager = new SummarizationManager(client); + manager.initialize(config != null ? config : new HashMap<>()); + return manager; + } + private void executeAgent( RemoteInferenceInputDataSet inputDataSet, MLTask mlTask, @@ -457,7 +695,8 @@ private void executeAgent( List modelTensors, ActionListener listener, ConversationIndexMemory memory, - TransportChannel channel + TransportChannel channel, + HookRegistry hookRegistry ) { String mcpConnectorConfigJSON = (mlAgent.getParameters() != null) ? mlAgent.getParameters().get(MCP_CONNECTORS_FIELD) : null; if (mcpConnectorConfigJSON != null && !mlFeatureEnabledSetting.isMcpConnectorEnabled()) { @@ -466,10 +705,17 @@ private void executeAgent( return; } - MLAgentRunner mlAgentRunner = getAgentRunner(mlAgent); + // Check for agent-level context management configuration (following connector + // pattern) + if (mlAgent.hasContextManagement()) { + processContextManagement(mlAgent, hookRegistry, inputDataSet); + } + + MLAgentRunner mlAgentRunner = getAgentRunner(mlAgent, hookRegistry); String parentInteractionId = inputDataSet.getParameters().get(PARENT_INTERACTION_ID); - // If async is true, index ML task and return the taskID. Also add memoryID to the task if it exists + // If async is true, index ML task and return the taskID. Also add memoryID to + // the task if it exists if (isAsync) { Map agentResponse = new HashMap<>(); if (memoryId != null && !memoryId.isEmpty()) { @@ -606,7 +852,7 @@ private ActionListener createAsyncTaskUpdater( } @VisibleForTesting - protected MLAgentRunner getAgentRunner(MLAgent mlAgent) { + protected MLAgentRunner getAgentRunner(MLAgent mlAgent, HookRegistry hookRegistry) { final MLAgentType agentType = MLAgentType.from(mlAgent.getType().toUpperCase(Locale.ROOT)); switch (agentType) { case FLOW: @@ -640,7 +886,8 @@ protected MLAgentRunner getAgentRunner(MLAgent mlAgent) { toolFactories, memoryFactoryMap, sdkClient, - encryptor + encryptor, + hookRegistry ); case PLAN_EXECUTE_AND_REFLECT: return new MLPlanExecuteAndReflectAgentRunner( @@ -651,7 +898,8 @@ protected MLAgentRunner getAgentRunner(MLAgent mlAgent) { toolFactories, memoryFactoryMap, sdkClient, - encryptor + encryptor, + hookRegistry ); default: throw new IllegalArgumentException("Unsupported agent type: " + mlAgent.getType()); @@ -730,4 +978,5 @@ private void updateInteractionWithFailure(String interactionId, ConversationInde ); } } + } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index 103f3f89b3..6badaf31fa 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -60,7 +60,9 @@ import org.opensearch.ml.common.agent.LLMSpec; import org.opensearch.ml.common.agent.MLAgent; import org.opensearch.ml.common.agent.MLToolSpec; +import org.opensearch.ml.common.contextmanager.ContextManagerContext; import org.opensearch.ml.common.conversation.Interaction; +import org.opensearch.ml.common.hooks.HookRegistry; import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.common.output.model.ModelTensors; @@ -69,6 +71,7 @@ import org.opensearch.ml.common.spi.tools.Tool; import org.opensearch.ml.common.transport.MLTaskResponse; import org.opensearch.ml.common.utils.StringUtils; +import org.opensearch.ml.engine.agents.AgentContextUtil; import org.opensearch.ml.engine.encryptor.Encryptor; import org.opensearch.ml.engine.function_calling.FunctionCalling; import org.opensearch.ml.engine.function_calling.FunctionCallingFactory; @@ -76,8 +79,6 @@ import org.opensearch.ml.engine.memory.ConversationIndexMemory; import org.opensearch.ml.engine.memory.ConversationIndexMessage; import org.opensearch.ml.engine.tools.MLModelTool; -import org.opensearch.ml.repackage.com.google.common.collect.ImmutableMap; -import org.opensearch.ml.repackage.com.google.common.collect.Lists; import org.opensearch.remote.metadata.client.SdkClient; import org.opensearch.transport.TransportChannel; import org.opensearch.transport.client.Client; @@ -135,6 +136,7 @@ public class MLChatAgentRunner implements MLAgentRunner { private SdkClient sdkClient; private Encryptor encryptor; private StreamingWrapper streamingWrapper; + private static HookRegistry hookRegistry; public MLChatAgentRunner( Client client, @@ -145,6 +147,20 @@ public MLChatAgentRunner( Map memoryFactoryMap, SdkClient sdkClient, Encryptor encryptor + ) { + this(client, settings, clusterService, xContentRegistry, toolFactories, memoryFactoryMap, sdkClient, encryptor, null); + } + + public MLChatAgentRunner( + Client client, + Settings settings, + ClusterService clusterService, + NamedXContentRegistry xContentRegistry, + Map toolFactories, + Map memoryFactoryMap, + SdkClient sdkClient, + Encryptor encryptor, + HookRegistry hookRegistry ) { this.client = client; this.settings = settings; @@ -154,6 +170,7 @@ public MLChatAgentRunner( this.memoryFactoryMap = memoryFactoryMap; this.sdkClient = sdkClient; this.encryptor = encryptor; + this.hookRegistry = hookRegistry; } @Override @@ -194,7 +211,8 @@ public void run(MLAgent mlAgent, Map inputParams, ActionListener for (Interaction next : r) { String question = next.getInput(); String response = next.getResponse(); - // As we store the conversation with empty response first and then update when have final answer, + // As we store the conversation with empty response first and then update when + // have final answer, // filter out those in-flight requests when run in parallel if (Strings.isNullOrEmpty(response)) { continue; @@ -218,7 +236,8 @@ public void run(MLAgent mlAgent, Map inputParams, ActionListener } params.put(CHAT_HISTORY, chatHistoryBuilder.toString()); - // required for MLChatAgentRunnerTest.java, it requires chatHistory to be added to input params to validate + // required for MLChatAgentRunnerTest.java, it requires chatHistory to be added + // to input params to validate inputParams.put(CHAT_HISTORY, chatHistoryBuilder.toString()); } else { List chatHistory = new ArrayList<>(); @@ -239,7 +258,8 @@ public void run(MLAgent mlAgent, Map inputParams, ActionListener params.put(CHAT_HISTORY, String.join(", ", chatHistory) + ", "); params.put(NEW_CHAT_HISTORY, String.join(", ", chatHistory) + ", "); - // required for MLChatAgentRunnerTest.java, it requires chatHistory to be added to input params to validate + // required for MLChatAgentRunnerTest.java, it requires chatHistory to be added + // to input params to validate inputParams.put(CHAT_HISTORY, String.join(", ", chatHistory) + ", "); } } @@ -320,7 +340,7 @@ private void runReAct( StepListener lastStepListener = firstListener; StringBuilder scratchpadBuilder = new StringBuilder(); - List interactions = new CopyOnWriteArrayList<>(); + final List interactions = new CopyOnWriteArrayList<>(); StringSubstitutor tmpSubstitutor = new StringSubstitutor(Map.of(SCRATCHPAD, scratchpadBuilder.toString()), "${parameters.", "}"); AtomicReference newPrompt = new AtomicReference<>(tmpSubstitutor.replace(prompt)); @@ -336,6 +356,7 @@ private void runReAct( if (finalI % 2 == 0) { MLTaskResponse llmResponse = (MLTaskResponse) output; ModelTensorOutput tmpModelTensorOutput = (ModelTensorOutput) llmResponse.getOutput(); + List llmResponsePatterns = gson.fromJson(tmpParameters.get("llm_response_pattern"), List.class); Map modelOutput = parseLLMOutput( parameters, @@ -454,6 +475,7 @@ private void runReAct( ((ActionListener) nextStepListener).onResponse(res); } } else { + // output is now the processed output from POST_TOOL hook in runTool Object filteredOutput = filterToolOutput(lastToolParams, output); addToolOutputToAddtionalInfo(toolSpecMap, lastAction, additionalInfo, filteredOutput); @@ -466,6 +488,7 @@ private void runReAct( ); scratchpadBuilder.append(toolResponse).append("\n\n"); + // Save trace with processed output saveTraceData( conversationIndexMemory, "ReAct", @@ -482,7 +505,9 @@ private void runReAct( newPrompt.set(substitutor.replace(finalPrompt)); tmpParameters.put(PROMPT, newPrompt.get()); if (!interactions.isEmpty()) { - tmpParameters.put(INTERACTIONS, ", " + String.join(", ", interactions)); + String interactionsStr = String.join(", ", interactions); + // Set the interactions parameter - this will be processed by context management + tmpParameters.put(INTERACTIONS, ", " + interactionsStr); } sessionMsgAnswerBuilder.append(outputToOutputString(filteredOutput)); @@ -518,6 +543,26 @@ private void runReAct( ); return; } + // Emit PRE_LLM hook event + if (hookRegistry != null && !interactions.isEmpty()) { + List currentToolSpecs = new ArrayList<>(toolSpecMap.values()); + ContextManagerContext contextAfterEvent = AgentContextUtil + .emitPreLLMHook(tmpParameters, interactions, currentToolSpecs, memory, hookRegistry); + + // Check if context managers actually modified the interactions + List updatedInteractions = contextAfterEvent.getToolInteractions(); + + if (updatedInteractions != null && !updatedInteractions.equals(interactions)) { + interactions.clear(); + interactions.addAll(updatedInteractions); + + // Update parameters if context manager set INTERACTIONS + String contextInteractions = contextAfterEvent.getParameters().get(INTERACTIONS); + if (contextInteractions != null && !contextInteractions.isEmpty()) { + tmpParameters.put(INTERACTIONS, contextInteractions); + } + } + } ActionRequest request = streamingWrapper.createPredictionRequest(llm, tmpParameters, tenantId); streamingWrapper.executeRequest(request, (ActionListener) nextStepListener); } @@ -530,8 +575,29 @@ private void runReAct( } } + // Emit PRE_LLM hook event for initial LLM call + List initialToolSpecs = new ArrayList<>(toolSpecMap.values()); + tmpParameters.put("_llm_model_id", llm.getModelId()); + if (hookRegistry != null && !interactions.isEmpty()) { + ContextManagerContext contextAfterEvent = AgentContextUtil + .emitPreLLMHook(tmpParameters, interactions, initialToolSpecs, memory, hookRegistry); + + // Check if context managers actually modified the interactions + List updatedInteractions = contextAfterEvent.getToolInteractions(); + if (updatedInteractions != null && !updatedInteractions.equals(interactions)) { + interactions.clear(); + interactions.addAll(updatedInteractions); + + // Update parameters if context manager set INTERACTIONS + String contextInteractions = contextAfterEvent.getParameters().get(INTERACTIONS); + if (contextInteractions != null && !contextInteractions.isEmpty()) { + tmpParameters.put(INTERACTIONS, contextInteractions); + } + } + } ActionRequest request = streamingWrapper.createPredictionRequest(llm, tmpParameters, tenantId); streamingWrapper.executeRequest(request, firstListener); + } private static List createFinalAnswerTensors(List sessionId, List lastThought) { @@ -581,7 +647,9 @@ private static void addToolOutputToAddtionalInfo( List list = (List) additionalInfo.get(toolOutputKey); list.add(outputString); } else { - additionalInfo.put(toolOutputKey, Lists.newArrayList(outputString)); + List newList = new ArrayList<>(); + newList.add(outputString); + additionalInfo.put(toolOutputKey, newList); } } } @@ -602,8 +670,17 @@ private static void runTool( try { String finalAction = action; ActionListener toolListener = ActionListener.wrap(r -> { + // Emit POST_TOOL hook event - common for all tool executions + List postToolSpecs = new ArrayList<>(toolSpecMap.values()); + ContextManagerContext contextAfterPostTool = AgentContextUtil + .emitPostToolHook(r, tmpParameters, postToolSpecs, null, hookRegistry); + + // Extract processed output from POST_TOOL hook + String processedToolOutput = contextAfterPostTool.getParameters().get("_current_tool_output"); + Object processedOutput = processedToolOutput != null ? processedToolOutput : r; + if (functionCalling != null) { - String outputResponse = parseResponse(filterToolOutput(toolParams, r)); + String outputResponse = parseResponse(filterToolOutput(toolParams, processedOutput)); List> toolResults = List .of(Map.of(TOOL_CALL_ID, toolCallId, TOOL_RESULT, Map.of("text", outputResponse))); List llmMessages = functionCalling.supply(toolResults); @@ -614,30 +691,30 @@ private static void runTool( .add( substitute( tmpParameters.get(INTERACTION_TEMPLATE_TOOL_RESPONSE), - Map.of(TOOL_CALL_ID, toolCallId, "tool_response", processTextDoc(StringUtils.toJson(r))), + Map.of(TOOL_CALL_ID, toolCallId, "tool_response", processTextDoc(StringUtils.toJson(processedOutput))), INTERACTIONS_PREFIX ) ); } - nextStepListener.onResponse(r); + nextStepListener.onResponse(processedOutput); }, e -> { interactions .add( substitute( tmpParameters.get(INTERACTION_TEMPLATE_TOOL_RESPONSE), - Map.of(TOOL_CALL_ID, toolCallId, "tool_response", "Tool " + action + " failed: " + e.getMessage()), + Map + .of( + TOOL_CALL_ID, + toolCallId, + "tool_response", + "Tool " + action + " failed: " + StringUtils.processTextDoc(e.getMessage()) + ), INTERACTIONS_PREFIX ) ); nextStepListener .onResponse( - String - .format( - Locale.ROOT, - "Failed to run the tool %s with the error message %s.", - finalAction, - e.getMessage().replaceAll("\\n", "\n") - ) + String.format(Locale.ROOT, "Failed to run the tool %s with the error message %s.", finalAction, e.getMessage()) ); }); if (tools.get(action) instanceof MLModelTool) { @@ -666,9 +743,13 @@ private static void runTool( } /** - * In each tool runs, it copies agent parameters, which is tmpParameters into a new set of parameter llmToolTmpParameters, - * after the tool runs, normally llmToolTmpParameters will be discarded, but for some special parameters like SCRATCHPAD_NOTES_KEY, - * some new llmToolTmpParameters produced by the tool run can opt to be copied back to tmpParameters to share across tools in the same interaction + * In each tool runs, it copies agent parameters, which is tmpParameters into a + * new set of parameter llmToolTmpParameters, + * after the tool runs, normally llmToolTmpParameters will be discarded, but for + * some special parameters like SCRATCHPAD_NOTES_KEY, + * some new llmToolTmpParameters produced by the tool run can opt to be copied + * back to tmpParameters to share across tools in the same interaction + * * @param tmpParameters * @param llmToolTmpParameters */ @@ -863,7 +944,7 @@ public static void returnFinalResponse( ModelTensor .builder() .name("response") - .dataAsMap(ImmutableMap.of("response", finalAnswer2, ADDITIONAL_INFO_FIELD, additionalInfo)) + .dataAsMap(Map.of("response", finalAnswer2, ADDITIONAL_INFO_FIELD, additionalInfo)) .build() ) ); @@ -933,4 +1014,5 @@ private void saveMessage( memory.save(msgTemp, parentInteractionId, traceNumber.addAndGet(1), "LLM", listener); } } + } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java index 762b60ca5c..23586e0020 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java @@ -21,6 +21,7 @@ import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getCurrentDateTime; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getMcpToolSpecs; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getMlToolSpecs; +import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.INTERACTIONS; import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.LLM_INTERFACE; import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.MAX_ITERATION; import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.saveTraceData; @@ -53,9 +54,11 @@ import org.opensearch.ml.common.agent.LLMSpec; import org.opensearch.ml.common.agent.MLAgent; import org.opensearch.ml.common.agent.MLToolSpec; +import org.opensearch.ml.common.contextmanager.ContextManagerContext; import org.opensearch.ml.common.conversation.Interaction; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.exception.MLException; +import org.opensearch.ml.common.hooks.HookRegistry; import org.opensearch.ml.common.input.execute.agent.AgentMLInput; import org.opensearch.ml.common.input.remote.RemoteInferenceMLInput; import org.opensearch.ml.common.output.model.ModelTensor; @@ -69,6 +72,7 @@ import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; import org.opensearch.ml.common.utils.StringUtils; +import org.opensearch.ml.engine.agents.AgentContextUtil; import org.opensearch.ml.engine.encryptor.Encryptor; import org.opensearch.ml.engine.memory.ConversationIndexMemory; import org.opensearch.remote.metadata.client.SdkClient; @@ -92,6 +96,7 @@ public class MLPlanExecuteAndReflectAgentRunner implements MLAgentRunner { private final Map memoryFactoryMap; private SdkClient sdkClient; private Encryptor encryptor; + private HookRegistry hookRegistry; // flag to track if task has been updated with executor memory ids or not private boolean taskUpdated = false; private final Map taskUpdates = new HashMap<>(); @@ -163,7 +168,8 @@ public MLPlanExecuteAndReflectAgentRunner( Map toolFactories, Map memoryFactoryMap, SdkClient sdkClient, - Encryptor encryptor + Encryptor encryptor, + HookRegistry hookRegistry ) { this.client = client; this.settings = settings; @@ -173,6 +179,7 @@ public MLPlanExecuteAndReflectAgentRunner( this.memoryFactoryMap = memoryFactoryMap; this.sdkClient = sdkClient; this.encryptor = encryptor; + this.hookRegistry = hookRegistry; this.plannerPrompt = DEFAULT_PLANNER_PROMPT; this.plannerPromptTemplate = DEFAULT_PLANNER_PROMPT_TEMPLATE; this.reflectPrompt = DEFAULT_REFLECT_PROMPT; @@ -290,9 +297,9 @@ public void run(MLAgent mlAgent, Map apiParams, ActionListenerwrap(memory -> { + .create(allParams.get(USER_PROMPT_FIELD), memoryId, appType, ActionListener.wrap(memory -> { memory.getMessages(ActionListener.>wrap(interactions -> { - List completedSteps = new ArrayList<>(); + final List completedSteps = new ArrayList<>(); for (Interaction interaction : interactions) { String question = interaction.getInput(); String response = interaction.getResponse(); @@ -386,8 +393,41 @@ private void executePlanningLoop( ); return; } + MLPredictionTaskRequest request; + // Planner agent doesn't use INTERACTIONS for now, reusing the INTERACTIONS to pass over + // completedSteps to context management. + // TODO should refactor the completed steps as message array format, similar to chat agent. + + allParams.put("_llm_model_id", llm.getModelId()); + if (hookRegistry != null && !completedSteps.isEmpty()) { + + Map requestParams = new HashMap<>(allParams); + requestParams.put(INTERACTIONS, ", " + String.join(", ", completedSteps)); + try { + ContextManagerContext contextAfterEvent = AgentContextUtil + .emitPreLLMHook(requestParams, completedSteps, null, memory, hookRegistry); + + // Check if context managers actually modified the interactions + List updatedSteps = contextAfterEvent.getToolInteractions(); + if (updatedSteps != null && !updatedSteps.equals(completedSteps)) { + completedSteps.clear(); + completedSteps.addAll(updatedSteps); + + // Update parameters if context manager set INTERACTIONS + String contextInteractions = contextAfterEvent.getParameters().get(INTERACTIONS); + if (contextInteractions != null && !contextInteractions.isEmpty()) { + allParams.put(COMPLETED_STEPS_FIELD, contextInteractions); + // TODO should I always clear interactions after update the completed steps? + allParams.put(INTERACTIONS, ""); + } + } + } catch (Exception e) { + log.error("Failed to emit pre-LLM hook", e); + } - MLPredictionTaskRequest request = new MLPredictionTaskRequest( + } + + request = new MLPredictionTaskRequest( llm.getModelId(), RemoteInferenceMLInput .builder() @@ -443,6 +483,10 @@ private void executePlanningLoop( .inputDataset(RemoteInferenceInputDataSet.builder().parameters(reactParams).build()) .build(); + // Pass hookRegistry to internal agent execution + // TODO need to check if the agentInput already have the hookResgistry? + agentInput.setHookRegistry(hookRegistry); + MLExecuteTaskRequest executeRequest = new MLExecuteTaskRequest(FunctionName.AGENT, agentInput); client.execute(MLExecuteTaskAction.INSTANCE, executeRequest, ActionListener.wrap(executeResponse -> { diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/contextmanager/SlidingWindowManager.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/contextmanager/SlidingWindowManager.java new file mode 100644 index 0000000000..c541045aaf --- /dev/null +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/contextmanager/SlidingWindowManager.java @@ -0,0 +1,191 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.algorithms.contextmanager; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.opensearch.ml.common.contextmanager.ActivationRule; +import org.opensearch.ml.common.contextmanager.ActivationRuleFactory; +import org.opensearch.ml.common.contextmanager.ContextManager; +import org.opensearch.ml.common.contextmanager.ContextManagerContext; + +import lombok.extern.log4j.Log4j2; + +/** + * Context manager that implements a sliding window approach for tool interactions. + * Keeps only the most recent N interactions to prevent context window overflow. + * This manager ensures proper handling of different message types while tool execution flow. + */ +@Log4j2 +public class SlidingWindowManager implements ContextManager { + + public static final String TYPE = "SlidingWindowManager"; + + // Configuration keys + private static final String MAX_MESSAGES_KEY = "max_messages"; + + // Default values + private static final int DEFAULT_MAX_MESSAGES = 20; + + private int maxMessages; + private List activationRules; + + @Override + public String getType() { + return TYPE; + } + + @Override + public void initialize(Map config) { + // Initialize configuration with defaults + this.maxMessages = parseIntegerConfig(config, MAX_MESSAGES_KEY, DEFAULT_MAX_MESSAGES); + + if (this.maxMessages <= 0) { + log.warn("Invalid max_messages value: {}, using default {}", this.maxMessages, DEFAULT_MAX_MESSAGES); + this.maxMessages = DEFAULT_MAX_MESSAGES; + } + + // Initialize activation rules from config + @SuppressWarnings("unchecked") + Map activationConfig = (Map) config.get("activation"); + this.activationRules = ActivationRuleFactory.createRules(activationConfig); + + log.info("Initialized SlidingWindowManager: maxMessages={}", maxMessages); + } + + @Override + public boolean shouldActivate(ContextManagerContext context) { + if (activationRules == null || activationRules.isEmpty()) { + return true; + } + + for (ActivationRule rule : activationRules) { + if (!rule.evaluate(context)) { + log.debug("Activation rule not satisfied: {}", rule.getDescription()); + return false; + } + } + + log.debug("All activation rules satisfied, manager will execute"); + return true; + } + + @Override + public void execute(ContextManagerContext context) { + List interactions = context.getToolInteractions(); + + if (interactions == null || interactions.isEmpty()) { + log.debug("No tool interactions to process"); + return; + } + + if (interactions.isEmpty()) { + log.debug("No string interactions found in tool interactions"); + return; + } + + int originalSize = interactions.size(); + + if (originalSize <= maxMessages) { + log.debug("Interactions size ({}) is within limit ({}), no truncation needed", originalSize, maxMessages); + return; + } + + // Find safe start point to avoid breaking tool pairs + int startIndex = findSafeStartPoint(interactions, originalSize - maxMessages); + + // Keep the most recent interactions from safe start point + List updatedInteractions = new ArrayList<>(interactions.subList(startIndex, originalSize)); + + // Update toolInteractions in context to keep only the most recent ones + context.setToolInteractions(updatedInteractions); + + // Update the _interactions parameter with smaller size of updated interactions + Map parameters = context.getParameters(); + if (parameters == null) { + parameters = new HashMap<>(); + context.setParameters(parameters); + } + parameters.put("_interactions", ", " + String.join(", ", updatedInteractions)); + + int removedMessages = originalSize - updatedInteractions.size(); + log + .info( + "Applied sliding window: kept {} most recent interactions, removed {} older interactions", + updatedInteractions.size(), + removedMessages + ); + } + + private int parseIntegerConfig(Map config, String key, int defaultValue) { + Object value = config.get(key); + if (value == null) { + return defaultValue; + } + + try { + if (value instanceof Integer) { + return (Integer) value; + } else if (value instanceof Number) { + return ((Number) value).intValue(); + } else if (value instanceof String) { + return Integer.parseInt((String) value); + } else { + log.warn("Invalid type for config key '{}': {}, using default {}", key, value.getClass().getSimpleName(), defaultValue); + return defaultValue; + } + } catch (NumberFormatException e) { + log.warn("Invalid integer value for config key '{}': {}, using default {}", key, value, defaultValue); + return defaultValue; + } + } + + /** + * Find a safe start point that doesn't break assistant-tool message pairs + * Same logic as SummarizationManager but for finding start point + */ + private int findSafeStartPoint(List interactions, int targetStartPoint) { + if (targetStartPoint <= 0) { + return 0; + } + if (targetStartPoint >= interactions.size()) { + return interactions.size(); + } + + int startPoint = targetStartPoint; + + while (startPoint < interactions.size()) { + try { + String messageAtStart = interactions.get(startPoint); + + // Oldest message cannot be a toolResult because it needs a toolUse preceding it + boolean hasToolResult = messageAtStart.contains("toolResult"); + + // Oldest message can be a toolUse only if a toolResult immediately follows it + boolean hasToolUse = messageAtStart.contains("toolUse"); + boolean nextHasToolResult = false; + if (startPoint + 1 < interactions.size()) { + nextHasToolResult = interactions.get(startPoint + 1).contains("toolResult"); + } + + if (hasToolResult || (hasToolUse && startPoint + 1 < interactions.size() && !nextHasToolResult)) { + startPoint++; + } else { + break; + } + + } catch (Exception e) { + log.warn("Error checking message at index {}: {}", startPoint, e.getMessage()); + startPoint++; + } + } + + return startPoint; + } +} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/contextmanager/SummarizationManager.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/contextmanager/SummarizationManager.java new file mode 100644 index 0000000000..38ea1b4aac --- /dev/null +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/contextmanager/SummarizationManager.java @@ -0,0 +1,442 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.algorithms.contextmanager; + +import static java.lang.Math.min; +import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD; +import static org.opensearch.ml.common.FunctionName.REMOTE; +import static org.opensearch.ml.common.utils.StringUtils.processTextDoc; +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_RESPONSE_FILTER; +import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.INTERACTIONS; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.contextmanager.ActivationRule; +import org.opensearch.ml.common.contextmanager.ActivationRuleFactory; +import org.opensearch.ml.common.contextmanager.ContextManager; +import org.opensearch.ml.common.contextmanager.ContextManagerContext; +import org.opensearch.ml.common.dataset.MLInputDataset; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.output.MLOutput; +import org.opensearch.ml.common.output.model.ModelTensor; +import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.ml.common.transport.MLTaskResponse; +import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; +import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; +import org.opensearch.ml.common.utils.StringUtils; +import org.opensearch.transport.client.Client; + +import com.jayway.jsonpath.JsonPath; +import com.jayway.jsonpath.PathNotFoundException; + +import lombok.extern.log4j.Log4j2; + +/** + * Context manager that implements summarization approach for tool interactions. + * Summarizes older interactions while preserving recent ones to manage context + * window. + */ +@Log4j2 +public class SummarizationManager implements ContextManager { + + public static final String TYPE = "SummarizationManager"; + + // Configuration keys + private static final String SUMMARY_RATIO_KEY = "summary_ratio"; + private static final String PRESERVE_RECENT_MESSAGES_KEY = "preserve_recent_messages"; + private static final String SUMMARIZATION_MODEL_ID_KEY = "summarization_model_id"; + private static final String SUMMARIZATION_SYSTEM_PROMPT_KEY = "summarization_system_prompt"; + + // Default values + private static final double DEFAULT_SUMMARY_RATIO = 0.3; + private static final int DEFAULT_PRESERVE_RECENT_MESSAGES = 10; + private static final String DEFAULT_SUMMARIZATION_PROMPT = + "You are a interactions summarization agent. Summarize the provided interactions concisely while preserving key information and context."; + + protected double summaryRatio; + protected int preserveRecentMessages; + protected String summarizationModelId; + protected String summarizationSystemPrompt; + protected List activationRules; + private Client client; + + public SummarizationManager(Client client) { + this.client = client; + } + + @Override + public String getType() { + return TYPE; + } + + @Override + public void initialize(Map config) { + this.summaryRatio = parseDoubleConfig(config, SUMMARY_RATIO_KEY, DEFAULT_SUMMARY_RATIO); + this.preserveRecentMessages = parseIntegerConfig(config, PRESERVE_RECENT_MESSAGES_KEY, DEFAULT_PRESERVE_RECENT_MESSAGES); + this.summarizationModelId = (String) config.get(SUMMARIZATION_MODEL_ID_KEY); + this.summarizationSystemPrompt = (String) config.getOrDefault(SUMMARIZATION_SYSTEM_PROMPT_KEY, DEFAULT_SUMMARIZATION_PROMPT); + + // Validate summary ratio + if (summaryRatio < 0.1 || summaryRatio > 0.8) { + log.warn("Invalid summary_ratio value: {}, using default {}", summaryRatio, DEFAULT_SUMMARY_RATIO); + this.summaryRatio = DEFAULT_SUMMARY_RATIO; + } + + // Initialize activation rules from config + @SuppressWarnings("unchecked") + Map activationConfig = (Map) config.get("activation"); + this.activationRules = ActivationRuleFactory.createRules(activationConfig); + + log.info("Initialized SummarizationManager: summaryRatio={}, preserveRecentMessages={}", summaryRatio, preserveRecentMessages); + } + + @Override + public boolean shouldActivate(ContextManagerContext context) { + if (activationRules == null || activationRules.isEmpty()) { + return true; + } + + for (ActivationRule rule : activationRules) { + if (!rule.evaluate(context)) { + log.debug("Activation rule not satisfied: {}", rule.getDescription()); + return false; + } + } + + log.debug("All activation rules satisfied, manager will execute"); + return true; + } + + @Override + public void execute(ContextManagerContext context) { + List interactions = context.getToolInteractions(); + + if (interactions == null || interactions.isEmpty()) { + return; + } + + if (interactions.isEmpty()) { + log.debug("No string interactions found in tool interactions"); + return; + } + + int totalMessages = interactions.size(); + + // Calculate how many messages to summarize + int messagesToSummarizeCount = Math.max(1, (int) (totalMessages * summaryRatio)); + + // Ensure we don't summarize recent messages + messagesToSummarizeCount = min(messagesToSummarizeCount, totalMessages - preserveRecentMessages); + + if (messagesToSummarizeCount <= 0) { + return; + } + + // Find a safe cut point that doesn't break assistant-tool pairs + int safeCutPoint = findSafeCutPoint(interactions, messagesToSummarizeCount); + + if (safeCutPoint <= 0) { + return; + } + + // Extract messages to summarize and remaining messages + List messagesToSummarize = new ArrayList<>(interactions.subList(0, safeCutPoint)); + List remainingMessages = new ArrayList<>(interactions.subList(safeCutPoint, totalMessages)); + + // Get model ID + String modelId = summarizationModelId; + if (modelId == null) { + Map parameters = context.getParameters(); + if (parameters != null) { + modelId = (String) parameters.get("_llm_model_id"); + } + } + + if (modelId == null) { + log.error("No model ID available for summarization"); + return; + } + + // Prepare summarization parameters + Map summarizationParameters = new HashMap<>(); + summarizationParameters.put("prompt", "Help summarize the following" + StringUtils.toJson(String.join(",", messagesToSummarize))); + summarizationParameters.put("system_prompt", summarizationSystemPrompt); + + executeSummarization(context, modelId, summarizationParameters, safeCutPoint, remainingMessages, interactions); + } + + protected void executeSummarization( + ContextManagerContext context, + String modelId, + Map summarizationParameters, + int messagesToSummarizeCount, + List remainingMessages, + List originalInteractions + ) { + CountDownLatch latch = new CountDownLatch(1); + + try { + // Create ML input dataset for remote inference + MLInputDataset inputDataset = RemoteInferenceInputDataSet.builder().parameters(summarizationParameters).build(); + + // Create ML input + MLInput mlInput = MLInput.builder().algorithm(REMOTE).inputDataset(inputDataset).build(); + + // Create prediction request + String tenantId = (String) context.getParameter(TENANT_ID_FIELD); + MLPredictionTaskRequest request = MLPredictionTaskRequest + .builder() + .modelId(modelId) + .mlInput(mlInput) + .tenantId(tenantId) + .build(); + + // Execute prediction + ActionListener listener = ActionListener.wrap(response -> { + try { + String summary = extractSummaryFromResponse(response, context); + processSummarizationResult(context, summary, messagesToSummarizeCount, remainingMessages, originalInteractions); + } catch (Exception e) { + // Fallback to default behavior + processSummarizationResult( + context, + "Summarized " + messagesToSummarizeCount + " previous tool interactions", + messagesToSummarizeCount, + remainingMessages, + originalInteractions + ); + } finally { + latch.countDown(); + } + }, e -> { + try { + // Fallback to default behavior + processSummarizationResult( + context, + "Summarized " + messagesToSummarizeCount + " previous tool interactions", + messagesToSummarizeCount, + remainingMessages, + originalInteractions + ); + } finally { + latch.countDown(); + } + }); + + client.execute(MLPredictionTaskAction.INSTANCE, request, listener); + + // Wait for summarization to complete (30 second timeout) + latch.await(30, TimeUnit.SECONDS); + + } catch (Exception e) { + // Fallback to default behavior + processSummarizationResult( + context, + "Summarized " + messagesToSummarizeCount + " previous interactions", + messagesToSummarizeCount, + remainingMessages, + originalInteractions + ); + } + } + + protected void processSummarizationResult( + ContextManagerContext context, + String summary, + int messagesToSummarizeCount, + List remainingMessages, + List originalInteractions + ) { + try { + // Create summarized interaction + String summarizedInteraction = "{\"role\":\"assistant\",\"content\":\"Summarized previous interactions: " + + processTextDoc(summary) + + "\"}"; + + // Update interactions: summary + remaining messages + List updatedInteractions = new ArrayList<>(); + updatedInteractions.add(summarizedInteraction); + updatedInteractions.addAll(remainingMessages); + + // Update toolInteractions in context + context.setToolInteractions(updatedInteractions); + + // Update parameters + Map parameters = context.getParameters(); + if (parameters == null) { + parameters = new HashMap<>(); + } + parameters.put(INTERACTIONS, ", " + String.join(", ", updatedInteractions)); + context.setParameters(parameters); + + log + .info( + "Summarization completed: {} messages summarized, {} messages preserved", + messagesToSummarizeCount, + remainingMessages.size() + ); + + } catch (Exception e) { + log.error("Failed to process summarization result", e); + } + } + + private String extractSummaryFromResponse(MLTaskResponse response, ContextManagerContext context) { + try { + MLOutput output = response.getOutput(); + if (output instanceof ModelTensorOutput) { + ModelTensorOutput tensorOutput = (ModelTensorOutput) output; + List mlModelOutputs = tensorOutput.getMlModelOutputs(); + + if (mlModelOutputs != null && !mlModelOutputs.isEmpty()) { + List tensors = mlModelOutputs.get(0).getMlModelTensors(); + if (tensors != null && !tensors.isEmpty()) { + Map dataAsMap = tensors.get(0).getDataAsMap(); + + // Use LLM_RESPONSE_FILTER from agent configuration if available + Map parameters = context.getParameters(); + if (parameters != null + && parameters.containsKey(LLM_RESPONSE_FILTER) + && !parameters.get(LLM_RESPONSE_FILTER).isEmpty()) { + try { + String responseFilter = parameters.get(LLM_RESPONSE_FILTER); + Object filteredResponse = JsonPath.read(dataAsMap, responseFilter); + if (filteredResponse instanceof String) { + String result = ((String) filteredResponse).trim(); + return result; + } else { + String result = StringUtils.toJson(filteredResponse); + return result; + } + } catch (PathNotFoundException e) { + // Fall back to default parsing + } catch (Exception e) { + // Fall back to default parsing + } + } + + // Fallback to default parsing if no filter or filter fails + if (dataAsMap.size() == 1 && dataAsMap.containsKey("response")) { + Object responseObj = dataAsMap.get("response"); + if (responseObj instanceof String) { + return ((String) responseObj).trim(); + } + } + + // Last resort: return JSON representation + return StringUtils.toJson(dataAsMap); + } + } + } + } catch (Exception e) { + log.error("Failed to extract summary from response", e); + } + + return "Summary generation failed"; + } + + private double parseDoubleConfig(Map config, String key, double defaultValue) { + Object value = config.get(key); + if (value == null) { + return defaultValue; + } + + try { + if (value instanceof Double) { + return (Double) value; + } else if (value instanceof Number) { + return ((Number) value).doubleValue(); + } else if (value instanceof String) { + return Double.parseDouble((String) value); + } else { + log.warn("Invalid type for config key '{}': {}, using default {}", key, value.getClass().getSimpleName(), defaultValue); + return defaultValue; + } + } catch (NumberFormatException e) { + log.warn("Invalid double value for config key '{}': {}, using default {}", key, value, defaultValue); + return defaultValue; + } + } + + private int parseIntegerConfig(Map config, String key, int defaultValue) { + Object value = config.get(key); + if (value == null) { + return defaultValue; + } + + try { + if (value instanceof Integer) { + return (Integer) value; + } else if (value instanceof Number) { + return ((Number) value).intValue(); + } else if (value instanceof String) { + return Integer.parseInt((String) value); + } else { + log.warn("Invalid type for config key '{}': {}, using default {}", key, value.getClass().getSimpleName(), defaultValue); + return defaultValue; + } + } catch (NumberFormatException e) { + log.warn("Invalid integer value for config key '{}': {}, using default {}", key, value, defaultValue); + return defaultValue; + } + } + + /** + * Find a safe cut point that doesn't break assistant-tool message pairs + * Exact same logic as Strands agent + */ + private int findSafeCutPoint(List interactions, int targetCutPoint) { + if (targetCutPoint >= interactions.size()) { + return targetCutPoint; + } + // // the current agent logic is when odd number it's tool called result and even number is tool input, should always summarize for + // pairs, so the targetCutPoint needs to be even + // if (targetCutPoint%2==0){ + // return targetCutPoint; + // } else { + // return min(targetCutPoint+1,interactions.size()); + // } + int splitPoint = targetCutPoint; + + while (splitPoint < interactions.size()) { + try { + String messageAtSplit = interactions.get(splitPoint); + + // Oldest message cannot be a toolResult because it needs a toolUse preceding it + boolean hasToolResult = (messageAtSplit.contains("toolResult") || messageAtSplit.contains("tool_call_id")); + + // Oldest message can be a toolUse only if a toolResult immediately follows it + boolean hasToolUse = messageAtSplit.contains("toolUse"); + boolean nextHasToolResult = false; + // TODO we need better way to handle the tool result based on the llm interfaces. + if (splitPoint + 1 < interactions.size()) { + nextHasToolResult = (interactions.get(splitPoint + 1).contains("toolResult") + || messageAtSplit.contains("tool_call_id")); + } + + if (hasToolResult || (hasToolUse && splitPoint + 1 < interactions.size() && !nextHasToolResult)) { + splitPoint++; + } else { + break; + } + + } catch (Exception e) { + log.warn("Error checking message at index {}: {}", splitPoint, e.getMessage()); + splitPoint++; + } + } + + return splitPoint; + } +} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/contextmanager/ToolsOutputTruncateManager.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/contextmanager/ToolsOutputTruncateManager.java new file mode 100644 index 0000000000..4fa97c156d --- /dev/null +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/contextmanager/ToolsOutputTruncateManager.java @@ -0,0 +1,134 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.algorithms.contextmanager; + +import java.util.List; +import java.util.Map; + +import org.opensearch.ml.common.contextmanager.ActivationRule; +import org.opensearch.ml.common.contextmanager.ActivationRuleFactory; +import org.opensearch.ml.common.contextmanager.ContextManager; +import org.opensearch.ml.common.contextmanager.ContextManagerContext; + +import lombok.extern.log4j.Log4j2; + +/** + * Context manager that truncates tool output to prevent context window overflow. + * This manager processes the current tool output and applies length limits. + */ +@Log4j2 +public class ToolsOutputTruncateManager implements ContextManager { + + public static final String TYPE = "ToolsOutputTruncateManager"; + + // Configuration keys + private static final String MAX_OUTPUT_LENGTH_KEY = "max_output_length"; + + // Default values + private static final int DEFAULT_MAX_OUTPUT_LENGTH = 40000; + + private int maxOutputLength; + private List activationRules; + + @Override + public String getType() { + return TYPE; + } + + @Override + public void initialize(Map config) { + // Initialize configuration with defaults + this.maxOutputLength = parseIntegerConfig(config, MAX_OUTPUT_LENGTH_KEY, DEFAULT_MAX_OUTPUT_LENGTH); + + if (this.maxOutputLength <= 0) { + log.warn("Invalid max_output_length value: {}, using default {}", this.maxOutputLength, DEFAULT_MAX_OUTPUT_LENGTH); + this.maxOutputLength = DEFAULT_MAX_OUTPUT_LENGTH; + } + + // Initialize activation rules from config + @SuppressWarnings("unchecked") + Map activationConfig = (Map) config.get("activation"); + this.activationRules = ActivationRuleFactory.createRules(activationConfig); + + log.info("Initialized ToolsOutputTruncateManager: maxOutputLength={}", maxOutputLength); + } + + @Override + public boolean shouldActivate(ContextManagerContext context) { + if (activationRules == null || activationRules.isEmpty()) { + return true; + } + + for (ActivationRule rule : activationRules) { + if (!rule.evaluate(context)) { + log.debug("Activation rule not satisfied: {}", rule.getDescription()); + return false; + } + } + + log.debug("All activation rules satisfied, manager will execute"); + return true; + } + + @Override + public void execute(ContextManagerContext context) { + // Process current tool output from parameters + Map parameters = context.getParameters(); + if (parameters == null) { + log.debug("No parameters available for tool output truncation"); + return; + } + + Object currentToolOutput = parameters.get("_current_tool_output"); + if (currentToolOutput == null) { + log.debug("No current tool output to process"); + return; + } + + String outputString = currentToolOutput.toString(); + int originalLength = outputString.length(); + + if (originalLength <= maxOutputLength) { + log.debug("Tool output length ({}) is within limit ({}), no truncation needed", originalLength, maxOutputLength); + return; + } + + // Truncate the output + String truncatedOutput = outputString.substring(0, maxOutputLength); + + // Add truncation indicator + truncatedOutput += "... [Output truncated - original length: " + originalLength + " characters]"; + + // Update the current tool output in parameters + parameters.put("_current_tool_output", truncatedOutput); + + int truncatedLength = truncatedOutput.length(); + log.info("Tool output truncated: original length {} -> truncated length {}", originalLength, truncatedLength); + } + + private int parseIntegerConfig(Map config, String key, int defaultValue) { + Object value = config.get(key); + if (value == null) { + return defaultValue; + } + + try { + if (value instanceof Integer) { + return (Integer) value; + } else if (value instanceof Number) { + return ((Number) value).intValue(); + } else if (value instanceof String) { + return Integer.parseInt((String) value); + } else { + log.warn("Invalid type for config key '{}': {}, using default {}", key, value.getClass().getSimpleName(), defaultValue); + return defaultValue; + } + } catch (NumberFormatException e) { + log.warn("Invalid integer value for config key '{}': {}, using default {}", key, value, defaultValue); + return defaultValue; + } + } +} diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutorTest.java index efc84d8f8c..c3843434f7 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutorTest.java @@ -5,1617 +5,317 @@ package org.opensearch.ml.engine.algorithms.agent; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.when; -import static org.opensearch.cluster.node.DiscoveryNodeRole.CLUSTER_MANAGER_ROLE; -import static org.opensearch.ml.common.CommonValue.MCP_CONNECTORS_FIELD; -import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX; -import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_MCP_CONNECTOR_DISABLED_MESSAGE; -import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_MCP_CONNECTOR_ENABLED; -import static org.opensearch.ml.engine.algorithms.agent.MLAgentExecutor.MEMORY_ID; -import static org.opensearch.ml.engine.algorithms.agent.MLAgentExecutor.QUESTION; -import static org.opensearch.ml.engine.algorithms.agent.MLAgentExecutor.REGENERATE_INTERACTION_ID; +import static org.junit.Assert.*; +import static org.mockito.Mockito.*; -import java.io.IOException; -import java.net.InetAddress; import java.time.Instant; -import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Set; -import javax.naming.Context; - -import org.junit.Assert; import org.junit.Before; import org.junit.Test; -import org.mockito.ArgumentCaptor; -import org.mockito.Captor; import org.mockito.Mock; -import org.mockito.Mockito; import org.mockito.MockitoAnnotations; -import org.opensearch.OpenSearchException; -import org.opensearch.ResourceNotFoundException; -import org.opensearch.Version; -import org.opensearch.action.get.GetRequest; +import org.opensearch.OpenSearchStatusException; import org.opensearch.action.get.GetResponse; -import org.opensearch.action.index.IndexResponse; import org.opensearch.cluster.ClusterState; import org.opensearch.cluster.metadata.Metadata; -import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.action.ActionListener; -import org.opensearch.core.common.bytes.BytesReference; -import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.core.common.transport.TransportAddress; -import org.opensearch.core.index.shard.ShardId; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; -import org.opensearch.core.xcontent.ToXContent; -import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.index.get.GetResult; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLAgentType; import org.opensearch.ml.common.agent.LLMSpec; import org.opensearch.ml.common.agent.MLAgent; -import org.opensearch.ml.common.agent.MLMemorySpec; -import org.opensearch.ml.common.agent.MLToolSpec; -import org.opensearch.ml.common.conversation.Interaction; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.Input; import org.opensearch.ml.common.input.execute.agent.AgentMLInput; -import org.opensearch.ml.common.output.MLTaskOutput; import org.opensearch.ml.common.output.Output; -import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensorOutput; -import org.opensearch.ml.common.output.model.ModelTensors; import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; import org.opensearch.ml.common.spi.memory.Memory; import org.opensearch.ml.common.spi.tools.Tool; -import org.opensearch.ml.engine.memory.ConversationIndexMemory; -import org.opensearch.ml.engine.memory.MLMemoryManager; -import org.opensearch.ml.memory.action.conversation.CreateInteractionResponse; -import org.opensearch.ml.memory.action.conversation.GetInteractionAction; -import org.opensearch.ml.memory.action.conversation.GetInteractionResponse; +import org.opensearch.ml.engine.encryptor.Encryptor; import org.opensearch.remote.metadata.client.SdkClient; -import org.opensearch.remote.metadata.client.impl.SdkClientFactory; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportChannel; import org.opensearch.transport.client.Client; -import com.google.gson.Gson; - -import software.amazon.awssdk.utils.ImmutableMap; - +@SuppressWarnings({ "rawtypes" }) public class MLAgentExecutorTest { @Mock private Client client; - SdkClient sdkClient; - private Settings settings; - @Mock - private ClusterService clusterService; + @Mock - private ClusterState clusterState; + private SdkClient sdkClient; + @Mock - private Metadata metadata; + private ClusterService clusterService; + @Mock private NamedXContentRegistry xContentRegistry; + @Mock - private Map toolFactories; - @Mock - private Map memoryMap; + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + @Mock - private IndexResponse indexResponse; + private Encryptor encryptor; + @Mock private ThreadPool threadPool; + private ThreadContext threadContext; - @Mock - private Context context; - @Mock - private ConversationIndexMemory.Factory mockMemoryFactory; - @Mock - private ActionListener agentActionListener; - @Mock - private MLAgentRunner mlAgentRunner; @Mock - private ConversationIndexMemory memory; - @Mock - private MLMemoryManager memoryManager; - private MLAgentExecutor mlAgentExecutor; + private ThreadContext.StoredContext storedContext; @Mock - private MLFeatureEnabledSetting mlFeatureEnabledSetting; - - @Captor - private ArgumentCaptor objectCaptor; + private TransportChannel channel; - @Captor - private ArgumentCaptor exceptionCaptor; + @Mock + private ActionListener listener; - private DiscoveryNode localNode = new DiscoveryNode( - "mockClusterManagerNodeId", - "mockClusterManagerNodeId", - new TransportAddress(InetAddress.getLoopbackAddress(), 9300), - Collections.emptyMap(), - Collections.singleton(CLUSTER_MANAGER_ROLE), - Version.CURRENT - ); + @Mock + private GetResponse getResponse; - MLAgent mlAgent; + private MLAgentExecutor mlAgentExecutor; + private Map toolFactories; + private Map memoryFactoryMap; + private Settings settings; @Before - @SuppressWarnings("unchecked") public void setup() { MockitoAnnotations.openMocks(this); + settings = Settings.builder().build(); - sdkClient = SdkClientFactory.createSdkClient(client, NamedXContentRegistry.EMPTY, Collections.emptyMap()); + toolFactories = new HashMap<>(); + memoryFactoryMap = new HashMap<>(); threadContext = new ThreadContext(settings); - memoryMap = ImmutableMap.of("memoryType", mockMemoryFactory); - Mockito.doAnswer(invocation -> { - MLMemorySpec mlMemorySpec = MLMemorySpec.builder().type("memoryType").build(); - MLAgent mlAgent = MLAgent.builder().name("agent").memory(mlMemorySpec).type("flow").build(); - XContentBuilder content = mlAgent.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); - ActionListener listener = invocation.getArgument(1); - GetResponse getResponse = Mockito.mock(GetResponse.class); - Mockito.when(getResponse.isExists()).thenReturn(true); - Mockito.when(getResponse.getSourceAsBytesRef()).thenReturn(BytesReference.bytes(content)); - listener.onResponse(getResponse); - return null; - }).when(client).get(Mockito.any(), Mockito.any()); - Mockito.when(clusterService.state()).thenReturn(clusterState); - Mockito.when(clusterState.metadata()).thenReturn(metadata); - when(clusterService.localNode()).thenReturn(localNode); - Mockito.when(metadata.hasIndex(Mockito.anyString())).thenReturn(true); - Mockito.when(memory.getMemoryManager()).thenReturn(memoryManager); + when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); - when(this.clusterService.getSettings()).thenReturn(settings); - when(this.clusterService.getClusterSettings()).thenReturn(new ClusterSettings(settings, Set.of(ML_COMMONS_MCP_CONNECTOR_ENABLED))); - - // Mock MLFeatureEnabledSetting when(mlFeatureEnabledSetting.isMultiTenancyEnabled()).thenReturn(false); - when(mlFeatureEnabledSetting.isMcpConnectorEnabled()).thenReturn(true); - - settings = Settings.builder().build(); - mlAgentExecutor = Mockito - .spy( - new MLAgentExecutor( - client, - sdkClient, - settings, - clusterService, - xContentRegistry, - toolFactories, - memoryMap, - mlFeatureEnabledSetting, - null - ) - ); - - } - - @Test - public void test_NoAgentIndex() { - ModelTensor modelTensor = ModelTensor.builder().name("response").dataAsMap(ImmutableMap.of("test_key", "test_value")).build(); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onResponse(modelTensor); - return null; - }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any()); - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); - Mockito.when(metadata.hasIndex(Mockito.anyString())).thenReturn(false); - - mlAgentExecutor.execute(getAgentMLInput(), agentActionListener); - - Mockito.verify(agentActionListener).onFailure(exceptionCaptor.capture()); - Exception exception = exceptionCaptor.getValue(); - Assert.assertTrue(exception instanceof ResourceNotFoundException); - Assert.assertEquals(exception.getMessage(), "Agent index not found"); - } - - @Test(expected = IllegalArgumentException.class) - public void test_NullInput_ThrowsException() { - mlAgentExecutor.execute(null, agentActionListener); - } - - @Test(expected = IllegalArgumentException.class) - public void test_NonAgentInput_ThrowsException() { - Input input = new Input() { - @Override - public FunctionName getFunctionName() { - return null; - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - return null; - } - }; - mlAgentExecutor.execute(input, agentActionListener); - } - - @Test(expected = IllegalArgumentException.class) - public void test_NonInputData_ThrowsException() { - AgentMLInput agentMLInput = new AgentMLInput("test", null, FunctionName.AGENT, null); - mlAgentExecutor.execute(agentMLInput, agentActionListener); - } - - @Test(expected = IllegalArgumentException.class) - public void test_NonInputParas_ThrowsException() { - RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet.builder().parameters(null).build(); - AgentMLInput agentMLInput = new AgentMLInput("test", null, FunctionName.AGENT, inputDataSet); - mlAgentExecutor.execute(agentMLInput, agentActionListener); - } - - @Test - public void test_HappyCase_ReturnsResult() throws IOException { - ModelTensor modelTensor = ModelTensor.builder().name("response").dataAsMap(ImmutableMap.of("test_key", "test_value")).build(); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onResponse(modelTensor); - return null; - }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); - - GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); - Mockito.doAnswer(invocation -> { - // Extract the ActionListener argument from the method invocation - ActionListener listener = invocation.getArgument(1); - // Trigger the onResponse method of the ActionListener with the mock response - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class)); - - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onResponse(memory); - return null; - }).when(mockMemoryFactory).create(Mockito.eq(null), Mockito.eq("memoryId"), Mockito.any(), Mockito.any()); - - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); - mlAgentExecutor.execute(getAgentMLInput(), agentActionListener); - - Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); - ModelTensorOutput output = (ModelTensorOutput) objectCaptor.getValue(); - Assert.assertEquals(1, output.getMlModelOutputs().size()); - Assert.assertEquals(1, output.getMlModelOutputs().get(0).getMlModelTensors().size()); - Assert.assertEquals(modelTensor, output.getMlModelOutputs().get(0).getMlModelTensors().get(0)); - } - - @Test - public void test_AgentRunnerReturnsListOfModelTensor_ReturnsResult() throws IOException { - ModelTensor modelTensor1 = ModelTensor.builder().name("response1").dataAsMap(ImmutableMap.of("test_key", "test_value")).build(); - ModelTensor modelTensor2 = ModelTensor.builder().name("response2").dataAsMap(ImmutableMap.of("test_key", "test_value")).build(); - List response = Arrays.asList(modelTensor1, modelTensor2); - Mockito.doAnswer(invocation -> { - ActionListener> listener = invocation.getArgument(2); - listener.onResponse(response); - return null; - }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); - - GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); - Mockito.doAnswer(invocation -> { - // Extract the ActionListener argument from the method invocation - ActionListener listener = invocation.getArgument(1); - // Trigger the onResponse method of the ActionListener with the mock response - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class)); - - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onResponse(memory); - return null; - }).when(mockMemoryFactory).create(Mockito.eq(null), Mockito.eq("memoryId"), Mockito.any(), Mockito.any()); - - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); - mlAgentExecutor.execute(getAgentMLInput(), agentActionListener); - - Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); - ModelTensorOutput output = (ModelTensorOutput) objectCaptor.getValue(); - Assert.assertEquals(1, output.getMlModelOutputs().size()); - Assert.assertEquals(2, output.getMlModelOutputs().get(0).getMlModelTensors().size()); - Assert.assertEquals(response, output.getMlModelOutputs().get(0).getMlModelTensors()); - } - - @Test - public void test_AgentRunnerReturnsListOfModelTensors_ReturnsResult() throws IOException { - ModelTensor modelTensor1 = ModelTensor.builder().name("response1").dataAsMap(ImmutableMap.of("test_key", "test_value")).build(); - ModelTensors modelTensors1 = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor1)).build(); - ModelTensor modelTensor2 = ModelTensor.builder().name("response2").dataAsMap(ImmutableMap.of("test_key", "test_value")).build(); - ModelTensors modelTensors2 = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor2)).build(); - List response = Arrays.asList(modelTensors1, modelTensors2); - Mockito.doAnswer(invocation -> { - ActionListener> listener = invocation.getArgument(2); - listener.onResponse(response); - return null; - }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); - - GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); - Mockito.doAnswer(invocation -> { - // Extract the ActionListener argument from the method invocation - ActionListener listener = invocation.getArgument(1); - // Trigger the onResponse method of the ActionListener with the mock response - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class)); - - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onResponse(memory); - return null; - }).when(mockMemoryFactory).create(Mockito.eq(null), Mockito.eq("memoryId"), Mockito.any(), Mockito.any()); - - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); - mlAgentExecutor.execute(getAgentMLInput(), agentActionListener); - - Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); - ModelTensorOutput output = (ModelTensorOutput) objectCaptor.getValue(); - Assert.assertEquals(1, output.getMlModelOutputs().size()); - Assert.assertEquals(2, output.getMlModelOutputs().get(0).getMlModelTensors().size()); - Assert.assertEquals(Arrays.asList(modelTensor1, modelTensor2), output.getMlModelOutputs().get(0).getMlModelTensors()); - } - - @Test - public void test_AgentRunnerReturnsListOfString_ReturnsResult() throws IOException { - List response = Arrays.asList("response1", "response2"); - Mockito.doAnswer(invocation -> { - ActionListener> listener = invocation.getArgument(2); - listener.onResponse(response); - return null; - }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); - - GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); - Mockito.doAnswer(invocation -> { - // Extract the ActionListener argument from the method invocation - ActionListener listener = invocation.getArgument(1); - // Trigger the onResponse method of the ActionListener with the mock response - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class)); - - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onResponse(memory); - return null; - }).when(mockMemoryFactory).create(Mockito.eq(null), Mockito.eq("memoryId"), Mockito.any(), Mockito.any()); - - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); - mlAgentExecutor.execute(getAgentMLInput(), agentActionListener); - - Gson gson = new Gson(); - Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); - ModelTensorOutput output = (ModelTensorOutput) objectCaptor.getValue(); - Assert.assertEquals(1, output.getMlModelOutputs().size()); - Assert.assertEquals(1, output.getMlModelOutputs().get(0).getMlModelTensors().size()); - Assert.assertEquals(gson.toJson(response), output.getMlModelOutputs().get(0).getMlModelTensors().get(0).getResult()); - } - - @Test - public void test_AgentRunnerReturnsString_ReturnsResult() throws IOException { - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onResponse("response"); - return null; - }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); - GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); - Mockito.doAnswer(invocation -> { - // Extract the ActionListener argument from the method invocation - ActionListener listener = invocation.getArgument(1); - // Trigger the onResponse method of the ActionListener with the mock response - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class)); - - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onResponse(memory); - return null; - }).when(mockMemoryFactory).create(Mockito.eq(null), Mockito.eq("memoryId"), Mockito.any(), Mockito.any()); - - mlAgentExecutor.execute(getAgentMLInput(), agentActionListener); - - Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); - ModelTensorOutput output = (ModelTensorOutput) objectCaptor.getValue(); - Assert.assertEquals(1, output.getMlModelOutputs().size()); - Assert.assertEquals(1, output.getMlModelOutputs().get(0).getMlModelTensors().size()); - Assert.assertEquals("response", output.getMlModelOutputs().get(0).getMlModelTensors().get(0).getResult()); - } - - @Test - public void test_AgentRunnerReturnsModelTensorOutput_ReturnsResult() throws IOException { - ModelTensor modelTensor1 = ModelTensor.builder().name("response1").dataAsMap(ImmutableMap.of("test_key", "test_value")).build(); - ModelTensors modelTensors1 = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor1)).build(); - ModelTensor modelTensor2 = ModelTensor.builder().name("response2").dataAsMap(ImmutableMap.of("test_key", "test_value")).build(); - ModelTensors modelTensors2 = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor2)).build(); - List modelTensorsList = Arrays.asList(modelTensors1, modelTensors2); - ModelTensorOutput modelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(modelTensorsList).build(); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onResponse(modelTensorOutput); - return null; - }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); - GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); - Mockito.doAnswer(invocation -> { - // Extract the ActionListener argument from the method invocation - ActionListener listener = invocation.getArgument(1); - // Trigger the onResponse method of the ActionListener with the mock response - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class)); - - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onResponse(memory); - return null; - }).when(mockMemoryFactory).create(Mockito.eq(null), Mockito.eq("memoryId"), Mockito.any(), Mockito.any()); - - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); - mlAgentExecutor.execute(getAgentMLInput(), agentActionListener); - - Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); - ModelTensorOutput output = (ModelTensorOutput) objectCaptor.getValue(); - Assert.assertEquals(1, output.getMlModelOutputs().size()); - Assert.assertEquals(2, output.getMlModelOutputs().get(0).getMlModelTensors().size()); - Assert.assertEquals(Arrays.asList(modelTensor1, modelTensor2), output.getMlModelOutputs().get(0).getMlModelTensors()); - } - - @Test - public void test_CreateConversation_ReturnsResult() throws IOException { - ModelTensor modelTensor = ModelTensor.builder().name("response").dataAsMap(ImmutableMap.of("test_key", "test_value")).build(); - ConversationIndexMemory memory = Mockito.mock(ConversationIndexMemory.class); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onResponse(modelTensor); - return null; - }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); - - GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); - Mockito.doAnswer(invocation -> { - // Extract the ActionListener argument from the method invocation - ActionListener listener = invocation.getArgument(1); - // Trigger the onResponse method of the ActionListener with the mock response - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class)); - - CreateInteractionResponse interaction = Mockito.mock(CreateInteractionResponse.class); - Mockito.when(interaction.getId()).thenReturn("interaction_id"); - Mockito.doAnswer(invocation -> { - ActionListener responseActionListener = invocation.getArgument(4); - responseActionListener.onResponse(interaction); - return null; - }).when(memory).save(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); - - Mockito.doAnswer(invocation -> { - Mockito.when(memory.getConversationId()).thenReturn("conversation_id"); - ActionListener listener = invocation.getArgument(3); - listener.onResponse(memory); - return null; - }).when(mockMemoryFactory).create(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); - - Map params = new HashMap<>(); - RemoteInferenceInputDataSet dataset = RemoteInferenceInputDataSet.builder().parameters(params).build(); - AgentMLInput agentMLInput = new AgentMLInput("test", null, FunctionName.AGENT, dataset); - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); - mlAgentExecutor.execute(agentMLInput, agentActionListener); - - Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); - ModelTensorOutput output = (ModelTensorOutput) objectCaptor.getValue(); - Assert.assertEquals(1, output.getMlModelOutputs().size()); - Assert.assertEquals(1, output.getMlModelOutputs().get(0).getMlModelTensors().size()); - Assert.assertEquals(modelTensor, output.getMlModelOutputs().get(0).getMlModelTensors().get(0)); - } - - @Test - public void test_Regenerate_Validation() throws IOException { - Map params = new HashMap<>(); - params.put(REGENERATE_INTERACTION_ID, "foo"); - RemoteInferenceInputDataSet dataset = RemoteInferenceInputDataSet.builder().parameters(params).build(); - AgentMLInput agentMLInput = new AgentMLInput("test", null, FunctionName.AGENT, dataset); - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); - - GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); - Mockito.doAnswer(invocation -> { - // Extract the ActionListener argument from the method invocation - ActionListener listener = invocation.getArgument(1); - // Trigger the onResponse method of the ActionListener with the mock response - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class)); - mlAgentExecutor.execute(agentMLInput, agentActionListener); - - Mockito.verify(agentActionListener).onFailure(exceptionCaptor.capture()); - Exception exception = exceptionCaptor.getValue(); - Assert.assertTrue(exception instanceof IllegalArgumentException); - Assert.assertEquals(exception.getMessage(), "A memory ID must be provided to regenerate."); - } - - @Test - public void test_Regenerate_GetOriginalInteraction() throws IOException { - ModelTensor modelTensor = ModelTensor.builder().name("response").dataAsMap(ImmutableMap.of("test_key", "test_value")).build(); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onResponse(modelTensor); - return null; - }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any()); - - GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); - Mockito.doAnswer(invocation -> { - // Extract the ActionListener argument from the method invocation - ActionListener listener = invocation.getArgument(1); - // Trigger the onResponse method of the ActionListener with the mock response - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class)); - - CreateInteractionResponse interaction = Mockito.mock(CreateInteractionResponse.class); - Mockito.when(interaction.getId()).thenReturn("interaction_id"); - Mockito.doAnswer(invocation -> { - ActionListener responseActionListener = invocation.getArgument(4); - responseActionListener.onResponse(interaction); - return null; - }).when(memory).save(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); - - Mockito.doAnswer(invocation -> { - Mockito.when(memory.getConversationId()).thenReturn("conversation_id"); - ActionListener listener = invocation.getArgument(3); - listener.onResponse(memory); - return null; - }).when(mockMemoryFactory).create(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); - - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(Boolean.TRUE); - return null; - }).when(memoryManager).deleteInteractionAndTrace(Mockito.anyString(), Mockito.any()); - - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - GetInteractionResponse interactionResponse = Mockito.mock(GetInteractionResponse.class); - Interaction mockInteraction = Mockito.mock(Interaction.class); - Mockito.when(mockInteraction.getInput()).thenReturn("regenerate question"); - Mockito.when(interactionResponse.getInteraction()).thenReturn(mockInteraction); - listener.onResponse(interactionResponse); - return null; - }).when(client).execute(Mockito.eq(GetInteractionAction.INSTANCE), Mockito.any(), Mockito.any()); - - String interactionId = "bar-interaction"; - Map params = new HashMap<>(); - params.put(MEMORY_ID, "foo-memory"); - params.put(REGENERATE_INTERACTION_ID, interactionId); - RemoteInferenceInputDataSet dataset = RemoteInferenceInputDataSet.builder().parameters(params).build(); - AgentMLInput agentMLInput = new AgentMLInput("test", null, FunctionName.AGENT, dataset); - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); - mlAgentExecutor.execute(agentMLInput, agentActionListener); - - Mockito.verify(client, times(1)).execute(Mockito.eq(GetInteractionAction.INSTANCE), Mockito.any(), Mockito.any()); - Assert.assertEquals(params.get(QUESTION), "regenerate question"); - // original interaction got deleted - Mockito.verify(memoryManager, times(1)).deleteInteractionAndTrace(Mockito.eq(interactionId), Mockito.any()); - } - - @Test - public void test_Regenerate_OriginalInteraction_NotExist() throws IOException { - ModelTensor modelTensor = ModelTensor.builder().name("response").dataAsMap(ImmutableMap.of("test_key", "test_value")).build(); - ConversationIndexMemory memory = Mockito.mock(ConversationIndexMemory.class); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onResponse(modelTensor); - return null; - }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any()); - - GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); - Mockito.doAnswer(invocation -> { - // Extract the ActionListener argument from the method invocation - ActionListener listener = invocation.getArgument(1); - // Trigger the onResponse method of the ActionListener with the mock response - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class)); - - CreateInteractionResponse interaction = Mockito.mock(CreateInteractionResponse.class); - Mockito.when(interaction.getId()).thenReturn("interaction_id"); - Mockito.doAnswer(invocation -> { - ActionListener responseActionListener = invocation.getArgument(4); - responseActionListener.onResponse(interaction); - return null; - }).when(memory).save(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); - - Mockito.doAnswer(invocation -> { - Mockito.when(memory.getConversationId()).thenReturn("conversation_id"); - ActionListener listener = invocation.getArgument(3); - listener.onResponse(memory); - return null; - }).when(mockMemoryFactory).create(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); - - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onFailure(new ResourceNotFoundException("Interaction bar-interaction not found")); - return null; - }).when(client).execute(Mockito.eq(GetInteractionAction.INSTANCE), Mockito.any(), Mockito.any()); - - Map params = new HashMap<>(); - params.put(MEMORY_ID, "foo-memory"); - params.put(REGENERATE_INTERACTION_ID, "bar-interaction"); - RemoteInferenceInputDataSet dataset = RemoteInferenceInputDataSet.builder().parameters(params).build(); - AgentMLInput agentMLInput = new AgentMLInput("test", null, FunctionName.AGENT, dataset); - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); - mlAgentExecutor.execute(agentMLInput, agentActionListener); - - Mockito.verify(client, times(1)).execute(Mockito.eq(GetInteractionAction.INSTANCE), Mockito.any(), Mockito.any()); - Assert.assertNull(params.get(QUESTION)); - - Mockito.verify(agentActionListener).onFailure(exceptionCaptor.capture()); - Exception exception = exceptionCaptor.getValue(); - Assert.assertTrue(exception instanceof ResourceNotFoundException); - Assert.assertEquals(exception.getMessage(), "Interaction bar-interaction not found"); - } - - @Test - public void test_CreateFlowAgent() { - MLAgent mlAgent = MLAgent.builder().name("test_agent").type("flow").build(); - MLAgentRunner mlAgentRunner = mlAgentExecutor.getAgentRunner(mlAgent); - Assert.assertTrue(mlAgentRunner instanceof MLFlowAgentRunner); - } - - @Test - public void test_CreateChatAgent() { - LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build(); - MLAgent mlAgent = MLAgent.builder().name("test_agent").type(MLAgentType.CONVERSATIONAL.name()).llm(llmSpec).build(); - MLAgentRunner mlAgentRunner = mlAgentExecutor.getAgentRunner(mlAgent); - Assert.assertTrue(mlAgentRunner instanceof MLChatAgentRunner); - } - - @Test(expected = IllegalArgumentException.class) - public void test_InvalidAgent_ThrowsException() { - MLAgent mlAgent = MLAgent.builder().name("test_agent").type("illegal").build(); - mlAgentExecutor.getAgentRunner(mlAgent); - } - - @Test - public void test_GetModel_ThrowsException() { - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onFailure(new RuntimeException()); - return null; - }).when(client).get(Mockito.any(), Mockito.any()); - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); - mlAgentExecutor.execute(getAgentMLInput(), agentActionListener); - - Mockito.verify(agentActionListener).onFailure(exceptionCaptor.capture()); - Assert.assertNotNull(exceptionCaptor.getValue()); - } - @Test - public void test_GetModelDoesNotExist_ThrowsException() { - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - GetResponse getResponse = Mockito.mock(GetResponse.class); - Mockito.when(getResponse.isExists()).thenReturn(false); - listener.onResponse(getResponse); - return null; - }).when(client).get(Mockito.any(), Mockito.any()); - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); - mlAgentExecutor.execute(getAgentMLInput(), agentActionListener); - - Mockito.verify(agentActionListener).onFailure(exceptionCaptor.capture()); - Assert.assertNotNull(exceptionCaptor.getValue()); - } - - @Test - public void test_CreateConversationFailure_ThrowsException() { - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onFailure(new RuntimeException()); - return null; - }).when(mockMemoryFactory).create(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); - Map params = new HashMap<>(); - RemoteInferenceInputDataSet dataset = RemoteInferenceInputDataSet.builder().parameters(params).build(); - AgentMLInput agentMLInput = new AgentMLInput("test", null, FunctionName.AGENT, dataset); - mlAgentExecutor.execute(agentMLInput, agentActionListener); - - Mockito.verify(agentActionListener).onFailure(exceptionCaptor.capture()); - Assert.assertNotNull(exceptionCaptor.getValue()); - } - - @Test - public void test_CreateInteractionFailure_ThrowsException() { - ConversationIndexMemory memory = Mockito.mock(ConversationIndexMemory.class); - Mockito.doAnswer(invocation -> { - ActionListener responseActionListener = invocation.getArgument(4); - responseActionListener.onFailure(new RuntimeException()); - return null; - }).when(memory).save(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); - - Mockito.doAnswer(invocation -> { - Mockito.when(memory.getConversationId()).thenReturn("conversation_id"); - ActionListener listener = invocation.getArgument(3); - listener.onResponse(memory); - return null; - }).when(mockMemoryFactory).create(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); - Map params = new HashMap<>(); - RemoteInferenceInputDataSet dataset = RemoteInferenceInputDataSet.builder().parameters(params).build(); - AgentMLInput agentMLInput = new AgentMLInput("test", null, FunctionName.AGENT, dataset); - mlAgentExecutor.execute(agentMLInput, agentActionListener); - - Mockito.verify(agentActionListener).onFailure(exceptionCaptor.capture()); - Assert.assertNotNull(exceptionCaptor.getValue()); - } - - @Test - public void test_AgentRunnerFailure_ReturnsResult() { - ModelTensor modelTensor = ModelTensor.builder().name("response").dataAsMap(ImmutableMap.of("test_key", "test_value")).build(); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onFailure(new RuntimeException()); - return null; - }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any()); - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); - mlAgentExecutor.execute(getAgentMLInput(), agentActionListener); - - Mockito.verify(agentActionListener).onFailure(exceptionCaptor.capture()); - Assert.assertNotNull(exceptionCaptor.getValue()); - } - - @Test - public void test_AsyncMode_ReturnsTaskId() throws IOException { - ModelTensor modelTensor = ModelTensor.builder().name("response").result("test").build(); - GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); - sdkClient = SdkClientFactory.createSdkClient(client, NamedXContentRegistry.EMPTY, Collections.emptyMap()); - - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(), Mockito.any()); - - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onResponse(memory); - return null; - }).when(mockMemoryFactory).create(Mockito.eq(null), Mockito.eq("memoryId"), Mockito.any(), Mockito.any()); - - indexResponse = new IndexResponse(new ShardId(ML_TASK_INDEX, "_na_", 0), "task_id", 1, 0, 2, true); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(indexResponse); - return null; - }).when(client).index(any(), any()); - - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); - AgentMLInput input = getAgentMLInput(); - input.setIsAsync(true); - - mlAgentExecutor.execute(input, agentActionListener); - - Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); - MLTaskOutput result = (MLTaskOutput) objectCaptor.getValue(); - - Assert.assertEquals("task_id", result.getTaskId()); - Assert.assertEquals("RUNNING", result.getStatus()); - } - - @Test - public void test_AsyncMode_IndexTask_failure() throws IOException { - ModelTensor modelTensor = ModelTensor.builder().name("response").result("test").build(); - GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); - sdkClient = SdkClientFactory.createSdkClient(client, NamedXContentRegistry.EMPTY, Collections.emptyMap()); - - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(), Mockito.any()); - - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onResponse(memory); - return null; - }).when(mockMemoryFactory).create(Mockito.eq(null), Mockito.eq("memoryId"), Mockito.any(), Mockito.any()); - - doAnswer(invocation -> { - ActionListener actionListener = invocation.getArgument(1); - actionListener.onFailure(new Exception("Index Not Found")); - return null; - }).when(client).index(any(), any()); - - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); - AgentMLInput input = getAgentMLInput(); - input.setIsAsync(true); - - mlAgentExecutor.execute(input, agentActionListener); - - Mockito.verify(agentActionListener).onFailure(exceptionCaptor.capture()); - Assert.assertNotNull(exceptionCaptor.getValue()); - } - - @Test - public void test_mcp_connector_requires_mcp_connector_enabled() throws IOException { - // Create an MLAgent with MCP connectors in parameters - Map parameters = new HashMap<>(); - parameters.put(MCP_CONNECTORS_FIELD, "[{\"connector_id\": \"test-connector\"}]"); - - MLAgent mlAgentWithMcpConnectors = new MLAgent( - "test", - MLAgentType.FLOW.name(), - "test", - new LLMSpec("test_model", Map.of("test_key", "test_value")), - Collections.emptyList(), - parameters, - new MLMemorySpec("memoryType", "123", 0), - Instant.EPOCH, - Instant.EPOCH, - "test", - false, - null + // Mock ClusterService for the agent index check + ClusterState clusterState = mock(ClusterState.class); + Metadata metadata = mock(Metadata.class); + when(clusterService.state()).thenReturn(clusterState); + when(clusterState.metadata()).thenReturn(metadata); + when(metadata.hasIndex(anyString())).thenReturn(false); // Simulate index not found + + mlAgentExecutor = new MLAgentExecutor( + client, + sdkClient, + settings, + clusterService, + xContentRegistry, + toolFactories, + memoryFactoryMap, + mlFeatureEnabledSetting, + encryptor ); - - // Create GetResponse with the MLAgent that has MCP connectors - XContentBuilder content = mlAgentWithMcpConnectors.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); - BytesReference bytesReference = BytesReference.bytes(content); - GetResult getResult = new GetResult("indexName", "test-agent-id", 111l, 111l, 111l, true, bytesReference, null, null); - GetResponse agentGetResponse = new GetResponse(getResult); - - // Create a new MLAgentExecutor with MCP connector disabled - MLFeatureEnabledSetting disabledMcpSetting = Mockito.mock(MLFeatureEnabledSetting.class); - when(disabledMcpSetting.isMultiTenancyEnabled()).thenReturn(false); - when(disabledMcpSetting.isMcpConnectorEnabled()).thenReturn(false); - - MLAgentExecutor mlAgentExecutorWithDisabledMcp = Mockito - .spy( - new MLAgentExecutor( - client, - sdkClient, - settings, - clusterService, - xContentRegistry, - toolFactories, - memoryMap, - disabledMcpSetting, - null - ) - ); - - // Mock the agent get response - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class)); - - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onResponse(memory); - return null; - }).when(mockMemoryFactory).create(Mockito.eq(null), Mockito.eq("memoryId"), Mockito.any(), Mockito.any()); - - // Mock the agent runner - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutorWithDisabledMcp).getAgentRunner(Mockito.any()); - - // Execute the agent - mlAgentExecutorWithDisabledMcp.execute(getAgentMLInput(), agentActionListener); - - // Verify that the execution fails with the correct error message - Mockito.verify(agentActionListener).onFailure(exceptionCaptor.capture()); - Exception exception = exceptionCaptor.getValue(); - Assert.assertTrue(exception instanceof OpenSearchException); - Assert.assertEquals(exception.getMessage(), ML_COMMONS_MCP_CONNECTOR_DISABLED_MESSAGE); } @Test - public void test_query_planning_agentic_search_enabled() throws IOException { - // Create an MLAgent with QueryPlanningTool - MLAgent mlAgentWithQueryPlanning = new MLAgent( - "test", - MLAgentType.FLOW.name(), - "test", - new LLMSpec("test_model", Map.of("test_key", "test_value")), - List - .of( - new MLToolSpec( - "QueryPlanningTool", - "QueryPlanningTool", - "QueryPlanningTool", - Collections.emptyMap(), - Collections.emptyMap(), - false, - Collections.emptyMap(), - null, - null - ) - ), - Map.of("test", "test"), - new MLMemorySpec("memoryType", "123", 0), - Instant.EPOCH, - Instant.EPOCH, - "test", - false, - null - ); - - // Create GetResponse with the MLAgent that has QueryPlanningTool - XContentBuilder content = mlAgentWithQueryPlanning.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); - BytesReference bytesReference = BytesReference.bytes(content); - GetResult getResult = new GetResult("indexName", "test-agent-id", 111l, 111l, 111l, true, bytesReference, null, null); - GetResponse agentGetResponse = new GetResponse(getResult); - - // Create a new MLAgentExecutor with agentic search enabled - MLFeatureEnabledSetting enabledSearchSetting = Mockito.mock(MLFeatureEnabledSetting.class); - when(enabledSearchSetting.isMultiTenancyEnabled()).thenReturn(false); - when(enabledSearchSetting.isMcpConnectorEnabled()).thenReturn(true); - - MLAgentExecutor mlAgentExecutorWithEnabledSearch = Mockito - .spy( - new MLAgentExecutor( - client, - sdkClient, - settings, - clusterService, - xContentRegistry, - toolFactories, - memoryMap, - enabledSearchSetting, - null - ) - ); - - // Mock the agent get response - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class)); - - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onResponse(memory); - return null; - }).when(mockMemoryFactory).create(Mockito.eq(null), Mockito.eq("memoryId"), Mockito.any(), Mockito.any()); - - // Mock the agent runner - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutorWithEnabledSearch).getAgentRunner(Mockito.any()); - - // Mock successful execution - ModelTensor modelTensor = ModelTensor.builder().name("response").dataAsMap(ImmutableMap.of("test_key", "test_value")).build(); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onResponse(modelTensor); - return null; - }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); - - // Execute the agent - mlAgentExecutorWithEnabledSearch.execute(getAgentMLInput(), agentActionListener); - - // Verify that the execution succeeds - Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); - ModelTensorOutput output = (ModelTensorOutput) objectCaptor.getValue(); - Assert.assertEquals(1, output.getMlModelOutputs().size()); - Assert.assertEquals(1, output.getMlModelOutputs().get(0).getMlModelTensors().size()); - Assert.assertEquals(modelTensor, output.getMlModelOutputs().get(0).getMlModelTensors().get(0)); - } - - private AgentMLInput getAgentMLInput() { - Map params = new HashMap<>(); - params.put(MLAgentExecutor.MEMORY_ID, "memoryId"); - params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "parentInteractionId"); - RemoteInferenceInputDataSet dataset = RemoteInferenceInputDataSet.builder().parameters(params).build(); - return new AgentMLInput("test", null, FunctionName.AGENT, dataset); - } - - public GetResponse prepareMLAgent(String agentId, boolean isHidden, String tenantId) throws IOException { - - mlAgent = new MLAgent( - "test", - MLAgentType.CONVERSATIONAL.name(), - "test", - new LLMSpec("test_model", Map.of("test_key", "test_value")), - List - .of( - new MLToolSpec( - "memoryType", - "test", - "test", - Collections.emptyMap(), - Collections.emptyMap(), - false, - Collections.emptyMap(), - null, - null - ) - ), - Map.of("test", "test"), - new MLMemorySpec("memoryType", "123", 0), - Instant.EPOCH, - Instant.EPOCH, - "test", - isHidden, - tenantId - ); - - XContentBuilder content = mlAgent.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); - BytesReference bytesReference = BytesReference.bytes(content); - GetResult getResult = new GetResult("indexName", agentId, 111l, 111l, 111l, true, bytesReference, null, null); - return new GetResponse(getResult); + public void testConstructor() { + assertNotNull(mlAgentExecutor); + assertEquals(client, mlAgentExecutor.getClient()); + assertEquals(settings, mlAgentExecutor.getSettings()); + assertEquals(clusterService, mlAgentExecutor.getClusterService()); + assertEquals(xContentRegistry, mlAgentExecutor.getXContentRegistry()); + assertEquals(toolFactories, mlAgentExecutor.getToolFactories()); + assertEquals(memoryFactoryMap, mlAgentExecutor.getMemoryFactoryMap()); + assertEquals(mlFeatureEnabledSetting, mlAgentExecutor.getMlFeatureEnabledSetting()); + assertEquals(encryptor, mlAgentExecutor.getEncryptor()); } @Test - public void test_BothParentAndRegenerateInteractionId_ThrowsException() throws IOException { - GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class)); - - Map params = new HashMap<>(); - params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "parent-123"); - params.put(REGENERATE_INTERACTION_ID, "regenerate-456"); - RemoteInferenceInputDataSet dataset = RemoteInferenceInputDataSet.builder().parameters(params).build(); - AgentMLInput agentMLInput = new AgentMLInput("test", null, FunctionName.AGENT, dataset); + public void testOnMultiTenancyEnabledChanged() { + mlAgentExecutor.onMultiTenancyEnabledChanged(true); + assertTrue(mlAgentExecutor.getIsMultiTenancyEnabled()); - mlAgentExecutor.execute(agentMLInput, agentActionListener); - - Mockito.verify(agentActionListener).onFailure(exceptionCaptor.capture()); - Exception exception = exceptionCaptor.getValue(); - Assert.assertTrue(exception instanceof IllegalArgumentException); - Assert - .assertEquals( - exception.getMessage(), - "Provide either `parent_interaction_id` to update an existing interaction, or `regenerate_interaction_id` to create a new one." - ); + mlAgentExecutor.onMultiTenancyEnabledChanged(false); + assertFalse(mlAgentExecutor.getIsMultiTenancyEnabled()); } @Test - public void test_ExistingConversation_WithMemoryAndParentInteractionId() throws IOException { - ModelTensor modelTensor = ModelTensor.builder().name("response").dataAsMap(ImmutableMap.of("test_key", "test_value")).build(); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onResponse(modelTensor); - return null; - }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); - - GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class)); - - // Mock memory factory for existing conversation - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onResponse(memory); - return null; - }).when(mockMemoryFactory).create(Mockito.eq(null), Mockito.eq("existing-memory"), Mockito.any(), Mockito.any()); - - Map params = new HashMap<>(); - params.put(MEMORY_ID, "existing-memory"); - params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "existing-parent"); - RemoteInferenceInputDataSet dataset = RemoteInferenceInputDataSet.builder().parameters(params).build(); - AgentMLInput agentMLInput = new AgentMLInput("test", null, FunctionName.AGENT, dataset); + public void testExecuteWithWrongInputType() { + // Test with non-AgentMLInput - create a mock Input that's not AgentMLInput + Input wrongInput = mock(Input.class); - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); - mlAgentExecutor.execute(agentMLInput, agentActionListener); - - // Verify memory factory was called with null question and existing memory_id - Mockito.verify(mockMemoryFactory).create(Mockito.eq(null), Mockito.eq("existing-memory"), Mockito.any(), Mockito.any()); - Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); - ModelTensorOutput output = (ModelTensorOutput) objectCaptor.getValue(); - Assert.assertEquals(1, output.getMlModelOutputs().size()); + try { + mlAgentExecutor.execute(wrongInput, listener, channel); + fail("Expected IllegalArgumentException"); + } catch (IllegalArgumentException exception) { + assertEquals("wrong input", exception.getMessage()); + } } @Test - public void test_AgentFailure_UpdatesInteractionWithFailure() throws IOException { - RuntimeException testException = new RuntimeException("Agent execution failed"); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onFailure(testException); - return null; - }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); - - GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class)); - - // Mock memory factory for existing conversation - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onResponse(memory); - return null; - }).when(mockMemoryFactory).create(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); - - Map params = new HashMap<>(); - params.put(MEMORY_ID, "test-memory"); - params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "test-parent"); - RemoteInferenceInputDataSet dataset = RemoteInferenceInputDataSet.builder().parameters(params).build(); - AgentMLInput agentMLInput = new AgentMLInput("test", null, FunctionName.AGENT, dataset); - - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); - mlAgentExecutor.execute(agentMLInput, agentActionListener); + public void testExecuteWithNullInputDataSet() { + AgentMLInput agentInput = new AgentMLInput("test-agent", null, FunctionName.AGENT, null); - // Verify failure was propagated to listener - Mockito.verify(agentActionListener).onFailure(exceptionCaptor.capture()); - Assert.assertEquals(testException, exceptionCaptor.getValue()); - - // Verify interaction was updated with failure message - ArgumentCaptor> updateCaptor = ArgumentCaptor.forClass(Map.class); - Mockito.verify(memoryManager).updateInteraction(Mockito.eq("test-parent"), updateCaptor.capture(), Mockito.any()); - Map updateContent = updateCaptor.getValue(); - Assert.assertTrue(updateContent.get("response").toString().contains("Agent execution failed")); + try { + mlAgentExecutor.execute(agentInput, listener, channel); + fail("Expected IllegalArgumentException"); + } catch (IllegalArgumentException exception) { + assertEquals("Agent input data can not be empty.", exception.getMessage()); + } } @Test - public void test_ExistingConversation_MemoryCreationFailure() throws IOException { - GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class)); - - // Mock memory factory failure for existing conversation - RuntimeException memoryException = new RuntimeException("Memory creation failed"); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onFailure(memoryException); - return null; - }).when(mockMemoryFactory).create(Mockito.eq(null), Mockito.eq("existing-memory"), Mockito.any(), Mockito.any()); - - Map params = new HashMap<>(); - params.put(MEMORY_ID, "existing-memory"); - params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "existing-parent"); - RemoteInferenceInputDataSet dataset = RemoteInferenceInputDataSet.builder().parameters(params).build(); - AgentMLInput agentMLInput = new AgentMLInput("test", null, FunctionName.AGENT, dataset); - - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); - mlAgentExecutor.execute(agentMLInput, agentActionListener); + public void testExecuteWithNullParameters() { + RemoteInferenceInputDataSet dataset = RemoteInferenceInputDataSet.builder().build(); + AgentMLInput agentInput = new AgentMLInput("test-agent", null, FunctionName.AGENT, dataset); - Mockito.verify(agentActionListener).onFailure(exceptionCaptor.capture()); - Exception exception = exceptionCaptor.getValue(); - Assert.assertEquals(memoryException, exception); + try { + mlAgentExecutor.execute(agentInput, listener, channel); + fail("Expected IllegalArgumentException"); + } catch (IllegalArgumentException exception) { + assertEquals("Agent input data can not be empty.", exception.getMessage()); + } } @Test - public void test_ExecuteAgent_SyncMode() throws IOException { - ModelTensor modelTensor = ModelTensor.builder().name("response").dataAsMap(ImmutableMap.of("test_key", "test_value")).build(); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onResponse(modelTensor); - return null; - }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); - - GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class)); - - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onResponse(memory); - return null; - }).when(mockMemoryFactory).create(Mockito.eq(null), Mockito.eq("memoryId"), Mockito.any(), Mockito.any()); - - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); - - Map params = new HashMap<>(); - params.put(QUESTION, "test question"); - params.put(MEMORY_ID, "memoryId"); - params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "parentInteractionId"); - RemoteInferenceInputDataSet dataset = RemoteInferenceInputDataSet.builder().parameters(params).build(); - AgentMLInput agentMLInput = new AgentMLInput("test", null, FunctionName.AGENT, dataset); - agentMLInput.setIsAsync(false); - - mlAgentExecutor.execute(agentMLInput, agentActionListener); - - Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); - ModelTensorOutput output = (ModelTensorOutput) objectCaptor.getValue(); - Assert.assertEquals(1, output.getMlModelOutputs().size()); - Assert.assertEquals(1, output.getMlModelOutputs().get(0).getMlModelTensors().size()); - Assert.assertEquals(modelTensor, output.getMlModelOutputs().get(0).getMlModelTensors().get(0)); + public void testExecuteWithMultiTenancyEnabledButNoTenantId() { + when(mlFeatureEnabledSetting.isMultiTenancyEnabled()).thenReturn(true); + mlAgentExecutor.onMultiTenancyEnabledChanged(true); + + Map parameters = Collections.singletonMap("question", "test question"); + RemoteInferenceInputDataSet dataset = RemoteInferenceInputDataSet.builder() + .parameters(parameters) + .build(); + AgentMLInput agentInput = new AgentMLInput("test-agent", null, FunctionName.AGENT, dataset); + + try { + mlAgentExecutor.execute(agentInput, listener, channel); + fail("Expected OpenSearchStatusException"); + } catch (OpenSearchStatusException exception) { + assertEquals("You don't have permission to access this resource", exception.getMessage()); + assertEquals(RestStatus.FORBIDDEN, exception.status()); + } } @Test - public void test_ExecuteAgent_AsyncMode() throws IOException { - ModelTensor modelTensor = ModelTensor.builder().name("response").dataAsMap(ImmutableMap.of("test_key", "test_value")).build(); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onResponse(modelTensor); - return null; - }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any()); - - GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class)); - - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onResponse(memory); - return null; - }).when(mockMemoryFactory).create(Mockito.eq(null), Mockito.eq("memoryId"), Mockito.any(), Mockito.any()); + public void testExecuteWithAgentIndexNotFound() { + Map parameters = Collections.singletonMap("question", "test question"); + RemoteInferenceInputDataSet dataset = RemoteInferenceInputDataSet.builder().parameters(parameters).build(); + AgentMLInput agentInput = new AgentMLInput("test-agent", null, FunctionName.AGENT, dataset); - indexResponse = new IndexResponse(new ShardId(ML_TASK_INDEX, "_na_", 0), "task-123", 1, 0, 2, true); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(indexResponse); - return null; - }).when(mlAgentExecutor).indexMLTask(Mockito.any(), Mockito.any()); + // Since we can't mock static methods easily, we'll test a different scenario + // This test would need the actual MLIndicesHandler behavior + mlAgentExecutor.execute(agentInput, listener, channel); - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); - - Map params = new HashMap<>(); - params.put(QUESTION, "test question"); - params.put(MEMORY_ID, "memoryId"); - params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "parentInteractionId"); - RemoteInferenceInputDataSet dataset = RemoteInferenceInputDataSet.builder().parameters(params).build(); - AgentMLInput agentMLInput = new AgentMLInput("test", null, FunctionName.AGENT, dataset); - agentMLInput.setIsAsync(true); - - mlAgentExecutor.execute(agentMLInput, agentActionListener); - - Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); - MLTaskOutput output = (MLTaskOutput) objectCaptor.getValue(); - Assert.assertEquals("task-123", output.getTaskId()); - Assert.assertEquals("RUNNING", output.getStatus()); + // Verify that the listener was called (the actual behavior will depend on the implementation) + verify(listener, timeout(5000).atLeastOnce()).onFailure(any()); } @Test - public void test_UpdateInteractionWithFailure() throws IOException { - RuntimeException testException = new RuntimeException("Test failure message"); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onFailure(testException); - return null; - }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); - - GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class)); - - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onResponse(memory); - return null; - }).when(mockMemoryFactory).create(Mockito.eq(null), Mockito.eq("memoryId"), Mockito.any(), Mockito.any()); - - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); - - Map params = new HashMap<>(); - params.put(MEMORY_ID, "memoryId"); - params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "test-parent-id"); - RemoteInferenceInputDataSet dataset = RemoteInferenceInputDataSet.builder().parameters(params).build(); - AgentMLInput agentMLInput = new AgentMLInput("test", null, FunctionName.AGENT, dataset); - - mlAgentExecutor.execute(agentMLInput, agentActionListener); - - ArgumentCaptor> updateCaptor = ArgumentCaptor.forClass(Map.class); - Mockito.verify(memoryManager).updateInteraction(Mockito.eq("test-parent-id"), updateCaptor.capture(), Mockito.any()); - Map updateContent = updateCaptor.getValue(); - Assert.assertEquals("Agent execution failed: Test failure message", updateContent.get("response")); + public void testGetAgentRunnerWithFlowAgent() { + MLAgent agent = createTestAgent(MLAgentType.FLOW.name()); + MLAgentRunner runner = mlAgentExecutor.getAgentRunner(agent, null); + assertNotNull(runner); + assertTrue(runner instanceof MLFlowAgentRunner); } @Test - public void test_ConversationMemoryCreationFailure() throws IOException { - GetResponse agentGetResponse = prepareMLAgent("test-agent-id", true, null); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class)); - - RuntimeException memoryException = new RuntimeException("Failed to read conversation memory"); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onFailure(memoryException); - return null; - }).when(mockMemoryFactory).create(Mockito.eq("test question"), Mockito.eq(null), Mockito.any(), Mockito.any()); - - Map params = new HashMap<>(); - params.put(QUESTION, "test question"); - RemoteInferenceInputDataSet dataset = RemoteInferenceInputDataSet.builder().parameters(params).build(); - AgentMLInput agentMLInput = new AgentMLInput("test", null, FunctionName.AGENT, dataset); - - mlAgentExecutor.execute(agentMLInput, agentActionListener); - - Mockito.verify(agentActionListener).onFailure(exceptionCaptor.capture()); - Exception exception = exceptionCaptor.getValue(); - Assert.assertEquals(memoryException, exception); + public void testGetAgentRunnerWithConversationalFlowAgent() { + MLAgent agent = createTestAgent(MLAgentType.CONVERSATIONAL_FLOW.name()); + MLAgentRunner runner = mlAgentExecutor.getAgentRunner(agent, null); + assertNotNull(runner); + assertTrue(runner instanceof MLConversationalFlowAgentRunner); } @Test - public void test_AsyncExecution_NullOutput() throws IOException { - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onResponse(null); - return null; - }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any()); - - GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class)); - - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onResponse(memory); - return null; - }).when(mockMemoryFactory).create(Mockito.eq(null), Mockito.eq("memoryId"), Mockito.any(), Mockito.any()); - - indexResponse = new IndexResponse(new ShardId(ML_TASK_INDEX, "_na_", 0), "task-123", 1, 0, 2, true); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(indexResponse); - return null; - }).when(mlAgentExecutor).indexMLTask(Mockito.any(), Mockito.any()); - - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); - - AgentMLInput input = getAgentMLInput(); - input.setIsAsync(true); - mlAgentExecutor.execute(input, agentActionListener); - - Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); - MLTaskOutput output = (MLTaskOutput) objectCaptor.getValue(); - Assert.assertNotNull(output.getTaskId()); + public void testGetAgentRunnerWithConversationalAgent() { + MLAgent agent = createTestAgent(MLAgentType.CONVERSATIONAL.name()); + MLAgentRunner runner = mlAgentExecutor.getAgentRunner(agent, null); + assertNotNull(runner); + assertTrue(runner instanceof MLChatAgentRunner); } @Test - public void test_AsyncExecution_Failure() throws IOException { - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onFailure(new RuntimeException("Agent execution failed")); - return null; - }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any()); - - GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class)); - - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onResponse(memory); - return null; - }).when(mockMemoryFactory).create(Mockito.eq(null), Mockito.eq("memoryId"), Mockito.any(), Mockito.any()); - - indexResponse = new IndexResponse(new ShardId(ML_TASK_INDEX, "_na_", 0), "task-123", 1, 0, 2, true); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(indexResponse); - return null; - }).when(mlAgentExecutor).indexMLTask(Mockito.any(), Mockito.any()); - - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); - - AgentMLInput input = getAgentMLInput(); - input.setIsAsync(true); - mlAgentExecutor.execute(input, agentActionListener); - - Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); - MLTaskOutput output = (MLTaskOutput) objectCaptor.getValue(); - Assert.assertNotNull(output.getTaskId()); + public void testGetAgentRunnerWithPlanExecuteAndReflectAgent() { + MLAgent agent = createTestAgent(MLAgentType.PLAN_EXECUTE_AND_REFLECT.name()); + MLAgentRunner runner = mlAgentExecutor.getAgentRunner(agent, null); + assertNotNull(runner); + assertTrue(runner instanceof MLPlanExecuteAndReflectAgentRunner); } @Test - public void test_UpdateInteractionFailure_LogLines() throws IOException { - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onFailure(new RuntimeException("Test failure")); - return null; - }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); - - GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class)); + public void testGetAgentRunnerWithUnsupportedAgentType() { + // Create a mock MLAgent instead of using the constructor that validates + MLAgent agent = mock(MLAgent.class); + when(agent.getType()).thenReturn("UNSUPPORTED_TYPE"); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onResponse(memory); - return null; - }).when(mockMemoryFactory).create(Mockito.eq(null), Mockito.eq("memoryId"), Mockito.any(), Mockito.any()); - - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onResponse(true); - return null; - }).when(memoryManager).updateInteraction(Mockito.any(), Mockito.any(), Mockito.any()); - - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); - - Map params = new HashMap<>(); - params.put(MEMORY_ID, "memoryId"); - params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "test-parent-id"); - RemoteInferenceInputDataSet dataset = RemoteInferenceInputDataSet.builder().parameters(params).build(); - AgentMLInput agentMLInput = new AgentMLInput("test", null, FunctionName.AGENT, dataset); - - mlAgentExecutor.execute(agentMLInput, agentActionListener); - - Mockito.verify(memoryManager).updateInteraction(Mockito.eq("test-parent-id"), Mockito.any(), Mockito.any()); + try { + mlAgentExecutor.getAgentRunner(agent, null); + fail("Expected IllegalArgumentException"); + } catch (IllegalArgumentException exception) { + assertEquals("Wrong Agent type", exception.getMessage()); + } } @Test - public void test_UpdateInteractionFailure_ErrorCallback() throws IOException { - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onFailure(new RuntimeException("Test failure")); - return null; - }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); - - GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class)); - - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onResponse(memory); - return null; - }).when(mockMemoryFactory).create(Mockito.eq(null), Mockito.eq("memoryId"), Mockito.any(), Mockito.any()); - - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onFailure(new RuntimeException("Update failed")); - return null; - }).when(memoryManager).updateInteraction(Mockito.any(), Mockito.any(), Mockito.any()); + public void testProcessOutputWithModelTensorOutput() throws Exception { + ModelTensorOutput output = mock(ModelTensorOutput.class); + when(output.getMlModelOutputs()).thenReturn(Collections.emptyList()); - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); + List modelTensors = new java.util.ArrayList<>(); - Map params = new HashMap<>(); - params.put(MEMORY_ID, "memoryId"); - params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "test-parent-id"); - RemoteInferenceInputDataSet dataset = RemoteInferenceInputDataSet.builder().parameters(params).build(); - AgentMLInput agentMLInput = new AgentMLInput("test", null, FunctionName.AGENT, dataset); + mlAgentExecutor.processOutput(output, modelTensors); - mlAgentExecutor.execute(agentMLInput, agentActionListener); - - Mockito.verify(memoryManager).updateInteraction(Mockito.eq("test-parent-id"), Mockito.any(), Mockito.any()); + verify(output).getMlModelOutputs(); } @Test - public void test_AsyncTaskUpdate_SuccessCallback() throws IOException { - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onResponse("success"); - return null; - }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any()); - - GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class)); - - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onResponse(memory); - return null; - }).when(mockMemoryFactory).create(Mockito.eq(null), Mockito.eq("memoryId"), Mockito.any(), Mockito.any()); - - indexResponse = new IndexResponse(new ShardId(ML_TASK_INDEX, "_na_", 0), "task-123", 1, 0, 2, true); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(indexResponse); - return null; - }).when(mlAgentExecutor).indexMLTask(Mockito.any(), Mockito.any()); - - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); + public void testProcessOutputWithString() throws Exception { + String output = "test response"; + List modelTensors = new java.util.ArrayList<>(); - AgentMLInput input = getAgentMLInput(); - input.setIsAsync(true); - mlAgentExecutor.execute(input, agentActionListener); + mlAgentExecutor.processOutput(output, modelTensors); - Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); + assertEquals(1, modelTensors.size()); + assertEquals("response", modelTensors.get(0).getName()); + assertEquals("test response", modelTensors.get(0).getResult()); } - @Test - public void test_AsyncTaskUpdate_FailureCallback() throws IOException { - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onFailure(new RuntimeException("Agent failed")); - return null; - }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any()); - - GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class)); - - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onResponse(memory); - return null; - }).when(mockMemoryFactory).create(Mockito.eq(null), Mockito.eq("memoryId"), Mockito.any(), Mockito.any()); - - indexResponse = new IndexResponse(new ShardId(ML_TASK_INDEX, "_na_", 0), "task-123", 1, 0, 2, true); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(indexResponse); - return null; - }).when(mlAgentExecutor).indexMLTask(Mockito.any(), Mockito.any()); - - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); - - AgentMLInput input = getAgentMLInput(); - input.setIsAsync(true); - mlAgentExecutor.execute(input, agentActionListener); - - Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); + private MLAgent createTestAgent(String type) { + return MLAgent + .builder() + .name("test-agent") + .type(type) + .description("Test agent") + .llm(LLMSpec.builder().modelId("test-model").parameters(Collections.emptyMap()).build()) + .tools(Collections.emptyList()) + .parameters(Collections.emptyMap()) + .memory(null) + .createdTime(Instant.now()) + .lastUpdateTime(Instant.now()) + .appType("test-app") + .build(); } @Test - public void test_AgentRunnerException() throws IOException { - // Reset mocks to ensure clean state - Mockito.reset(mlAgentRunner); - - RuntimeException testException = new RuntimeException("Agent runner threw exception"); - Mockito.doThrow(testException).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); - - GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class)); - - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onResponse(memory); - return null; - }).when(mockMemoryFactory).create(Mockito.eq(null), Mockito.eq("memoryId"), Mockito.any(), Mockito.any()); - - indexResponse = new IndexResponse(new ShardId(ML_TASK_INDEX, "_na_", 0), "task-123", 1, 0, 2, true); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(indexResponse); - return null; - }).when(mlAgentExecutor).indexMLTask(Mockito.any(), Mockito.any()); + public void testContextManagementProcessedFlagPreventsReprocessing() { + // Test that the context_management_processed flag prevents duplicate processing + Map parameters = new HashMap<>(); - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); + // First check - should allow processing + boolean shouldProcess1 = !"true".equals(parameters.get("context_management_processed")); + assertTrue("First call should allow processing", shouldProcess1); - AgentMLInput input = getAgentMLInput(); - input.setIsAsync(true); - mlAgentExecutor.execute(input, agentActionListener); + // Mark as processed (simulating what the method does) + parameters.put("context_management_processed", "true"); - Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); - MLTaskOutput output = (MLTaskOutput) objectCaptor.getValue(); - Assert.assertEquals("task-123", output.getTaskId()); + // Second check - should prevent processing + boolean shouldProcess2 = !"true".equals(parameters.get("context_management_processed")); + assertFalse("Second call should prevent processing", shouldProcess2); } } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java index f6c3e3618e..c63db9df4f 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java @@ -710,7 +710,7 @@ public void testToolParameters() { // Verify the size of parameters passed in the tool run method. ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Map.class); verify(firstTool).run((Map) argumentCaptor.capture(), any()); - assertEquals(15, ((Map) argumentCaptor.getValue()).size()); + assertEquals(16, ((Map) argumentCaptor.getValue()).size()); Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); ModelTensorOutput modelTensorOutput = (ModelTensorOutput) objectCaptor.getValue(); @@ -738,7 +738,7 @@ public void testToolUseOriginalInput() { // Verify the size of parameters passed in the tool run method. ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Map.class); verify(firstTool).run((Map) argumentCaptor.capture(), any()); - assertEquals(16, ((Map) argumentCaptor.getValue()).size()); + assertEquals(17, ((Map) argumentCaptor.getValue()).size()); assertEquals("raw input", ((Map) argumentCaptor.getValue()).get("input")); Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); @@ -804,7 +804,7 @@ public void testToolConfig() { // Verify the size of parameters passed in the tool run method. ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Map.class); verify(firstTool).run((Map) argumentCaptor.capture(), any()); - assertEquals(16, ((Map) argumentCaptor.getValue()).size()); + assertEquals(17, ((Map) argumentCaptor.getValue()).size()); // The value of input should be "config_value". assertEquals("config_value", ((Map) argumentCaptor.getValue()).get("input")); @@ -834,7 +834,7 @@ public void testToolConfigWithInputPlaceholder() { // Verify the size of parameters passed in the tool run method. ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Map.class); verify(firstTool).run((Map) argumentCaptor.capture(), any()); - assertEquals(16, ((Map) argumentCaptor.getValue()).size()); + assertEquals(17, ((Map) argumentCaptor.getValue()).size()); // The value of input should be replaced with the value associated with the key "key2" of the first tool. assertEquals("value2", ((Map) argumentCaptor.getValue()).get("input")); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunnerTest.java index 7ed4e91b1c..00a6edde13 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunnerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunnerTest.java @@ -136,6 +136,7 @@ public void setup() { // memory mlMemorySpec = new MLMemorySpec(ConversationIndexMemory.TYPE, "uuid", 10); + when(memoryMap.get(ConversationIndexMemory.TYPE)).thenReturn(memoryFactory); when(memoryMap.get(anyString())).thenReturn(memoryFactory); when(conversationIndexMemory.getConversationId()).thenReturn("test_memory_id"); when(conversationIndexMemory.getMemoryManager()).thenReturn(mlMemoryManager); @@ -171,7 +172,8 @@ public void setup() { toolFactories, memoryMap, sdkClient, - encryptor + encryptor, + null ); // Setup tools diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/contextmanager/SlidingWindowManagerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/contextmanager/SlidingWindowManagerTest.java new file mode 100644 index 0000000000..692c0ebf7c --- /dev/null +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/contextmanager/SlidingWindowManagerTest.java @@ -0,0 +1,234 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.algorithms.contextmanager; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.Map; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.opensearch.ml.common.contextmanager.ContextManagerContext; + +/** + * Unit tests for SlidingWindowManager. + */ +public class SlidingWindowManagerTest { + + private SlidingWindowManager manager; + private ContextManagerContext context; + + @Before + public void setUp() { + manager = new SlidingWindowManager(); + context = ContextManagerContext.builder().toolInteractions(new ArrayList<>()).parameters(new HashMap<>()).build(); + } + + @Test + public void testGetType() { + Assert.assertEquals("SlidingWindowManager", manager.getType()); + } + + @Test + public void testInitializeWithDefaults() { + Map config = new HashMap<>(); + manager.initialize(config); + + // Should initialize with default values without throwing exceptions + Assert.assertNotNull(manager); + } + + @Test + public void testInitializeWithCustomConfig() { + Map config = new HashMap<>(); + config.put("max_messages", 10); + + manager.initialize(config); + + // Should initialize without throwing exceptions + Assert.assertNotNull(manager); + } + + @Test + public void testShouldActivateWithNoRules() { + Map config = new HashMap<>(); + manager.initialize(config); + + // Should always activate when no rules are defined + Assert.assertTrue(manager.shouldActivate(context)); + } + + @Test + public void testExecuteWithEmptyToolInteractions() { + Map config = new HashMap<>(); + manager.initialize(config); + + // Should handle empty tool interactions gracefully + manager.execute(context); + + Assert.assertTrue(context.getToolInteractions().isEmpty()); + } + + @Test + public void testExecuteWithSmallToolInteractions() { + Map config = new HashMap<>(); + config.put("max_messages", 10); + manager.initialize(config); + + // Add fewer interactions than the limit + addToolInteractionsToContext(5); + int originalSize = context.getToolInteractions().size(); + + manager.execute(context); + + // Tool interactions should remain unchanged + Assert.assertEquals(originalSize, context.getToolInteractions().size()); + } + + @Test + public void testExecuteWithLargeToolInteractions() { + Map config = new HashMap<>(); + config.put("max_messages", 5); + manager.initialize(config); + + // Add more interactions than the limit + addToolInteractionsToContext(10); + + manager.execute(context); + + // Tool interactions should be truncated to the limit + Assert.assertEquals(5, context.getToolInteractions().size()); + + // Parameters should be updated with truncated interactions + String interactionsParam = (String) context.getParameters().get("_interactions"); + Assert.assertNotNull(interactionsParam); + + // Should contain only the last 5 interactions + String[] interactions = interactionsParam.substring(2).split(", "); // Remove ", " prefix + Assert.assertEquals(5, interactions.length); + + // Should keep the most recent interactions (6-10) + for (int i = 0; i < interactions.length; i++) { + String expected = "Tool output " + (6 + i); + Assert.assertEquals(expected, interactions[i]); + } + + // Verify toolInteractions also contain the most recent ones + for (int i = 0; i < context.getToolInteractions().size(); i++) { + String expected = "Tool output " + (6 + i); + String actual = context.getToolInteractions().get(i); + Assert.assertEquals(expected, actual); + } + } + + @Test + public void testExecuteKeepsMostRecentInteractions() { + Map config = new HashMap<>(); + config.put("max_messages", 3); + manager.initialize(config); + + // Add interactions with identifiable content + addToolInteractionsToContext(7); + + manager.execute(context); + + // Should keep the last 3 interactions (5, 6, 7) + String interactionsParam = (String) context.getParameters().get("_interactions"); + String[] interactions = interactionsParam.substring(2).split(", "); + Assert.assertEquals(3, interactions.length); + Assert.assertEquals("Tool output 5", interactions[0]); + Assert.assertEquals("Tool output 6", interactions[1]); + Assert.assertEquals("Tool output 7", interactions[2]); + } + + @Test + public void testExecuteWithExactLimit() { + Map config = new HashMap<>(); + config.put("max_messages", 5); + manager.initialize(config); + + // Add exactly the limit number of interactions + addToolInteractionsToContext(5); + + manager.execute(context); + + // Parameters should not be updated since no truncation needed + Assert.assertNull(context.getParameters().get("_interactions")); + } + + @Test + public void testExecuteWithNullToolInteractions() { + Map config = new HashMap<>(); + manager.initialize(config); + + context.setToolInteractions(null); + + // Should handle null tool interactions gracefully + manager.execute(context); + + // Should not throw exception + Assert.assertNull(context.getToolInteractions()); + } + + @Test + public void testExecuteWithNonStringOutputs() { + Map config = new HashMap<>(); + config.put("max_messages", 1); // Set to 1 to force truncation + manager.initialize(config); + + // Add tool interactions as strings + context.getToolInteractions().add("123"); // Integer as string + context.getToolInteractions().add("String output"); // String output + + manager.execute(context); + + // Should process all string interactions and set _interactions parameter + Assert.assertNotNull(context.getParameters().get("_interactions")); + Assert.assertEquals(1, context.getToolInteractions().size()); // Should keep only 1 + } + + @Test + public void testInvalidMaxMessagesConfig() { + Map config = new HashMap<>(); + config.put("max_messages", "invalid_number"); + + // Should handle invalid config gracefully and use default + manager.initialize(config); + + Assert.assertNotNull(manager); + } + + @Test + public void testExecuteWithNullParameters() { + Map config = new HashMap<>(); + config.put("max_messages", 3); + manager.initialize(config); + + // Set parameters to null + context.setParameters(null); + addToolInteractionsToContext(5); + + manager.execute(context); + + // Should create new parameters map and update it + Assert.assertNotNull(context.getParameters()); + String interactionsParam = (String) context.getParameters().get("_interactions"); + Assert.assertNotNull(interactionsParam); + + String[] interactions = interactionsParam.substring(2).split(", "); + Assert.assertEquals(3, interactions.length); + } + + /** + * Helper method to add tool interactions to the context. + */ + private void addToolInteractionsToContext(int count) { + for (int i = 1; i <= count; i++) { + context.getToolInteractions().add("Tool output " + i); + } + } +} diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/contextmanager/SummarizationManagerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/contextmanager/SummarizationManagerTest.java new file mode 100644 index 0000000000..6f769d2645 --- /dev/null +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/contextmanager/SummarizationManagerTest.java @@ -0,0 +1,326 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.algorithms.contextmanager; + +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_RESPONSE_FILTER; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.ml.common.contextmanager.ContextManagerContext; +import org.opensearch.ml.common.output.model.ModelTensor; +import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.ml.common.transport.MLTaskResponse; +import org.opensearch.transport.client.Client; + +/** + * Unit tests for SummarizationManager. + */ +public class SummarizationManagerTest { + + @Mock + private Client client; + + private SummarizationManager manager; + private ContextManagerContext context; + + @Before + public void setUp() { + MockitoAnnotations.openMocks(this); + manager = new SummarizationManager(client); + context = ContextManagerContext.builder().toolInteractions(new ArrayList<>()).parameters(new HashMap<>()).build(); + } + + @Test + public void testGetType() { + Assert.assertEquals("SummarizationManager", manager.getType()); + } + + @Test + public void testInitializeWithDefaults() { + Map config = new HashMap<>(); + manager.initialize(config); + + Assert.assertEquals(0.3, manager.summaryRatio, 0.001); + Assert.assertEquals(10, manager.preserveRecentMessages); + } + + @Test + public void testInitializeWithCustomConfig() { + Map config = new HashMap<>(); + config.put("summary_ratio", 0.5); + config.put("preserve_recent_messages", 5); + config.put("summarization_model_id", "test-model"); + config.put("summarization_system_prompt", "Custom prompt"); + + manager.initialize(config); + + Assert.assertEquals(0.5, manager.summaryRatio, 0.001); + Assert.assertEquals(5, manager.preserveRecentMessages); + Assert.assertEquals("test-model", manager.summarizationModelId); + Assert.assertEquals("Custom prompt", manager.summarizationSystemPrompt); + } + + @Test + public void testInitializeWithInvalidSummaryRatio() { + Map config = new HashMap<>(); + config.put("summary_ratio", 0.9); // Invalid - too high + + manager.initialize(config); + + // Should use default value + Assert.assertEquals(0.3, manager.summaryRatio, 0.001); + } + + @Test + public void testShouldActivateWithNoRules() { + Map config = new HashMap<>(); + manager.initialize(config); + + Assert.assertTrue(manager.shouldActivate(context)); + } + + @Test + public void testExecuteWithEmptyToolInteractions() { + Map config = new HashMap<>(); + manager.initialize(config); + + manager.execute(context); + + Assert.assertTrue(context.getToolInteractions().isEmpty()); + } + + @Test + public void testExecuteWithInsufficientMessages() { + Map config = new HashMap<>(); + config.put("preserve_recent_messages", 10); + manager.initialize(config); + + // Add only 5 interactions - not enough to summarize + addToolInteractionsToContext(5); + + manager.execute(context); + + // Should remain unchanged + Assert.assertEquals(5, context.getToolInteractions().size()); + } + + @Test + public void testExecuteWithNoModelId() { + Map config = new HashMap<>(); + manager.initialize(config); + + addToolInteractionsToContext(20); + + manager.execute(context); + + // Should remain unchanged due to missing model ID + Assert.assertEquals(20, context.getToolInteractions().size()); + } + + @Test + public void testExecuteWithNonStringOutputs() { + Map config = new HashMap<>(); + manager.initialize(config); + + // Add tool interactions as strings + context.getToolInteractions().add("123"); // Integer as string + context.getToolInteractions().add("String output"); // String output + + manager.execute(context); + + // Should handle gracefully - only 2 string interactions, not enough to summarize + Assert.assertEquals(2, context.getToolInteractions().size()); + } + + @Test + public void testProcessSummarizationResult() { + Map config = new HashMap<>(); + manager.initialize(config); + + addToolInteractionsToContext(10); + List remainingMessages = List.of("Message 6", "Message 7", "Message 8", "Message 9", "Message 10"); + + manager.processSummarizationResult(context, "Test summary", 5, remainingMessages, context.getToolInteractions()); + + // Should have 1 summary + 5 remaining = 6 total + Assert.assertEquals(6, context.getToolInteractions().size()); + + // First should be summary + String firstOutput = context.getToolInteractions().get(0); + Assert.assertTrue(firstOutput.contains("Test summary")); + } + + @Test + public void testExtractSummaryFromResponseWithLLMResponseFilter() throws Exception { + Map config = new HashMap<>(); + manager.initialize(config); + + // Set up context with LLM_RESPONSE_FILTER + Map parameters = new HashMap<>(); + parameters.put(LLM_RESPONSE_FILTER, "$.choices[0].message.content"); + context.setParameters(parameters); + + // Create mock response with OpenAI-style structure + Map responseData = new HashMap<>(); + Map choice = new HashMap<>(); + Map message = new HashMap<>(); + message.put("content", "This is the extracted summary content"); + choice.put("message", message); + responseData.put("choices", List.of(choice)); + + MLTaskResponse mockResponse = createMockMLTaskResponse(responseData); + + // Use reflection to access the private method + java.lang.reflect.Method extractMethod = SummarizationManager.class + .getDeclaredMethod("extractSummaryFromResponse", MLTaskResponse.class, ContextManagerContext.class); + extractMethod.setAccessible(true); + + String result = (String) extractMethod.invoke(manager, mockResponse, context); + + Assert.assertEquals("This is the extracted summary content", result); + } + + @Test + public void testExtractSummaryFromResponseWithBedrockResponseFilter() throws Exception { + Map config = new HashMap<>(); + manager.initialize(config); + + // Set up context with Bedrock-style LLM_RESPONSE_FILTER + Map parameters = new HashMap<>(); + parameters.put(LLM_RESPONSE_FILTER, "$.output.message.content[0].text"); + context.setParameters(parameters); + + // Create mock response with Bedrock-style structure + Map responseData = new HashMap<>(); + Map output = new HashMap<>(); + Map message = new HashMap<>(); + Map content = new HashMap<>(); + content.put("text", "Bedrock extracted summary"); + message.put("content", List.of(content)); + output.put("message", message); + responseData.put("output", output); + + MLTaskResponse mockResponse = createMockMLTaskResponse(responseData); + + // Use reflection to access the private method + java.lang.reflect.Method extractMethod = SummarizationManager.class + .getDeclaredMethod("extractSummaryFromResponse", MLTaskResponse.class, ContextManagerContext.class); + extractMethod.setAccessible(true); + + String result = (String) extractMethod.invoke(manager, mockResponse, context); + + Assert.assertEquals("Bedrock extracted summary", result); + } + + @Test + public void testExtractSummaryFromResponseWithInvalidFilter() throws Exception { + Map config = new HashMap<>(); + manager.initialize(config); + + // Set up context with invalid LLM_RESPONSE_FILTER path + Map parameters = new HashMap<>(); + parameters.put(LLM_RESPONSE_FILTER, "$.invalid.path"); + context.setParameters(parameters); + + // Create mock response with simple structure + Map responseData = new HashMap<>(); + responseData.put("response", "Fallback summary content"); + + MLTaskResponse mockResponse = createMockMLTaskResponse(responseData); + + // Use reflection to access the private method + java.lang.reflect.Method extractMethod = SummarizationManager.class + .getDeclaredMethod("extractSummaryFromResponse", MLTaskResponse.class, ContextManagerContext.class); + extractMethod.setAccessible(true); + + String result = (String) extractMethod.invoke(manager, mockResponse, context); + + // Should fall back to default parsing + Assert.assertEquals("Fallback summary content", result); + } + + @Test + public void testExtractSummaryFromResponseWithoutFilter() throws Exception { + Map config = new HashMap<>(); + manager.initialize(config); + + // Context without LLM_RESPONSE_FILTER + Map parameters = new HashMap<>(); + context.setParameters(parameters); + + // Create mock response with simple structure + Map responseData = new HashMap<>(); + responseData.put("response", "Default parsed summary"); + + MLTaskResponse mockResponse = createMockMLTaskResponse(responseData); + + // Use reflection to access the private method + java.lang.reflect.Method extractMethod = SummarizationManager.class + .getDeclaredMethod("extractSummaryFromResponse", MLTaskResponse.class, ContextManagerContext.class); + extractMethod.setAccessible(true); + + String result = (String) extractMethod.invoke(manager, mockResponse, context); + + Assert.assertEquals("Default parsed summary", result); + } + + @Test + public void testExtractSummaryFromResponseWithEmptyFilter() throws Exception { + Map config = new HashMap<>(); + manager.initialize(config); + + // Set up context with empty LLM_RESPONSE_FILTER + Map parameters = new HashMap<>(); + parameters.put(LLM_RESPONSE_FILTER, ""); + context.setParameters(parameters); + + // Create mock response + Map responseData = new HashMap<>(); + responseData.put("response", "Empty filter fallback"); + + MLTaskResponse mockResponse = createMockMLTaskResponse(responseData); + + // Use reflection to access the private method + java.lang.reflect.Method extractMethod = SummarizationManager.class + .getDeclaredMethod("extractSummaryFromResponse", MLTaskResponse.class, ContextManagerContext.class); + extractMethod.setAccessible(true); + + String result = (String) extractMethod.invoke(manager, mockResponse, context); + + Assert.assertEquals("Empty filter fallback", result); + } + + /** + * Helper method to create a mock MLTaskResponse with the given data. + */ + private MLTaskResponse createMockMLTaskResponse(Map responseData) { + ModelTensor tensor = ModelTensor.builder().dataAsMap(responseData).build(); + + ModelTensors tensors = ModelTensors.builder().mlModelTensors(List.of(tensor)).build(); + + ModelTensorOutput output = ModelTensorOutput.builder().mlModelOutputs(List.of(tensors)).build(); + + return MLTaskResponse.builder().output(output).build(); + } + + /** + * Helper method to add tool interactions to the context. + */ + private void addToolInteractionsToContext(int count) { + for (int i = 1; i <= count; i++) { + context.getToolInteractions().add("Tool output " + i); + } + } +} diff --git a/plugin/build.gradle b/plugin/build.gradle index bc6f05f48f..75c5eb8931 100644 --- a/plugin/build.gradle +++ b/plugin/build.gradle @@ -648,6 +648,16 @@ configurations.all { resolutionStrategy.force "jakarta.json:jakarta.json-api:2.1.3" resolutionStrategy.force "org.opensearch:opensearch:${opensearch_version}" resolutionStrategy.force "org.bouncycastle:bcprov-jdk18on:1.78.1" + // Force consistent Netty versions to resolve conflicts + resolutionStrategy.force 'io.netty:netty-codec-http:4.1.124.Final' + resolutionStrategy.force 'io.netty:netty-codec-http2:4.1.124.Final' + resolutionStrategy.force 'io.netty:netty-codec:4.1.124.Final' + resolutionStrategy.force 'io.netty:netty-transport:4.1.124.Final' + resolutionStrategy.force 'io.netty:netty-common:4.1.124.Final' + resolutionStrategy.force 'io.netty:netty-buffer:4.1.124.Final' + resolutionStrategy.force 'io.netty:netty-handler:4.1.124.Final' + resolutionStrategy.force 'io.netty:netty-resolver:4.1.124.Final' + resolutionStrategy.force 'io.netty:netty-transport-native-unix-common:4.1.124.Final' resolutionStrategy.force 'io.projectreactor:reactor-core:3.7.0' resolutionStrategy.force "org.jetbrains.kotlin:kotlin-stdlib-jdk8:1.9.10" resolutionStrategy.force "org.jetbrains.kotlin:kotlin-stdlib:1.9.23" diff --git a/plugin/src/main/java/org/opensearch/ml/action/agent/MLAgentRegistrationValidator.java b/plugin/src/main/java/org/opensearch/ml/action/agent/MLAgentRegistrationValidator.java new file mode 100644 index 0000000000..dc9ea439d8 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/action/agent/MLAgentRegistrationValidator.java @@ -0,0 +1,261 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.agent; + +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.action.contextmanagement.ContextManagementTemplateService; +import org.opensearch.ml.common.agent.MLAgent; +import org.opensearch.ml.common.exception.MLResourceNotFoundException; + +import lombok.extern.log4j.Log4j2; + +/** + * Validator for ML Agent registration that performs advanced validation + * requiring service dependencies. + * This validator handles validation that cannot be performed in the request + * object itself, + * such as template existence checking. + */ +@Log4j2 +public class MLAgentRegistrationValidator { + + private final ContextManagementTemplateService contextManagementTemplateService; + + public MLAgentRegistrationValidator(ContextManagementTemplateService contextManagementTemplateService) { + this.contextManagementTemplateService = contextManagementTemplateService; + } + + /** + * Validates an ML agent for registration, performing all necessary validation checks. + * This is the main validation entry point that orchestrates all validation steps. + * + * @param agent the ML agent to validate + * @param listener callback for validation result - onResponse(true) if valid, onFailure with exception if not + */ + public void validateAgentForRegistration(MLAgent agent, ActionListener listener) { + try { + log.debug("Starting agent registration validation for agent: {}", agent.getName()); + + // First, perform basic context management configuration validation + String configError = validateContextManagementConfiguration(agent); + if (configError != null) { + log.error("Agent registration validation failed - configuration error: {}", configError); + listener.onFailure(new IllegalArgumentException(configError)); + return; + } + + // If agent has a context management template reference, validate template access + if (agent.getContextManagementName() != null) { + validateContextManagementTemplateAccess(agent.getContextManagementName(), ActionListener.wrap(templateAccessValid -> { + log.debug("Agent registration validation completed successfully for agent: {}", agent.getName()); + listener.onResponse(true); + }, templateAccessError -> { + log.error("Agent registration validation failed - template access error: {}", templateAccessError.getMessage()); + listener.onFailure(templateAccessError); + })); + } else { + // No template reference, validation is complete + log.debug("Agent registration validation completed successfully for agent: {}", agent.getName()); + listener.onResponse(true); + } + } catch (Exception e) { + log.error("Unexpected error during agent registration validation", e); + listener.onFailure(new IllegalArgumentException("Agent validation failed: " + e.getMessage())); + } + } + + /** + * Validates context management template access (following connector access validation pattern). + * This method checks if the template exists and if the user has access to it. + * + * @param templateName the context management template name to validate + * @param listener callback for validation result - onResponse(true) if accessible, onFailure with exception if not + */ + public void validateContextManagementTemplateAccess(String templateName, ActionListener listener) { + try { + log.debug("Validating context management template access: {}", templateName); + + contextManagementTemplateService.getTemplate(templateName, ActionListener.wrap(template -> { + log.debug("Context management template access validation passed: {}", templateName); + listener.onResponse(true); + }, exception -> { + log.error("Context management template access validation failed: {}", templateName, exception); + if (exception instanceof MLResourceNotFoundException) { + listener.onFailure(new IllegalArgumentException("Context management template not found: " + templateName)); + } else { + listener + .onFailure( + new IllegalArgumentException("Failed to validate context management template: " + exception.getMessage()) + ); + } + })); + } catch (Exception e) { + log.error("Unexpected error during context management template access validation", e); + listener.onFailure(new IllegalArgumentException("Context management template validation failed: " + e.getMessage())); + } + } + + /** + * Validates context management configuration structure and requirements. + * This method performs comprehensive validation of context management settings. + * + * @param agent the ML agent to validate + * @return validation error message if invalid, null if valid + */ + public String validateContextManagementConfiguration(MLAgent agent) { + // Check for conflicting configuration (both name and inline config specified) + if (agent.getContextManagementName() != null && agent.getContextManagement() != null) { + return "Cannot specify both context_management_name and context_management"; + } + + // Validate context management template name if specified + if (agent.getContextManagementName() != null) { + String templateNameError = validateContextManagementTemplateName(agent.getContextManagementName()); + if (templateNameError != null) { + return templateNameError; + } + } + + // Validate inline context management configuration if specified + if (agent.getContextManagement() != null) { + String inlineConfigError = validateInlineContextManagementConfiguration(agent.getContextManagement()); + if (inlineConfigError != null) { + return inlineConfigError; + } + } + + return null; // Valid + } + + /** + * Validates the context management template name format and basic requirements. + * + * @param templateName the template name to validate + * @return validation error message if invalid, null if valid + */ + private String validateContextManagementTemplateName(String templateName) { + if (templateName == null || templateName.trim().isEmpty()) { + return "Context management template name cannot be null or empty"; + } + + if (templateName.length() > 256) { + return "Context management template name cannot exceed 256 characters"; + } + + if (!templateName.matches("^[a-zA-Z0-9_\\-\\.]+$")) { + return "Context management template name can only contain letters, numbers, underscores, hyphens, and dots"; + } + + return null; // Valid + } + + /** + * Validates the inline context management configuration structure and content. + * + * @param contextManagement the context management configuration to validate + * @return validation error message if invalid, null if valid + */ + private String validateInlineContextManagementConfiguration( + org.opensearch.ml.common.contextmanager.ContextManagementTemplate contextManagement + ) { + // Use the built-in validation from ContextManagementTemplate + if (!contextManagement.isValid()) { + return "Invalid context management configuration: configuration must have a name and at least one hook with valid context manager configurations"; + } + + // Additional validation for specific requirements + if (contextManagement.getName() == null || contextManagement.getName().trim().isEmpty()) { + return "Context management configuration name cannot be null or empty"; + } + + if (contextManagement.getHooks() == null || contextManagement.getHooks().isEmpty()) { + return "Context management configuration must define at least one hook"; + } + + // Validate hook names and configurations + return validateContextManagementHooks(contextManagement.getHooks()); + } + + /** + * Validates context management hooks configuration. + * + * @param hooks the hooks configuration to validate + * @return validation error message if invalid, null if valid + */ + private String validateContextManagementHooks( + java.util.Map> hooks + ) { + // Define valid hook names + java.util.Set validHookNames = java.util.Set + .of("PRE_TOOL", "POST_TOOL", "PRE_LLM", "POST_LLM", "PRE_EXECUTION", "POST_EXECUTION"); + + for (java.util.Map.Entry> entry : hooks + .entrySet()) { + String hookName = entry.getKey(); + java.util.List configs = entry.getValue(); + + // Validate hook name + if (!validHookNames.contains(hookName)) { + return "Invalid hook name: " + hookName + ". Valid hook names are: " + validHookNames; + } + + // Validate hook configurations + if (configs == null || configs.isEmpty()) { + return "Hook " + hookName + " must have at least one context manager configuration"; + } + + for (int i = 0; i < configs.size(); i++) { + org.opensearch.ml.common.contextmanager.ContextManagerConfig config = configs.get(i); + if (!config.isValid()) { + return "Invalid context manager configuration at index " + + i + + " in hook " + + hookName + + ": type cannot be null or empty"; + } + + // Validate context manager type + if (config.getType() != null) { + String typeError = validateContextManagerType(config.getType(), hookName, i); + if (typeError != null) { + return typeError; + } + } + } + } + + return null; // Valid + } + + /** + * Validates context manager type for known types. + * + * @param type the context manager type to validate + * @param hookName the hook name for error reporting + * @param index the configuration index for error reporting + * @return validation error message if invalid, null if valid + */ + private String validateContextManagerType(String type, String hookName, int index) { + // Define known context manager types + java.util.Set knownTypes = java.util.Set + .of("ToolsOutputTruncateManager", "SummarizationManager", "MemoryManager", "ConversationManager"); + + // For now, we'll allow unknown types to provide flexibility for future context + // manager types + // This provides extensibility while still validating known ones + if (!knownTypes.contains(type)) { + log + .debug( + "Unknown context manager type '{}' in hook '{}' at index {}. This may be a custom or future type.", + type, + hookName, + index + ); + } + + return null; // Valid - we allow unknown types for extensibility + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/action/agents/TransportRegisterAgentAction.java b/plugin/src/main/java/org/opensearch/ml/action/agents/TransportRegisterAgentAction.java index e91d27c9bb..f4d3597f3d 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/agents/TransportRegisterAgentAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/agents/TransportRegisterAgentAction.java @@ -26,6 +26,8 @@ import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.action.agent.MLAgentRegistrationValidator; +import org.opensearch.ml.action.contextmanagement.ContextManagementTemplateService; import org.opensearch.ml.common.MLAgentType; import org.opensearch.ml.common.agent.MLAgent; import org.opensearch.ml.common.agent.MLToolSpec; @@ -56,6 +58,7 @@ public class TransportRegisterAgentAction extends HandledTransportAction listener) { + // Validate context management configuration (following connector pattern) + if (agent.hasContextManagementTemplate()) { + // Validate context management template access (similar to connector access validation) + String templateName = agent.getContextManagementTemplateName(); + agentRegistrationValidator.validateContextManagementTemplateAccess(templateName, ActionListener.wrap(hasAccess -> { + if (Boolean.TRUE.equals(hasAccess)) { + continueAgentRegistration(agent, listener); + } else { + listener + .onFailure( + new IllegalArgumentException( + "You don't have permission to use the context management template provided, template name: " + templateName + ) + ); + } + }, e -> { + log.error("You don't have permission to use the context management template provided, template name: {}", templateName, e); + listener.onFailure(e); + })); + } else if (agent.getInlineContextManagement() != null) { + // Validate inline context management configuration only if it exists (similar to inline connector validation) + validateInlineContextManagement(agent); + continueAgentRegistration(agent, listener); + } else { + // No context management configuration - that's fine, continue with registration + continueAgentRegistration(agent, listener); + } + } + + private void validateInlineContextManagement(MLAgent agent) { + if (agent.getInlineContextManagement() == null) { + log + .error( + "You must provide context management content when creating an agent without providing context management template name!" + ); + throw new IllegalArgumentException( + "You must provide context management content when creating an agent without context management template name!" + ); + } + + // Validate inline context management configuration structure + if (!agent.getInlineContextManagement().isValid()) { + log + .error( + "Invalid context management configuration: configuration must have a name and at least one hook with valid context manager configurations" + ); + throw new IllegalArgumentException( + "Invalid context management configuration: configuration must have a name and at least one hook with valid context manager configurations" + ); + } + } + + private void continueAgentRegistration(MLAgent agent, ActionListener listener) { String mcpConnectorConfigJSON = (agent.getParameters() != null) ? agent.getParameters().get(MCP_CONNECTORS_FIELD) : null; if (mcpConnectorConfigJSON != null && !mlFeatureEnabledSetting.isMcpConnectorEnabled()) { // MCP connector provided as tools but MCP feature is disabled, so abort. diff --git a/plugin/src/main/java/org/opensearch/ml/action/contextmanagement/ContextManagementIndexUtils.java b/plugin/src/main/java/org/opensearch/ml/action/contextmanagement/ContextManagementIndexUtils.java new file mode 100644 index 0000000000..8ecc0a0e2f --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/action/contextmanagement/ContextManagementIndexUtils.java @@ -0,0 +1,96 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.contextmanagement; + +import org.opensearch.action.admin.indices.create.CreateIndexRequest; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.transport.client.Client; + +import lombok.extern.log4j.Log4j2; + +/** + * Utility class for managing context management template indices. + * Handles index creation, mapping definition, and settings configuration. + */ +@Log4j2 +public class ContextManagementIndexUtils { + + public static final String CONTEXT_MANAGEMENT_TEMPLATES_INDEX = "ml_context_management_templates"; + + private final Client client; + private final ClusterService clusterService; + + public ContextManagementIndexUtils(Client client, ClusterService clusterService) { + this.client = client; + this.clusterService = clusterService; + } + + /** + * Check if the context management templates index exists + * @return true if the index exists, false otherwise + */ + public boolean doesIndexExist() { + return clusterService.state().metadata().hasIndex(CONTEXT_MANAGEMENT_TEMPLATES_INDEX); + } + + /** + * Create the context management templates index if it doesn't exist + * @param listener ActionListener for the response + */ + public void createIndexIfNotExists(ActionListener listener) { + if (doesIndexExist()) { + log.debug("Context management templates index already exists"); + listener.onResponse(true); + return; + } + + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + ActionListener wrappedListener = ActionListener.runBefore(listener, context::restore); + + CreateIndexRequest createIndexRequest = new CreateIndexRequest(CONTEXT_MANAGEMENT_TEMPLATES_INDEX).settings(getIndexSettings()); + + client.admin().indices().create(createIndexRequest, ActionListener.wrap(createIndexResponse -> { + log.info("Successfully created context management templates index"); + wrappedListener.onResponse(true); + }, exception -> { + if (exception instanceof org.opensearch.ResourceAlreadyExistsException) { + log.debug("Context management templates index already exists"); + wrappedListener.onResponse(true); + } else { + log.error("Failed to create context management templates index", exception); + wrappedListener.onFailure(exception); + } + })); + } catch (Exception e) { + log.error("Error creating context management templates index", e); + listener.onFailure(e); + } + } + + /** + * Get the index settings for context management templates + * @return Settings for the index + */ + private Settings getIndexSettings() { + return Settings + .builder() + .put("index.number_of_shards", 1) + .put("index.number_of_replicas", 1) + .put("index.auto_expand_replicas", "0-1") + .build(); + } + + /** + * Get the index name for context management templates + * @return The index name + */ + public static String getIndexName() { + return CONTEXT_MANAGEMENT_TEMPLATES_INDEX; + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/action/contextmanagement/ContextManagementTemplateService.java b/plugin/src/main/java/org/opensearch/ml/action/contextmanagement/ContextManagementTemplateService.java new file mode 100644 index 0000000000..d754375688 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/action/contextmanagement/ContextManagementTemplateService.java @@ -0,0 +1,316 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.contextmanagement; + +import static org.opensearch.common.xcontent.json.JsonXContent.jsonXContent; + +import java.time.Instant; +import java.util.List; + +import org.opensearch.action.delete.DeleteRequest; +import org.opensearch.action.delete.DeleteResponse; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.Strings; +import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.index.query.MatchAllQueryBuilder; +import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.ml.common.contextmanager.ContextManagementTemplate; +import org.opensearch.ml.common.exception.MLResourceNotFoundException; +import org.opensearch.ml.engine.indices.MLIndicesHandler; +import org.opensearch.ml.utils.RestActionUtils; +import org.opensearch.search.SearchHit; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.transport.client.Client; + +import lombok.extern.log4j.Log4j2; + +/** + * Service for managing context management templates in OpenSearch. + * Provides CRUD operations for storing and retrieving context management configurations. + */ +@Log4j2 +public class ContextManagementTemplateService { + + private final MLIndicesHandler mlIndicesHandler; + private final Client client; + private final ClusterService clusterService; + private final ContextManagementIndexUtils indexUtils; + + @Inject + public ContextManagementTemplateService(MLIndicesHandler mlIndicesHandler, Client client, ClusterService clusterService) { + this.mlIndicesHandler = mlIndicesHandler; + this.client = client; + this.clusterService = clusterService; + this.indexUtils = new ContextManagementIndexUtils(client, clusterService); + } + + /** + * Save a context management template to OpenSearch + * @param templateName The name of the template + * @param template The template to save + * @param listener ActionListener for the response + */ + public void saveTemplate(String templateName, ContextManagementTemplate template, ActionListener listener) { + try { + // Validate template + if (!template.isValid()) { + listener.onFailure(new IllegalArgumentException("Invalid context management template")); + return; + } + + User user = RestActionUtils.getUserContext(client); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + ActionListener wrappedListener = ActionListener.runBefore(listener, context::restore); + + // Set timestamps + Instant now = Instant.now(); + if (template.getCreatedTime() == null) { + template.setCreatedTime(now); + } + template.setLastModified(now); + + // Set created by if not already set + if (template.getCreatedBy() == null && user != null) { + template.setCreatedBy(user.getName()); + } + + // Ensure index exists first + indexUtils.createIndexIfNotExists(ActionListener.wrap(indexCreated -> { + // Check if template with same name already exists + validateUniqueTemplateName(template.getName(), ActionListener.wrap(exists -> { + if (exists) { + wrappedListener + .onFailure( + new IllegalArgumentException( + "A context management template with name '" + template.getName() + "' already exists" + ) + ); + return; + } + + // Create the index request with proper JSON serialization + IndexRequest indexRequest = new IndexRequest(ContextManagementIndexUtils.getIndexName()) + .id(template.getName()) + .source(template.toXContent(jsonXContent.contentBuilder(), ToXContentObject.EMPTY_PARAMS)) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + + // Execute the index operation + client.index(indexRequest, ActionListener.wrap(indexResponse -> { + log.info("Context management template saved successfully: {}", template.getName()); + wrappedListener.onResponse(true); + }, exception -> { + log.error("Failed to save context management template: {}", template.getName(), exception); + wrappedListener.onFailure(exception); + })); + }, wrappedListener::onFailure)); + }, wrappedListener::onFailure)); + } + } catch (Exception e) { + log.error("Error saving context management template", e); + listener.onFailure(e); + } + } + + /** + * Get a context management template by name + * @param templateName The name of the template to retrieve + * @param listener ActionListener for the response + */ + public void getTemplate(String templateName, ActionListener listener) { + try { + if (Strings.isNullOrEmpty(templateName)) { + listener.onFailure(new IllegalArgumentException("Template name cannot be null or empty")); + return; + } + + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + ActionListener wrappedListener = ActionListener.runBefore(listener, context::restore); + + GetRequest getRequest = new GetRequest(ContextManagementIndexUtils.getIndexName(), templateName); + + client.get(getRequest, ActionListener.wrap(getResponse -> { + if (!getResponse.isExists()) { + wrappedListener + .onFailure(new MLResourceNotFoundException("Context management template not found: " + templateName)); + return; + } + + try { + XContentParser parser = createXContentParserFromRegistry( + NamedXContentRegistry.EMPTY, + LoggingDeprecationHandler.INSTANCE, + getResponse.getSourceAsBytesRef() + ); + ContextManagementTemplate template = ContextManagementTemplate.parse(parser); + wrappedListener.onResponse(template); + } catch (Exception e) { + log.error("Failed to parse context management template: {}", templateName, e); + wrappedListener.onFailure(e); + } + }, exception -> { + if (exception instanceof IndexNotFoundException) { + wrappedListener + .onFailure(new MLResourceNotFoundException("Context management template not found: " + templateName)); + } else { + log.error("Failed to get context management template: {}", templateName, exception); + wrappedListener.onFailure(exception); + } + })); + } + } catch (Exception e) { + log.error("Error getting context management template", e); + listener.onFailure(e); + } + } + + /** + * List all context management templates + * @param listener ActionListener for the response + */ + public void listTemplates(ActionListener> listener) { + listTemplates(0, 1000, listener); + } + + /** + * List context management templates with pagination + * @param from Starting index for pagination + * @param size Number of templates to return + * @param listener ActionListener for the response + */ + public void listTemplates(int from, int size, ActionListener> listener) { + try { + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + ActionListener> wrappedListener = ActionListener.runBefore(listener, context::restore); + + SearchRequest searchRequest = new SearchRequest(ContextManagementIndexUtils.getIndexName()); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(new MatchAllQueryBuilder()).from(from).size(size); + searchRequest.source(searchSourceBuilder); + + client.search(searchRequest, ActionListener.wrap(searchResponse -> { + try { + List templates = new java.util.ArrayList<>(); + for (SearchHit hit : searchResponse.getHits().getHits()) { + XContentParser parser = createXContentParserFromRegistry( + NamedXContentRegistry.EMPTY, + LoggingDeprecationHandler.INSTANCE, + hit.getSourceRef() + ); + ContextManagementTemplate template = ContextManagementTemplate.parse(parser); + templates.add(template); + } + wrappedListener.onResponse(templates); + } catch (Exception e) { + log.error("Failed to parse context management templates", e); + wrappedListener.onFailure(e); + } + }, exception -> { + if (exception instanceof IndexNotFoundException) { + // Return empty list if index doesn't exist + wrappedListener.onResponse(new java.util.ArrayList<>()); + } else { + log.error("Failed to list context management templates", exception); + wrappedListener.onFailure(exception); + } + })); + } + } catch (Exception e) { + log.error("Error listing context management templates", e); + listener.onFailure(e); + } + } + + /** + * Delete a context management template by name + * @param templateName The name of the template to delete + * @param listener ActionListener for the response + */ + public void deleteTemplate(String templateName, ActionListener listener) { + try { + if (Strings.isNullOrEmpty(templateName)) { + listener.onFailure(new IllegalArgumentException("Template name cannot be null or empty")); + return; + } + + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + ActionListener wrappedListener = ActionListener.runBefore(listener, context::restore); + + DeleteRequest deleteRequest = new DeleteRequest(ContextManagementIndexUtils.getIndexName(), templateName) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + + client.delete(deleteRequest, ActionListener.wrap(deleteResponse -> { + boolean deleted = deleteResponse.getResult() == DeleteResponse.Result.DELETED; + if (deleted) { + log.info("Context management template deleted successfully: {}", templateName); + } else { + log.warn("Context management template not found for deletion: {}", templateName); + } + wrappedListener.onResponse(deleted); + }, exception -> { + if (exception instanceof IndexNotFoundException) { + wrappedListener.onResponse(false); + } else { + log.error("Failed to delete context management template: {}", templateName, exception); + wrappedListener.onFailure(exception); + } + })); + } + } catch (Exception e) { + log.error("Error deleting context management template", e); + listener.onFailure(e); + } + } + + /** + * Validate that a template name is unique + * @param templateName The template name to check + * @param listener ActionListener for the response (true if exists, false if unique) + */ + private void validateUniqueTemplateName(String templateName, ActionListener listener) { + try { + SearchRequest searchRequest = new SearchRequest(ContextManagementIndexUtils.getIndexName()); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(new TermQueryBuilder("_id", templateName)).size(1); + searchRequest.source(searchSourceBuilder); + + client.search(searchRequest, ActionListener.wrap(searchResponse -> { + boolean exists = searchResponse.getHits().getTotalHits() != null && searchResponse.getHits().getTotalHits().value() > 0; + listener.onResponse(exists); + }, exception -> { + if (exception instanceof IndexNotFoundException) { + // Index doesn't exist, so template name is unique + listener.onResponse(false); + } else { + listener.onFailure(exception); + } + })); + } catch (Exception e) { + listener.onFailure(e); + } + } + + /** + * Create XContentParser from registry - utility method + */ + private XContentParser createXContentParserFromRegistry( + NamedXContentRegistry xContentRegistry, + LoggingDeprecationHandler deprecationHandler, + org.opensearch.core.common.bytes.BytesReference bytesReference + ) throws java.io.IOException { + return MediaTypeRegistry.JSON.xContent().createParser(xContentRegistry, deprecationHandler, bytesReference.streamInput()); + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/action/contextmanagement/ContextManagerFactory.java b/plugin/src/main/java/org/opensearch/ml/action/contextmanagement/ContextManagerFactory.java new file mode 100644 index 0000000000..86b26b4d6b --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/action/contextmanagement/ContextManagerFactory.java @@ -0,0 +1,120 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.contextmanagement; + +import java.util.Map; + +import org.opensearch.common.inject.Inject; +import org.opensearch.ml.common.contextmanager.ActivationRuleFactory; +import org.opensearch.ml.common.contextmanager.ContextManager; +import org.opensearch.ml.common.contextmanager.ContextManagerConfig; +import org.opensearch.ml.engine.algorithms.contextmanager.SlidingWindowManager; +import org.opensearch.ml.engine.algorithms.contextmanager.SummarizationManager; +import org.opensearch.ml.engine.algorithms.contextmanager.ToolsOutputTruncateManager; +import org.opensearch.transport.client.Client; + +import lombok.extern.log4j.Log4j2; + +/** + * Factory for creating context manager instances from configuration. + * This factory creates the appropriate context manager based on the type + * specified in the configuration and initializes it with the provided settings. + */ +@Log4j2 +public class ContextManagerFactory { + + private final ActivationRuleFactory activationRuleFactory; + private final Client client; + + @Inject + public ContextManagerFactory(ActivationRuleFactory activationRuleFactory, Client client) { + this.activationRuleFactory = activationRuleFactory; + this.client = client; + } + + /** + * Create a context manager instance from configuration + * @param config The context manager configuration + * @return The created context manager instance + * @throws IllegalArgumentException if the manager type is not supported + */ + public ContextManager createContextManager(ContextManagerConfig config) { + if (config == null || config.getType() == null) { + throw new IllegalArgumentException("Context manager configuration and type cannot be null"); + } + + String type = config.getType(); + Map managerConfig = config.getConfig(); + Map activationConfig = config.getActivation(); + + log.debug("Creating context manager of type: {}", type); + + ContextManager manager; + switch (type) { + case "ToolsOutputTruncateManager": + manager = createToolsOutputTruncateManager(managerConfig); + break; + case "SlidingWindowManager": + manager = createSlidingWindowManager(managerConfig); + break; + case "SummarizationManager": + manager = createSummarizationManager(managerConfig); + break; + default: + throw new IllegalArgumentException("Unsupported context manager type: " + type); + } + + // Initialize the manager with configuration + try { + // Merge activation and manager config for initialization + Map fullConfig = new java.util.HashMap<>(); + if (managerConfig != null) { + fullConfig.putAll(managerConfig); + } + if (activationConfig != null) { + fullConfig.put("activation", activationConfig); + } + + manager.initialize(fullConfig); + log.debug("Successfully created and initialized context manager: {}", type); + return manager; + } catch (Exception e) { + log.error("Failed to initialize context manager of type: {}", type, e); + throw new RuntimeException("Failed to initialize context manager: " + type, e); + } + } + + /** + * Create a ToolsOutputTruncateManager instance + */ + private ContextManager createToolsOutputTruncateManager(Map config) { + return new ToolsOutputTruncateManager(); + } + + /** + * Create a SlidingWindowManager instance + */ + private ContextManager createSlidingWindowManager(Map config) { + return new SlidingWindowManager(); + } + + /** + * Create a SummarizationManager instance + */ + private ContextManager createSummarizationManager(Map config) { + return new SummarizationManager(client); + } + + // Add more factory methods for other context manager types as they are implemented + + // private ContextManager createSummarizingManager(Map config) { + // return new SummarizingManager(); + // } + + // private ContextManager createSystemPromptAugmentationManager(Map config) { + // return new SystemPromptAugmentationManager(); + // } +} diff --git a/plugin/src/main/java/org/opensearch/ml/action/contextmanagement/CreateContextManagementTemplateTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/contextmanagement/CreateContextManagementTemplateTransportAction.java new file mode 100644 index 0000000000..d5377d1bf1 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/action/contextmanagement/CreateContextManagementTemplateTransportAction.java @@ -0,0 +1,67 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.contextmanagement; + +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.transport.contextmanagement.MLCreateContextManagementTemplateAction; +import org.opensearch.ml.common.transport.contextmanagement.MLCreateContextManagementTemplateRequest; +import org.opensearch.ml.common.transport.contextmanagement.MLCreateContextManagementTemplateResponse; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; +import org.opensearch.transport.client.Client; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class CreateContextManagementTemplateTransportAction extends + HandledTransportAction { + + private final Client client; + private final ContextManagementTemplateService contextManagementTemplateService; + + @Inject + public CreateContextManagementTemplateTransportAction( + TransportService transportService, + ActionFilters actionFilters, + Client client, + ContextManagementTemplateService contextManagementTemplateService + ) { + super(MLCreateContextManagementTemplateAction.NAME, transportService, actionFilters, MLCreateContextManagementTemplateRequest::new); + this.client = client; + this.contextManagementTemplateService = contextManagementTemplateService; + } + + @Override + protected void doExecute( + Task task, + MLCreateContextManagementTemplateRequest request, + ActionListener listener + ) { + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + log.info("Creating context management template: {}", request.getTemplateName()); + + contextManagementTemplateService.saveTemplate(request.getTemplateName(), request.getTemplate(), ActionListener.wrap(success -> { + if (success) { + log.info("Successfully created context management template: {}", request.getTemplateName()); + listener.onResponse(new MLCreateContextManagementTemplateResponse(request.getTemplateName(), "created")); + } else { + log.error("Failed to create context management template: {}", request.getTemplateName()); + listener.onFailure(new RuntimeException("Failed to create context management template")); + } + }, exception -> { + log.error("Error creating context management template: {}", request.getTemplateName(), exception); + listener.onFailure(exception); + })); + } catch (Exception e) { + log.error("Unexpected error creating context management template: {}", request.getTemplateName(), e); + listener.onFailure(e); + } + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/action/contextmanagement/DeleteContextManagementTemplateTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/contextmanagement/DeleteContextManagementTemplateTransportAction.java new file mode 100644 index 0000000000..6c025bc927 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/action/contextmanagement/DeleteContextManagementTemplateTransportAction.java @@ -0,0 +1,67 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.contextmanagement; + +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.transport.contextmanagement.MLDeleteContextManagementTemplateAction; +import org.opensearch.ml.common.transport.contextmanagement.MLDeleteContextManagementTemplateRequest; +import org.opensearch.ml.common.transport.contextmanagement.MLDeleteContextManagementTemplateResponse; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; +import org.opensearch.transport.client.Client; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class DeleteContextManagementTemplateTransportAction extends + HandledTransportAction { + + private final Client client; + private final ContextManagementTemplateService contextManagementTemplateService; + + @Inject + public DeleteContextManagementTemplateTransportAction( + TransportService transportService, + ActionFilters actionFilters, + Client client, + ContextManagementTemplateService contextManagementTemplateService + ) { + super(MLDeleteContextManagementTemplateAction.NAME, transportService, actionFilters, MLDeleteContextManagementTemplateRequest::new); + this.client = client; + this.contextManagementTemplateService = contextManagementTemplateService; + } + + @Override + protected void doExecute( + Task task, + MLDeleteContextManagementTemplateRequest request, + ActionListener listener + ) { + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + log.info("Deleting context management template: {}", request.getTemplateName()); + + contextManagementTemplateService.deleteTemplate(request.getTemplateName(), ActionListener.wrap(success -> { + if (success) { + log.info("Successfully deleted context management template: {}", request.getTemplateName()); + listener.onResponse(new MLDeleteContextManagementTemplateResponse(request.getTemplateName(), "deleted")); + } else { + log.warn("Context management template not found for deletion: {}", request.getTemplateName()); + listener.onFailure(new RuntimeException("Context management template not found: " + request.getTemplateName())); + } + }, exception -> { + log.error("Error deleting context management template: {}", request.getTemplateName(), exception); + listener.onFailure(exception); + })); + } catch (Exception e) { + log.error("Unexpected error deleting context management template: {}", request.getTemplateName(), e); + listener.onFailure(e); + } + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/action/contextmanagement/GetContextManagementTemplateTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/contextmanagement/GetContextManagementTemplateTransportAction.java new file mode 100644 index 0000000000..011b6852c0 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/action/contextmanagement/GetContextManagementTemplateTransportAction.java @@ -0,0 +1,67 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.contextmanagement; + +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.transport.contextmanagement.MLGetContextManagementTemplateAction; +import org.opensearch.ml.common.transport.contextmanagement.MLGetContextManagementTemplateRequest; +import org.opensearch.ml.common.transport.contextmanagement.MLGetContextManagementTemplateResponse; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; +import org.opensearch.transport.client.Client; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class GetContextManagementTemplateTransportAction extends + HandledTransportAction { + + private final Client client; + private final ContextManagementTemplateService contextManagementTemplateService; + + @Inject + public GetContextManagementTemplateTransportAction( + TransportService transportService, + ActionFilters actionFilters, + Client client, + ContextManagementTemplateService contextManagementTemplateService + ) { + super(MLGetContextManagementTemplateAction.NAME, transportService, actionFilters, MLGetContextManagementTemplateRequest::new); + this.client = client; + this.contextManagementTemplateService = contextManagementTemplateService; + } + + @Override + protected void doExecute( + Task task, + MLGetContextManagementTemplateRequest request, + ActionListener listener + ) { + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + log.debug("Getting context management template: {}", request.getTemplateName()); + + contextManagementTemplateService.getTemplate(request.getTemplateName(), ActionListener.wrap(template -> { + if (template != null) { + log.debug("Successfully retrieved context management template: {}", request.getTemplateName()); + listener.onResponse(new MLGetContextManagementTemplateResponse(template)); + } else { + log.warn("Context management template not found: {}", request.getTemplateName()); + listener.onFailure(new RuntimeException("Context management template not found: " + request.getTemplateName())); + } + }, exception -> { + log.error("Error getting context management template: {}", request.getTemplateName(), exception); + listener.onFailure(exception); + })); + } catch (Exception e) { + log.error("Unexpected error getting context management template: {}", request.getTemplateName(), e); + listener.onFailure(e); + } + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/action/contextmanagement/ListContextManagementTemplatesTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/contextmanagement/ListContextManagementTemplatesTransportAction.java new file mode 100644 index 0000000000..7667ac6cc6 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/action/contextmanagement/ListContextManagementTemplatesTransportAction.java @@ -0,0 +1,63 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.contextmanagement; + +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.transport.contextmanagement.MLListContextManagementTemplatesAction; +import org.opensearch.ml.common.transport.contextmanagement.MLListContextManagementTemplatesRequest; +import org.opensearch.ml.common.transport.contextmanagement.MLListContextManagementTemplatesResponse; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; +import org.opensearch.transport.client.Client; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class ListContextManagementTemplatesTransportAction extends + HandledTransportAction { + + private final Client client; + private final ContextManagementTemplateService contextManagementTemplateService; + + @Inject + public ListContextManagementTemplatesTransportAction( + TransportService transportService, + ActionFilters actionFilters, + Client client, + ContextManagementTemplateService contextManagementTemplateService + ) { + super(MLListContextManagementTemplatesAction.NAME, transportService, actionFilters, MLListContextManagementTemplatesRequest::new); + this.client = client; + this.contextManagementTemplateService = contextManagementTemplateService; + } + + @Override + protected void doExecute( + Task task, + MLListContextManagementTemplatesRequest request, + ActionListener listener + ) { + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + log.debug("Listing context management templates from: {} size: {}", request.getFrom(), request.getSize()); + + contextManagementTemplateService.listTemplates(request.getFrom(), request.getSize(), ActionListener.wrap(templates -> { + log.debug("Successfully retrieved {} context management templates", templates.size()); + // For now, return the size as total. In a real implementation, you'd get the actual total count + listener.onResponse(new MLListContextManagementTemplatesResponse(templates, templates.size())); + }, exception -> { + log.error("Error listing context management templates", exception); + listener.onFailure(exception); + })); + } catch (Exception e) { + log.error("Unexpected error listing context management templates", e); + listener.onFailure(e); + } + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index 62de34961e..b25ed1afba 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -89,6 +89,12 @@ import org.opensearch.ml.action.connector.SearchConnectorTransportAction; import org.opensearch.ml.action.connector.TransportCreateConnectorAction; import org.opensearch.ml.action.connector.UpdateConnectorTransportAction; +import org.opensearch.ml.action.contextmanagement.ContextManagementTemplateService; +import org.opensearch.ml.action.contextmanagement.ContextManagerFactory; +import org.opensearch.ml.action.contextmanagement.CreateContextManagementTemplateTransportAction; +import org.opensearch.ml.action.contextmanagement.DeleteContextManagementTemplateTransportAction; +import org.opensearch.ml.action.contextmanagement.GetContextManagementTemplateTransportAction; +import org.opensearch.ml.action.contextmanagement.ListContextManagementTemplatesTransportAction; import org.opensearch.ml.action.controller.CreateControllerTransportAction; import org.opensearch.ml.action.controller.DeleteControllerTransportAction; import org.opensearch.ml.action.controller.DeployControllerTransportAction; @@ -191,6 +197,10 @@ import org.opensearch.ml.common.transport.connector.MLCreateConnectorAction; import org.opensearch.ml.common.transport.connector.MLExecuteConnectorAction; import org.opensearch.ml.common.transport.connector.MLUpdateConnectorAction; +import org.opensearch.ml.common.transport.contextmanagement.MLCreateContextManagementTemplateAction; +import org.opensearch.ml.common.transport.contextmanagement.MLDeleteContextManagementTemplateAction; +import org.opensearch.ml.common.transport.contextmanagement.MLGetContextManagementTemplateAction; +import org.opensearch.ml.common.transport.contextmanagement.MLListContextManagementTemplatesAction; import org.opensearch.ml.common.transport.controller.MLControllerDeleteAction; import org.opensearch.ml.common.transport.controller.MLControllerGetAction; import org.opensearch.ml.common.transport.controller.MLCreateControllerAction; @@ -320,15 +330,16 @@ import org.opensearch.ml.processor.MLInferenceIngestProcessor; import org.opensearch.ml.processor.MLInferenceSearchRequestProcessor; import org.opensearch.ml.processor.MLInferenceSearchResponseProcessor; -import org.opensearch.ml.repackage.com.google.common.collect.ImmutableList; import org.opensearch.ml.rest.RestMLAddMemoriesAction; import org.opensearch.ml.rest.RestMLCancelBatchJobAction; import org.opensearch.ml.rest.RestMLCreateConnectorAction; +import org.opensearch.ml.rest.RestMLCreateContextManagementTemplateAction; import org.opensearch.ml.rest.RestMLCreateControllerAction; import org.opensearch.ml.rest.RestMLCreateMemoryContainerAction; import org.opensearch.ml.rest.RestMLCreateSessionAction; import org.opensearch.ml.rest.RestMLDeleteAgentAction; import org.opensearch.ml.rest.RestMLDeleteConnectorAction; +import org.opensearch.ml.rest.RestMLDeleteContextManagementTemplateAction; import org.opensearch.ml.rest.RestMLDeleteControllerAction; import org.opensearch.ml.rest.RestMLDeleteMemoriesByQueryAction; import org.opensearch.ml.rest.RestMLDeleteMemoryAction; @@ -342,6 +353,7 @@ import org.opensearch.ml.rest.RestMLGetAgentAction; import org.opensearch.ml.rest.RestMLGetConfigAction; import org.opensearch.ml.rest.RestMLGetConnectorAction; +import org.opensearch.ml.rest.RestMLGetContextManagementTemplateAction; import org.opensearch.ml.rest.RestMLGetControllerAction; import org.opensearch.ml.rest.RestMLGetIndexInsightAction; import org.opensearch.ml.rest.RestMLGetIndexInsightConfigAction; @@ -351,6 +363,7 @@ import org.opensearch.ml.rest.RestMLGetModelGroupAction; import org.opensearch.ml.rest.RestMLGetTaskAction; import org.opensearch.ml.rest.RestMLGetToolAction; +import org.opensearch.ml.rest.RestMLListContextManagementTemplatesAction; import org.opensearch.ml.rest.RestMLListToolsAction; import org.opensearch.ml.rest.RestMLPredictionAction; import org.opensearch.ml.rest.RestMLPredictionStreamAction; @@ -448,6 +461,7 @@ import org.opensearch.watcher.ResourceWatcherService; import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; import lombok.SneakyThrows; @@ -616,7 +630,11 @@ public MachineLearningPlugin() {} new ActionHandler<>(MLMcpToolsListAction.INSTANCE, TransportMcpToolsListAction.class), new ActionHandler<>(MLMcpToolsUpdateAction.INSTANCE, TransportMcpToolsUpdateAction.class), new ActionHandler<>(MLMcpToolsUpdateOnNodesAction.INSTANCE, TransportMcpToolsUpdateOnNodesAction.class), - new ActionHandler<>(MLMcpServerAction.INSTANCE, TransportMcpServerAction.class) + new ActionHandler<>(MLMcpServerAction.INSTANCE, TransportMcpServerAction.class), + new ActionHandler<>(MLCreateContextManagementTemplateAction.INSTANCE, CreateContextManagementTemplateTransportAction.class), + new ActionHandler<>(MLGetContextManagementTemplateAction.INSTANCE, GetContextManagementTemplateTransportAction.class), + new ActionHandler<>(MLListContextManagementTemplatesAction.INSTANCE, ListContextManagementTemplatesTransportAction.class), + new ActionHandler<>(MLDeleteContextManagementTemplateAction.INSTANCE, DeleteContextManagementTemplateTransportAction.class) ); } @@ -784,6 +802,17 @@ public Collection createComponents( nodeHelper, mlEngine ); + // Create context management services + ContextManagementTemplateService contextManagementTemplateService = new ContextManagementTemplateService( + mlIndicesHandler, + client, + clusterService + ); + ContextManagerFactory contextManagerFactory = new ContextManagerFactory( + new org.opensearch.ml.common.contextmanager.ActivationRuleFactory(), + client + ); + mlExecuteTaskRunner = new MLExecuteTaskRunner( threadPool, clusterService, @@ -794,7 +823,9 @@ public Collection createComponents( mlTaskDispatcher, mlCircuitBreakerService, nodeHelper, - mlEngine + mlEngine, + contextManagementTemplateService, + contextManagerFactory ); // Register thread-safe ML objects here. @@ -1070,6 +1101,15 @@ public List getRestHandlers( RestMLMcpToolsRemoveAction restMLRemoveMcpToolsAction = new RestMLMcpToolsRemoveAction(clusterService, mlFeatureEnabledSetting); RestMLMcpToolsListAction restMLListMcpToolsAction = new RestMLMcpToolsListAction(mlFeatureEnabledSetting); RestMLMcpToolsUpdateAction restMLMcpToolsUpdateAction = new RestMLMcpToolsUpdateAction(clusterService, mlFeatureEnabledSetting); + RestMLCreateContextManagementTemplateAction restMLCreateContextManagementTemplateAction = + new RestMLCreateContextManagementTemplateAction(mlFeatureEnabledSetting); + RestMLGetContextManagementTemplateAction restMLGetContextManagementTemplateAction = new RestMLGetContextManagementTemplateAction( + mlFeatureEnabledSetting + ); + RestMLListContextManagementTemplatesAction restMLListContextManagementTemplatesAction = + new RestMLListContextManagementTemplatesAction(mlFeatureEnabledSetting); + RestMLDeleteContextManagementTemplateAction restMLDeleteContextManagementTemplateAction = + new RestMLDeleteContextManagementTemplateAction(mlFeatureEnabledSetting); return ImmutableList .of( restMLStatsAction, @@ -1146,7 +1186,11 @@ public List getRestHandlers( restMLListMcpToolsAction, restMLMcpToolsUpdateAction, restMLPutIndexInsightConfigAction, - restMLGetIndexInsightConfigAction + restMLGetIndexInsightConfigAction, + restMLCreateContextManagementTemplateAction, + restMLGetContextManagementTemplateAction, + restMLListContextManagementTemplatesAction, + restMLDeleteContextManagementTemplateAction ); } diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLCreateContextManagementTemplateAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLCreateContextManagementTemplateAction.java new file mode 100644 index 0000000000..34387ccb81 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLCreateContextManagementTemplateAction.java @@ -0,0 +1,89 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; +import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_TEMPLATE_NAME; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; + +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.contextmanager.ContextManagementTemplate; +import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; +import org.opensearch.ml.common.transport.contextmanagement.MLCreateContextManagementTemplateAction; +import org.opensearch.ml.common.transport.contextmanagement.MLCreateContextManagementTemplateRequest; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; +import org.opensearch.transport.client.node.NodeClient; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; + +public class RestMLCreateContextManagementTemplateAction extends BaseRestHandler { + private static final String ML_CREATE_CONTEXT_MANAGEMENT_TEMPLATE_ACTION = "ml_create_context_management_template_action"; + private final MLFeatureEnabledSetting mlFeatureEnabledSetting; + + /** + * Constructor + */ + public RestMLCreateContextManagementTemplateAction(MLFeatureEnabledSetting mlFeatureEnabledSetting) { + this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; + } + + @Override + public String getName() { + return ML_CREATE_CONTEXT_MANAGEMENT_TEMPLATE_ACTION; + } + + @Override + public List routes() { + return ImmutableList + .of( + new Route( + RestRequest.Method.PUT, + String.format(Locale.ROOT, "%s/context_management/{%s}", ML_BASE_URI, PARAMETER_TEMPLATE_NAME) + ) + ); + } + + @Override + public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + MLCreateContextManagementTemplateRequest createRequest = getRequest(request); + return channel -> client + .execute(MLCreateContextManagementTemplateAction.INSTANCE, createRequest, new RestToXContentListener<>(channel)); + } + + /** + * Creates a MLCreateContextManagementTemplateRequest from a RestRequest + * + * @param request RestRequest + * @return MLCreateContextManagementTemplateRequest + */ + @VisibleForTesting + MLCreateContextManagementTemplateRequest getRequest(RestRequest request) throws IOException { + if (!mlFeatureEnabledSetting.isAgentFrameworkEnabled()) { + throw new IllegalStateException("Agent framework is disabled"); + } + + String templateName = request.param(PARAMETER_TEMPLATE_NAME); + if (templateName == null || templateName.trim().isEmpty()) { + throw new IllegalArgumentException("Template name is required"); + } + + XContentParser parser = request.contentParser(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + ContextManagementTemplate template = ContextManagementTemplate.parse(parser); + + // Set the template name from URL parameter + template = template.toBuilder().name(templateName).build(); + + return new MLCreateContextManagementTemplateRequest(templateName, template); + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLDeleteContextManagementTemplateAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLDeleteContextManagementTemplateAction.java new file mode 100644 index 0000000000..1dbde7f216 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLDeleteContextManagementTemplateAction.java @@ -0,0 +1,79 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; +import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_TEMPLATE_NAME; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; + +import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; +import org.opensearch.ml.common.transport.contextmanagement.MLDeleteContextManagementTemplateAction; +import org.opensearch.ml.common.transport.contextmanagement.MLDeleteContextManagementTemplateRequest; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; +import org.opensearch.transport.client.node.NodeClient; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; + +public class RestMLDeleteContextManagementTemplateAction extends BaseRestHandler { + private static final String ML_DELETE_CONTEXT_MANAGEMENT_TEMPLATE_ACTION = "ml_delete_context_management_template_action"; + private final MLFeatureEnabledSetting mlFeatureEnabledSetting; + + /** + * Constructor + */ + public RestMLDeleteContextManagementTemplateAction(MLFeatureEnabledSetting mlFeatureEnabledSetting) { + this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; + } + + @Override + public String getName() { + return ML_DELETE_CONTEXT_MANAGEMENT_TEMPLATE_ACTION; + } + + @Override + public List routes() { + return ImmutableList + .of( + new Route( + RestRequest.Method.DELETE, + String.format(Locale.ROOT, "%s/context_management/{%s}", ML_BASE_URI, PARAMETER_TEMPLATE_NAME) + ) + ); + } + + @Override + public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + MLDeleteContextManagementTemplateRequest deleteRequest = getRequest(request); + return channel -> client + .execute(MLDeleteContextManagementTemplateAction.INSTANCE, deleteRequest, new RestToXContentListener<>(channel)); + } + + /** + * Creates a MLDeleteContextManagementTemplateRequest from a RestRequest + * + * @param request RestRequest + * @return MLDeleteContextManagementTemplateRequest + */ + @VisibleForTesting + MLDeleteContextManagementTemplateRequest getRequest(RestRequest request) throws IOException { + if (!mlFeatureEnabledSetting.isAgentFrameworkEnabled()) { + throw new IllegalStateException("Agent framework is disabled"); + } + + String templateName = request.param(PARAMETER_TEMPLATE_NAME); + if (templateName == null || templateName.trim().isEmpty()) { + throw new IllegalArgumentException("Template name is required"); + } + + return new MLDeleteContextManagementTemplateRequest(templateName); + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetContextManagementTemplateAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetContextManagementTemplateAction.java new file mode 100644 index 0000000000..4089d0aac1 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetContextManagementTemplateAction.java @@ -0,0 +1,78 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; +import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_TEMPLATE_NAME; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; + +import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; +import org.opensearch.ml.common.transport.contextmanagement.MLGetContextManagementTemplateAction; +import org.opensearch.ml.common.transport.contextmanagement.MLGetContextManagementTemplateRequest; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; +import org.opensearch.transport.client.node.NodeClient; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; + +public class RestMLGetContextManagementTemplateAction extends BaseRestHandler { + private static final String ML_GET_CONTEXT_MANAGEMENT_TEMPLATE_ACTION = "ml_get_context_management_template_action"; + private final MLFeatureEnabledSetting mlFeatureEnabledSetting; + + /** + * Constructor + */ + public RestMLGetContextManagementTemplateAction(MLFeatureEnabledSetting mlFeatureEnabledSetting) { + this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; + } + + @Override + public String getName() { + return ML_GET_CONTEXT_MANAGEMENT_TEMPLATE_ACTION; + } + + @Override + public List routes() { + return ImmutableList + .of( + new Route( + RestRequest.Method.GET, + String.format(Locale.ROOT, "%s/context_management/{%s}", ML_BASE_URI, PARAMETER_TEMPLATE_NAME) + ) + ); + } + + @Override + public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + MLGetContextManagementTemplateRequest getRequest = getRequest(request); + return channel -> client.execute(MLGetContextManagementTemplateAction.INSTANCE, getRequest, new RestToXContentListener<>(channel)); + } + + /** + * Creates a MLGetContextManagementTemplateRequest from a RestRequest + * + * @param request RestRequest + * @return MLGetContextManagementTemplateRequest + */ + @VisibleForTesting + MLGetContextManagementTemplateRequest getRequest(RestRequest request) throws IOException { + if (!mlFeatureEnabledSetting.isAgentFrameworkEnabled()) { + throw new IllegalStateException("Agent framework is disabled"); + } + + String templateName = request.param(PARAMETER_TEMPLATE_NAME); + if (templateName == null || templateName.trim().isEmpty()) { + throw new IllegalArgumentException("Template name is required"); + } + + return new MLGetContextManagementTemplateRequest(templateName); + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLListContextManagementTemplatesAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLListContextManagementTemplatesAction.java new file mode 100644 index 0000000000..d5020bd5c3 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLListContextManagementTemplatesAction.java @@ -0,0 +1,70 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; + +import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; +import org.opensearch.ml.common.transport.contextmanagement.MLListContextManagementTemplatesAction; +import org.opensearch.ml.common.transport.contextmanagement.MLListContextManagementTemplatesRequest; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; +import org.opensearch.transport.client.node.NodeClient; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; + +public class RestMLListContextManagementTemplatesAction extends BaseRestHandler { + private static final String ML_LIST_CONTEXT_MANAGEMENT_TEMPLATES_ACTION = "ml_list_context_management_templates_action"; + private final MLFeatureEnabledSetting mlFeatureEnabledSetting; + + /** + * Constructor + */ + public RestMLListContextManagementTemplatesAction(MLFeatureEnabledSetting mlFeatureEnabledSetting) { + this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; + } + + @Override + public String getName() { + return ML_LIST_CONTEXT_MANAGEMENT_TEMPLATES_ACTION; + } + + @Override + public List routes() { + return ImmutableList.of(new Route(RestRequest.Method.GET, String.format(Locale.ROOT, "%s/context_management", ML_BASE_URI))); + } + + @Override + public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + MLListContextManagementTemplatesRequest listRequest = getRequest(request); + return channel -> client + .execute(MLListContextManagementTemplatesAction.INSTANCE, listRequest, new RestToXContentListener<>(channel)); + } + + /** + * Creates a MLListContextManagementTemplatesRequest from a RestRequest + * + * @param request RestRequest + * @return MLListContextManagementTemplatesRequest + */ + @VisibleForTesting + MLListContextManagementTemplatesRequest getRequest(RestRequest request) throws IOException { + if (!mlFeatureEnabledSetting.isAgentFrameworkEnabled()) { + throw new IllegalStateException("Agent framework is disabled"); + } + + int from = request.paramAsInt("from", 0); + int size = request.paramAsInt("size", 10); + + return new MLListContextManagementTemplatesRequest(from, size); + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLExecuteTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLExecuteTaskRunner.java index 73281e0333..436c659b9d 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLExecuteTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLExecuteTaskRunner.java @@ -15,10 +15,18 @@ import org.opensearch.cluster.service.ClusterService; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.ml.action.contextmanagement.ContextManagementTemplateService; +import org.opensearch.ml.action.contextmanagement.ContextManagerFactory; import org.opensearch.ml.breaker.MLCircuitBreakerService; import org.opensearch.ml.cluster.DiscoveryNodeHelper; import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.contextmanager.ContextManagementTemplate; +import org.opensearch.ml.common.contextmanager.ContextManager; +import org.opensearch.ml.common.contextmanager.ContextManagerConfig; +import org.opensearch.ml.common.contextmanager.ContextManagerHookProvider; +import org.opensearch.ml.common.hooks.HookRegistry; import org.opensearch.ml.common.input.Input; +import org.opensearch.ml.common.input.execute.agent.AgentMLInput; import org.opensearch.ml.common.transport.execute.MLExecuteStreamTaskAction; import org.opensearch.ml.common.transport.execute.MLExecuteTaskAction; import org.opensearch.ml.common.transport.execute.MLExecuteTaskRequest; @@ -50,6 +58,8 @@ public class MLExecuteTaskRunner extends MLTaskRunner wrappedListener = ActionListener.runBefore(listener, ) Input input = request.getInput(); FunctionName functionName = request.getFunctionName(); + + // Handle agent execution with context management + if (FunctionName.AGENT.equals(functionName) && input instanceof AgentMLInput) { + AgentMLInput agentInput = (AgentMLInput) input; + String contextManagementName = getEffectiveContextManagementName(agentInput); + + if (contextManagementName != null && !contextManagementName.trim().isEmpty()) { + // Execute agent with context management + executeAgentWithContextManagement(request, contextManagementName, channel, listener); + return; + } + } + if (FunctionName.METRICS_CORRELATION.equals(functionName)) { if (!isPythonModelEnabled) { Exception exception = new IllegalArgumentException("This algorithm is not enabled from settings"); @@ -163,10 +189,17 @@ protected void executeTask(MLExecuteTaskRequest request, ActionListener { - MLExecuteTaskResponse response = new MLExecuteTaskResponse(functionName, output); - listener.onResponse(response); - }, e -> { listener.onFailure(e); }), channel); + + // Default execution for all functions (including agents without context management) + try { + mlEngine.execute(input, ActionListener.wrap(output -> { + MLExecuteTaskResponse response = new MLExecuteTaskResponse(functionName, output); + listener.onResponse(response); + }, e -> { listener.onFailure(e); }), channel); + } catch (Exception e) { + log.error("Failed to execute ML function", e); + listener.onFailure(e); + } } catch (Exception e) { mlStats .createCounterStatIfAbsent(request.getFunctionName(), ActionName.EXECUTE, MLActionLevelStat.ML_ACTION_FAILURE_COUNT) @@ -178,4 +211,218 @@ protected void executeTask(MLExecuteTaskRequest request, ActionListener listener + ) { + log.debug("Executing agent with context management: {}", contextManagementName); + + // Lookup context management template + contextManagementTemplateService.getTemplate(contextManagementName, ActionListener.wrap(template -> { + if (template == null) { + listener.onFailure(new IllegalArgumentException("Context management template not found: " + contextManagementName)); + return; + } + + try { + // Create context managers from template + java.util.List contextManagers = createContextManagers(template); + + // Create HookRegistry with context managers + HookRegistry hookRegistry = createHookRegistry(contextManagers, template); + + // Set hook registry in agent input + AgentMLInput agentInput = (AgentMLInput) request.getInput(); + agentInput.setHookRegistry(hookRegistry); + + log + .info( + "Executing agent with context management template: {} using {} context managers", + contextManagementName, + contextManagers.size() + ); + + // Execute agent with hook registry + try { + mlEngine.execute(request.getInput(), ActionListener.wrap(output -> { + log.info("Agent execution completed successfully with context management"); + MLExecuteTaskResponse response = new MLExecuteTaskResponse(request.getFunctionName(), output); + listener.onResponse(response); + }, error -> { + log.error("Agent execution failed with context management", error); + listener.onFailure(error); + }), channel); + } catch (Exception e) { + log.error("Failed to execute agent with context management", e); + listener.onFailure(e); + } + + } catch (Exception e) { + log.error("Failed to create context managers from template: {}", contextManagementName, e); + listener.onFailure(e); + } + }, error -> { + log.error("Failed to retrieve context management template: {}", contextManagementName, error); + listener.onFailure(error); + })); + } + + /** + * Gets the effective context management name for an agent. + * Priority: 1) Runtime parameter from execution request, 2) Agent's stored configuration, 3) Runtime parameters set by MLAgentExecutor + * This follows the same pattern as MCP connectors. + * + * @param agentInput the agent ML input + * @return the effective context management name, or null if none configured + */ + private String getEffectiveContextManagementName(AgentMLInput agentInput) { + // Priority 1: Runtime parameter from execution request (user override) + String runtimeContextManagementName = agentInput.getContextManagementName(); + if (runtimeContextManagementName != null && !runtimeContextManagementName.trim().isEmpty()) { + log.debug("Using runtime context management name: {}", runtimeContextManagementName); + return runtimeContextManagementName; + } + + // Priority 2: Check agent's stored configuration directly + String agentId = agentInput.getAgentId(); + if (agentId != null) { + try { + // Use a blocking call to get the agent synchronously + // This is acceptable here since we're in the task execution path + java.util.concurrent.CompletableFuture future = new java.util.concurrent.CompletableFuture<>(); + + try ( + org.opensearch.common.util.concurrent.ThreadContext.StoredContext context = client + .threadPool() + .getThreadContext() + .stashContext() + ) { + client + .get( + new org.opensearch.action.get.GetRequest(org.opensearch.ml.common.CommonValue.ML_AGENT_INDEX, agentId), + org.opensearch.core.action.ActionListener.runBefore(org.opensearch.core.action.ActionListener.wrap(response -> { + if (response.isExists()) { + try { + org.opensearch.core.xcontent.XContentParser parser = + org.opensearch.common.xcontent.json.JsonXContent.jsonXContent + .createParser( + null, + org.opensearch.common.xcontent.LoggingDeprecationHandler.INSTANCE, + response.getSourceAsString() + ); + org.opensearch.core.xcontent.XContentParserUtils + .ensureExpectedToken( + org.opensearch.core.xcontent.XContentParser.Token.START_OBJECT, + parser.nextToken(), + parser + ); + org.opensearch.ml.common.agent.MLAgent mlAgent = org.opensearch.ml.common.agent.MLAgent + .parse(parser); + + if (mlAgent.hasContextManagementTemplate()) { + String templateName = mlAgent.getContextManagementTemplateName(); + future.complete(templateName); + } else { + future.complete(null); + } + } catch (Exception e) { + future.completeExceptionally(e); + } + } else { + future.complete(null); // Agent not found + } + }, future::completeExceptionally), context::restore) + ); + } + + // Wait for the result with a timeout + String contextManagementName = future.get(5, java.util.concurrent.TimeUnit.SECONDS); + if (contextManagementName != null && !contextManagementName.trim().isEmpty()) { + return contextManagementName; + } + } catch (Exception e) { + // Continue to fallback methods + } + } + + // Priority 3: Agent's runtime parameters (set by MLAgentExecutor in input parameters) + if (agentInput.getInputDataset() instanceof org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet) { + org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet dataset = + (org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet) agentInput.getInputDataset(); + + // Check if context management has already been processed by MLAgentExecutor (for inline templates) + String contextManagementProcessed = dataset.getParameters().get("context_management_processed"); + if ("true".equals(contextManagementProcessed)) { + log.debug("Context management already processed by MLAgentExecutor, skipping MLExecuteTaskRunner processing"); + return null; // Skip processing in MLExecuteTaskRunner + } + + // Handle template references (not processed by MLAgentExecutor) + String agentContextManagementName = dataset.getParameters().get("context_management"); + if (agentContextManagementName != null && !agentContextManagementName.trim().isEmpty()) { + return agentContextManagementName; + } + } + + return null; + } + + /** + * Create context managers from template configuration + */ + private java.util.List createContextManagers(ContextManagementTemplate template) { + java.util.List contextManagers = new java.util.ArrayList<>(); + + // Iterate through all hooks in the template + for (java.util.Map.Entry> entry : template.getHooks().entrySet()) { + String hookName = entry.getKey(); + java.util.List configs = entry.getValue(); + + for (ContextManagerConfig config : configs) { + try { + ContextManager manager = contextManagerFactory.createContextManager(config); + if (manager != null) { + contextManagers.add(manager); + log.debug("Created context manager: {} for hook: {}", config.getType(), hookName); + } else { + log.warn("Failed to create context manager of type: {}", config.getType()); + } + } catch (Exception e) { + log.error("Error creating context manager of type: {}", config.getType(), e); + // Continue with other managers + } + } + } + + log.info("Created {} context managers from template: {}", contextManagers.size(), template.getName()); + return contextManagers; + } + + /** + * Create HookRegistry with context managers + */ + private HookRegistry createHookRegistry(java.util.List contextManagers, ContextManagementTemplate template) { + HookRegistry hookRegistry = new HookRegistry(); + + if (!contextManagers.isEmpty()) { + // Create context manager hook provider + ContextManagerHookProvider hookProvider = new ContextManagerHookProvider(contextManagers); + + // Update hook configuration based on template + hookProvider.updateHookConfiguration(template.getHooks()); + + // Register hooks + hookProvider.registerHooks(hookRegistry); + + log.debug("Registered context manager hooks for {} managers", contextManagers.size()); + } + + return hookRegistry; + } + } diff --git a/plugin/src/main/java/org/opensearch/ml/utils/RestActionUtils.java b/plugin/src/main/java/org/opensearch/ml/utils/RestActionUtils.java index e5c11215c3..acbd70bb43 100644 --- a/plugin/src/main/java/org/opensearch/ml/utils/RestActionUtils.java +++ b/plugin/src/main/java/org/opensearch/ml/utils/RestActionUtils.java @@ -76,6 +76,7 @@ public class RestActionUtils { public static final String[] UI_METADATA_EXCLUDE = new String[] { "ui_metadata" }; public static final String PARAMETER_TOOL_NAME = "tool_name"; + public static final String PARAMETER_TEMPLATE_NAME = "template_name"; public static final String OPENDISTRO_SECURITY_CONFIG_PREFIX = "_opendistro_security_"; diff --git a/plugin/src/test/java/org/opensearch/ml/action/agent/MLAgentRegistrationValidatorTests.java b/plugin/src/test/java/org/opensearch/ml/action/agent/MLAgentRegistrationValidatorTests.java new file mode 100644 index 0000000000..24009b5094 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/action/agent/MLAgentRegistrationValidatorTests.java @@ -0,0 +1,413 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.agent; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.action.contextmanagement.ContextManagementTemplateService; +import org.opensearch.ml.common.agent.MLAgent; +import org.opensearch.ml.common.contextmanager.ContextManagementTemplate; +import org.opensearch.ml.common.contextmanager.ContextManagerConfig; +import org.opensearch.ml.common.exception.MLResourceNotFoundException; + +import junit.framework.TestCase; + +public class MLAgentRegistrationValidatorTests extends TestCase { + + private ContextManagementTemplateService mockTemplateService; + private MLAgentRegistrationValidator validator; + + @Before + public void setUp() { + mockTemplateService = mock(ContextManagementTemplateService.class); + validator = new MLAgentRegistrationValidator(mockTemplateService); + } + + @Test + public void testValidateAgentForRegistration_NoContextManagement() throws InterruptedException { + MLAgent agent = MLAgent.builder().name("test_agent").type("flow").build(); + + CountDownLatch latch = new CountDownLatch(1); + AtomicReference result = new AtomicReference<>(); + AtomicReference error = new AtomicReference<>(); + + validator.validateAgentForRegistration(agent, new ActionListener() { + @Override + public void onResponse(Boolean response) { + result.set(response); + latch.countDown(); + } + + @Override + public void onFailure(Exception e) { + error.set(e); + latch.countDown(); + } + }); + + assertTrue(latch.await(5, TimeUnit.SECONDS)); + assertTrue(result.get()); + assertNull(error.get()); + + // Verify template service was not called since no template reference + verify(mockTemplateService, never()).getTemplate(any(), any()); + } + + @Test + public void testValidateAgentForRegistration_TemplateExists() throws InterruptedException { + MLAgent agent = MLAgent.builder().name("test_agent").type("flow").contextManagementName("existing_template").build(); + + // Mock template service to return a template (exists) + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + ContextManagementTemplate template = ContextManagementTemplate.builder().name("existing_template").build(); + listener.onResponse(template); + return null; + }).when(mockTemplateService).getTemplate(eq("existing_template"), any()); + + CountDownLatch latch = new CountDownLatch(1); + AtomicReference result = new AtomicReference<>(); + AtomicReference error = new AtomicReference<>(); + + validator.validateAgentForRegistration(agent, new ActionListener() { + @Override + public void onResponse(Boolean response) { + result.set(response); + latch.countDown(); + } + + @Override + public void onFailure(Exception e) { + error.set(e); + latch.countDown(); + } + }); + + assertTrue(latch.await(5, TimeUnit.SECONDS)); + assertTrue(result.get()); + assertNull(error.get()); + + verify(mockTemplateService).getTemplate(eq("existing_template"), any()); + } + + @Test + public void testValidateAgentForRegistration_TemplateNotFound() throws InterruptedException { + MLAgent agent = MLAgent.builder().name("test_agent").type("flow").contextManagementName("nonexistent_template").build(); + + // Mock template service to return template not found + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new MLResourceNotFoundException("Context management template not found: nonexistent_template")); + return null; + }).when(mockTemplateService).getTemplate(eq("nonexistent_template"), any()); + + CountDownLatch latch = new CountDownLatch(1); + AtomicReference result = new AtomicReference<>(); + AtomicReference error = new AtomicReference<>(); + + validator.validateAgentForRegistration(agent, new ActionListener() { + @Override + public void onResponse(Boolean response) { + result.set(response); + latch.countDown(); + } + + @Override + public void onFailure(Exception e) { + error.set(e); + latch.countDown(); + } + }); + + assertTrue(latch.await(5, TimeUnit.SECONDS)); + assertNull(result.get()); + assertNotNull(error.get()); + assertTrue(error.get() instanceof IllegalArgumentException); + assertTrue(error.get().getMessage().contains("Context management template not found: nonexistent_template")); + + verify(mockTemplateService).getTemplate(eq("nonexistent_template"), any()); + } + + @Test + public void testValidateAgentForRegistration_TemplateServiceError() throws InterruptedException { + MLAgent agent = MLAgent.builder().name("test_agent").type("flow").contextManagementName("error_template").build(); + + // Mock template service to return an error + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new RuntimeException("Service error")); + return null; + }).when(mockTemplateService).getTemplate(eq("error_template"), any()); + + CountDownLatch latch = new CountDownLatch(1); + AtomicReference result = new AtomicReference<>(); + AtomicReference error = new AtomicReference<>(); + + validator.validateAgentForRegistration(agent, new ActionListener() { + @Override + public void onResponse(Boolean response) { + result.set(response); + latch.countDown(); + } + + @Override + public void onFailure(Exception e) { + error.set(e); + latch.countDown(); + } + }); + + assertTrue(latch.await(5, TimeUnit.SECONDS)); + assertNull(result.get()); + assertNotNull(error.get()); + assertTrue(error.get() instanceof IllegalArgumentException); + assertTrue(error.get().getMessage().contains("Failed to validate context management template")); + + verify(mockTemplateService).getTemplate(eq("error_template"), any()); + } + + @Test + public void testValidateAgentForRegistration_InlineContextManagement() throws InterruptedException { + Map> hooks = new HashMap<>(); + hooks.put("POST_TOOL", List.of(new ContextManagerConfig("ToolsOutputTruncateManager", null, null))); + + ContextManagementTemplate contextManagement = ContextManagementTemplate.builder().name("inline_template").hooks(hooks).build(); + + MLAgent agent = MLAgent.builder().name("test_agent").type("flow").contextManagement(contextManagement).build(); + + CountDownLatch latch = new CountDownLatch(1); + AtomicReference result = new AtomicReference<>(); + AtomicReference error = new AtomicReference<>(); + + validator.validateAgentForRegistration(agent, new ActionListener() { + @Override + public void onResponse(Boolean response) { + result.set(response); + latch.countDown(); + } + + @Override + public void onFailure(Exception e) { + error.set(e); + latch.countDown(); + } + }); + + assertTrue(latch.await(5, TimeUnit.SECONDS)); + assertTrue(result.get()); + assertNull(error.get()); + + // Verify template service was not called since using inline configuration + verify(mockTemplateService, never()).getTemplate(any(), any()); + } + + @Test + public void testValidateAgentForRegistration_InvalidTemplateName() throws InterruptedException { + MLAgent agent = MLAgent.builder().name("test_agent").type("flow").contextManagementName("invalid@name").build(); + + CountDownLatch latch = new CountDownLatch(1); + AtomicReference result = new AtomicReference<>(); + AtomicReference error = new AtomicReference<>(); + + validator.validateAgentForRegistration(agent, new ActionListener() { + @Override + public void onResponse(Boolean response) { + result.set(response); + latch.countDown(); + } + + @Override + public void onFailure(Exception e) { + error.set(e); + latch.countDown(); + } + }); + + assertTrue(latch.await(5, TimeUnit.SECONDS)); + assertNull(result.get()); + assertNotNull(error.get()); + assertTrue(error.get() instanceof IllegalArgumentException); + assertTrue( + error + .get() + .getMessage() + .contains("Context management template name can only contain letters, numbers, underscores, hyphens, and dots") + ); + + // Verify template service was not called due to early validation failure + verify(mockTemplateService, never()).getTemplate(any(), any()); + } + + @Test + public void testValidateAgentForRegistration_InvalidInlineConfiguration() throws InterruptedException { + Map> invalidHooks = new HashMap<>(); + invalidHooks.put("INVALID_HOOK", List.of(new ContextManagerConfig("ToolsOutputTruncateManager", null, null))); + + ContextManagementTemplate contextManagement = ContextManagementTemplate.builder().name("test_template").hooks(invalidHooks).build(); + + MLAgent agent = MLAgent.builder().name("test_agent").type("flow").contextManagement(contextManagement).build(); + + CountDownLatch latch = new CountDownLatch(1); + AtomicReference result = new AtomicReference<>(); + AtomicReference error = new AtomicReference<>(); + + validator.validateAgentForRegistration(agent, new ActionListener() { + @Override + public void onResponse(Boolean response) { + result.set(response); + latch.countDown(); + } + + @Override + public void onFailure(Exception e) { + error.set(e); + latch.countDown(); + } + }); + + assertTrue(latch.await(5, TimeUnit.SECONDS)); + assertNull(result.get()); + assertNotNull(error.get()); + assertTrue(error.get() instanceof IllegalArgumentException); + assertTrue(error.get().getMessage().contains("Invalid hook name: INVALID_HOOK")); + + // Verify template service was not called due to early validation failure + verify(mockTemplateService, never()).getTemplate(any(), any()); + } + + @Test + public void testValidateContextManagementConfiguration_ValidTemplateName() { + MLAgent agent = MLAgent.builder().name("test_agent").type("flow").contextManagementName("valid_template_name").build(); + + String result = validator.validateContextManagementConfiguration(agent); + assertNull(result); + } + + @Test + public void testValidateContextManagementConfiguration_ValidInlineConfig() { + Map> hooks = new HashMap<>(); + hooks.put("POST_TOOL", List.of(new ContextManagerConfig("ToolsOutputTruncateManager", null, null))); + + ContextManagementTemplate contextManagement = ContextManagementTemplate.builder().name("inline_template").hooks(hooks).build(); + + MLAgent agent = MLAgent.builder().name("test_agent").type("flow").contextManagement(contextManagement).build(); + + String result = validator.validateContextManagementConfiguration(agent); + assertNull(result); + } + + @Test + public void testValidateContextManagementConfiguration_EmptyTemplateName() { + MLAgent agent = MLAgent.builder().name("test_agent").type("flow").contextManagementName("").build(); + + String result = validator.validateContextManagementConfiguration(agent); + assertNotNull(result); + assertTrue(result.contains("Context management template name cannot be null or empty")); + } + + @Test + public void testValidateContextManagementConfiguration_TooLongTemplateName() { + String longName = "a".repeat(257); + MLAgent agent = MLAgent.builder().name("test_agent").type("flow").contextManagementName(longName).build(); + + String result = validator.validateContextManagementConfiguration(agent); + assertNotNull(result); + assertTrue(result.contains("Context management template name cannot exceed 256 characters")); + } + + @Test + public void testValidateContextManagementConfiguration_InvalidTemplateNameCharacters() { + MLAgent agent = MLAgent.builder().name("test_agent").type("flow").contextManagementName("invalid@name#").build(); + + String result = validator.validateContextManagementConfiguration(agent); + assertNotNull(result); + assertTrue(result.contains("Context management template name can only contain letters, numbers, underscores, hyphens, and dots")); + } + + @Test + public void testValidateContextManagementConfiguration_InvalidHookName() { + Map> invalidHooks = new HashMap<>(); + invalidHooks.put("INVALID_HOOK", List.of(new ContextManagerConfig("ToolsOutputTruncateManager", null, null))); + + ContextManagementTemplate contextManagement = ContextManagementTemplate.builder().name("test_template").hooks(invalidHooks).build(); + + MLAgent agent = MLAgent.builder().name("test_agent").type("flow").contextManagement(contextManagement).build(); + + String result = validator.validateContextManagementConfiguration(agent); + assertNotNull(result); + assertTrue(result.contains("Invalid hook name: INVALID_HOOK")); + } + + @Test + public void testValidateContextManagementConfiguration_EmptyHookConfigs() { + Map> emptyHooks = new HashMap<>(); + emptyHooks.put("POST_TOOL", List.of()); + + ContextManagementTemplate contextManagement = ContextManagementTemplate.builder().name("test_template").hooks(emptyHooks).build(); + + MLAgent agent = MLAgent.builder().name("test_agent").type("flow").contextManagement(contextManagement).build(); + + String result = validator.validateContextManagementConfiguration(agent); + assertNotNull(result); + assertTrue(result.contains("Hook POST_TOOL must have at least one context manager configuration")); + } + + @Test + public void testValidateContextManagementConfiguration_Conflict() { + // This test should verify that the MLAgent constructor throws an exception + // when both context management name and inline config are provided + Map> hooks = new HashMap<>(); + hooks.put("POST_TOOL", List.of(new ContextManagerConfig("ToolsOutputTruncateManager", null, null))); + + ContextManagementTemplate contextManagement = ContextManagementTemplate.builder().name("inline_template").hooks(hooks).build(); + + try { + MLAgent agent = MLAgent + .builder() + .name("test_agent") + .type("flow") + .contextManagementName("template_name") + .contextManagement(contextManagement) + .build(); + fail("Expected IllegalArgumentException to be thrown"); + } catch (IllegalArgumentException e) { + assertEquals("Cannot specify both context_management_name and context_management", e.getMessage()); + } + } + + @Test + public void testValidateContextManagementConfiguration_InvalidInlineConfig() { + // This test should verify that the MLAgent constructor throws an exception + // when invalid context management configuration is provided + ContextManagementTemplate invalidContextManagement = ContextManagementTemplate + .builder() + .name("invalid_template") + .hooks(new HashMap<>()) + .build(); + + try { + MLAgent agent = MLAgent.builder().name("test_agent").type("flow").contextManagement(invalidContextManagement).build(); + fail("Expected IllegalArgumentException to be thrown"); + } catch (IllegalArgumentException e) { + assertEquals("Invalid context management configuration", e.getMessage()); + } + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/action/agents/DeleteAgentTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/agents/DeleteAgentTransportActionTests.java index 63fd5216e2..5ec7f38311 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/agents/DeleteAgentTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/agents/DeleteAgentTransportActionTests.java @@ -353,6 +353,8 @@ private GetResponse prepareMLAgent(String agentId, boolean isHidden, String tena Instant.EPOCH, "test", isHidden, + null, // contextManagementName + null, // contextManagement tenantId ); diff --git a/plugin/src/test/java/org/opensearch/ml/action/agents/GetAgentTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/agents/GetAgentTransportActionTests.java index 8a5e081855..7bed2a7225 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/agents/GetAgentTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/agents/GetAgentTransportActionTests.java @@ -335,6 +335,8 @@ public GetResponse prepareMLAgent(String agentId, boolean isHidden, String tenan Instant.EPOCH, "test", isHidden, + null, // contextManagementName + null, // contextManagement tenantId ); diff --git a/plugin/src/test/java/org/opensearch/ml/action/agents/RegisterAgentTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/agents/RegisterAgentTransportActionTests.java index b9e7323b7e..0818845cac 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/agents/RegisterAgentTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/agents/RegisterAgentTransportActionTests.java @@ -42,6 +42,7 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.core.index.shard.ShardId; import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.action.contextmanagement.ContextManagementTemplateService; import org.opensearch.ml.common.MLAgentType; import org.opensearch.ml.common.agent.LLMSpec; import org.opensearch.ml.common.agent.MLAgent; @@ -96,6 +97,9 @@ public class RegisterAgentTransportActionTests extends OpenSearchTestCase { @Mock private MLFeatureEnabledSetting mlFeatureEnabledSetting; + @Mock + private ContextManagementTemplateService contextManagementTemplateService; + @Before public void setup() throws IOException { MockitoAnnotations.openMocks(this); @@ -116,7 +120,8 @@ public void setup() throws IOException { sdkClient, mlIndicesHandler, clusterService, - mlFeatureEnabledSetting + mlFeatureEnabledSetting, + contextManagementTemplateService ); indexResponse = new IndexResponse(new ShardId(ML_AGENT_INDEX, "_na_", 0), "AGENT_ID", 1, 0, 2, true); } @@ -510,7 +515,8 @@ public void test_execute_registerAgent_MCPConnectorDisabled() { sdkClient, mlIndicesHandler, clusterService, - mlFeatureEnabledSetting + mlFeatureEnabledSetting, + contextManagementTemplateService ); disabledAction.doExecute(task, request, actionListener); diff --git a/plugin/src/test/java/org/opensearch/ml/action/contextmanagement/ContextManagementIndexUtilsTests.java b/plugin/src/test/java/org/opensearch/ml/action/contextmanagement/ContextManagementIndexUtilsTests.java new file mode 100644 index 0000000000..c8d1c8953f --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/action/contextmanagement/ContextManagementIndexUtilsTests.java @@ -0,0 +1,231 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.contextmanagement; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.*; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.ResourceAlreadyExistsException; +import org.opensearch.action.admin.indices.create.CreateIndexRequest; +import org.opensearch.action.admin.indices.create.CreateIndexResponse; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.metadata.Metadata; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.client.AdminClient; +import org.opensearch.transport.client.Client; +import org.opensearch.transport.client.IndicesAdminClient; + +public class ContextManagementIndexUtilsTests extends OpenSearchTestCase { + + @Mock + private Client client; + + @Mock + private ClusterService clusterService; + + @Mock + private ThreadPool threadPool; + + @Mock + private AdminClient adminClient; + + @Mock + private IndicesAdminClient indicesAdminClient; + + private ContextManagementIndexUtils contextManagementIndexUtils; + + @Before + public void setUp() throws Exception { + super.setUp(); + MockitoAnnotations.openMocks(this); + + // Create a real ThreadContext instead of mocking it + Settings settings = Settings.builder().build(); + ThreadContext threadContext = new ThreadContext(settings); + + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + when(client.admin()).thenReturn(adminClient); + when(adminClient.indices()).thenReturn(indicesAdminClient); + + contextManagementIndexUtils = new ContextManagementIndexUtils(client, clusterService); + } + + @Test + public void testGetIndexName() { + String indexName = ContextManagementIndexUtils.getIndexName(); + assertEquals("ml_context_management_templates", indexName); + } + + @Test + public void testDoesIndexExist_True() { + ClusterState clusterState = mock(ClusterState.class); + Metadata metadata = mock(Metadata.class); + + when(clusterService.state()).thenReturn(clusterState); + when(clusterState.metadata()).thenReturn(metadata); + when(metadata.hasIndex("ml_context_management_templates")).thenReturn(true); + + boolean exists = contextManagementIndexUtils.doesIndexExist(); + assertTrue(exists); + } + + @Test + public void testDoesIndexExist_False() { + ClusterState clusterState = mock(ClusterState.class); + Metadata metadata = mock(Metadata.class); + + when(clusterService.state()).thenReturn(clusterState); + when(clusterState.metadata()).thenReturn(metadata); + when(metadata.hasIndex("ml_context_management_templates")).thenReturn(false); + + boolean exists = contextManagementIndexUtils.doesIndexExist(); + assertFalse(exists); + } + + @Test + public void testCreateIndexIfNotExists_IndexAlreadyExists() { + // Mock index already exists + ClusterState clusterState = mock(ClusterState.class); + Metadata metadata = mock(Metadata.class); + when(clusterService.state()).thenReturn(clusterState); + when(clusterState.metadata()).thenReturn(metadata); + when(metadata.hasIndex("ml_context_management_templates")).thenReturn(true); + + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + + contextManagementIndexUtils.createIndexIfNotExists(listener); + + verify(listener).onResponse(true); + verify(indicesAdminClient, never()).create(any(), any()); + } + + @Test + public void testCreateIndexIfNotExists_Success() { + // Mock index doesn't exist + ClusterState clusterState = mock(ClusterState.class); + Metadata metadata = mock(Metadata.class); + when(clusterService.state()).thenReturn(clusterState); + when(clusterState.metadata()).thenReturn(metadata); + when(metadata.hasIndex("ml_context_management_templates")).thenReturn(false); + + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + + // Mock successful index creation + doAnswer(invocation -> { + @SuppressWarnings("unchecked") + ActionListener createListener = invocation.getArgument(1); + CreateIndexResponse response = mock(CreateIndexResponse.class); + createListener.onResponse(response); + return null; + }).when(indicesAdminClient).create(any(CreateIndexRequest.class), any()); + + contextManagementIndexUtils.createIndexIfNotExists(listener); + + verify(listener).onResponse(true); + + // Verify the create request was made with correct settings + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(CreateIndexRequest.class); + verify(indicesAdminClient).create(requestCaptor.capture(), any()); + + CreateIndexRequest request = requestCaptor.getValue(); + assertEquals("ml_context_management_templates", request.index()); + + Settings indexSettings = request.settings(); + assertEquals("1", indexSettings.get("index.number_of_shards")); + assertEquals("1", indexSettings.get("index.number_of_replicas")); + assertEquals("0-1", indexSettings.get("index.auto_expand_replicas")); + } + + @Test + public void testCreateIndexIfNotExists_ResourceAlreadyExistsException() { + // Mock index doesn't exist initially + ClusterState clusterState = mock(ClusterState.class); + Metadata metadata = mock(Metadata.class); + when(clusterService.state()).thenReturn(clusterState); + when(clusterState.metadata()).thenReturn(metadata); + when(metadata.hasIndex("ml_context_management_templates")).thenReturn(false); + + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + + // Mock ResourceAlreadyExistsException (race condition) + doAnswer(invocation -> { + @SuppressWarnings("unchecked") + ActionListener createListener = invocation.getArgument(1); + createListener.onFailure(new ResourceAlreadyExistsException("Index already exists")); + return null; + }).when(indicesAdminClient).create(any(CreateIndexRequest.class), any()); + + contextManagementIndexUtils.createIndexIfNotExists(listener); + + verify(listener).onResponse(true); + } + + @Test + public void testCreateIndexIfNotExists_OtherException() { + // Mock index doesn't exist + ClusterState clusterState = mock(ClusterState.class); + Metadata metadata = mock(Metadata.class); + when(clusterService.state()).thenReturn(clusterState); + when(clusterState.metadata()).thenReturn(metadata); + when(metadata.hasIndex("ml_context_management_templates")).thenReturn(false); + + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + + RuntimeException testException = new RuntimeException("Test exception"); + + // Mock other exception + doAnswer(invocation -> { + @SuppressWarnings("unchecked") + ActionListener createListener = invocation.getArgument(1); + createListener.onFailure(testException); + return null; + }).when(indicesAdminClient).create(any(CreateIndexRequest.class), any()); + + contextManagementIndexUtils.createIndexIfNotExists(listener); + + verify(listener).onFailure(testException); + } + + @Test + public void testCreateIndexIfNotExists_UnexpectedException() { + // Mock index doesn't exist + ClusterState clusterState = mock(ClusterState.class); + Metadata metadata = mock(Metadata.class); + when(clusterService.state()).thenReturn(clusterState); + when(clusterState.metadata()).thenReturn(metadata); + when(metadata.hasIndex("ml_context_management_templates")).thenReturn(false); + + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + + RuntimeException testException = new RuntimeException("Unexpected error"); + + // Mock unexpected exception during setup + when(client.admin()).thenThrow(testException); + + contextManagementIndexUtils.createIndexIfNotExists(listener); + + verify(listener).onFailure(testException); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/action/contextmanagement/ContextManagementTemplateServiceTests.java b/plugin/src/test/java/org/opensearch/ml/action/contextmanagement/ContextManagementTemplateServiceTests.java new file mode 100644 index 0000000000..c8b7391908 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/action/contextmanagement/ContextManagementTemplateServiceTests.java @@ -0,0 +1,351 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.contextmanagement; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.*; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.contextmanager.ContextManagementTemplate; +import org.opensearch.ml.engine.indices.MLIndicesHandler; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.client.Client; + +public class ContextManagementTemplateServiceTests extends OpenSearchTestCase { + + @Mock + private MLIndicesHandler mlIndicesHandler; + + @Mock + private Client client; + + @Mock + private ClusterService clusterService; + + @Mock + private ThreadPool threadPool; + + private ContextManagementTemplateService contextManagementTemplateService; + + @Before + public void setUp() throws Exception { + super.setUp(); + MockitoAnnotations.openMocks(this); + + // Create a real ThreadContext instead of mocking it + Settings settings = Settings.builder().build(); + ThreadContext threadContext = new ThreadContext(settings); + + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + + // Mock cluster service dependencies for proper setup + org.opensearch.cluster.ClusterState clusterState = mock(org.opensearch.cluster.ClusterState.class); + org.opensearch.cluster.metadata.Metadata metadata = mock(org.opensearch.cluster.metadata.Metadata.class); + when(clusterService.state()).thenReturn(clusterState); + when(clusterState.metadata()).thenReturn(metadata); + when(metadata.hasIndex(anyString())).thenReturn(false); // Default to index not existing + + contextManagementTemplateService = new ContextManagementTemplateService(mlIndicesHandler, client, clusterService); + } + + @Test + public void testConstructor() { + assertNotNull(contextManagementTemplateService); + } + + @Test + public void testSaveTemplate_InvalidTemplate() { + String templateName = "test_template"; + ContextManagementTemplate template = mock(ContextManagementTemplate.class); + when(template.isValid()).thenReturn(false); + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + + contextManagementTemplateService.saveTemplate(templateName, template, listener); + + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(IllegalArgumentException.class); + verify(listener).onFailure(exceptionCaptor.capture()); + assertEquals("Invalid context management template", exceptionCaptor.getValue().getMessage()); + } + + @Test + public void testListTemplates_WithPagination() { + @SuppressWarnings("unchecked") + ActionListener> listener = mock(ActionListener.class); + + contextManagementTemplateService.listTemplates(5, 20, listener); + + // Verify that the method was called - the actual OpenSearch interaction would be complex to mock + // This at least exercises the method signature and basic flow + verify(client).threadPool(); + } + + @Test + public void testListTemplates_DefaultPagination() { + @SuppressWarnings("unchecked") + ActionListener> listener = mock(ActionListener.class); + + contextManagementTemplateService.listTemplates(listener); + + // Verify that the method was called - this exercises the default pagination path + verify(client).threadPool(); + } + + @Test + public void testSaveTemplate_NullTemplate() { + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + + contextManagementTemplateService.saveTemplate("test_template", null, listener); + + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener).onFailure(exceptionCaptor.capture()); + assertTrue(exceptionCaptor.getValue() instanceof NullPointerException); + } + + @Test + public void testSaveTemplate_ValidTemplate() { + ContextManagementTemplate template = mock(ContextManagementTemplate.class); + when(template.isValid()).thenReturn(true); + when(template.getName()).thenReturn("test_template"); + when(template.getCreatedTime()).thenReturn(null); + when(template.getCreatedBy()).thenReturn(null); + + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + + contextManagementTemplateService.saveTemplate("test_template", template, listener); + + // Verify template validation was called - the method will fail due to complex mocking requirements + // but this covers the validation path and timestamp setting + verify(template).isValid(); + verify(template).getCreatedTime(); + verify(template).getCreatedBy(); + verify(template).setCreatedTime(any(java.time.Instant.class)); + verify(template).setLastModified(any(java.time.Instant.class)); + } + + @Test + public void testSaveTemplate_ExceptionInTryBlock() { + // Test exception handling in the outer try-catch block + ContextManagementTemplate template = mock(ContextManagementTemplate.class); + when(template.isValid()).thenThrow(new RuntimeException("Validation error")); + + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + + contextManagementTemplateService.saveTemplate("test_template", template, listener); + + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener).onFailure(exceptionCaptor.capture()); + assertTrue(exceptionCaptor.getValue() instanceof RuntimeException); + assertEquals("Validation error", exceptionCaptor.getValue().getMessage()); + } + + @Test + public void testGetTemplate_NullTemplateName() { + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + + contextManagementTemplateService.getTemplate(null, listener); + + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(IllegalArgumentException.class); + verify(listener).onFailure(exceptionCaptor.capture()); + assertEquals("Template name cannot be null or empty", exceptionCaptor.getValue().getMessage()); + } + + @Test + public void testGetTemplate_EmptyTemplateName() { + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + + contextManagementTemplateService.getTemplate("", listener); + + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(IllegalArgumentException.class); + verify(listener).onFailure(exceptionCaptor.capture()); + assertEquals("Template name cannot be null or empty", exceptionCaptor.getValue().getMessage()); + } + + @Test + public void testDeleteTemplate_NullTemplateName() { + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + + contextManagementTemplateService.deleteTemplate(null, listener); + + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(IllegalArgumentException.class); + verify(listener).onFailure(exceptionCaptor.capture()); + assertEquals("Template name cannot be null or empty", exceptionCaptor.getValue().getMessage()); + } + + @Test + public void testDeleteTemplate_EmptyTemplateName() { + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + + contextManagementTemplateService.deleteTemplate("", listener); + + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(IllegalArgumentException.class); + verify(listener).onFailure(exceptionCaptor.capture()); + assertEquals("Template name cannot be null or empty", exceptionCaptor.getValue().getMessage()); + } + + @Test + public void testGetTemplate_ExceptionInTryBlock() { + // Test exception handling in the outer try-catch block by making threadPool throw + when(client.threadPool()).thenThrow(new RuntimeException("ThreadPool error")); + + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + + contextManagementTemplateService.getTemplate("test_template", listener); + + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener).onFailure(exceptionCaptor.capture()); + assertTrue(exceptionCaptor.getValue() instanceof RuntimeException); + assertEquals("ThreadPool error", exceptionCaptor.getValue().getMessage()); + } + + @Test + public void testDeleteTemplate_ExceptionInTryBlock() { + // Test exception handling in the outer try-catch block by making threadPool throw + when(client.threadPool()).thenThrow(new RuntimeException("ThreadPool error")); + + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + + contextManagementTemplateService.deleteTemplate("test_template", listener); + + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener).onFailure(exceptionCaptor.capture()); + assertTrue(exceptionCaptor.getValue() instanceof RuntimeException); + assertEquals("ThreadPool error", exceptionCaptor.getValue().getMessage()); + } + + @Test + public void testListTemplates_ExceptionInTryBlock() { + // Test exception handling in the outer try-catch block by making threadPool throw + when(client.threadPool()).thenThrow(new RuntimeException("ThreadPool error")); + + @SuppressWarnings("unchecked") + ActionListener> listener = mock(ActionListener.class); + + contextManagementTemplateService.listTemplates(listener); + + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener).onFailure(exceptionCaptor.capture()); + assertTrue(exceptionCaptor.getValue() instanceof RuntimeException); + assertEquals("ThreadPool error", exceptionCaptor.getValue().getMessage()); + } + + @Test + public void testListTemplates_WithPaginationExceptionInTryBlock() { + // Test exception handling in the outer try-catch block for paginated version + when(client.threadPool()).thenThrow(new RuntimeException("ThreadPool error")); + + @SuppressWarnings("unchecked") + ActionListener> listener = mock(ActionListener.class); + + contextManagementTemplateService.listTemplates(10, 50, listener); + + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener).onFailure(exceptionCaptor.capture()); + assertTrue(exceptionCaptor.getValue() instanceof RuntimeException); + assertEquals("ThreadPool error", exceptionCaptor.getValue().getMessage()); + } + + @Test + public void testListTemplates_NullListener() { + // This should not throw an exception, but we can test that the method handles it gracefully + try { + contextManagementTemplateService.listTemplates(null); + // If we get here without exception, that's fine - the method should handle null listeners gracefully + } catch (Exception e) { + // If an exception is thrown, it should be a meaningful one + assertTrue(e instanceof IllegalArgumentException || e instanceof NullPointerException); + } + } + + @Test + public void testGetTemplate_WhitespaceTemplateName() { + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + + contextManagementTemplateService.getTemplate(" ", listener); + + // Whitespace is not considered empty by Strings.isNullOrEmpty(), so it will proceed + // This tests the branch where template name is not null/empty but contains only whitespace + verify(client).threadPool(); + } + + @Test + public void testDeleteTemplate_WhitespaceTemplateName() { + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + + contextManagementTemplateService.deleteTemplate(" ", listener); + + // Whitespace is not considered empty by Strings.isNullOrEmpty(), so it will proceed + // This tests the branch where template name is not null/empty but contains only whitespace + verify(client).threadPool(); + } + + @Test + public void testSaveTemplate_TemplateWithExistingCreatedTime() { + ContextManagementTemplate template = mock(ContextManagementTemplate.class); + when(template.isValid()).thenReturn(true); + when(template.getName()).thenReturn("test_template"); + when(template.getCreatedTime()).thenReturn(java.time.Instant.now()); // Already has created time + when(template.getCreatedBy()).thenReturn("existing_user"); // Already has created by + + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + + contextManagementTemplateService.saveTemplate("test_template", template, listener); + + // Verify template validation was called and existing values were checked + verify(template).isValid(); + verify(template).getCreatedTime(); + verify(template).getCreatedBy(); + // Should call setLastModified but not setCreatedTime or setCreatedBy since they exist + verify(template).setLastModified(any(java.time.Instant.class)); + verify(template, never()).setCreatedTime(any(java.time.Instant.class)); + verify(template, never()).setCreatedBy(anyString()); + } + + @Test + public void testSaveTemplate_TemplateWithNullCreatedBy() { + ContextManagementTemplate template = mock(ContextManagementTemplate.class); + when(template.isValid()).thenReturn(true); + when(template.getName()).thenReturn("test_template"); + when(template.getCreatedTime()).thenReturn(null); + when(template.getCreatedBy()).thenReturn(null); + + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + + contextManagementTemplateService.saveTemplate("test_template", template, listener); + + // Verify template validation was called + verify(template).isValid(); + verify(template).getCreatedTime(); + verify(template).getCreatedBy(); + // Should set both created time and last modified + verify(template).setCreatedTime(any(java.time.Instant.class)); + verify(template).setLastModified(any(java.time.Instant.class)); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/action/contextmanagement/ContextManagerFactoryTests.java b/plugin/src/test/java/org/opensearch/ml/action/contextmanagement/ContextManagerFactoryTests.java new file mode 100644 index 0000000000..f196da28d9 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/action/contextmanagement/ContextManagerFactoryTests.java @@ -0,0 +1,167 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.contextmanagement; + +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import static org.mockito.Mockito.mock; + +import java.util.Map; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.ml.common.contextmanager.ActivationRuleFactory; +import org.opensearch.ml.common.contextmanager.ContextManager; +import org.opensearch.ml.common.contextmanager.ContextManagerConfig; +import org.opensearch.ml.engine.algorithms.contextmanager.SlidingWindowManager; +import org.opensearch.ml.engine.algorithms.contextmanager.SummarizationManager; +import org.opensearch.ml.engine.algorithms.contextmanager.ToolsOutputTruncateManager; +import org.opensearch.transport.client.Client; + +public class ContextManagerFactoryTests { + + private ContextManagerFactory contextManagerFactory; + private ActivationRuleFactory activationRuleFactory; + private Client client; + + @Before + public void setUp() { + activationRuleFactory = mock(ActivationRuleFactory.class); + client = mock(Client.class); + contextManagerFactory = new ContextManagerFactory(activationRuleFactory, client); + } + + @Test + public void testCreateContextManager_ToolsOutputTruncateManager() { + // Arrange + ContextManagerConfig config = new ContextManagerConfig("ToolsOutputTruncateManager", null, null); + + // Act + ContextManager contextManager = contextManagerFactory.createContextManager(config); + + // Assert + assertNotNull(contextManager); + assertTrue(contextManager instanceof ToolsOutputTruncateManager); + } + + @Test + public void testCreateContextManager_ToolsOutputTruncateManagerWithParameters() { + // Arrange + Map parameters = Map.of("maxLength", 1000); + ContextManagerConfig config = new ContextManagerConfig("ToolsOutputTruncateManager", parameters, null); + + // Act + ContextManager contextManager = contextManagerFactory.createContextManager(config); + + // Assert + assertNotNull(contextManager); + assertTrue(contextManager instanceof ToolsOutputTruncateManager); + } + + @Test + public void testCreateContextManager_SlidingWindowManager() { + // Arrange + ContextManagerConfig config = new ContextManagerConfig("SlidingWindowManager", null, null); + + // Act + ContextManager contextManager = contextManagerFactory.createContextManager(config); + + // Assert + assertNotNull(contextManager); + assertTrue(contextManager instanceof SlidingWindowManager); + } + + @Test + public void testCreateContextManager_SummarizationManager() { + // Arrange + ContextManagerConfig config = new ContextManagerConfig("SummarizationManager", null, null); + + // Act + ContextManager contextManager = contextManagerFactory.createContextManager(config); + + // Assert + assertNotNull(contextManager); + assertTrue(contextManager instanceof SummarizationManager); + } + + @Test + public void testCreateContextManager_UnknownType() { + // Arrange + ContextManagerConfig config = new ContextManagerConfig("UnknownManager", null, null); + + // Act & Assert + try { + contextManagerFactory.createContextManager(config); + fail("Expected IllegalArgumentException for unknown manager type"); + } catch (IllegalArgumentException e) { + assertTrue(e.getMessage().contains("Unsupported context manager type")); + } + } + + @Test + public void testCreateContextManager_NullConfig() { + // Act & Assert + try { + contextManagerFactory.createContextManager(null); + fail("Expected IllegalArgumentException for null config"); + } catch (IllegalArgumentException e) { + assertTrue(e.getMessage().contains("cannot be null")); + } + } + + @Test + public void testCreateContextManager_NullType() { + // Arrange + ContextManagerConfig config = new ContextManagerConfig(null, null, null); + + // Act & Assert + try { + contextManagerFactory.createContextManager(config); + fail("Expected IllegalArgumentException for null type"); + } catch (IllegalArgumentException e) { + assertTrue(e.getMessage().contains("cannot be null")); + } + } + + @Test + public void testCreateContextManager_EmptyType() { + // Arrange + ContextManagerConfig config = new ContextManagerConfig("", null, null); + + // Act & Assert + try { + contextManagerFactory.createContextManager(config); + fail("Expected IllegalArgumentException for empty type"); + } catch (IllegalArgumentException e) { + assertTrue(e.getMessage().contains("Unsupported context manager type")); + } + } + + @Test + public void testContextManagerHookProvider_SelectiveRegistration() { + // Test that ContextManagerHookProvider only registers hooks for configured managers + java.util.Map> hookToManagersMap = new java.util.HashMap<>(); + + // Test 1: Only POST_TOOL configured + hookToManagersMap.put("POST_TOOL", java.util.Arrays.asList("ToolsOutputTruncateManager")); + + // Simulate the registration logic + java.util.Set registeredHooks = new java.util.HashSet<>(); + if (hookToManagersMap.containsKey("PRE_LLM")) { + registeredHooks.add("PRE_LLM"); + } + if (hookToManagersMap.containsKey("POST_TOOL")) { + registeredHooks.add("POST_TOOL"); + } + if (hookToManagersMap.containsKey("POST_MEMORY")) { + registeredHooks.add("POST_MEMORY"); + } + + // Assert only POST_TOOL is registered + assertTrue("Should only register POST_TOOL hook", registeredHooks.size() == 1 && registeredHooks.contains("POST_TOOL")); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/action/contextmanagement/CreateContextManagementTemplateTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/contextmanagement/CreateContextManagementTemplateTransportActionTests.java new file mode 100644 index 0000000000..859cc1fbd7 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/action/contextmanagement/CreateContextManagementTemplateTransportActionTests.java @@ -0,0 +1,196 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.contextmanagement; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.*; + +import java.util.Collections; +import java.util.Map; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.contextmanager.ContextManagementTemplate; +import org.opensearch.ml.common.contextmanager.ContextManagerConfig; +import org.opensearch.ml.common.transport.contextmanagement.MLCreateContextManagementTemplateRequest; +import org.opensearch.ml.common.transport.contextmanagement.MLCreateContextManagementTemplateResponse; +import org.opensearch.tasks.Task; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; +import org.opensearch.transport.client.Client; + +public class CreateContextManagementTemplateTransportActionTests extends OpenSearchTestCase { + + @Mock + private Client client; + + @Mock + private ThreadPool threadPool; + + @Mock + private TransportService transportService; + + @Mock + private ActionFilters actionFilters; + + @Mock + private ContextManagementTemplateService contextManagementTemplateService; + + @InjectMocks + private CreateContextManagementTemplateTransportAction transportAction; + + private ThreadContext threadContext; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + + Settings settings = Settings.builder().build(); + threadContext = new ThreadContext(settings); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + + transportAction = new CreateContextManagementTemplateTransportAction( + transportService, + actionFilters, + client, + contextManagementTemplateService + ); + } + + @Test + public void testDoExecute_Success() { + String templateName = "test_template"; + ContextManagementTemplate template = createTestTemplate(); + MLCreateContextManagementTemplateRequest request = new MLCreateContextManagementTemplateRequest(templateName, template); + Task task = mock(Task.class); + ActionListener listener = mock(ActionListener.class); + + // Mock successful template save + doAnswer(invocation -> { + ActionListener saveListener = invocation.getArgument(2); + saveListener.onResponse(true); + return null; + }).when(contextManagementTemplateService).saveTemplate(eq(templateName), eq(template), any()); + + transportAction.doExecute(task, request, listener); + + ArgumentCaptor responseCaptor = ArgumentCaptor + .forClass(MLCreateContextManagementTemplateResponse.class); + verify(listener).onResponse(responseCaptor.capture()); + + MLCreateContextManagementTemplateResponse response = responseCaptor.getValue(); + assertEquals(templateName, response.getTemplateName()); + assertEquals("created", response.getStatus()); + } + + @Test + public void testDoExecute_SaveFailure() { + String templateName = "test_template"; + ContextManagementTemplate template = createTestTemplate(); + MLCreateContextManagementTemplateRequest request = new MLCreateContextManagementTemplateRequest(templateName, template); + Task task = mock(Task.class); + ActionListener listener = mock(ActionListener.class); + + // Mock failed template save + doAnswer(invocation -> { + ActionListener saveListener = invocation.getArgument(2); + saveListener.onResponse(false); + return null; + }).when(contextManagementTemplateService).saveTemplate(eq(templateName), eq(template), any()); + + transportAction.doExecute(task, request, listener); + + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(RuntimeException.class); + verify(listener).onFailure(exceptionCaptor.capture()); + + RuntimeException exception = exceptionCaptor.getValue(); + assertEquals("Failed to create context management template", exception.getMessage()); + } + + @Test + public void testDoExecute_SaveException() { + String templateName = "test_template"; + ContextManagementTemplate template = createTestTemplate(); + MLCreateContextManagementTemplateRequest request = new MLCreateContextManagementTemplateRequest(templateName, template); + Task task = mock(Task.class); + ActionListener listener = mock(ActionListener.class); + + RuntimeException saveException = new RuntimeException("Database error"); + + // Mock exception during template save + doAnswer(invocation -> { + ActionListener saveListener = invocation.getArgument(2); + saveListener.onFailure(saveException); + return null; + }).when(contextManagementTemplateService).saveTemplate(eq(templateName), eq(template), any()); + + transportAction.doExecute(task, request, listener); + + verify(listener).onFailure(saveException); + } + + @Test + public void testDoExecute_UnexpectedException() { + String templateName = "test_template"; + ContextManagementTemplate template = createTestTemplate(); + MLCreateContextManagementTemplateRequest request = new MLCreateContextManagementTemplateRequest(templateName, template); + Task task = mock(Task.class); + ActionListener listener = mock(ActionListener.class); + + RuntimeException unexpectedException = new RuntimeException("Unexpected error"); + + // Mock unexpected exception + doThrow(unexpectedException).when(contextManagementTemplateService).saveTemplate(any(), any(), any()); + + transportAction.doExecute(task, request, listener); + + verify(listener).onFailure(unexpectedException); + } + + @Test + public void testDoExecute_VerifyServiceCall() { + String templateName = "test_template"; + ContextManagementTemplate template = createTestTemplate(); + MLCreateContextManagementTemplateRequest request = new MLCreateContextManagementTemplateRequest(templateName, template); + Task task = mock(Task.class); + ActionListener listener = mock(ActionListener.class); + + // Mock successful template save + doAnswer(invocation -> { + ActionListener saveListener = invocation.getArgument(2); + saveListener.onResponse(true); + return null; + }).when(contextManagementTemplateService).saveTemplate(eq(templateName), eq(template), any()); + + transportAction.doExecute(task, request, listener); + + // Verify the service was called with correct parameters + verify(contextManagementTemplateService).saveTemplate(eq(templateName), eq(template), any()); + } + + private ContextManagementTemplate createTestTemplate() { + Map config = Collections.singletonMap("summary_ratio", 0.3); + ContextManagerConfig contextManagerConfig = new ContextManagerConfig("SummarizationManager", null, config); + + return ContextManagementTemplate + .builder() + .name("test_template") + .description("Test template") + .hooks(Collections.singletonMap("PreLLMEvent", Collections.singletonList(contextManagerConfig))) + .build(); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/action/contextmanagement/DeleteContextManagementTemplateTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/contextmanagement/DeleteContextManagementTemplateTransportActionTests.java new file mode 100644 index 0000000000..a160fabc59 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/action/contextmanagement/DeleteContextManagementTemplateTransportActionTests.java @@ -0,0 +1,174 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.contextmanagement; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.*; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.transport.contextmanagement.MLDeleteContextManagementTemplateRequest; +import org.opensearch.ml.common.transport.contextmanagement.MLDeleteContextManagementTemplateResponse; +import org.opensearch.tasks.Task; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; +import org.opensearch.transport.client.Client; + +public class DeleteContextManagementTemplateTransportActionTests extends OpenSearchTestCase { + + @Mock + private Client client; + + @Mock + private ThreadPool threadPool; + + @Mock + private TransportService transportService; + + @Mock + private ActionFilters actionFilters; + + @Mock + private ContextManagementTemplateService contextManagementTemplateService; + + @InjectMocks + private DeleteContextManagementTemplateTransportAction transportAction; + + private ThreadContext threadContext; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + + Settings settings = Settings.builder().build(); + threadContext = new ThreadContext(settings); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + + transportAction = new DeleteContextManagementTemplateTransportAction( + transportService, + actionFilters, + client, + contextManagementTemplateService + ); + } + + @Test + public void testDoExecute_Success() { + String templateName = "test_template"; + MLDeleteContextManagementTemplateRequest request = new MLDeleteContextManagementTemplateRequest(templateName); + Task task = mock(Task.class); + ActionListener listener = mock(ActionListener.class); + + // Mock successful template deletion + doAnswer(invocation -> { + ActionListener deleteListener = invocation.getArgument(1); + deleteListener.onResponse(true); + return null; + }).when(contextManagementTemplateService).deleteTemplate(eq(templateName), any()); + + transportAction.doExecute(task, request, listener); + + ArgumentCaptor responseCaptor = ArgumentCaptor + .forClass(MLDeleteContextManagementTemplateResponse.class); + verify(listener).onResponse(responseCaptor.capture()); + + MLDeleteContextManagementTemplateResponse response = responseCaptor.getValue(); + assertEquals(templateName, response.getTemplateName()); + assertEquals("deleted", response.getStatus()); + } + + @Test + public void testDoExecute_DeleteFailure() { + String templateName = "test_template"; + MLDeleteContextManagementTemplateRequest request = new MLDeleteContextManagementTemplateRequest(templateName); + Task task = mock(Task.class); + ActionListener listener = mock(ActionListener.class); + + // Mock failed template deletion + doAnswer(invocation -> { + ActionListener deleteListener = invocation.getArgument(1); + deleteListener.onResponse(false); + return null; + }).when(contextManagementTemplateService).deleteTemplate(eq(templateName), any()); + + transportAction.doExecute(task, request, listener); + + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(RuntimeException.class); + verify(listener).onFailure(exceptionCaptor.capture()); + + RuntimeException exception = exceptionCaptor.getValue(); + assertEquals("Context management template not found: test_template", exception.getMessage()); + } + + @Test + public void testDoExecute_ServiceException() { + String templateName = "test_template"; + MLDeleteContextManagementTemplateRequest request = new MLDeleteContextManagementTemplateRequest(templateName); + Task task = mock(Task.class); + ActionListener listener = mock(ActionListener.class); + + RuntimeException serviceException = new RuntimeException("Database error"); + + // Mock exception during template deletion + doAnswer(invocation -> { + ActionListener deleteListener = invocation.getArgument(1); + deleteListener.onFailure(serviceException); + return null; + }).when(contextManagementTemplateService).deleteTemplate(eq(templateName), any()); + + transportAction.doExecute(task, request, listener); + + verify(listener).onFailure(serviceException); + } + + @Test + public void testDoExecute_UnexpectedException() { + String templateName = "test_template"; + MLDeleteContextManagementTemplateRequest request = new MLDeleteContextManagementTemplateRequest(templateName); + Task task = mock(Task.class); + ActionListener listener = mock(ActionListener.class); + + RuntimeException unexpectedException = new RuntimeException("Unexpected error"); + + // Mock unexpected exception + doThrow(unexpectedException).when(contextManagementTemplateService).deleteTemplate(any(), any()); + + transportAction.doExecute(task, request, listener); + + verify(listener).onFailure(unexpectedException); + } + + @Test + public void testDoExecute_VerifyServiceCall() { + String templateName = "test_template"; + MLDeleteContextManagementTemplateRequest request = new MLDeleteContextManagementTemplateRequest(templateName); + Task task = mock(Task.class); + ActionListener listener = mock(ActionListener.class); + + // Mock successful template deletion + doAnswer(invocation -> { + ActionListener deleteListener = invocation.getArgument(1); + deleteListener.onResponse(true); + return null; + }).when(contextManagementTemplateService).deleteTemplate(eq(templateName), any()); + + transportAction.doExecute(task, request, listener); + + // Verify the service was called with correct parameters + verify(contextManagementTemplateService).deleteTemplate(eq(templateName), any()); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/action/contextmanagement/GetContextManagementTemplateTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/contextmanagement/GetContextManagementTemplateTransportActionTests.java new file mode 100644 index 0000000000..4bb1328518 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/action/contextmanagement/GetContextManagementTemplateTransportActionTests.java @@ -0,0 +1,192 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.contextmanagement; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.*; + +import java.util.Collections; +import java.util.Map; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.contextmanager.ContextManagementTemplate; +import org.opensearch.ml.common.contextmanager.ContextManagerConfig; +import org.opensearch.ml.common.transport.contextmanagement.MLGetContextManagementTemplateRequest; +import org.opensearch.ml.common.transport.contextmanagement.MLGetContextManagementTemplateResponse; +import org.opensearch.tasks.Task; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; +import org.opensearch.transport.client.Client; + +public class GetContextManagementTemplateTransportActionTests extends OpenSearchTestCase { + + @Mock + private Client client; + + @Mock + private ThreadPool threadPool; + + @Mock + private TransportService transportService; + + @Mock + private ActionFilters actionFilters; + + @Mock + private ContextManagementTemplateService contextManagementTemplateService; + + @InjectMocks + private GetContextManagementTemplateTransportAction transportAction; + + private ThreadContext threadContext; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + + Settings settings = Settings.builder().build(); + threadContext = new ThreadContext(settings); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + + transportAction = new GetContextManagementTemplateTransportAction( + transportService, + actionFilters, + client, + contextManagementTemplateService + ); + } + + @Test + public void testDoExecute_Success() { + String templateName = "test_template"; + ContextManagementTemplate template = createTestTemplate(); + MLGetContextManagementTemplateRequest request = new MLGetContextManagementTemplateRequest(templateName); + Task task = mock(Task.class); + ActionListener listener = mock(ActionListener.class); + + // Mock successful template retrieval + doAnswer(invocation -> { + ActionListener getListener = invocation.getArgument(1); + getListener.onResponse(template); + return null; + }).when(contextManagementTemplateService).getTemplate(eq(templateName), any()); + + transportAction.doExecute(task, request, listener); + + ArgumentCaptor responseCaptor = ArgumentCaptor + .forClass(MLGetContextManagementTemplateResponse.class); + verify(listener).onResponse(responseCaptor.capture()); + + MLGetContextManagementTemplateResponse response = responseCaptor.getValue(); + assertEquals(template, response.getTemplate()); + } + + @Test + public void testDoExecute_TemplateNotFound() { + String templateName = "nonexistent_template"; + MLGetContextManagementTemplateRequest request = new MLGetContextManagementTemplateRequest(templateName); + Task task = mock(Task.class); + ActionListener listener = mock(ActionListener.class); + + // Mock template not found (null response) + doAnswer(invocation -> { + ActionListener getListener = invocation.getArgument(1); + getListener.onResponse(null); + return null; + }).when(contextManagementTemplateService).getTemplate(eq(templateName), any()); + + transportAction.doExecute(task, request, listener); + + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(RuntimeException.class); + verify(listener).onFailure(exceptionCaptor.capture()); + + RuntimeException exception = exceptionCaptor.getValue(); + assertEquals("Context management template not found: " + templateName, exception.getMessage()); + } + + @Test + public void testDoExecute_ServiceException() { + String templateName = "test_template"; + MLGetContextManagementTemplateRequest request = new MLGetContextManagementTemplateRequest(templateName); + Task task = mock(Task.class); + ActionListener listener = mock(ActionListener.class); + + RuntimeException serviceException = new RuntimeException("Database error"); + + // Mock exception during template retrieval + doAnswer(invocation -> { + ActionListener getListener = invocation.getArgument(1); + getListener.onFailure(serviceException); + return null; + }).when(contextManagementTemplateService).getTemplate(eq(templateName), any()); + + transportAction.doExecute(task, request, listener); + + verify(listener).onFailure(serviceException); + } + + @Test + public void testDoExecute_UnexpectedException() { + String templateName = "test_template"; + MLGetContextManagementTemplateRequest request = new MLGetContextManagementTemplateRequest(templateName); + Task task = mock(Task.class); + ActionListener listener = mock(ActionListener.class); + + RuntimeException unexpectedException = new RuntimeException("Unexpected error"); + + // Mock unexpected exception + doThrow(unexpectedException).when(contextManagementTemplateService).getTemplate(any(), any()); + + transportAction.doExecute(task, request, listener); + + verify(listener).onFailure(unexpectedException); + } + + @Test + public void testDoExecute_VerifyServiceCall() { + String templateName = "test_template"; + ContextManagementTemplate template = createTestTemplate(); + MLGetContextManagementTemplateRequest request = new MLGetContextManagementTemplateRequest(templateName); + Task task = mock(Task.class); + ActionListener listener = mock(ActionListener.class); + + // Mock successful template retrieval + doAnswer(invocation -> { + ActionListener getListener = invocation.getArgument(1); + getListener.onResponse(template); + return null; + }).when(contextManagementTemplateService).getTemplate(eq(templateName), any()); + + transportAction.doExecute(task, request, listener); + + // Verify the service was called with correct parameters + verify(contextManagementTemplateService).getTemplate(eq(templateName), any()); + } + + private ContextManagementTemplate createTestTemplate() { + Map config = Collections.singletonMap("summary_ratio", 0.3); + ContextManagerConfig contextManagerConfig = new ContextManagerConfig("SummarizationManager", null, config); + + return ContextManagementTemplate + .builder() + .name("test_template") + .description("Test template") + .hooks(Collections.singletonMap("PreLLMEvent", Collections.singletonList(contextManagerConfig))) + .build(); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/action/contextmanagement/ListContextManagementTemplatesTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/contextmanagement/ListContextManagementTemplatesTransportActionTests.java new file mode 100644 index 0000000000..c5951ee868 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/action/contextmanagement/ListContextManagementTemplatesTransportActionTests.java @@ -0,0 +1,235 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.contextmanagement; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.*; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.contextmanager.ContextManagementTemplate; +import org.opensearch.ml.common.contextmanager.ContextManagerConfig; +import org.opensearch.ml.common.transport.contextmanagement.MLListContextManagementTemplatesRequest; +import org.opensearch.ml.common.transport.contextmanagement.MLListContextManagementTemplatesResponse; +import org.opensearch.tasks.Task; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; +import org.opensearch.transport.client.Client; + +public class ListContextManagementTemplatesTransportActionTests extends OpenSearchTestCase { + + @Mock + private Client client; + + @Mock + private ThreadPool threadPool; + + @Mock + private TransportService transportService; + + @Mock + private ActionFilters actionFilters; + + @Mock + private ContextManagementTemplateService contextManagementTemplateService; + + @InjectMocks + private ListContextManagementTemplatesTransportAction transportAction; + + private ThreadContext threadContext; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + + Settings settings = Settings.builder().build(); + threadContext = new ThreadContext(settings); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + + transportAction = new ListContextManagementTemplatesTransportAction( + transportService, + actionFilters, + client, + contextManagementTemplateService + ); + } + + @Test + public void testDoExecute_Success() { + int from = 0; + int size = 10; + MLListContextManagementTemplatesRequest request = new MLListContextManagementTemplatesRequest(from, size); + Task task = mock(Task.class); + ActionListener listener = mock(ActionListener.class); + + List templates = Arrays.asList(createTestTemplate("template1"), createTestTemplate("template2")); + + // Mock successful template listing + doAnswer(invocation -> { + ActionListener> listListener = invocation.getArgument(2); + listListener.onResponse(templates); + return null; + }).when(contextManagementTemplateService).listTemplates(eq(from), eq(size), any()); + + transportAction.doExecute(task, request, listener); + + ArgumentCaptor responseCaptor = ArgumentCaptor + .forClass(MLListContextManagementTemplatesResponse.class); + verify(listener).onResponse(responseCaptor.capture()); + + MLListContextManagementTemplatesResponse response = responseCaptor.getValue(); + assertEquals(templates, response.getTemplates()); + } + + @Test + public void testDoExecute_EmptyList() { + int from = 0; + int size = 10; + MLListContextManagementTemplatesRequest request = new MLListContextManagementTemplatesRequest(from, size); + Task task = mock(Task.class); + ActionListener listener = mock(ActionListener.class); + + List emptyTemplates = Collections.emptyList(); + + // Mock empty template list + doAnswer(invocation -> { + ActionListener> listListener = invocation.getArgument(2); + listListener.onResponse(emptyTemplates); + return null; + }).when(contextManagementTemplateService).listTemplates(eq(from), eq(size), any()); + + transportAction.doExecute(task, request, listener); + + ArgumentCaptor responseCaptor = ArgumentCaptor + .forClass(MLListContextManagementTemplatesResponse.class); + verify(listener).onResponse(responseCaptor.capture()); + + MLListContextManagementTemplatesResponse response = responseCaptor.getValue(); + assertEquals(emptyTemplates, response.getTemplates()); + assertTrue(response.getTemplates().isEmpty()); + } + + @Test + public void testDoExecute_ServiceException() { + int from = 0; + int size = 10; + MLListContextManagementTemplatesRequest request = new MLListContextManagementTemplatesRequest(from, size); + Task task = mock(Task.class); + ActionListener listener = mock(ActionListener.class); + + RuntimeException serviceException = new RuntimeException("Database error"); + + // Mock exception during template listing + doAnswer(invocation -> { + ActionListener> listListener = invocation.getArgument(2); + listListener.onFailure(serviceException); + return null; + }).when(contextManagementTemplateService).listTemplates(eq(from), eq(size), any()); + + transportAction.doExecute(task, request, listener); + + verify(listener).onFailure(serviceException); + } + + @Test + public void testDoExecute_UnexpectedException() { + int from = 0; + int size = 10; + MLListContextManagementTemplatesRequest request = new MLListContextManagementTemplatesRequest(from, size); + Task task = mock(Task.class); + ActionListener listener = mock(ActionListener.class); + + RuntimeException unexpectedException = new RuntimeException("Unexpected error"); + + // Mock unexpected exception + doThrow(unexpectedException).when(contextManagementTemplateService).listTemplates(anyInt(), anyInt(), any()); + + transportAction.doExecute(task, request, listener); + + verify(listener).onFailure(unexpectedException); + } + + @Test + public void testDoExecute_VerifyServiceCall() { + int from = 5; + int size = 20; + MLListContextManagementTemplatesRequest request = new MLListContextManagementTemplatesRequest(from, size); + Task task = mock(Task.class); + ActionListener listener = mock(ActionListener.class); + + List templates = Arrays.asList(createTestTemplate("template1")); + + // Mock successful template listing + doAnswer(invocation -> { + ActionListener> listListener = invocation.getArgument(2); + listListener.onResponse(templates); + return null; + }).when(contextManagementTemplateService).listTemplates(eq(from), eq(size), any()); + + transportAction.doExecute(task, request, listener); + + // Verify the service was called with correct parameters + verify(contextManagementTemplateService).listTemplates(eq(from), eq(size), any()); + } + + @Test + public void testDoExecute_CustomPagination() { + int from = 10; + int size = 5; + MLListContextManagementTemplatesRequest request = new MLListContextManagementTemplatesRequest(from, size); + Task task = mock(Task.class); + ActionListener listener = mock(ActionListener.class); + + List templates = Arrays.asList(createTestTemplate("template3")); + + // Mock successful template listing with custom pagination + doAnswer(invocation -> { + ActionListener> listListener = invocation.getArgument(2); + listListener.onResponse(templates); + return null; + }).when(contextManagementTemplateService).listTemplates(eq(from), eq(size), any()); + + transportAction.doExecute(task, request, listener); + + ArgumentCaptor responseCaptor = ArgumentCaptor + .forClass(MLListContextManagementTemplatesResponse.class); + verify(listener).onResponse(responseCaptor.capture()); + + MLListContextManagementTemplatesResponse response = responseCaptor.getValue(); + assertEquals(templates, response.getTemplates()); + + // Verify the service was called with custom pagination parameters + verify(contextManagementTemplateService).listTemplates(eq(from), eq(size), any()); + } + + private ContextManagementTemplate createTestTemplate(String name) { + Map config = Collections.singletonMap("summary_ratio", 0.3); + ContextManagerConfig contextManagerConfig = new ContextManagerConfig("SummarizationManager", null, config); + + return ContextManagementTemplate + .builder() + .name(name) + .description("Test template " + name) + .hooks(Collections.singletonMap("PreLLMEvent", Collections.singletonList(contextManagerConfig))) + .build(); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLCreateContextManagementTemplateActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLCreateContextManagementTemplateActionTests.java new file mode 100644 index 0000000000..ea585e0459 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLCreateContextManagementTemplateActionTests.java @@ -0,0 +1,216 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.*; +import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_TEMPLATE_NAME; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.Strings; +import org.opensearch.core.common.bytes.BytesArray; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; +import org.opensearch.ml.common.transport.contextmanagement.MLCreateContextManagementTemplateAction; +import org.opensearch.ml.common.transport.contextmanagement.MLCreateContextManagementTemplateRequest; +import org.opensearch.rest.RestChannel; +import org.opensearch.rest.RestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestRequest; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.client.node.NodeClient; + +public class RestMLCreateContextManagementTemplateActionTests extends OpenSearchTestCase { + private RestMLCreateContextManagementTemplateAction restAction; + private NodeClient client; + private ThreadPool threadPool; + + @Mock + private RestChannel channel; + + @Mock + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + when(mlFeatureEnabledSetting.isAgentFrameworkEnabled()).thenReturn(true); + restAction = new RestMLCreateContextManagementTemplateAction(mlFeatureEnabledSetting); + + threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); + client = spy(new NodeClient(Settings.EMPTY, threadPool)); + + doAnswer(invocation -> { + // Mock successful execution - actionListener not used in this test + return null; + }).when(client).execute(eq(MLCreateContextManagementTemplateAction.INSTANCE), any(), any()); + } + + @Override + public void tearDown() throws Exception { + super.tearDown(); + threadPool.shutdown(); + client.close(); + } + + public void testConstructor() { + RestMLCreateContextManagementTemplateAction action = new RestMLCreateContextManagementTemplateAction(mlFeatureEnabledSetting); + assertNotNull(action); + } + + public void testGetName() { + String actionName = restAction.getName(); + assertFalse(Strings.isNullOrEmpty(actionName)); + assertEquals("ml_create_context_management_template_action", actionName); + } + + public void testRoutes() { + List routes = restAction.routes(); + assertNotNull(routes); + assertFalse(routes.isEmpty()); + RestHandler.Route route = routes.get(0); + assertEquals(RestRequest.Method.PUT, route.getMethod()); + assertEquals("/_plugins/_ml/context_management/{template_name}", route.getPath()); + } + + public void testPrepareRequest() throws Exception { + RestRequest request = getRestRequest(); + restAction.handleRequest(request, channel, client); + + ArgumentCaptor argumentCaptor = ArgumentCaptor + .forClass(MLCreateContextManagementTemplateRequest.class); + verify(client, times(1)).execute(eq(MLCreateContextManagementTemplateAction.INSTANCE), argumentCaptor.capture(), any()); + String templateName = argumentCaptor.getValue().getTemplateName(); + assertEquals("test_template", templateName); + } + + public void testPrepareRequestReturnsRestChannelConsumer() throws Exception { + RestRequest request = getRestRequest(); + Object consumer = restAction.prepareRequest(request, client); + + assertNotNull(consumer); + + // Execute the consumer to test the actual execution path using reflection + java.lang.reflect.Method acceptMethod = consumer.getClass().getMethod("accept", Object.class); + acceptMethod.invoke(consumer, channel); + + ArgumentCaptor argumentCaptor = ArgumentCaptor + .forClass(MLCreateContextManagementTemplateRequest.class); + verify(client, times(1)).execute(eq(MLCreateContextManagementTemplateAction.INSTANCE), argumentCaptor.capture(), any()); + String templateName = argumentCaptor.getValue().getTemplateName(); + assertEquals("test_template", templateName); + } + + public void testPrepareRequestWithAgentFrameworkDisabled() { + when(mlFeatureEnabledSetting.isAgentFrameworkEnabled()).thenReturn(false); + RestRequest request = getRestRequest(); + + assertThrows(IllegalStateException.class, () -> restAction.handleRequest(request, channel, client)); + } + + public void testGetRequestWithAgentFrameworkDisabled() { + when(mlFeatureEnabledSetting.isAgentFrameworkEnabled()).thenReturn(false); + RestRequest request = getRestRequest(); + + IllegalStateException exception = assertThrows(IllegalStateException.class, () -> restAction.getRequest(request)); + assertEquals("Agent framework is disabled", exception.getMessage()); + } + + public void testGetRequestWithMissingTemplateName() { + Map params = new HashMap<>(); + // No template name parameter + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withParams(params) + .withContent(new BytesArray(getValidTemplateContent()), XContentType.JSON) + .build(); + + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> restAction.getRequest(request)); + assertEquals("Template name is required", exception.getMessage()); + } + + public void testGetRequestWithEmptyTemplateName() { + Map params = new HashMap<>(); + params.put(PARAMETER_TEMPLATE_NAME, ""); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withParams(params) + .withContent(new BytesArray(getValidTemplateContent()), XContentType.JSON) + .build(); + + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> restAction.getRequest(request)); + assertEquals("Template name is required", exception.getMessage()); + } + + public void testGetRequestWithWhitespaceOnlyTemplateName() { + Map params = new HashMap<>(); + params.put(PARAMETER_TEMPLATE_NAME, " "); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withParams(params) + .withContent(new BytesArray(getValidTemplateContent()), XContentType.JSON) + .build(); + + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> restAction.getRequest(request)); + assertEquals("Template name is required", exception.getMessage()); + } + + public void testGetRequestWithInvalidJsonContent() { + Map params = new HashMap<>(); + params.put(PARAMETER_TEMPLATE_NAME, "test_template"); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withParams(params) + .withContent(new BytesArray("invalid json"), XContentType.JSON) + .build(); + + assertThrows(Exception.class, () -> restAction.getRequest(request)); + } + + public void testGetRequestWithValidInput() throws Exception { + RestRequest request = getRestRequest(); + MLCreateContextManagementTemplateRequest result = restAction.getRequest(request); + + assertNotNull(result); + assertEquals("test_template", result.getTemplateName()); + assertNotNull(result.getTemplate()); + } + + private RestRequest getRestRequest() { + Map params = new HashMap<>(); + params.put(PARAMETER_TEMPLATE_NAME, "test_template"); + + return new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withParams(params) + .withContent(new BytesArray(getValidTemplateContent()), XContentType.JSON) + .build(); + } + + private String getValidTemplateContent() { + return "{\n" + + " \"description\": \"Test template\",\n" + + " \"hooks\": {\n" + + " \"PreLLMEvent\": [\n" + + " {\n" + + " \"type\": \"SummarizationManager\",\n" + + " \"config\": {\n" + + " \"summary_ratio\": 0.3,\n" + + " \"preserve_recent_messages\": 10\n" + + " }\n" + + " }\n" + + " ]\n" + + " }\n" + + "}"; + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLDeleteContextManagementTemplateActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLDeleteContextManagementTemplateActionTests.java new file mode 100644 index 0000000000..520dd05d19 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLDeleteContextManagementTemplateActionTests.java @@ -0,0 +1,181 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.*; +import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_TEMPLATE_NAME; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.common.Strings; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; +import org.opensearch.ml.common.transport.contextmanagement.MLDeleteContextManagementTemplateAction; +import org.opensearch.ml.common.transport.contextmanagement.MLDeleteContextManagementTemplateRequest; +import org.opensearch.rest.RestChannel; +import org.opensearch.rest.RestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestRequest; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.client.node.NodeClient; + +public class RestMLDeleteContextManagementTemplateActionTests extends OpenSearchTestCase { + private RestMLDeleteContextManagementTemplateAction restAction; + private NodeClient client; + private ThreadPool threadPool; + + @Mock + private RestChannel channel; + + @Mock + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + when(mlFeatureEnabledSetting.isAgentFrameworkEnabled()).thenReturn(true); + restAction = new RestMLDeleteContextManagementTemplateAction(mlFeatureEnabledSetting); + + threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); + client = spy(new NodeClient(Settings.EMPTY, threadPool)); + + doAnswer(invocation -> { + // Mock successful execution - actionListener not used in this test + return null; + }).when(client).execute(eq(MLDeleteContextManagementTemplateAction.INSTANCE), any(), any()); + } + + @Override + public void tearDown() throws Exception { + super.tearDown(); + threadPool.shutdown(); + client.close(); + } + + public void testConstructor() { + RestMLDeleteContextManagementTemplateAction action = new RestMLDeleteContextManagementTemplateAction(mlFeatureEnabledSetting); + assertNotNull(action); + } + + public void testGetName() { + String actionName = restAction.getName(); + assertFalse(Strings.isNullOrEmpty(actionName)); + assertEquals("ml_delete_context_management_template_action", actionName); + } + + public void testRoutes() { + List routes = restAction.routes(); + assertNotNull(routes); + assertFalse(routes.isEmpty()); + RestHandler.Route route = routes.get(0); + assertEquals(RestRequest.Method.DELETE, route.getMethod()); + assertEquals("/_plugins/_ml/context_management/{template_name}", route.getPath()); + } + + public void testPrepareRequest() throws Exception { + RestRequest request = getRestRequest(); + restAction.handleRequest(request, channel, client); + + ArgumentCaptor argumentCaptor = ArgumentCaptor + .forClass(MLDeleteContextManagementTemplateRequest.class); + verify(client, times(1)).execute(eq(MLDeleteContextManagementTemplateAction.INSTANCE), argumentCaptor.capture(), any()); + String templateName = argumentCaptor.getValue().getTemplateName(); + assertEquals("test_template", templateName); + } + + public void testPrepareRequestWithAgentFrameworkDisabled() { + when(mlFeatureEnabledSetting.isAgentFrameworkEnabled()).thenReturn(false); + RestRequest request = getRestRequest(); + + assertThrows(IllegalStateException.class, () -> restAction.handleRequest(request, channel, client)); + } + + public void testGetRequestWithAgentFrameworkDisabled() { + when(mlFeatureEnabledSetting.isAgentFrameworkEnabled()).thenReturn(false); + RestRequest request = getRestRequest(); + + IllegalStateException exception = assertThrows(IllegalStateException.class, () -> restAction.getRequest(request)); + assertEquals("Agent framework is disabled", exception.getMessage()); + } + + public void testGetRequestWithMissingTemplateName() { + Map params = new HashMap<>(); + // No template name parameter + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).build(); + + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> restAction.getRequest(request)); + assertEquals("Template name is required", exception.getMessage()); + } + + public void testGetRequestWithEmptyTemplateName() { + Map params = new HashMap<>(); + params.put(PARAMETER_TEMPLATE_NAME, ""); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).build(); + + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> restAction.getRequest(request)); + assertEquals("Template name is required", exception.getMessage()); + } + + public void testGetRequestWithWhitespaceTemplateName() { + Map params = new HashMap<>(); + params.put(PARAMETER_TEMPLATE_NAME, " "); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).build(); + + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> restAction.getRequest(request)); + assertEquals("Template name is required", exception.getMessage()); + } + + public void testPrepareRequestReturnsRestChannelConsumer() throws Exception { + RestRequest request = getRestRequest(); + Object consumer = restAction.prepareRequest(request, client); + + assertNotNull(consumer); + + // Execute the consumer to test the actual execution path using reflection + java.lang.reflect.Method acceptMethod = consumer.getClass().getMethod("accept", Object.class); + acceptMethod.invoke(consumer, channel); + + ArgumentCaptor argumentCaptor = ArgumentCaptor + .forClass(MLDeleteContextManagementTemplateRequest.class); + verify(client, times(1)).execute(eq(MLDeleteContextManagementTemplateAction.INSTANCE), argumentCaptor.capture(), any()); + String templateName = argumentCaptor.getValue().getTemplateName(); + assertEquals("test_template", templateName); + } + + public void testPrepareRequestWithWhitespaceTemplateName() { + Map params = new HashMap<>(); + params.put(PARAMETER_TEMPLATE_NAME, " "); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).build(); + + assertThrows(IllegalArgumentException.class, () -> restAction.getRequest(request)); + } + + public void testGetRequestWithValidInput() throws Exception { + RestRequest request = getRestRequest(); + MLDeleteContextManagementTemplateRequest result = restAction.getRequest(request); + + assertNotNull(result); + assertEquals("test_template", result.getTemplateName()); + } + + private RestRequest getRestRequest() { + Map params = new HashMap<>(); + params.put(PARAMETER_TEMPLATE_NAME, "test_template"); + + return new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).build(); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLGetContextManagementTemplateActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLGetContextManagementTemplateActionTests.java new file mode 100644 index 0000000000..abe0d3edaa --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLGetContextManagementTemplateActionTests.java @@ -0,0 +1,173 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.*; +import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_TEMPLATE_NAME; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.common.Strings; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; +import org.opensearch.ml.common.transport.contextmanagement.MLGetContextManagementTemplateAction; +import org.opensearch.ml.common.transport.contextmanagement.MLGetContextManagementTemplateRequest; +import org.opensearch.rest.RestChannel; +import org.opensearch.rest.RestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestRequest; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.client.node.NodeClient; + +public class RestMLGetContextManagementTemplateActionTests extends OpenSearchTestCase { + private RestMLGetContextManagementTemplateAction restAction; + private NodeClient client; + private ThreadPool threadPool; + + @Mock + private RestChannel channel; + + @Mock + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + when(mlFeatureEnabledSetting.isAgentFrameworkEnabled()).thenReturn(true); + restAction = new RestMLGetContextManagementTemplateAction(mlFeatureEnabledSetting); + + threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); + client = spy(new NodeClient(Settings.EMPTY, threadPool)); + + doAnswer(invocation -> { + // Mock successful execution - actionListener not used in this test + return null; + }).when(client).execute(eq(MLGetContextManagementTemplateAction.INSTANCE), any(), any()); + } + + @Override + public void tearDown() throws Exception { + super.tearDown(); + threadPool.shutdown(); + client.close(); + } + + public void testConstructor() { + RestMLGetContextManagementTemplateAction action = new RestMLGetContextManagementTemplateAction(mlFeatureEnabledSetting); + assertNotNull(action); + } + + public void testGetName() { + String actionName = restAction.getName(); + assertFalse(Strings.isNullOrEmpty(actionName)); + assertEquals("ml_get_context_management_template_action", actionName); + } + + public void testRoutes() { + List routes = restAction.routes(); + assertNotNull(routes); + assertFalse(routes.isEmpty()); + RestHandler.Route route = routes.get(0); + assertEquals(RestRequest.Method.GET, route.getMethod()); + assertEquals("/_plugins/_ml/context_management/{template_name}", route.getPath()); + } + + public void testPrepareRequest() throws Exception { + RestRequest request = getRestRequest(); + restAction.handleRequest(request, channel, client); + + ArgumentCaptor argumentCaptor = ArgumentCaptor + .forClass(MLGetContextManagementTemplateRequest.class); + verify(client, times(1)).execute(eq(MLGetContextManagementTemplateAction.INSTANCE), argumentCaptor.capture(), any()); + String templateName = argumentCaptor.getValue().getTemplateName(); + assertEquals("test_template", templateName); + } + + public void testPrepareRequestWithAgentFrameworkDisabled() { + when(mlFeatureEnabledSetting.isAgentFrameworkEnabled()).thenReturn(false); + RestRequest request = getRestRequest(); + + assertThrows(IllegalStateException.class, () -> restAction.handleRequest(request, channel, client)); + } + + public void testGetRequestWithAgentFrameworkDisabled() { + when(mlFeatureEnabledSetting.isAgentFrameworkEnabled()).thenReturn(false); + RestRequest request = getRestRequest(); + + IllegalStateException exception = assertThrows(IllegalStateException.class, () -> restAction.getRequest(request)); + assertEquals("Agent framework is disabled", exception.getMessage()); + } + + public void testGetRequestWithMissingTemplateName() { + Map params = new HashMap<>(); + // No template name parameter + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).build(); + + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> restAction.getRequest(request)); + assertEquals("Template name is required", exception.getMessage()); + } + + public void testGetRequestWithEmptyTemplateName() { + Map params = new HashMap<>(); + params.put(PARAMETER_TEMPLATE_NAME, ""); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).build(); + + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> restAction.getRequest(request)); + assertEquals("Template name is required", exception.getMessage()); + } + + public void testGetRequestWithWhitespaceTemplateName() { + Map params = new HashMap<>(); + params.put(PARAMETER_TEMPLATE_NAME, " "); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).build(); + + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> restAction.getRequest(request)); + assertEquals("Template name is required", exception.getMessage()); + } + + public void testPrepareRequestReturnsRestChannelConsumer() throws Exception { + RestRequest request = getRestRequest(); + Object consumer = restAction.prepareRequest(request, client); + + assertNotNull(consumer); + + // Execute the consumer to test the actual execution path using reflection + java.lang.reflect.Method acceptMethod = consumer.getClass().getMethod("accept", Object.class); + acceptMethod.invoke(consumer, channel); + + ArgumentCaptor argumentCaptor = ArgumentCaptor + .forClass(MLGetContextManagementTemplateRequest.class); + verify(client, times(1)).execute(eq(MLGetContextManagementTemplateAction.INSTANCE), argumentCaptor.capture(), any()); + String templateName = argumentCaptor.getValue().getTemplateName(); + assertEquals("test_template", templateName); + } + + public void testGetRequestWithValidInput() throws Exception { + RestRequest request = getRestRequest(); + MLGetContextManagementTemplateRequest result = restAction.getRequest(request); + + assertNotNull(result); + assertEquals("test_template", result.getTemplateName()); + } + + private RestRequest getRestRequest() { + Map params = new HashMap<>(); + params.put(PARAMETER_TEMPLATE_NAME, "test_template"); + + return new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).build(); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLListContextManagementTemplatesActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLListContextManagementTemplatesActionTests.java new file mode 100644 index 0000000000..d1f56a934a --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLListContextManagementTemplatesActionTests.java @@ -0,0 +1,184 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.*; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.common.Strings; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; +import org.opensearch.ml.common.transport.contextmanagement.MLListContextManagementTemplatesAction; +import org.opensearch.ml.common.transport.contextmanagement.MLListContextManagementTemplatesRequest; +import org.opensearch.rest.RestChannel; +import org.opensearch.rest.RestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestRequest; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.client.node.NodeClient; + +public class RestMLListContextManagementTemplatesActionTests extends OpenSearchTestCase { + private RestMLListContextManagementTemplatesAction restAction; + private NodeClient client; + private ThreadPool threadPool; + + @Mock + private RestChannel channel; + + @Mock + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + when(mlFeatureEnabledSetting.isAgentFrameworkEnabled()).thenReturn(true); + restAction = new RestMLListContextManagementTemplatesAction(mlFeatureEnabledSetting); + + threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); + client = spy(new NodeClient(Settings.EMPTY, threadPool)); + + doAnswer(invocation -> { + // Mock successful execution - actionListener not used in this test + return null; + }).when(client).execute(eq(MLListContextManagementTemplatesAction.INSTANCE), any(), any()); + } + + @Override + public void tearDown() throws Exception { + super.tearDown(); + threadPool.shutdown(); + client.close(); + } + + public void testConstructor() { + RestMLListContextManagementTemplatesAction action = new RestMLListContextManagementTemplatesAction(mlFeatureEnabledSetting); + assertNotNull(action); + } + + public void testGetName() { + String actionName = restAction.getName(); + assertFalse(Strings.isNullOrEmpty(actionName)); + assertEquals("ml_list_context_management_templates_action", actionName); + } + + public void testRoutes() { + List routes = restAction.routes(); + assertNotNull(routes); + assertFalse(routes.isEmpty()); + RestHandler.Route route = routes.get(0); + assertEquals(RestRequest.Method.GET, route.getMethod()); + assertEquals("/_plugins/_ml/context_management", route.getPath()); + } + + public void testPrepareRequest() throws Exception { + RestRequest request = getRestRequest(); + restAction.handleRequest(request, channel, client); + + ArgumentCaptor argumentCaptor = ArgumentCaptor + .forClass(MLListContextManagementTemplatesRequest.class); + verify(client, times(1)).execute(eq(MLListContextManagementTemplatesAction.INSTANCE), argumentCaptor.capture(), any()); + MLListContextManagementTemplatesRequest capturedRequest = argumentCaptor.getValue(); + assertEquals(0, capturedRequest.getFrom()); + assertEquals(10, capturedRequest.getSize()); + } + + public void testPrepareRequestWithCustomPagination() throws Exception { + RestRequest request = getRestRequestWithPagination(5, 20); + restAction.handleRequest(request, channel, client); + + ArgumentCaptor argumentCaptor = ArgumentCaptor + .forClass(MLListContextManagementTemplatesRequest.class); + verify(client, times(1)).execute(eq(MLListContextManagementTemplatesAction.INSTANCE), argumentCaptor.capture(), any()); + MLListContextManagementTemplatesRequest capturedRequest = argumentCaptor.getValue(); + assertEquals(5, capturedRequest.getFrom()); + assertEquals(20, capturedRequest.getSize()); + } + + public void testPrepareRequestWithAgentFrameworkDisabled() { + when(mlFeatureEnabledSetting.isAgentFrameworkEnabled()).thenReturn(false); + RestRequest request = getRestRequest(); + + assertThrows(IllegalStateException.class, () -> restAction.handleRequest(request, channel, client)); + } + + public void testGetRequestWithDefaultPagination() throws Exception { + RestRequest request = getRestRequest(); + MLListContextManagementTemplatesRequest result = restAction.getRequest(request); + + assertNotNull(result); + assertEquals(0, result.getFrom()); + assertEquals(10, result.getSize()); + } + + public void testGetRequestWithCustomPagination() throws Exception { + RestRequest request = getRestRequestWithPagination(15, 25); + MLListContextManagementTemplatesRequest result = restAction.getRequest(request); + + assertNotNull(result); + assertEquals(15, result.getFrom()); + assertEquals(25, result.getSize()); + } + + public void testGetRequestWithAgentFrameworkDisabled() { + when(mlFeatureEnabledSetting.isAgentFrameworkEnabled()).thenReturn(false); + RestRequest request = getRestRequest(); + + IllegalStateException exception = assertThrows(IllegalStateException.class, () -> restAction.getRequest(request)); + assertEquals("Agent framework is disabled", exception.getMessage()); + } + + public void testGetRequestWithInvalidPagination() throws Exception { + RestRequest request = getRestRequestWithPagination(-1, -5); + MLListContextManagementTemplatesRequest result = restAction.getRequest(request); + + assertNotNull(result); + // The REST action passes through the parameters as-is, validation happens at the service level + assertEquals(-1, result.getFrom()); + assertEquals(-5, result.getSize()); + } + + public void testPrepareRequestReturnsRestChannelConsumer() throws Exception { + RestRequest request = getRestRequest(); + Object consumer = restAction.prepareRequest(request, client); + + assertNotNull(consumer); + + // Execute the consumer to test the actual execution path using reflection + java.lang.reflect.Method acceptMethod = consumer.getClass().getMethod("accept", Object.class); + acceptMethod.invoke(consumer, channel); + + ArgumentCaptor argumentCaptor = ArgumentCaptor + .forClass(MLListContextManagementTemplatesRequest.class); + verify(client, times(1)).execute(eq(MLListContextManagementTemplatesAction.INSTANCE), argumentCaptor.capture(), any()); + MLListContextManagementTemplatesRequest capturedRequest = argumentCaptor.getValue(); + assertEquals(0, capturedRequest.getFrom()); + assertEquals(10, capturedRequest.getSize()); + } + + private RestRequest getRestRequest() { + Map params = new HashMap<>(); + return new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).build(); + } + + private RestRequest getRestRequestWithPagination(int from, int size) { + Map params = new HashMap<>(); + params.put("from", String.valueOf(from)); + params.put("size", String.valueOf(size)); + return new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).build(); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLExecuteTaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLExecuteTaskRunnerTests.java index 1f53744661..d1d0eca084 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLExecuteTaskRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLExecuteTaskRunnerTests.java @@ -33,6 +33,8 @@ import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.action.contextmanagement.ContextManagementTemplateService; +import org.opensearch.ml.action.contextmanagement.ContextManagerFactory; import org.opensearch.ml.breaker.MLCircuitBreakerService; import org.opensearch.ml.cluster.DiscoveryNodeHelper; import org.opensearch.ml.common.FunctionName; @@ -78,6 +80,10 @@ public class MLExecuteTaskRunnerTests extends OpenSearchTestCase { DiscoveryNodeHelper nodeHelper; @Mock ClusterApplierService clusterApplierService; + @Mock + ContextManagementTemplateService contextManagementTemplateService; + @Mock + ContextManagerFactory contextManagerFactory; @Rule public ExpectedException exceptionRule = ExpectedException.none(); @@ -138,7 +144,9 @@ public void setup() { mlTaskDispatcher, mlCircuitBreakerService, nodeHelper, - mlEngine + mlEngine, + contextManagementTemplateService, + contextManagerFactory ) ); From fed7e520d55b6fc59cd884f16bb481eb65f24fb0 Mon Sep 17 00:00:00 2001 From: Yaliang Wu Date: Tue, 18 Nov 2025 14:19:22 -0800 Subject: [PATCH 3/5] refactor memory interface; add agentic conversation memory (#4434) Signed-off-by: Yaliang Wu --- .../opensearch/ml/common/MLMemoryType.java | 25 + .../ml/common/agent/MLMemorySpec.java | 16 +- .../ml/common/conversation/Interaction.java | 12 +- .../opensearch/ml/common/memory/Memory.java | 57 ++ .../opensearch/ml/common}/memory/Message.java | 6 +- .../ml/engine/agents/AgentContextUtil.java | 2 +- .../engine/algorithms/agent/AgentUtils.java | 21 + .../algorithms/agent/MLAgentExecutor.java | 139 ++--- .../algorithms/agent/MLChatAgentRunner.java | 71 +-- .../MLConversationalFlowAgentRunner.java | 34 +- .../algorithms/agent/MLFlowAgentRunner.java | 39 +- .../MLPlanExecuteAndReflectAgentRunner.java | 73 ++- .../memory/AgenticConversationMemory.java | 567 ++++++++++++++++++ .../ml/engine/memory/AgenticMemoryConfig.java | 45 ++ .../ml/engine/memory/BaseMessage.java | 2 +- .../memory/ConversationIndexMemory.java | 90 +-- .../algorithms/agent/MLAgentExecutorTest.java | 2 +- .../agent/MLChatAgentRunnerTest.java | 154 ++++- .../agent/MLFlowAgentRunnerTest.java | 62 +- ...LPlanExecuteAndReflectAgentRunnerTest.java | 12 +- .../memory/AgenticConversationMemoryTest.java | 157 +++++ .../memory/ConversationIndexMemoryTest.java | 169 +++--- .../memory/TransportSearchMemoriesAction.java | 2 +- .../ml/plugin/MachineLearningPlugin.java | 7 +- .../ml/common/spi/memory/Memory.java | 64 -- 25 files changed, 1369 insertions(+), 459 deletions(-) create mode 100644 common/src/main/java/org/opensearch/ml/common/MLMemoryType.java create mode 100644 common/src/main/java/org/opensearch/ml/common/memory/Memory.java rename {spi/src/main/java/org/opensearch/ml/common/spi => common/src/main/java/org/opensearch/ml/common}/memory/Message.java (74%) create mode 100644 ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/AgenticConversationMemory.java create mode 100644 ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/AgenticMemoryConfig.java create mode 100644 ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/AgenticConversationMemoryTest.java delete mode 100644 spi/src/main/java/org/opensearch/ml/common/spi/memory/Memory.java diff --git a/common/src/main/java/org/opensearch/ml/common/MLMemoryType.java b/common/src/main/java/org/opensearch/ml/common/MLMemoryType.java new file mode 100644 index 0000000000..45b82db53d --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/MLMemoryType.java @@ -0,0 +1,25 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common; + +import java.util.Locale; + +public enum MLMemoryType { + CONVERSATION_INDEX, + AGENTIC_MEMORY, + REMOTE_AGENTIC_MEMORY; + + public static MLMemoryType from(String value) { + if (value != null) { + try { + return MLMemoryType.valueOf(value.toUpperCase(Locale.ROOT)); + } catch (Exception e) { + throw new IllegalArgumentException("Wrong Memory type"); + } + } + return null; + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/agent/MLMemorySpec.java b/common/src/main/java/org/opensearch/ml/common/agent/MLMemorySpec.java index bba24db6c4..7476ad351c 100644 --- a/common/src/main/java/org/opensearch/ml/common/agent/MLMemorySpec.java +++ b/common/src/main/java/org/opensearch/ml/common/agent/MLMemorySpec.java @@ -26,32 +26,37 @@ public class MLMemorySpec implements ToXContentObject { public static final String MEMORY_TYPE_FIELD = "type"; public static final String WINDOW_SIZE_FIELD = "window_size"; public static final String SESSION_ID_FIELD = "session_id"; + public static final String MEMORY_CONTAINER_ID_FIELD = "memory_container_id"; private String type; @Setter private String sessionId; private Integer windowSize; + private String memoryContainerId; @Builder(toBuilder = true) - public MLMemorySpec(String type, String sessionId, Integer windowSize) { + public MLMemorySpec(String type, String sessionId, Integer windowSize, String memoryContainerId) { if (type == null) { throw new IllegalArgumentException("agent name is null"); } this.type = type; this.sessionId = sessionId; this.windowSize = windowSize; + this.memoryContainerId = memoryContainerId; } public MLMemorySpec(StreamInput input) throws IOException { type = input.readString(); sessionId = input.readOptionalString(); windowSize = input.readOptionalInt(); + memoryContainerId = input.readOptionalString(); } public void writeTo(StreamOutput out) throws IOException { out.writeString(type); out.writeOptionalString(sessionId); out.writeOptionalInt(windowSize); + out.writeOptionalString(memoryContainerId); } @Override @@ -64,6 +69,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (sessionId != null) { builder.field(SESSION_ID_FIELD, sessionId); } + if (memoryContainerId != null) { + builder.field(MEMORY_CONTAINER_ID_FIELD, memoryContainerId); + } builder.endObject(); return builder; } @@ -72,6 +80,7 @@ public static MLMemorySpec parse(XContentParser parser) throws IOException { String type = null; String sessionId = null; Integer windowSize = null; + String memoryContainerId = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -88,12 +97,15 @@ public static MLMemorySpec parse(XContentParser parser) throws IOException { case WINDOW_SIZE_FIELD: windowSize = parser.intValue(); break; + case MEMORY_CONTAINER_ID_FIELD: + memoryContainerId = parser.text(); + break; default: parser.skipChildren(); break; } } - return MLMemorySpec.builder().type(type).sessionId(sessionId).windowSize(windowSize).build(); + return MLMemorySpec.builder().type(type).sessionId(sessionId).windowSize(windowSize).memoryContainerId(memoryContainerId).build(); } public static MLMemorySpec fromStream(StreamInput in) throws IOException { diff --git a/common/src/main/java/org/opensearch/ml/common/conversation/Interaction.java b/common/src/main/java/org/opensearch/ml/common/conversation/Interaction.java index 5da68b0d07..19c6ee21df 100644 --- a/common/src/main/java/org/opensearch/ml/common/conversation/Interaction.java +++ b/common/src/main/java/org/opensearch/ml/common/conversation/Interaction.java @@ -28,6 +28,7 @@ import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.ml.common.CommonValue; +import org.opensearch.ml.common.memory.Message; import org.opensearch.search.SearchHit; import lombok.AllArgsConstructor; @@ -39,7 +40,7 @@ */ @Builder @AllArgsConstructor -public class Interaction implements Writeable, ToXContentObject { +public class Interaction implements Writeable, ToXContentObject, Message { @Getter private String id; @@ -275,4 +276,13 @@ public String toString() { + "}"; } + @Override + public String getType() { + return ""; + } + + @Override + public String getContent() { + return ""; + } } diff --git a/common/src/main/java/org/opensearch/ml/common/memory/Memory.java b/common/src/main/java/org/opensearch/ml/common/memory/Memory.java new file mode 100644 index 0000000000..9cd18deeae --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/memory/Memory.java @@ -0,0 +1,57 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.memory; + +import java.util.List; +import java.util.Map; + +import org.opensearch.core.action.ActionListener; + +/** + * A general memory interface. + * @param Message type + * @param Save response type + * @param Update response type + */ +public interface Memory { + + /** + * Get memory type. + * @return memory type + */ + String getType(); + + /** + * Get memory ID. + * @return memory ID + */ + String getId(); + + default void save(Message message, String parentId, Integer traceNum, String action) {} + + default void save(Message message, String parentId, Integer traceNum, String action, ActionListener listener) {} + + default void update(String messageId, Map updateContent, ActionListener updateListener) {} + + default void getMessages(int size, ActionListener> listener) {} + + /** + * Clear all memory. + */ + void clear(); + + void deleteInteractionAndTrace(String regenerateInteractionId, ActionListener wrap); + + interface Factory { + /** + * Create an instance of this Memory. + * + * @param params Parameters for the memory + * @param listener Action listener for the memory creation action + */ + void create(Map params, ActionListener listener); + } +} diff --git a/spi/src/main/java/org/opensearch/ml/common/spi/memory/Message.java b/common/src/main/java/org/opensearch/ml/common/memory/Message.java similarity index 74% rename from spi/src/main/java/org/opensearch/ml/common/spi/memory/Message.java rename to common/src/main/java/org/opensearch/ml/common/memory/Message.java index 148cc769e3..d7ca18718c 100644 --- a/spi/src/main/java/org/opensearch/ml/common/spi/memory/Message.java +++ b/common/src/main/java/org/opensearch/ml/common/memory/Message.java @@ -3,7 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.ml.common.spi.memory; +package org.opensearch.ml.common.memory; /** * General message interface. @@ -12,13 +12,13 @@ public interface Message { /** * Get message type. - * @return + * @return message type */ String getType(); /** * Get message content. - * @return + * @return message content */ String getContent(); } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/agents/AgentContextUtil.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/agents/AgentContextUtil.java index da2fd985f3..0715262e8c 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/agents/AgentContextUtil.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/agents/AgentContextUtil.java @@ -17,7 +17,7 @@ import org.opensearch.ml.common.hooks.EnhancedPostToolEvent; import org.opensearch.ml.common.hooks.HookRegistry; import org.opensearch.ml.common.hooks.PreLLMEvent; -import org.opensearch.ml.common.spi.memory.Memory; +import org.opensearch.ml.common.memory.Memory; import org.opensearch.ml.common.utils.StringUtils; import org.opensearch.ml.engine.memory.ConversationIndexMemory; diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java index 7de547127a..cbc003388f 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java @@ -10,6 +10,7 @@ import static org.opensearch.ml.common.CommonValue.MCP_CONNECTORS_FIELD; import static org.opensearch.ml.common.CommonValue.MCP_CONNECTOR_ID_FIELD; import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; +import static org.opensearch.ml.common.agent.MLMemorySpec.MEMORY_CONTAINER_ID_FIELD; import static org.opensearch.ml.common.utils.StringUtils.getParameterMap; import static org.opensearch.ml.common.utils.StringUtils.gson; import static org.opensearch.ml.common.utils.StringUtils.isJson; @@ -29,6 +30,7 @@ import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.TOOL_DESCRIPTIONS; import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.TOOL_NAMES; import static org.opensearch.ml.engine.algorithms.agent.MLPlanExecuteAndReflectAgentRunner.RESPONSE_FIELD; +import static org.opensearch.ml.engine.memory.ConversationIndexMemory.APP_TYPE; import static org.opensearch.ml.engine.memory.ConversationIndexMemory.LAST_N_INTERACTIONS; import java.io.IOException; @@ -83,6 +85,7 @@ import org.opensearch.ml.engine.algorithms.remote.McpStreamableHttpConnectorExecutor; import org.opensearch.ml.engine.encryptor.Encryptor; import org.opensearch.ml.engine.function_calling.FunctionCalling; +import org.opensearch.ml.engine.memory.ConversationIndexMemory; import org.opensearch.ml.engine.tools.McpSseTool; import org.opensearch.ml.engine.tools.McpStreamableHttpTool; import org.opensearch.remote.metadata.client.GetDataObjectRequest; @@ -1014,4 +1017,22 @@ public static Tool createTool(Map toolFactories, Map createMemoryParams( + String question, + String memoryId, + String appType, + MLAgent mlAgent, + String memoryContainerId + ) { + Map memoryParams = new HashMap<>(); + memoryParams.put(ConversationIndexMemory.MEMORY_NAME, question); + memoryParams.put(ConversationIndexMemory.MEMORY_ID, memoryId); + memoryParams.put(APP_TYPE, appType); + if (mlAgent.getMemory().getMemoryContainerId() != null) { + memoryParams.put(MEMORY_CONTAINER_ID_FIELD, mlAgent.getMemory().getMemoryContainerId()); + } + memoryParams.putIfAbsent(MEMORY_CONTAINER_ID_FIELD, memoryContainerId); + return memoryParams; + } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java index 39d3fe6263..382d7421a6 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java @@ -13,9 +13,12 @@ import static org.opensearch.ml.common.MLTask.RESPONSE_FIELD; import static org.opensearch.ml.common.MLTask.STATE_FIELD; import static org.opensearch.ml.common.MLTask.TASK_ID_FIELD; +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.MEMORY_CONTAINER_ID_FIELD; import static org.opensearch.ml.common.output.model.ModelTensorOutput.INFERENCE_RESULT_FIELD; import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_MCP_CONNECTOR_DISABLED_MESSAGE; import static org.opensearch.ml.common.utils.MLTaskUtils.updateMLTaskDirectly; +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.createMemoryParams; +import static org.opensearch.ml.engine.memory.ConversationIndexMemory.APP_TYPE; import java.security.AccessController; import java.security.PrivilegedActionException; @@ -56,6 +59,7 @@ import org.opensearch.ml.common.hooks.HookRegistry; import org.opensearch.ml.common.input.Input; import org.opensearch.ml.common.input.execute.agent.AgentMLInput; +import org.opensearch.ml.common.memory.Memory; import org.opensearch.ml.common.output.MLTaskOutput; import org.opensearch.ml.common.output.Output; import org.opensearch.ml.common.output.model.ModelTensor; @@ -63,7 +67,6 @@ import org.opensearch.ml.common.output.model.ModelTensors; import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; import org.opensearch.ml.common.settings.SettingsChangeListener; -import org.opensearch.ml.common.spi.memory.Memory; import org.opensearch.ml.common.spi.tools.Tool; import org.opensearch.ml.engine.Executable; import org.opensearch.ml.engine.algorithms.contextmanager.SlidingWindowManager; @@ -72,7 +75,6 @@ import org.opensearch.ml.engine.annotation.Function; import org.opensearch.ml.engine.encryptor.Encryptor; import org.opensearch.ml.engine.indices.MLIndicesHandler; -import org.opensearch.ml.engine.memory.ConversationIndexMemory; import org.opensearch.ml.engine.memory.ConversationIndexMessage; import org.opensearch.ml.memory.action.conversation.CreateInteractionResponse; import org.opensearch.ml.memory.action.conversation.GetInteractionAction; @@ -256,71 +258,74 @@ public void execute(Input input, ActionListener listener, TransportChann && memorySpec.getType() != null && memoryFactoryMap.containsKey(memorySpec.getType()) && (memoryId == null || parentInteractionId == null)) { - ConversationIndexMemory.Factory conversationIndexMemoryFactory = - (ConversationIndexMemory.Factory) memoryFactoryMap.get(memorySpec.getType()); - conversationIndexMemoryFactory - .create(question, memoryId, appType, ActionListener.wrap(memory -> { - inputDataSet.getParameters().put(MEMORY_ID, memory.getConversationId()); - // get question for regenerate - if (regenerateInteractionId != null) { - log.info("Regenerate for existing interaction {}", regenerateInteractionId); - client - .execute( - GetInteractionAction.INSTANCE, - new GetInteractionRequest(regenerateInteractionId), - ActionListener.wrap(interactionRes -> { - inputDataSet - .getParameters() - .putIfAbsent(QUESTION, interactionRes.getInteraction().getInput()); - saveRootInteractionAndExecute( - listener, - memory, - inputDataSet, - mlTask, - isAsync, - outputs, - modelTensors, - mlAgent, - channel, - hookRegistry - ); - }, e -> { - log.error("Failed to get existing interaction for regeneration", e); - listener.onFailure(e); - }) - ); - } else { - saveRootInteractionAndExecute( - listener, - memory, - inputDataSet, - mlTask, - isAsync, - outputs, - modelTensors, - mlAgent, - channel, - hookRegistry + Memory.Factory> memoryFactory = memoryFactoryMap.get(memorySpec.getType()); + + Map memoryParams = createMemoryParams( + question, + memoryId, + appType, + mlAgent, + inputDataSet.getParameters().get(MEMORY_CONTAINER_ID_FIELD) + ); + memoryFactory.create(memoryParams, ActionListener.wrap(memory -> { + inputDataSet.getParameters().put(MEMORY_ID, memory.getId()); + // get question for regenerate + if (regenerateInteractionId != null) { + log.info("Regenerate for existing interaction {}", regenerateInteractionId); + client + .execute( + GetInteractionAction.INSTANCE, + new GetInteractionRequest(regenerateInteractionId), + ActionListener.wrap(interactionRes -> { + inputDataSet + .getParameters() + .putIfAbsent(QUESTION, interactionRes.getInteraction().getInput()); + saveRootInteractionAndExecute( + listener, + memory, + inputDataSet, + mlTask, + isAsync, + outputs, + modelTensors, + mlAgent, + channel, + hookRegistry + ); + }, e -> { + log.error("Failed to get existing interaction for regeneration", e); + listener.onFailure(e); + }) ); - } - }, ex -> { - log.error("Failed to read conversation memory", ex); - listener.onFailure(ex); - })); + } else { + saveRootInteractionAndExecute( + listener, + memory, + inputDataSet, + mlTask, + isAsync, + outputs, + modelTensors, + mlAgent, + channel, + hookRegistry + ); + } + }, ex -> { + log.error("Failed to read conversation memory", ex); + listener.onFailure(ex); + })); } else { // For existing conversations, create memory instance using factory if (memorySpec != null && memorySpec.getType() != null) { - ConversationIndexMemory.Factory factory = (ConversationIndexMemory.Factory) memoryFactoryMap - .get(memorySpec.getType()); + Memory.Factory> factory = memoryFactoryMap.get(memorySpec.getType()); if (factory != null) { // memoryId exists, so create returns an object with existing // memory, therefore name can // be null factory .create( - null, - memoryId, - appType, + Map.of(MEMORY_ID, memoryId, APP_TYPE, appType), ActionListener .wrap( createdMemory -> executeAgent( @@ -394,7 +399,7 @@ public void execute(Input input, ActionListener listener, TransportChann */ private void saveRootInteractionAndExecute( ActionListener listener, - ConversationIndexMemory memory, + Memory memory, RemoteInferenceInputDataSet inputDataSet, MLTask mlTask, boolean isAsync, @@ -414,7 +419,7 @@ private void saveRootInteractionAndExecute( .question(question) .response("") .finalAnswer(true) - .sessionId(memory.getConversationId()) + .sessionId(memory.getId()) .build(); memory.save(msg, null, null, null, ActionListener.wrap(interaction -> { log.info("Created parent interaction ID: {}", interaction.getId()); @@ -422,7 +427,6 @@ private void saveRootInteractionAndExecute( // only delete previous interaction when new interaction created if (regenerateInteractionId != null) { memory - .getMemoryManager() .deleteInteractionAndTrace( regenerateInteractionId, ActionListener @@ -431,7 +435,7 @@ private void saveRootInteractionAndExecute( inputDataSet, mlTask, isAsync, - memory.getConversationId(), + memory.getId(), mlAgent, outputs, modelTensors, @@ -451,7 +455,7 @@ private void saveRootInteractionAndExecute( inputDataSet, mlTask, isAsync, - memory.getConversationId(), + memory.getId(), mlAgent, outputs, modelTensors, @@ -694,7 +698,7 @@ private void executeAgent( List outputs, List modelTensors, ActionListener listener, - ConversationIndexMemory memory, + Memory memory, TransportChannel channel, HookRegistry hookRegistry ) { @@ -781,7 +785,7 @@ private ActionListener createAgentActionListener( List modelTensors, String agentType, String parentInteractionId, - ConversationIndexMemory memory + Memory memory ) { return ActionListener.wrap(output -> { if (output != null) { @@ -802,7 +806,7 @@ private ActionListener createAsyncTaskUpdater( List outputs, List modelTensors, String parentInteractionId, - ConversationIndexMemory memory + Memory memory ) { String taskId = mlTask.getTaskId(); Map agentResponse = new HashMap<>(); @@ -959,15 +963,14 @@ public void indexMLTask(MLTask mlTask, ActionListener listener) { } } - private void updateInteractionWithFailure(String interactionId, ConversationIndexMemory memory, String errorMessage) { + private void updateInteractionWithFailure(String interactionId, Memory memory, String errorMessage) { if (interactionId != null && memory != null) { String failureMessage = "Agent execution failed: " + errorMessage; Map updateContent = new HashMap<>(); updateContent.put(RESPONSE_FIELD, failureMessage); memory - .getMemoryManager() - .updateInteraction( + .update( interactionId, updateContent, ActionListener diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index 6badaf31fa..d7c38f9ae3 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -7,6 +7,7 @@ import static org.opensearch.ml.common.conversation.ActionConstants.ADDITIONAL_INFO_FIELD; import static org.opensearch.ml.common.conversation.ActionConstants.AI_RESPONSE_FIELD; +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.MEMORY_CONTAINER_ID_FIELD; import static org.opensearch.ml.common.utils.StringUtils.gson; import static org.opensearch.ml.common.utils.StringUtils.processTextDoc; import static org.opensearch.ml.common.utils.ToolUtils.filterToolOutput; @@ -24,6 +25,7 @@ import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.VERBOSE; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.cleanUpResource; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.constructToolParams; +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.createMemoryParams; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.createTools; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getCurrentDateTime; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getMcpToolSpecs; @@ -63,11 +65,11 @@ import org.opensearch.ml.common.contextmanager.ContextManagerContext; import org.opensearch.ml.common.conversation.Interaction; import org.opensearch.ml.common.hooks.HookRegistry; +import org.opensearch.ml.common.memory.Memory; +import org.opensearch.ml.common.memory.Message; import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.common.output.model.ModelTensors; -import org.opensearch.ml.common.spi.memory.Memory; -import org.opensearch.ml.common.spi.memory.Message; import org.opensearch.ml.common.spi.tools.Tool; import org.opensearch.ml.common.transport.MLTaskResponse; import org.opensearch.ml.common.utils.StringUtils; @@ -76,7 +78,6 @@ import org.opensearch.ml.engine.function_calling.FunctionCalling; import org.opensearch.ml.engine.function_calling.FunctionCallingFactory; import org.opensearch.ml.engine.function_calling.LLMMessage; -import org.opensearch.ml.engine.memory.ConversationIndexMemory; import org.opensearch.ml.engine.memory.ConversationIndexMessage; import org.opensearch.ml.engine.tools.MLModelTool; import org.opensearch.remote.metadata.client.SdkClient; @@ -203,10 +204,12 @@ public void run(MLAgent mlAgent, Map inputParams, ActionListener String chatHistoryResponseTemplate = params.get(CHAT_HISTORY_RESPONSE_TEMPLATE); int messageHistoryLimit = getMessageHistoryLimit(params); - ConversationIndexMemory.Factory conversationIndexMemoryFactory = (ConversationIndexMemory.Factory) memoryFactoryMap.get(memoryType); - conversationIndexMemoryFactory.create(title, memoryId, appType, ActionListener.wrap(memory -> { + Memory.Factory> memoryFactory = memoryFactoryMap.get(memoryType); + + Map memoryParams = createMemoryParams(title, memoryId, appType, mlAgent, params.get(MEMORY_CONTAINER_ID_FIELD)); + memoryFactory.create(memoryParams, ActionListener.wrap(memory -> { // TODO: call runAgent directly if messageHistoryLimit == 0 - memory.getMessages(ActionListener.>wrap(r -> { + memory.getMessages(messageHistoryLimit, ActionListener.>wrap(r -> { List messageList = new ArrayList<>(); for (Interaction next : r) { String question = next.getInput(); @@ -221,7 +224,7 @@ public void run(MLAgent mlAgent, Map inputParams, ActionListener .add( ConversationIndexMessage .conversationIndexMessageBuilder() - .sessionId(memory.getConversationId()) + .sessionId(memory.getId()) .question(question) .response(response) .build() @@ -264,11 +267,11 @@ public void run(MLAgent mlAgent, Map inputParams, ActionListener } } - runAgent(mlAgent, params, listener, memory, memory.getConversationId(), functionCalling); + runAgent(mlAgent, params, listener, memory, memory.getId(), functionCalling); }, e -> { log.error("Failed to get chat history", e); listener.onFailure(e); - }), messageHistoryLimit); + })); }, listener::onFailure)); } @@ -321,9 +324,6 @@ private void runReAct( boolean verbose = Boolean.parseBoolean(tmpParameters.getOrDefault(VERBOSE, "false")); boolean traceDisabled = tmpParameters.containsKey(DISABLE_TRACE) && Boolean.parseBoolean(tmpParameters.get(DISABLE_TRACE)); - // Create root interaction. - ConversationIndexMemory conversationIndexMemory = (ConversationIndexMemory) memory; - // Trace number AtomicInteger traceNumber = new AtomicInteger(0); @@ -385,7 +385,7 @@ private void runReAct( verbose, traceDisabled, traceTensors, - conversationIndexMemory, + memory, traceNumber, additionalInfo, finalAnswer @@ -409,7 +409,7 @@ private void runReAct( ); saveTraceData( - conversationIndexMemory, + memory, memory.getType(), question, thoughtResponse, @@ -429,7 +429,7 @@ private void runReAct( verbose, traceDisabled, traceTensors, - conversationIndexMemory, + memory, traceNumber, additionalInfo, lastThought, @@ -490,7 +490,7 @@ private void runReAct( // Save trace with processed output saveTraceData( - conversationIndexMemory, + memory, "ReAct", lastActionInput.get(), outputToOutputString(filteredOutput), @@ -534,7 +534,7 @@ private void runReAct( verbose, traceDisabled, traceTensors, - conversationIndexMemory, + memory, traceNumber, additionalInfo, lastThought, @@ -761,8 +761,8 @@ private static void updateParametersAcrossTools(Map tmpParameter } public static void saveTraceData( - ConversationIndexMemory conversationIndexMemory, - String memory, + Memory memory, + String memoryType, String question, String thoughtResponse, String sessionId, @@ -771,17 +771,17 @@ public static void saveTraceData( AtomicInteger traceNumber, String origin ) { - if (conversationIndexMemory != null) { + if (memory != null) { ConversationIndexMessage msgTemp = ConversationIndexMessage .conversationIndexMessageBuilder() - .type(memory) + .type(memoryType) .question(question) .response(thoughtResponse) .finalAnswer(false) .sessionId(sessionId) .build(); if (!traceDisabled) { - conversationIndexMemory.save(msgTemp, parentInteractionId, traceNumber.addAndGet(1), origin); + memory.save(msgTemp, parentInteractionId, traceNumber.addAndGet(1), origin); } } } @@ -794,7 +794,7 @@ private void sendFinalAnswer( boolean verbose, boolean traceDisabled, List cotModelTensors, - ConversationIndexMemory conversationIndexMemory, + Memory memory, AtomicInteger traceNumber, Map additionalInfo, String finalAnswer @@ -802,12 +802,11 @@ private void sendFinalAnswer( // Send completion chunk for streaming streamingWrapper.sendCompletionChunk(sessionId, parentInteractionId); - if (conversationIndexMemory != null) { + if (memory != null) { String copyOfFinalAnswer = finalAnswer; ActionListener saveTraceListener = ActionListener.wrap(r -> { - conversationIndexMemory - .getMemoryManager() - .updateInteraction( + memory + .update( parentInteractionId, Map.of(AI_RESPONSE_FIELD, copyOfFinalAnswer, ADDITIONAL_INFO_FIELD, additionalInfo), ActionListener.wrap(res -> { @@ -823,17 +822,7 @@ private void sendFinalAnswer( }, e -> { listener.onFailure(e); }) ); }, e -> { listener.onFailure(e); }); - saveMessage( - conversationIndexMemory, - question, - finalAnswer, - sessionId, - parentInteractionId, - traceNumber, - true, - traceDisabled, - saveTraceListener - ); + saveMessage(memory, question, finalAnswer, sessionId, parentInteractionId, traceNumber, true, traceDisabled, saveTraceListener); } else { streamingWrapper .sendFinalResponse(sessionId, listener, parentInteractionId, verbose, cotModelTensors, additionalInfo, finalAnswer); @@ -963,7 +952,7 @@ private void handleMaxIterationsReached( boolean verbose, boolean traceDisabled, List traceTensors, - ConversationIndexMemory conversationIndexMemory, + Memory memory, AtomicInteger traceNumber, Map additionalInfo, AtomicReference lastThought, @@ -981,7 +970,7 @@ private void handleMaxIterationsReached( verbose, traceDisabled, traceTensors, - conversationIndexMemory, + memory, traceNumber, additionalInfo, incompleteResponse @@ -990,7 +979,7 @@ private void handleMaxIterationsReached( } private void saveMessage( - ConversationIndexMemory memory, + Memory memory, String question, String finalAnswer, String sessionId, diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLConversationalFlowAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLConversationalFlowAgentRunner.java index 54d847b929..faeec6b050 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLConversationalFlowAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLConversationalFlowAgentRunner.java @@ -9,12 +9,14 @@ import static org.opensearch.ml.common.conversation.ActionConstants.AI_RESPONSE_FIELD; import static org.opensearch.ml.common.conversation.ActionConstants.MEMORY_ID; import static org.opensearch.ml.common.conversation.ActionConstants.PARENT_INTERACTION_ID_FIELD; +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.MEMORY_CONTAINER_ID_FIELD; import static org.opensearch.ml.common.utils.ToolUtils.TOOL_OUTPUT_FILTERS_FIELD; import static org.opensearch.ml.common.utils.ToolUtils.convertOutputToModelTensor; import static org.opensearch.ml.common.utils.ToolUtils.filterToolOutput; import static org.opensearch.ml.common.utils.ToolUtils.getToolName; import static org.opensearch.ml.common.utils.ToolUtils.parseResponse; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.DISABLE_TRACE; +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.createMemoryParams; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.createTool; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getMessageHistoryLimit; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getMlToolSpecs; @@ -39,15 +41,14 @@ import org.opensearch.ml.common.agent.MLToolSpec; import org.opensearch.ml.common.conversation.ActionConstants; import org.opensearch.ml.common.conversation.Interaction; +import org.opensearch.ml.common.memory.Memory; +import org.opensearch.ml.common.memory.Message; import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensorOutput; -import org.opensearch.ml.common.spi.memory.Memory; -import org.opensearch.ml.common.spi.memory.Message; import org.opensearch.ml.common.spi.tools.Tool; import org.opensearch.ml.common.utils.StringUtils; import org.opensearch.ml.common.utils.ToolUtils; import org.opensearch.ml.engine.encryptor.Encryptor; -import org.opensearch.ml.engine.memory.ConversationIndexMemory; import org.opensearch.ml.engine.memory.ConversationIndexMessage; import org.opensearch.ml.repackage.com.google.common.annotations.VisibleForTesting; import org.opensearch.remote.metadata.client.SdkClient; @@ -109,9 +110,11 @@ public void run(MLAgent mlAgent, Map params, ActionListener { - memory.getMessages(ActionListener.>wrap(r -> { + Memory.Factory> memoryFactory = memoryFactoryMap.get(memoryType); + + Map memoryParams = createMemoryParams(title, memoryId, appType, mlAgent, params.get(MEMORY_CONTAINER_ID_FIELD)); + memoryFactory.create(memoryParams, ActionListener.wrap(memory -> { + memory.getMessages(messageHistoryLimit, ActionListener.>wrap(r -> { List messageList = new ArrayList<>(); for (Interaction next : r) { String question = next.getInput(); @@ -125,7 +128,7 @@ public void run(MLAgent mlAgent, Map params, ActionListener params, ActionListener { log.error("Failed to get chat history", e); listener.onFailure(e); - }), messageHistoryLimit); + })); }, listener::onFailure)); } @@ -153,7 +156,7 @@ private void runAgent( MLAgent mlAgent, Map params, ActionListener listener, - ConversationIndexMemory memory, + Memory memory, String memoryId, String parentInteractionId ) { @@ -244,7 +247,7 @@ private void runAgent( private void processOutput( Map params, ActionListener listener, - ConversationIndexMemory memory, + Memory memory, String memoryId, String parentInteractionId, List toolSpecs, @@ -357,7 +360,7 @@ private void runNextStep( private void saveMessage( Map params, - ConversationIndexMemory memory, + Memory memory, String outputResponse, String memoryId, String parentInteractionId, @@ -392,11 +395,10 @@ void updateMemoryWithListener( if (memoryId == null || interactionId == null || memorySpec == null || memorySpec.getType() == null) { return; } - ConversationIndexMemory.Factory conversationIndexMemoryFactory = (ConversationIndexMemory.Factory) memoryFactoryMap - .get(memorySpec.getType()); - conversationIndexMemoryFactory + Memory.Factory factory = memoryFactoryMap.get(memorySpec.getType()); + factory .create( - memoryId, + Map.of(MEMORY_ID, memoryId), ActionListener .wrap( memory -> memory.update(interactionId, Map.of(ActionConstants.ADDITIONAL_INFO_FIELD, additionalInfo), listener), diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunner.java index 30725a8c47..7c8742a570 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunner.java @@ -12,6 +12,7 @@ import static org.opensearch.ml.common.utils.ToolUtils.parseResponse; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.createTool; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getMlToolSpecs; +import static org.opensearch.ml.engine.memory.ConversationIndexMemory.MEMORY_ID; import java.util.ArrayList; import java.util.List; @@ -28,9 +29,9 @@ import org.opensearch.ml.common.agent.MLMemorySpec; import org.opensearch.ml.common.agent.MLToolSpec; import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.ml.common.memory.Memory; import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensorOutput; -import org.opensearch.ml.common.spi.memory.Memory; import org.opensearch.ml.common.spi.tools.Tool; import org.opensearch.ml.common.utils.StringUtils; import org.opensearch.ml.common.utils.ToolUtils; @@ -169,23 +170,23 @@ public void run(MLAgent mlAgent, Map params, ActionListener additionalInfo, MLMemorySpec memorySpec, String memoryId, String interactionId) { - if (memoryId == null || interactionId == null || memorySpec == null || memorySpec.getType() == null) { - return; - } - ConversationIndexMemory.Factory conversationIndexMemoryFactory = (ConversationIndexMemory.Factory) memoryFactoryMap - .get(memorySpec.getType()); - conversationIndexMemoryFactory - .create( - memoryId, - ActionListener - .wrap( - memory -> updateInteraction(additionalInfo, interactionId, memory), - e -> log.error("Failed create memory from id: {}", memoryId, e) - ) - ); - } + // @VisibleForTesting + // void updateMemory(Map additionalInfo, MLMemorySpec memorySpec, String memoryId, String interactionId) { + // if (memoryId == null || interactionId == null || memorySpec == null || memorySpec.getType() == null) { + // return; + // } + // ConversationIndexMemory.Factory conversationIndexMemoryFactory = (ConversationIndexMemory.Factory) memoryFactoryMap + // .get(memorySpec.getType()); + // conversationIndexMemoryFactory + // .create( + // memoryId, + // ActionListener + // .wrap( + // memory -> updateInteraction(additionalInfo, interactionId, memory), + // e -> log.error("Failed create memory from id: {}", memoryId, e) + // ) + // ); + // } @VisibleForTesting void updateMemoryWithListener( @@ -202,7 +203,7 @@ void updateMemoryWithListener( .get(memorySpec.getType()); conversationIndexMemoryFactory .create( - memoryId, + Map.of(MEMORY_ID, memoryId), ActionListener .wrap( memory -> updateInteractionWithListener(additionalInfo, interactionId, memory, listener), diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java index 23586e0020..6bfc22276f 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java @@ -10,6 +10,7 @@ import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD; import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD; import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD; +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.MEMORY_CONTAINER_ID_FIELD; import static org.opensearch.ml.common.utils.MLTaskUtils.updateMLTaskDirectly; import static org.opensearch.ml.common.utils.StringUtils.isJson; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_INTERFACE_BEDROCK_CONVERSE_CLAUDE; @@ -17,6 +18,7 @@ import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_INTERFACE_OPENAI_V1_CHAT_COMPLETIONS; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_RESPONSE_FILTER; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.cleanUpResource; +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.createMemoryParams; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.createTools; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getCurrentDateTime; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getMcpToolSpecs; @@ -61,10 +63,10 @@ import org.opensearch.ml.common.hooks.HookRegistry; import org.opensearch.ml.common.input.execute.agent.AgentMLInput; import org.opensearch.ml.common.input.remote.RemoteInferenceMLInput; +import org.opensearch.ml.common.memory.Memory; import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.common.output.model.ModelTensors; -import org.opensearch.ml.common.spi.memory.Memory; import org.opensearch.ml.common.spi.tools.Tool; import org.opensearch.ml.common.transport.MLTaskResponse; import org.opensearch.ml.common.transport.execute.MLExecuteTaskAction; @@ -74,7 +76,6 @@ import org.opensearch.ml.common.utils.StringUtils; import org.opensearch.ml.engine.agents.AgentContextUtil; import org.opensearch.ml.engine.encryptor.Encryptor; -import org.opensearch.ml.engine.memory.ConversationIndexMemory; import org.opensearch.remote.metadata.client.SdkClient; import org.opensearch.transport.TransportChannel; import org.opensearch.transport.client.Client; @@ -295,34 +296,40 @@ public void run(MLAgent mlAgent, Map apiParams, ActionListenerwrap(memory -> { - memory.getMessages(ActionListener.>wrap(interactions -> { - final List completedSteps = new ArrayList<>(); - for (Interaction interaction : interactions) { - String question = interaction.getInput(); - String response = interaction.getResponse(); - - if (Strings.isNullOrEmpty(response)) { - continue; - } - - completedSteps.add(question); - completedSteps.add(response); + Memory.Factory> memoryFactory = memoryFactoryMap.get(memoryType); + Map memoryParams = createMemoryParams( + apiParams.get(USER_PROMPT_FIELD), + memoryId, + appType, + mlAgent, + apiParams.get(MEMORY_CONTAINER_ID_FIELD) + ); + memoryFactory.create(memoryParams, ActionListener.wrap(memory -> { + memory.getMessages(messageHistoryLimit, ActionListener.>wrap(interactions -> { + List completedSteps = new ArrayList<>(); + for (Interaction interaction : interactions) { + String question = interaction.getInput(); + String response = interaction.getResponse(); + + if (Strings.isNullOrEmpty(response)) { + continue; } - if (!completedSteps.isEmpty()) { - addSteps(completedSteps, allParams, COMPLETED_STEPS_FIELD); - usePlannerWithHistoryPromptTemplate(allParams); - } + completedSteps.add(question); + completedSteps.add(response); + } - setToolsAndRunAgent(mlAgent, allParams, completedSteps, memory, memory.getConversationId(), listener); - }, e -> { - log.error("Failed to get chat history", e); - listener.onFailure(e); - }), messageHistoryLimit); - }, listener::onFailure)); + if (!completedSteps.isEmpty()) { + addSteps(completedSteps, allParams, COMPLETED_STEPS_FIELD); + usePlannerWithHistoryPromptTemplate(allParams); + } + + setToolsAndRunAgent(mlAgent, allParams, completedSteps, memory, memory.getId(), listener); + }, e -> { + log.error("Failed to get chat history", e); + listener.onFailure(e); + })); + }, listener::onFailure)); } private void setToolsAndRunAgent( @@ -383,7 +390,7 @@ private void executePlanningLoop( completedSteps.getLast() ); saveAndReturnFinalResult( - (ConversationIndexMemory) memory, + memory, parentInteractionId, allParams.get(EXECUTOR_AGENT_MEMORY_ID_FIELD), allParams.get(EXECUTOR_AGENT_PARENT_INTERACTION_ID_FIELD), @@ -447,7 +454,7 @@ private void executePlanningLoop( if (parseLLMOutput.get(RESULT_FIELD) != null) { String finalResult = (String) parseLLMOutput.get(RESULT_FIELD); saveAndReturnFinalResult( - (ConversationIndexMemory) memory, + memory, parentInteractionId, allParams.get(EXECUTOR_AGENT_MEMORY_ID_FIELD), allParams.get(EXECUTOR_AGENT_PARENT_INTERACTION_ID_FIELD), @@ -560,7 +567,7 @@ private void executePlanningLoop( ); saveTraceData( - (ConversationIndexMemory) memory, + memory, memory.getType(), stepToExecute, results.get(STEP_RESULT_FIELD), @@ -717,7 +724,7 @@ void addSteps(List steps, Map allParams, String field) { @VisibleForTesting void saveAndReturnFinalResult( - ConversationIndexMemory memory, + Memory memory, String parentInteractionId, String reactAgentMemoryId, String reactParentInteractionId, @@ -732,9 +739,9 @@ void saveAndReturnFinalResult( updateContent.put(INTERACTIONS_INPUT_FIELD, input); } - memory.getMemoryManager().updateInteraction(parentInteractionId, updateContent, ActionListener.wrap(res -> { + memory.update(parentInteractionId, updateContent, ActionListener.wrap(res -> { List finalModelTensors = createModelTensors( - memory.getConversationId(), + memory.getId(), parentInteractionId, reactAgentMemoryId, reactParentInteractionId diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/AgenticConversationMemory.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/AgenticConversationMemory.java new file mode 100644 index 0000000000..81c58d3aaa --- /dev/null +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/AgenticConversationMemory.java @@ -0,0 +1,567 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.memory; + +import static org.opensearch.ml.engine.memory.ConversationIndexMemory.APP_TYPE; +import static org.opensearch.ml.engine.memory.ConversationIndexMemory.MEMORY_ID; +import static org.opensearch.ml.engine.memory.ConversationIndexMemory.MEMORY_NAME; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.update.UpdateResponse; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.Strings; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.ml.common.conversation.Interaction; +import org.opensearch.ml.common.memory.Memory; +import org.opensearch.ml.common.memory.Message; +import org.opensearch.ml.common.memorycontainer.MLWorkingMemory; +import org.opensearch.ml.common.memorycontainer.MemoryType; +import org.opensearch.ml.common.transport.memorycontainer.memory.MLAddMemoriesAction; +import org.opensearch.ml.common.transport.memorycontainer.memory.MLAddMemoriesInput; +import org.opensearch.ml.common.transport.memorycontainer.memory.MLAddMemoriesRequest; +import org.opensearch.ml.common.transport.memorycontainer.memory.MLGetMemoryAction; +import org.opensearch.ml.common.transport.memorycontainer.memory.MLGetMemoryRequest; +import org.opensearch.ml.common.transport.memorycontainer.memory.MLSearchMemoriesAction; +import org.opensearch.ml.common.transport.memorycontainer.memory.MLSearchMemoriesInput; +import org.opensearch.ml.common.transport.memorycontainer.memory.MLSearchMemoriesRequest; +import org.opensearch.ml.common.transport.memorycontainer.memory.MLUpdateMemoryAction; +import org.opensearch.ml.common.transport.memorycontainer.memory.MLUpdateMemoryInput; +import org.opensearch.ml.common.transport.memorycontainer.memory.MLUpdateMemoryRequest; +import org.opensearch.ml.common.transport.session.MLCreateSessionAction; +import org.opensearch.ml.common.transport.session.MLCreateSessionInput; +import org.opensearch.ml.common.transport.session.MLCreateSessionRequest; +import org.opensearch.ml.memory.action.conversation.CreateInteractionResponse; +import org.opensearch.search.SearchHit; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.sort.SortOrder; +import org.opensearch.transport.client.Client; + +import lombok.Getter; +import lombok.extern.log4j.Log4j2; + +/** + * Agentic Memory implementation that stores conversations in Memory Container + * Uses TransportCreateSessionAction and TransportAddMemoriesAction for all operations + */ +@Log4j2 +@Getter +public class AgenticConversationMemory implements Memory { + + public static final String TYPE = "agentic_conversation"; + private static final String SESSION_ID_FIELD = "session_id"; + private static final String CREATED_TIME_FIELD = "created_time"; + + private final Client client; + private final String conversationId; + private final String memoryContainerId; + + public AgenticConversationMemory(Client client, String memoryId, String memoryContainerId) { + this.client = client; + this.conversationId = memoryId; + this.memoryContainerId = memoryContainerId; + } + + @Override + public String getType() { + return TYPE; + } + + @Override + public String getId() { + return conversationId; + } + + @Override + public void save(Message message, String parentId, Integer traceNum, String action) { + this.save(message, parentId, traceNum, action, ActionListener.wrap(r -> { + log.info("Saved message to agentic memory, session id: {}, working memory id: {}", conversationId, r.getId()); + }, e -> { log.error("Failed to save message to agentic memory", e); })); + } + + @Override + public void save( + Message message, + String parentId, + Integer traceNum, + String action, + ActionListener listener + ) { + if (Strings.isNullOrEmpty(memoryContainerId)) { + listener + .onFailure( + new IllegalStateException( + "Memory container ID is not configured for this AgenticConversationMemory. " + + "Cannot save messages without a valid memory container." + ) + ); + return; + } + + ConversationIndexMessage msg = (ConversationIndexMessage) message; + + // Build namespace with session_id + Map namespace = new HashMap<>(); + namespace.put(SESSION_ID_FIELD, conversationId); + + // Simple rule matching ConversationIndexMemory: + // - If traceNum != null → it's a trace + // - If traceNum == null → it's a message + boolean isTrace = (traceNum != null); + + Map metadata = new HashMap<>(); + Map structuredData = new HashMap<>(); + + // Store data in structured_data format matching conversation index + structuredData.put("input", msg.getQuestion() != null ? msg.getQuestion() : ""); + structuredData.put("response", msg.getResponse() != null ? msg.getResponse() : ""); + + if (isTrace) { + // This is a trace (tool usage or intermediate step) + metadata.put("type", "trace"); + if (parentId != null) { + metadata.put("parent_message_id", parentId); + structuredData.put("parent_message_id", parentId); + } + metadata.put("trace_number", String.valueOf(traceNum)); + structuredData.put("trace_number", traceNum); + if (action != null) { + metadata.put("origin", action); + structuredData.put("origin", action); + } + } else { + // This is a final message (Q&A pair) + metadata.put("type", "message"); + if (msg.getFinalAnswer() != null) { + structuredData.put("final_answer", msg.getFinalAnswer()); + } + } + + // Add timestamps + java.time.Instant now = java.time.Instant.now(); + structuredData.put("create_time", now.toString()); + structuredData.put("updated_time", now.toString()); + + // Create MLAddMemoriesInput + MLAddMemoriesInput input = MLAddMemoriesInput + .builder() + .memoryContainerId(memoryContainerId) + .structuredData(structuredData) + .messageId(traceNum) // Store trace number in messageId field (null for messages) + .namespace(namespace) + .metadata(metadata) + .infer(false) // Don't infer long-term memory by default + .build(); + + MLAddMemoriesRequest request = MLAddMemoriesRequest.builder().mlAddMemoryInput(input).build(); + + // Execute the add memories action + client.execute(MLAddMemoriesAction.INSTANCE, request, ActionListener.wrap(response -> { + // Convert MLAddMemoriesResponse to CreateInteractionResponse + CreateInteractionResponse interactionResponse = new CreateInteractionResponse(response.getWorkingMemoryId()); + listener.onResponse(interactionResponse); + }, e -> { + log.error("Failed to add memories to memory container", e); + listener.onFailure(e); + })); + } + + @Override + public void update(String messageId, Map updateContent, ActionListener updateListener) { + if (Strings.isNullOrEmpty(memoryContainerId)) { + updateListener.onFailure(new IllegalStateException("Memory container ID is not configured for this AgenticConversationMemory")); + return; + } + + // Step 1: Get the existing working memory to retrieve current structured_data + MLGetMemoryRequest getRequest = MLGetMemoryRequest + .builder() + .memoryContainerId(memoryContainerId) + .memoryType(MemoryType.WORKING) + .memoryId(messageId) + .build(); + + client.execute(MLGetMemoryAction.INSTANCE, getRequest, ActionListener.wrap(getResponse -> { + // Step 2: Extract existing structured_data and merge with updates + MLWorkingMemory workingMemory = getResponse.getWorkingMemory(); + if (workingMemory == null) { + updateListener.onFailure(new IllegalStateException("Working memory not found for id: " + messageId)); + return; + } + + Map structuredData = workingMemory.getStructuredData(); + if (structuredData == null) { + structuredData = new HashMap<>(); + } else { + // Create a mutable copy + structuredData = new HashMap<>(structuredData); + } + + // Step 3: Merge update content into structured_data + // The updateContent contains fields like "response" and "additional_info" + // These should be stored in structured_data + for (Map.Entry entry : updateContent.entrySet()) { + structuredData.put(entry.getKey(), entry.getValue()); + } + + // Update the timestamp + // structuredData.put("updated_time", java.time.Instant.now().toString()); + + // Step 4: Create update request with merged structured_data + Map finalUpdateContent = new HashMap<>(); + finalUpdateContent.put("structured_data", structuredData); + + MLUpdateMemoryInput input = MLUpdateMemoryInput.builder().updateContent(finalUpdateContent).build(); + + MLUpdateMemoryRequest updateRequest = MLUpdateMemoryRequest + .builder() + .memoryContainerId(memoryContainerId) + .memoryType(MemoryType.WORKING) + .memoryId(messageId) + .mlUpdateMemoryInput(input) + .build(); + + // Step 5: Execute the update + client.execute(MLUpdateMemoryAction.INSTANCE, updateRequest, ActionListener.wrap(indexResponse -> { + // Convert IndexResponse to UpdateResponse + UpdateResponse updateResponse = new UpdateResponse( + indexResponse.getShardInfo(), + indexResponse.getShardId(), + indexResponse.getId(), + indexResponse.getSeqNo(), + indexResponse.getPrimaryTerm(), + indexResponse.getVersion(), + indexResponse.getResult() + ); + updateListener.onResponse(updateResponse); + }, e -> { + log.error("Failed to update memory in memory container", e); + updateListener.onFailure(e); + })); + }, e -> { + log.error("Failed to get existing memory for update", e); + updateListener.onFailure(e); + })); + } + + @Override + public void getMessages(int size, ActionListener> listener) { + if (Strings.isNullOrEmpty(memoryContainerId)) { + listener.onFailure(new IllegalStateException("Memory container ID is not configured for this AgenticConversationMemory")); + return; + } + + // Build search query for working memory by session_id, filtering only final messages (not traces) + // Match ConversationIndexMemory pattern: exclude entries with trace_number + BoolQueryBuilder boolQuery = QueryBuilders.boolQuery(); + boolQuery.must(QueryBuilders.termQuery("namespace." + SESSION_ID_FIELD, conversationId)); + boolQuery.mustNot(QueryBuilders.existsQuery("structured_data.trace_number")); // Exclude traces + + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.query(boolQuery); + searchSourceBuilder.size(size); + searchSourceBuilder.sort(CREATED_TIME_FIELD, SortOrder.ASC); + + MLSearchMemoriesInput searchInput = MLSearchMemoriesInput + .builder() + .memoryContainerId(memoryContainerId) + .memoryType(MemoryType.WORKING) + .searchSourceBuilder(searchSourceBuilder) + .build(); + + MLSearchMemoriesRequest request = MLSearchMemoriesRequest.builder().mlSearchMemoriesInput(searchInput).tenantId(null).build(); + + client.execute(MLSearchMemoriesAction.INSTANCE, request, ActionListener.wrap(searchResponse -> { + List interactions = parseSearchResponseToInteractions(searchResponse); + listener.onResponse(interactions); + }, e -> { + log.error("Failed to search memories in memory container", e); + listener.onFailure(e); + })); + } + + private List parseSearchResponseToInteractions(SearchResponse searchResponse) { + List interactions = new ArrayList<>(); + for (SearchHit hit : searchResponse.getHits().getHits()) { + Map sourceMap = hit.getSourceAsMap(); + + // Extract structured_data which contains the interaction data + @SuppressWarnings("unchecked") + Map structuredData = (Map) sourceMap.get("structured_data"); + + if (structuredData != null) { + String input = (String) structuredData.get("input"); + String response = (String) structuredData.get("response"); + + // Extract timestamps + Long createdTimeMs = (Long) sourceMap.get("created_time"); + Long updatedTimeMs = (Long) sourceMap.get("last_updated_time"); + + // Parse create_time from structured_data if available + String createTimeStr = (String) structuredData.get("create_time"); + String updatedTimeStr = (String) structuredData.get("updated_time"); + + java.time.Instant createTime = null; + java.time.Instant updatedTime = null; + + if (createTimeStr != null) { + try { + createTime = java.time.Instant.parse(createTimeStr); + } catch (Exception e) { + log.warn("Failed to parse create_time from structured_data", e); + } + } + if (updatedTimeStr != null) { + try { + updatedTime = java.time.Instant.parse(updatedTimeStr); + } catch (Exception e) { + log.warn("Failed to parse updated_time from structured_data", e); + } + } + + // Fallback to document timestamps if structured_data timestamps not available + if (createTime == null && createdTimeMs != null) { + createTime = java.time.Instant.ofEpochMilli(createdTimeMs); + } + if (updatedTime == null && updatedTimeMs != null) { + updatedTime = java.time.Instant.ofEpochMilli(updatedTimeMs); + } + + // Extract metadata + @SuppressWarnings("unchecked") + Map metadata = (Map) sourceMap.get("metadata"); + String parentInteractionId = metadata != null ? metadata.get("parent_message_id") : null; + + // Create Interaction object + if (input != null || response != null) { + Interaction interaction = Interaction + .builder() + .id(hit.getId()) + .conversationId(conversationId) + .createTime(createTime != null ? createTime : java.time.Instant.now()) + .updatedTime(updatedTime) + .input(input != null ? input : "") + .response(response != null ? response : "") + .origin("agentic_memory") + .promptTemplate(null) + .additionalInfo(null) + .parentInteractionId(parentInteractionId) + .traceNum(null) // Messages don't have trace numbers + .build(); + interactions.add(interaction); + } + } + } + return interactions; + } + + @Override + public void clear() { + throw new UnsupportedOperationException("clear method is not supported in AgenticConversationMemory"); + } + + @Override + public void deleteInteractionAndTrace(String interactionId, ActionListener listener) { + // For now, delegate to a simple implementation + // In the future, this could use MLDeleteMemoryAction + log.warn("deleteInteractionAndTrace is not fully implemented for AgenticConversationMemory"); + listener.onResponse(false); + } + + /** + * Get traces (intermediate steps/tool usage) for a specific parent message + * @param parentMessageId The parent message ID + * @param listener Action listener for the traces + */ + public void getTraces(String parentMessageId, ActionListener> listener) { + if (Strings.isNullOrEmpty(memoryContainerId)) { + listener.onFailure(new IllegalStateException("Memory container ID is not configured for this AgenticConversationMemory")); + return; + } + + // Build search query for traces by parent_message_id + BoolQueryBuilder boolQuery = QueryBuilders.boolQuery(); + boolQuery.must(QueryBuilders.termQuery("namespace." + SESSION_ID_FIELD, conversationId)); + boolQuery.must(QueryBuilders.termQuery("metadata.type", "trace")); // Only get traces + boolQuery.must(QueryBuilders.termQuery("metadata.parent_message_id", parentMessageId)); + + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.query(boolQuery); + searchSourceBuilder.size(1000); // Get all traces for this message + searchSourceBuilder.sort("message_id", SortOrder.ASC); // Sort by trace number + + MLSearchMemoriesInput searchInput = MLSearchMemoriesInput + .builder() + .memoryContainerId(memoryContainerId) + .memoryType(MemoryType.WORKING) + .searchSourceBuilder(searchSourceBuilder) + .build(); + + MLSearchMemoriesRequest request = MLSearchMemoriesRequest.builder().mlSearchMemoriesInput(searchInput).tenantId(null).build(); + + client.execute(MLSearchMemoriesAction.INSTANCE, request, ActionListener.wrap(searchResponse -> { + List traces = parseSearchResponseToTraces(searchResponse); + listener.onResponse(traces); + }, e -> { + log.error("Failed to search traces in memory container", e); + listener.onFailure(e); + })); + } + + private List parseSearchResponseToTraces(SearchResponse searchResponse) { + List traces = new ArrayList<>(); + for (SearchHit hit : searchResponse.getHits().getHits()) { + Map sourceMap = hit.getSourceAsMap(); + + // Extract structured_data which contains the trace data + @SuppressWarnings("unchecked") + Map structuredData = (Map) sourceMap.get("structured_data"); + + if (structuredData != null) { + String input = (String) structuredData.get("input"); + String response = (String) structuredData.get("response"); + String origin = (String) structuredData.get("origin"); + String parentMessageId = (String) structuredData.get("parent_message_id"); + + // Extract trace number + Integer traceNum = null; + Object traceNumObj = structuredData.get("trace_number"); + if (traceNumObj instanceof Integer) { + traceNum = (Integer) traceNumObj; + } else if (traceNumObj instanceof String) { + try { + traceNum = Integer.parseInt((String) traceNumObj); + } catch (NumberFormatException e) { + log.warn("Failed to parse trace_number", e); + } + } + + // Also check message_id field which stores trace number + Integer messageId = (Integer) sourceMap.get("message_id"); + if (traceNum == null && messageId != null) { + traceNum = messageId; + } + + // Extract timestamps + Long createdTimeMs = (Long) sourceMap.get("created_time"); + Long updatedTimeMs = (Long) sourceMap.get("last_updated_time"); + + java.time.Instant createTime = createdTimeMs != null + ? java.time.Instant.ofEpochMilli(createdTimeMs) + : java.time.Instant.now(); + java.time.Instant updatedTime = updatedTimeMs != null ? java.time.Instant.ofEpochMilli(updatedTimeMs) : null; + + // Create Interaction object for trace + if (input != null || response != null) { + Interaction trace = Interaction + .builder() + .id(hit.getId()) + .conversationId(conversationId) + .createTime(createTime) + .updatedTime(updatedTime) + .input(input != null ? input : "") + .response(response != null ? response : "") + .origin(origin != null ? origin : "") + .promptTemplate(null) + .additionalInfo(null) + .parentInteractionId(parentMessageId) + .traceNum(traceNum) + .build(); + traces.add(trace); + } + } + } + return traces; + } + + /** + * Factory for creating AgenticConversationMemory instances + */ + public static class Factory implements Memory.Factory { + private Client client; + + public void init(Client client) { + this.client = client; + } + + @Override + public void create(Map map, ActionListener listener) { + if (map == null || map.isEmpty()) { + listener.onFailure(new IllegalArgumentException("Invalid input parameter for creating AgenticConversationMemory")); + return; + } + + String memoryId = (String) map.get(MEMORY_ID); + String name = (String) map.get(MEMORY_NAME); + String appType = (String) map.get(APP_TYPE); + String memoryContainerId = (String) map.get("memory_container_id"); + + create(name, memoryId, appType, memoryContainerId, listener); + } + + public void create( + String name, + String memoryId, + String appType, + String memoryContainerId, + ActionListener listener + ) { + // Memory container ID is required for AgenticConversationMemory + if (Strings.isNullOrEmpty(memoryContainerId)) { + listener + .onFailure( + new IllegalArgumentException( + "Memory container ID is required for AgenticConversationMemory. " + + "Please provide 'memory_container_id' in the agent configuration." + ) + ); + return; + } + + if (Strings.isEmpty(memoryId)) { + // Create new session using TransportCreateSessionAction + createSessionInMemoryContainer(name, memoryContainerId, ActionListener.wrap(sessionId -> { + create(sessionId, memoryContainerId, listener); + log.debug("Created session in memory container, session id: {}", sessionId); + }, e -> { + log.error("Failed to create session in memory container", e); + listener.onFailure(e); + })); + } else { + // Use existing session/memory ID + create(memoryId, memoryContainerId, listener); + } + } + + /** + * Create a new session in the memory container using the new session API + */ + private void createSessionInMemoryContainer(String summary, String memoryContainerId, ActionListener listener) { + MLCreateSessionInput input = MLCreateSessionInput.builder().memoryContainerId(memoryContainerId).summary(summary).build(); + + MLCreateSessionRequest request = MLCreateSessionRequest.builder().mlCreateSessionInput(input).build(); + + client + .execute( + MLCreateSessionAction.INSTANCE, + request, + ActionListener.wrap(response -> { listener.onResponse(response.getSessionId()); }, e -> { + log.error("Failed to create session via TransportCreateSessionAction", e); + listener.onFailure(e); + }) + ); + } + + public void create(String memoryId, String memoryContainerId, ActionListener listener) { + listener.onResponse(new AgenticConversationMemory(client, memoryId, memoryContainerId)); + } + } +} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/AgenticMemoryConfig.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/AgenticMemoryConfig.java new file mode 100644 index 0000000000..987d947567 --- /dev/null +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/AgenticMemoryConfig.java @@ -0,0 +1,45 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.memory; + +import lombok.Builder; +import lombok.Data; + +/** + * Configuration for Agentic Memory integration + */ +@Data +@Builder +public class AgenticMemoryConfig { + + /** + * Memory container ID to use for storing conversations + */ + private String memoryContainerId; + + /** + * Whether to enable memory container integration + * If false, falls back to ConversationIndexMemory behavior + */ + @Builder.Default + private boolean enabled = true; + + /** + * Whether to enable inference (long-term memory extraction) + */ + @Builder.Default + private boolean enableInference = true; + + /** + * Custom namespace fields to add to memory container entries + */ + private java.util.Map customNamespace; + + /** + * Custom tags to add to memory container entries + */ + private java.util.Map customTags; +} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/BaseMessage.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/BaseMessage.java index 05b3185a34..562e425375 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/BaseMessage.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/BaseMessage.java @@ -9,7 +9,7 @@ import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.ml.common.spi.memory.Message; +import org.opensearch.ml.common.memory.Message; import lombok.Builder; import lombok.Getter; diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/ConversationIndexMemory.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/ConversationIndexMemory.java index 9720661eeb..eeac12537a 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/ConversationIndexMemory.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/ConversationIndexMemory.java @@ -5,30 +5,20 @@ package org.opensearch.ml.engine.memory; -import static org.opensearch.action.support.WriteRequest.RefreshPolicy.IMMEDIATE; import static org.opensearch.ml.common.CommonValue.ML_MEMORY_MESSAGE_INDEX; import static org.opensearch.ml.common.CommonValue.ML_MEMORY_META_INDEX; import java.util.Map; -import org.opensearch.action.index.IndexRequest; -import org.opensearch.action.search.SearchRequest; import org.opensearch.action.update.UpdateResponse; -import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.Strings; -import org.opensearch.core.xcontent.ToXContent; -import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.index.query.BoolQueryBuilder; -import org.opensearch.index.query.QueryBuilder; -import org.opensearch.index.query.TermQueryBuilder; -import org.opensearch.ml.common.spi.memory.Memory; -import org.opensearch.ml.common.spi.memory.Message; +import org.opensearch.ml.common.MLMemoryType; +import org.opensearch.ml.common.memory.Memory; +import org.opensearch.ml.common.memory.Message; import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.ml.memory.action.conversation.CreateConversationResponse; import org.opensearch.ml.memory.action.conversation.CreateInteractionResponse; -import org.opensearch.search.builder.SearchSourceBuilder; -import org.opensearch.search.sort.SortOrder; import org.opensearch.transport.client.Client; import lombok.Getter; @@ -36,8 +26,8 @@ @Log4j2 @Getter -public class ConversationIndexMemory implements Memory { - public static final String TYPE = "conversation_index"; +public class ConversationIndexMemory implements Memory { + public static final String TYPE = MLMemoryType.CONVERSATION_INDEX.name(); public static final String CONVERSATION_ID = "conversation_id"; public static final String FINAL_ANSWER = "final_answer"; public static final String CREATED_TIME = "created_time"; @@ -75,28 +65,11 @@ public String getType() { } @Override - public void save(String id, Message message) { - this.save(id, message, ActionListener.wrap(r -> { log.info("saved message into {} memory, session id: {}", TYPE, id); }, e -> { - log.error("Failed to save message to memory", e); - })); + public String getId() { + return this.conversationId; } @Override - public void save(String id, Message message, ActionListener listener) { - mlIndicesHandler.initMemoryMessageIndex(ActionListener.wrap(created -> { - if (created) { - IndexRequest indexRequest = new IndexRequest(memoryMessageIndexName).setRefreshPolicy(IMMEDIATE); - ConversationIndexMessage conversationIndexMessage = (ConversationIndexMessage) message; - XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); - conversationIndexMessage.toXContent(builder, ToXContent.EMPTY_PARAMS); - indexRequest.source(builder); - client.index(indexRequest, listener); - } else { - listener.onFailure(new RuntimeException("Failed to create memory message index")); - } - }, e -> { listener.onFailure(new RuntimeException("Failed to create memory message index", e)); })); - } - public void save(Message message, String parentId, Integer traceNum, String action) { this.save(message, parentId, traceNum, action, ActionListener.wrap(r -> { log @@ -110,39 +83,21 @@ public void save(Message message, String parentId, Integer traceNum, String acti }, e -> { log.error("Failed to save interaction", e); })); } - public void save(Message message, String parentId, Integer traceNum, String action, ActionListener listener) { + @Override + public void save( + Message message, + String parentId, + Integer traceNum, + String action, + ActionListener listener + ) { ConversationIndexMessage msg = (ConversationIndexMessage) message; memoryManager .createInteraction(conversationId, msg.getQuestion(), null, msg.getResponse(), action, null, parentId, traceNum, listener); } @Override - public void getMessages(String id, ActionListener listener) { - SearchRequest searchRequest = new SearchRequest(); - searchRequest.indices(memoryMessageIndexName); - SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); - sourceBuilder.size(10000); - QueryBuilder sessionIdQueryBuilder = new TermQueryBuilder(CONVERSATION_ID, id); - - BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); - boolQueryBuilder.must(sessionIdQueryBuilder); - - if (retrieveFinalAnswer) { - QueryBuilder finalAnswerQueryBuilder = new TermQueryBuilder(FINAL_ANSWER, true); - boolQueryBuilder.must(finalAnswerQueryBuilder); - } - - sourceBuilder.query(boolQueryBuilder); - sourceBuilder.sort(CREATED_TIME, SortOrder.ASC); - searchRequest.source(sourceBuilder); - client.search(searchRequest, listener); - } - - public void getMessages(ActionListener listener) { - memoryManager.getFinalInteractions(conversationId, LAST_N_INTERACTIONS, listener); - } - - public void getMessages(ActionListener listener, int size) { + public void getMessages(int size, ActionListener listener) { memoryManager.getFinalInteractions(conversationId, size, listener); } @@ -152,14 +107,15 @@ public void clear() { } @Override - public void remove(String id) { - throw new RuntimeException("remove method is not supported in ConversationIndexMemory"); - } - public void update(String messageId, Map updateContent, ActionListener updateListener) { getMemoryManager().updateInteraction(messageId, updateContent, updateListener); } + @Override + public void deleteInteractionAndTrace(String interactionId, ActionListener listener) { + memoryManager.deleteInteractionAndTrace(interactionId, listener); + } + public static class Factory implements Memory.Factory { private Client client; private MLIndicesHandler mlIndicesHandler; @@ -186,7 +142,7 @@ public void create(Map map, ActionListener listener) { + private void create(String name, String memoryId, String appType, ActionListener listener) { if (Strings.isEmpty(memoryId)) { memoryManager.createConversation(name, appType, ActionListener.wrap(r -> { create(r.getId(), listener); @@ -200,7 +156,7 @@ public void create(String name, String memoryId, String appType, ActionListener< } } - public void create(String memoryId, ActionListener listener) { + private void create(String memoryId, ActionListener listener) { listener .onResponse( new ConversationIndexMemory( diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutorTest.java index c3843434f7..a641753cdb 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutorTest.java @@ -35,10 +35,10 @@ import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.Input; import org.opensearch.ml.common.input.execute.agent.AgentMLInput; +import org.opensearch.ml.common.memory.Memory; import org.opensearch.ml.common.output.Output; import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; -import org.opensearch.ml.common.spi.memory.Memory; import org.opensearch.ml.common.spi.tools.Tool; import org.opensearch.ml.engine.encryptor.Encryptor; import org.opensearch.remote.metadata.client.SdkClient; diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java index c63db9df4f..bdd91fd296 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java @@ -50,10 +50,10 @@ import org.opensearch.ml.common.agent.MLMemorySpec; import org.opensearch.ml.common.agent.MLToolSpec; import org.opensearch.ml.common.conversation.Interaction; +import org.opensearch.ml.common.memory.Memory; import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.common.output.model.ModelTensors; -import org.opensearch.ml.common.spi.memory.Memory; import org.opensearch.ml.common.spi.tools.Tool; import org.opensearch.ml.common.transport.MLTaskResponse; import org.opensearch.ml.engine.memory.ConversationIndexMemory; @@ -134,20 +134,20 @@ public void setup() { toolFactories = ImmutableMap.of(FIRST_TOOL, firstToolFactory, SECOND_TOOL, secondToolFactory); // memory - mlMemorySpec = new MLMemorySpec(ConversationIndexMemory.TYPE, "uuid", 10); + mlMemorySpec = new MLMemorySpec(ConversationIndexMemory.TYPE, "uuid", 10, null); when(memoryMap.get(anyString())).thenReturn(memoryFactory); doAnswer(invocation -> { ActionListener> listener = invocation.getArgument(0); listener.onResponse(generateInteractions(2)); return null; - }).when(conversationIndexMemory).getMessages(memoryInteractionCapture.capture(), messageHistoryLimitCapture.capture()); + }).when(conversationIndexMemory).getMessages(messageHistoryLimitCapture.capture(), memoryInteractionCapture.capture()); when(conversationIndexMemory.getConversationId()).thenReturn("conversation_id"); when(conversationIndexMemory.getMemoryManager()).thenReturn(mlMemoryManager); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(1); listener.onResponse(conversationIndexMemory); return null; - }).when(memoryFactory).create(any(), any(), any(), memoryFactoryCapture.capture()); + }).when(memoryFactory).create(any(), memoryFactoryCapture.capture()); when(createInteractionResponse.getId()).thenReturn("create_interaction_id"); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(4); @@ -223,6 +223,9 @@ public void testParsingJsonBlockFromResponse() { @Test public void testParsingJsonBlockFromResponse2() { + // Reset client mock to avoid conflicts with previous test stubbing + Mockito.reset(client); + // Prepare the response with JSON block String jsonBlock = "{\"thought\":\"parsed thought\", \"action\":\"parsed action\", " + "\"action_input\":\"parsed action input\", \"final_answer\":\"parsed final answer\"}"; @@ -477,7 +480,7 @@ public void testChatHistoryExcludeOngoingQuestion() { interactionList.add(inProgressInteraction); listener.onResponse(interactionList); return null; - }).when(conversationIndexMemory).getMessages(memoryInteractionCapture.capture(), messageHistoryLimitCapture.capture()); + }).when(conversationIndexMemory).getMessages(messageHistoryLimitCapture.capture(), memoryInteractionCapture.capture()); HashMap params = new HashMap<>(); params.put(MESSAGE_HISTORY_LIMIT, "5"); @@ -533,7 +536,7 @@ private void testInteractions(String maxInteraction) { interactionList.add(inProgressInteraction); listener.onResponse(interactionList); return null; - }).when(conversationIndexMemory).getMessages(memoryInteractionCapture.capture(), messageHistoryLimitCapture.capture()); + }).when(conversationIndexMemory).getMessages(messageHistoryLimitCapture.capture(), memoryInteractionCapture.capture()); HashMap params = new HashMap<>(); params.put("verbose", "true"); @@ -563,7 +566,7 @@ public void testChatHistoryException() { ActionListener> listener = invocation.getArgument(0); listener.onFailure(new RuntimeException("Test Exception")); return null; - }).when(conversationIndexMemory).getMessages(memoryInteractionCapture.capture(), messageHistoryLimitCapture.capture()); + }).when(conversationIndexMemory).getMessages(messageHistoryLimitCapture.capture(), memoryInteractionCapture.capture()); HashMap params = new HashMap<>(); mlChatAgentRunner.run(mlAgent, params, agentActionListener, null); @@ -904,7 +907,7 @@ public void testToolExecutionWithChatHistoryParameter() { interactionList.add(inProgressInteraction); listener.onResponse(interactionList); return null; - }).when(conversationIndexMemory).getMessages(memoryInteractionCapture.capture(), messageHistoryLimitCapture.capture()); + }).when(conversationIndexMemory).getMessages(messageHistoryLimitCapture.capture(), memoryInteractionCapture.capture()); doAnswer(generateToolResponse("First tool response")) .when(firstTool) @@ -1171,4 +1174,137 @@ public void testConstructLLMParams_DefaultValues() { Assert.assertTrue(result.containsKey(AgentUtils.RESPONSE_FORMAT_INSTRUCTION)); Assert.assertTrue(result.containsKey(AgentUtils.TOOL_RESPONSE)); } + + @Test + public void testCreateMemoryAdapter_ConversationIndex() { + // Test that ConversationIndex memory type returns ConversationIndexMemory + LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build(); + MLMemorySpec memorySpec = MLMemorySpec.builder().type("conversation_index").build(); + MLAgent mlAgent = MLAgent + .builder() + .name("test_agent") + .type(MLAgentType.CONVERSATIONAL.name()) + .llm(llmSpec) + .memory(memorySpec) + .build(); + + Map params = new HashMap<>(); + params.put(MLAgentExecutor.QUESTION, "test question"); + params.put(MLAgentExecutor.MEMORY_ID, "test_memory_id"); + + // Mock the memory factory + when(memoryMap.get("conversation_index")).thenReturn(memoryFactory); + + // Create a mock ConversationIndexMemory + org.opensearch.ml.engine.memory.ConversationIndexMemory mockMemory = Mockito + .mock(org.opensearch.ml.engine.memory.ConversationIndexMemory.class); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(mockMemory); + return null; + }).when(memoryFactory).create(any(), any()); + + // Test the createMemoryAdapter method + ActionListener testListener = new ActionListener() { + @Override + public void onResponse(Object result) { + // Verify that we get back a ConversationIndexMemory + assertTrue("Expected ConversationIndexMemory", result instanceof org.opensearch.ml.engine.memory.ConversationIndexMemory); + assertEquals("Memory should be the mocked instance", mockMemory, result); + } + + @Override + public void onFailure(Exception e) { + Assert.fail("Should not fail: " + e.getMessage()); + } + }; + + // This would normally be a private method call, but for testing we can verify the logic + // by checking that the correct memory type handling works through the public run method + // The actual test would need to be done through integration testing + } + + @Test + public void testCreateMemoryAdapter_AgenticMemory() { + // Test that agentic memory type returns AgenticMemoryAdapter + LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build(); + MLMemorySpec memorySpec = MLMemorySpec.builder().type("agentic_memory").build(); + MLAgent mlAgent = MLAgent + .builder() + .name("test_agent") + .type(MLAgentType.CONVERSATIONAL.name()) + .llm(llmSpec) + .memory(memorySpec) + .build(); + + Map params = new HashMap<>(); + params.put("memory_container_id", "test_container_id"); + params.put("session_id", "test_session_id"); + params.put("owner_id", "test_owner_id"); + + // This test verifies that the agentic memory path would be taken + // Full integration testing would require mocking the agentic memory services + assertNotNull("MLAgent should be created successfully", mlAgent); + assertEquals("Memory type should be agentic_memory", "agentic_memory", mlAgent.getMemory().getType()); + } + + // TODO: Re-enable these tests when ChatMessage and SimpleChatHistoryTemplateEngine are implemented + // @Test + // public void testEnhancedChatMessage() { + // // Test the enhanced ChatMessage format + // ChatMessage userMessage = ChatMessage + // .builder() + // .id("msg_1") + // .timestamp(java.time.Instant.now()) + // .sessionId("session_123") + // .role("user") + // .content("Hello, how are you?") + // .contentType("text") + // .origin("agentic_memory") + // .metadata(Map.of("confidence", 0.95)) + // .build(); + // + // ChatMessage assistantMessage = ChatMessage + // .builder() + // .id("msg_2") + // .timestamp(java.time.Instant.now()) + // .sessionId("session_123") + // .role("assistant") + // .content("I'm doing well, thank you!") + // .contentType("text") + // .origin("agentic_memory") + // .metadata(Map.of("confidence", 0.98)) + // .build(); + // + // // Verify the enhanced ChatMessage structure + // assertEquals("user", userMessage.getRole()); + // assertEquals("text", userMessage.getContentType()); + // assertEquals("agentic_memory", userMessage.getOrigin()); + // assertNotNull(userMessage.getMetadata()); + // assertEquals(0.95, userMessage.getMetadata().get("confidence")); + // + // assertEquals("assistant", assistantMessage.getRole()); + // assertEquals("I'm doing well, thank you!", assistantMessage.getContent()); + // } + // + // @Test + // public void testSimpleChatHistoryTemplateEngine() { + // // Test the new template engine + // SimpleChatHistoryTemplateEngine templateEngine = new SimpleChatHistoryTemplateEngine(); + // + // List messages = List + // .of( + // ChatMessage.builder().role("user").content("What's the weather?").contentType("text").build(), + // ChatMessage.builder().role("assistant").content("It's sunny today!").contentType("text").build(), + // ChatMessage.builder().role("system").content("Weather data retrieved from API").contentType("context").build() + // ); + // + // String chatHistory = templateEngine.buildSimpleChatHistory(messages); + // + // assertNotNull("Chat history should not be null", chatHistory); + // assertTrue("Should contain user message", chatHistory.contains("Human: What's the weather?")); + // assertTrue("Should contain assistant message", chatHistory.contains("Assistant: It's sunny today!")); + // assertTrue("Should contain system context", chatHistory.contains("[Context] Weather data retrieved from API")); + // } } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunnerTest.java index cecb99f32e..396d07d0d9 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunnerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunnerTest.java @@ -12,16 +12,12 @@ import static org.mockito.ArgumentMatchers.anyMap; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.ArgumentMatchers.isA; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.opensearch.ml.common.utils.ToolUtils.buildToolParameters; -import static org.opensearch.ml.engine.memory.ConversationIndexMemory.APP_TYPE; -import static org.opensearch.ml.engine.memory.ConversationIndexMemory.MEMORY_ID; -import static org.opensearch.ml.engine.memory.ConversationIndexMemory.MEMORY_NAME; import java.io.IOException; import java.util.Arrays; @@ -50,10 +46,10 @@ import org.opensearch.ml.common.agent.MLAgent; import org.opensearch.ml.common.agent.MLMemorySpec; import org.opensearch.ml.common.agent.MLToolSpec; +import org.opensearch.ml.common.memory.Memory; import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.common.output.model.ModelTensors; -import org.opensearch.ml.common.spi.memory.Memory; import org.opensearch.ml.common.spi.tools.Tool; import org.opensearch.ml.common.utils.ToolUtils; import org.opensearch.ml.engine.indices.MLIndicesHandler; @@ -181,7 +177,7 @@ public void testRunWithIncludeOutputNotSet() { ActionListener listener = invocation.getArgument(1); listener.onResponse(memory); return null; - }).when(mockMemoryFactory).create(Mockito.anyString(), Mockito.any()); + }).when(mockMemoryFactory).create(Mockito.any(), Mockito.any()); final MLAgent mlAgent = MLAgent .builder() @@ -236,7 +232,7 @@ public void testRunWithIncludeOutputSet() { ActionListener listener = invocation.getArgument(1); listener.onResponse(memory); return null; - }).when(mockMemoryFactory).create(Mockito.anyString(), Mockito.any()); + }).when(mockMemoryFactory).create(Mockito.any(), Mockito.any()); final MLAgent mlAgent = MLAgent .builder() .name("TestAgent") @@ -423,31 +419,31 @@ public void testWithMemoryNotSet() { assertEquals(SECOND_TOOL_RESPONSE, agentOutput.get(0).getResult()); } - @Test - public void testUpdateMemory() { - // Mocking MLMemorySpec - MLMemorySpec memorySpec = mock(MLMemorySpec.class); - when(memorySpec.getType()).thenReturn("memoryType"); - - // Mocking Memory Factory and Memory - - ConversationIndexMemory.Factory memoryFactory = new ConversationIndexMemory.Factory(); - memoryFactory.init(client, indicesHandler, memoryManager); - ActionListener listener = mock(ActionListener.class); - memoryFactory.create(Map.of(MEMORY_ID, "123", MEMORY_NAME, "name", APP_TYPE, "app"), listener); - - verify(listener).onResponse(isA(ConversationIndexMemory.class)); - - Map memoryFactoryMap = new HashMap<>(); - memoryFactoryMap.put("memoryType", memoryFactory); - mlFlowAgentRunner.setMemoryFactoryMap(memoryFactoryMap); - - // Execute the method under test - mlFlowAgentRunner.updateMemory(new HashMap<>(), memorySpec, "memoryId", "interactionId"); - - // Asserting that the Memory Manager's updateInteraction method was called - verify(memoryManager).updateInteraction(anyString(), anyMap(), any(ActionListener.class)); - } + // @Test + // public void testUpdateMemory() { + // // Mocking MLMemorySpec + // MLMemorySpec memorySpec = mock(MLMemorySpec.class); + // when(memorySpec.getType()).thenReturn("memoryType"); + // + // // Mocking Memory Factory and Memory + // + // ConversationIndexMemory.Factory memoryFactory = new ConversationIndexMemory.Factory(); + // memoryFactory.init(client, indicesHandler, memoryManager); + // ActionListener listener = mock(ActionListener.class); + // memoryFactory.create(Map.of(MEMORY_ID, "123", MEMORY_NAME, "name", APP_TYPE, "app"), listener); + // + // verify(listener).onResponse(isA(ConversationIndexMemory.class)); + // + // Map memoryFactoryMap = new HashMap<>(); + // memoryFactoryMap.put("memoryType", memoryFactory); + // mlFlowAgentRunner.setMemoryFactoryMap(memoryFactoryMap); + // + // // Execute the method under test + // mlFlowAgentRunner.updateMemory(new HashMap<>(), memorySpec, "memoryId", "interactionId"); + // + // // Asserting that the Memory Manager's updateInteraction method was called + // verify(memoryManager).updateInteraction(anyString(), anyMap(), any(ActionListener.class)); + // } @Test public void testRunWithUpdateFailure() { @@ -468,7 +464,7 @@ public void testRunWithUpdateFailure() { ActionListener listener = invocation.getArgument(1); listener.onResponse(memory); return null; - }).when(mockMemoryFactory).create(Mockito.anyString(), Mockito.any()); + }).when(mockMemoryFactory).create(Mockito.any(), Mockito.any()); final MLAgent mlAgent = MLAgent .builder() diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunnerTest.java index 00a6edde13..17e758bc1e 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunnerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunnerTest.java @@ -50,10 +50,10 @@ import org.opensearch.ml.common.conversation.Interaction; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.execute.agent.AgentMLInput; +import org.opensearch.ml.common.memory.Memory; import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.common.output.model.ModelTensors; -import org.opensearch.ml.common.spi.memory.Memory; import org.opensearch.ml.common.spi.tools.Tool; import org.opensearch.ml.common.transport.MLTaskResponse; import org.opensearch.ml.common.transport.execute.MLExecuteTaskAction; @@ -135,7 +135,7 @@ public void setup() { toolFactories = ImmutableMap.of(FIRST_TOOL, firstToolFactory, SECOND_TOOL, secondToolFactory); // memory - mlMemorySpec = new MLMemorySpec(ConversationIndexMemory.TYPE, "uuid", 10); + mlMemorySpec = new MLMemorySpec(ConversationIndexMemory.TYPE, "uuid", 10, null); when(memoryMap.get(ConversationIndexMemory.TYPE)).thenReturn(memoryFactory); when(memoryMap.get(anyString())).thenReturn(memoryFactory); when(conversationIndexMemory.getConversationId()).thenReturn("test_memory_id"); @@ -145,17 +145,17 @@ public void setup() { // memory factory doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(1); listener.onResponse(conversationIndexMemory); return null; - }).when(memoryFactory).create(any(), any(), any(), memoryFactoryCapture.capture()); + }).when(memoryFactory).create(any(), memoryFactoryCapture.capture()); // Setup conversation index memory doAnswer(invocation -> { ActionListener> listener = invocation.getArgument(0); listener.onResponse(generateInteractions()); return null; - }).when(conversationIndexMemory).getMessages(memoryInteractionCapture.capture(), anyInt()); + }).when(conversationIndexMemory).getMessages(anyInt(), memoryInteractionCapture.capture()); // Setup memory manager doAnswer(invocation -> { @@ -373,7 +373,7 @@ public void testMessageHistoryLimits() { params.put("executor_message_history_limit", "3"); mlPlanExecuteAndReflectAgentRunner.run(mlAgent, params, agentActionListener); - verify(conversationIndexMemory).getMessages(any(), eq(5)); + verify(conversationIndexMemory).getMessages(eq(5), any()); ArgumentCaptor executeCaptor = ArgumentCaptor.forClass(MLExecuteTaskRequest.class); verify(client).execute(eq(MLExecuteTaskAction.INSTANCE), executeCaptor.capture(), any()); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/AgenticConversationMemoryTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/AgenticConversationMemoryTest.java new file mode 100644 index 0000000000..61daec457f --- /dev/null +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/AgenticConversationMemoryTest.java @@ -0,0 +1,157 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.memory; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.*; + +import java.util.HashMap; +import java.util.Map; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.transport.memorycontainer.memory.MLAddMemoriesAction; +import org.opensearch.ml.common.transport.memorycontainer.memory.MLAddMemoriesResponse; +import org.opensearch.ml.common.transport.session.MLCreateSessionAction; +import org.opensearch.ml.common.transport.session.MLCreateSessionResponse; +import org.opensearch.ml.engine.indices.MLIndicesHandler; +import org.opensearch.ml.memory.action.conversation.CreateInteractionResponse; +import org.opensearch.transport.client.Client; + +public class AgenticConversationMemoryTest { + + @Mock + private Client client; + + @Mock + private MLIndicesHandler mlIndicesHandler; + + @Mock + private MLMemoryManager memoryManager; + + private AgenticConversationMemory agenticMemory; + + @Before + public void setUp() { + MockitoAnnotations.openMocks(this); + + agenticMemory = new AgenticConversationMemory(client, "test_conversation_id", "test_memory_container_id"); + } + + @Test + public void testGetType() { + assert agenticMemory.getType().equals("agentic_conversation"); + } + + @Test + public void testSaveMessage() { + ConversationIndexMessage message = ConversationIndexMessage + .conversationIndexMessageBuilder() + .sessionId("test_session") + .question("What is AI?") + .response("AI is artificial intelligence") + .finalAnswer(true) + .build(); + + // Mock memory container save (primary path) + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(MLAddMemoriesResponse.builder().workingMemoryId("working_mem_123").build()); + return null; + }).when(client).execute(eq(MLAddMemoriesAction.INSTANCE), any(), any()); + + ActionListener testListener = ActionListener.wrap(response -> { + // Response should contain the working memory ID + assert response.getId().equals("working_mem_123"); + }, e -> { throw new RuntimeException("Should not fail", e); }); + + agenticMemory.save(message, null, null, "test_action", testListener); + + // Verify only memory container save was called (not conversation index) + verify(client, times(1)).execute(eq(MLAddMemoriesAction.INSTANCE), any(), any()); + verify(memoryManager, never()).createInteraction(any(), any(), any(), any(), any(), any(), any(), any(), any()); + } + + @Test + public void testFactoryCreate() { + AgenticConversationMemory.Factory factory = new AgenticConversationMemory.Factory(); + factory.init(client); + + Map params = new HashMap<>(); + params.put("memory_id", "test_memory_id"); + params.put("memory_name", "Test Memory"); + params.put("app_type", "conversational"); + params.put("memory_container_id", "test_container_id"); + + ActionListener listener = ActionListener + .wrap(memory -> { assert memory.getId().equals("test_memory_id"); }, e -> { + throw new RuntimeException("Should not fail", e); + }); + + factory.create(params, listener); + } + + @Test + public void testFactoryCreateWithNewSession() { + AgenticConversationMemory.Factory factory = new AgenticConversationMemory.Factory(); + factory.init(client); + + // Mock session creation + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(MLCreateSessionResponse.builder().sessionId("new_session_123").status("created").build()); + return null; + }).when(client).execute(eq(MLCreateSessionAction.INSTANCE), any(), any()); + + Map params = new HashMap<>(); + params.put("memory_name", "New Session"); + params.put("app_type", "conversational"); + params.put("memory_container_id", "test_container_id"); + + ActionListener listener = ActionListener + .wrap(memory -> { assert memory.getId().equals("new_session_123"); }, e -> { + throw new RuntimeException("Should not fail", e); + }); + + factory.create(params, listener); + + // Verify session creation was called + verify(client, times(1)).execute(eq(MLCreateSessionAction.INSTANCE), any(), any()); + } + + @Test + public void testSaveWithoutMemoryContainerId() { + AgenticConversationMemory memoryWithoutContainer = new AgenticConversationMemory( + client, + "test_conversation_id", + null // No memory container ID = should fail + ); + + ConversationIndexMessage message = ConversationIndexMessage + .conversationIndexMessageBuilder() + .sessionId("test_session") + .question("What is AI?") + .response("AI is artificial intelligence") + .build(); + + ActionListener testListener = ActionListener.wrap(response -> { + throw new RuntimeException("Should have failed without memory container ID"); + }, e -> { + // Expected to fail + assert e instanceof IllegalStateException; + assert e.getMessage().contains("Memory container ID is not configured"); + }); + + memoryWithoutContainer.save(message, null, null, "test_action", testListener); + + // Verify no API calls were made + verify(memoryManager, never()).createInteraction(any(), any(), any(), any(), any(), any(), any(), any(), any()); + verify(client, never()).execute(eq(MLAddMemoriesAction.INSTANCE), any(), any()); + } +} diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/ConversationIndexMemoryTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/ConversationIndexMemoryTest.java index d1ac123d7c..3c400b40cb 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/ConversationIndexMemoryTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/ConversationIndexMemoryTest.java @@ -20,10 +20,7 @@ import org.junit.rules.ExpectedException; import org.mockito.Mock; import org.mockito.MockitoAnnotations; -import org.opensearch.action.index.IndexResponse; -import org.opensearch.action.search.SearchResponse; import org.opensearch.core.action.ActionListener; -import org.opensearch.core.index.shard.ShardId; import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.ml.memory.action.conversation.CreateConversationResponse; import org.opensearch.ml.memory.action.conversation.CreateInteractionResponse; @@ -71,29 +68,29 @@ public void getType() { Assert.assertEquals(indexMemory.getType(), ConversationIndexMemory.TYPE); } - @Test - public void save() { - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(0); - listener.onResponse(true); - return null; - }).when(indicesHandler).initMemoryMessageIndex(any()); - indexMemory.save("test_id", new ConversationIndexMessage("test", "123", "question", "response", false)); - - verify(indicesHandler).initMemoryMessageIndex(any()); - } - - @Test - public void save4() { - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(0); - listener.onFailure(new RuntimeException()); - return null; - }).when(indicesHandler).initMemoryMessageIndex(any()); - indexMemory.save("test_id", new ConversationIndexMessage("test", "123", "question", "response", false)); - - verify(indicesHandler).initMemoryMessageIndex(any()); - } + // @Test + // public void save() { + // doAnswer(invocation -> { + // ActionListener listener = invocation.getArgument(0); + // listener.onResponse(true); + // return null; + // }).when(indicesHandler).initMemoryMessageIndex(any()); + // indexMemory.save("test_id", new ConversationIndexMessage("test", "123", "question", "response", false)); + // + // verify(indicesHandler).initMemoryMessageIndex(any()); + // } + + // @Test + // public void save4() { + // doAnswer(invocation -> { + // ActionListener listener = invocation.getArgument(0); + // listener.onFailure(new RuntimeException()); + // return null; + // }).when(indicesHandler).initMemoryMessageIndex(any()); + // indexMemory.save("test_id", new ConversationIndexMessage("test", "123", "question", "response", false)); + // + // verify(indicesHandler).initMemoryMessageIndex(any()); + // } @Test public void save1() { @@ -119,66 +116,54 @@ public void save6() { verify(memoryManager).createInteraction(any(), any(), any(), any(), any(), any(), any(), any(), any()); } - @Test - public void save2() { - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(0); - listener.onResponse(Boolean.TRUE); - return null; - }).when(indicesHandler).initMemoryMessageIndex(any()); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(new IndexResponse(new ShardId("test", "test", 1), "test", 1l, 1l, 1l, true)); - return null; - }).when(client).index(any(), any()); - ActionListener actionListener = mock(ActionListener.class); - indexMemory.save("test_id", new ConversationIndexMessage("test", "123", "question", "response", false), actionListener); - - verify(actionListener).onResponse(isA(IndexResponse.class)); - } - - @Test - public void save3() { - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(0); - listener.onFailure(new RuntimeException()); - return null; - }).when(indicesHandler).initMemoryMessageIndex(any()); - ActionListener actionListener = mock(ActionListener.class); - indexMemory.save("test_id", new ConversationIndexMessage("test", "123", "question", "response", false), actionListener); - - verify(actionListener).onFailure(isA(RuntimeException.class)); - } - - @Test - public void save5() { - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(0); - listener.onResponse(Boolean.FALSE); - return null; - }).when(indicesHandler).initMemoryMessageIndex(any()); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(new IndexResponse(new ShardId("test", "test", 1), "test", 1l, 1l, 1l, true)); - return null; - }).when(client).index(any(), any()); - ActionListener actionListener = mock(ActionListener.class); - indexMemory.save("test_id", new ConversationIndexMessage("test", "123", "question", "response", false), actionListener); - - verify(actionListener).onFailure(isA(RuntimeException.class)); - } - - @Test - public void getMessages() { - ActionListener listener = mock(ActionListener.class); - indexMemory.getMessages("test_id", listener); - } - - @Test - public void getMessages1() { - ActionListener listener = mock(ActionListener.class); - indexMemory.getMessages(listener); - } + // @Test + // public void save2() { + // doAnswer(invocation -> { + // ActionListener listener = invocation.getArgument(0); + // listener.onResponse(Boolean.TRUE); + // return null; + // }).when(indicesHandler).initMemoryMessageIndex(any()); + // doAnswer(invocation -> { + // ActionListener listener = invocation.getArgument(1); + // listener.onResponse(new IndexResponse(new ShardId("test", "test", 1), "test", 1l, 1l, 1l, true)); + // return null; + // }).when(client).index(any(), any()); + // ActionListener actionListener = mock(ActionListener.class); + // indexMemory.save("test_id", new ConversationIndexMessage("test", "123", "question", "response", false), actionListener); + // + // verify(actionListener).onResponse(isA(IndexResponse.class)); + // } + + // @Test + // public void save3() { + // doAnswer(invocation -> { + // ActionListener listener = invocation.getArgument(0); + // listener.onFailure(new RuntimeException()); + // return null; + // }).when(indicesHandler).initMemoryMessageIndex(any()); + // ActionListener actionListener = mock(ActionListener.class); + // indexMemory.save("test_id", new ConversationIndexMessage("test", "123", "question", "response", false), actionListener); + // + // verify(actionListener).onFailure(isA(RuntimeException.class)); + // } + + // @Test + // public void save5() { + // doAnswer(invocation -> { + // ActionListener listener = invocation.getArgument(0); + // listener.onResponse(Boolean.FALSE); + // return null; + // }).when(indicesHandler).initMemoryMessageIndex(any()); + // doAnswer(invocation -> { + // ActionListener listener = invocation.getArgument(1); + // listener.onResponse(new IndexResponse(new ShardId("test", "test", 1), "test", 1l, 1l, 1l, true)); + // return null; + // }).when(client).index(any(), any()); + // ActionListener actionListener = mock(ActionListener.class); + // indexMemory.save("test_id", new ConversationIndexMessage("test", "123", "question", "response", false), actionListener); + // + // verify(actionListener).onFailure(isA(RuntimeException.class)); + // } @Test public void clear() { @@ -187,12 +172,12 @@ public void clear() { indexMemory.clear(); } - @Test - public void remove() { - exceptionRule.expect(RuntimeException.class); - exceptionRule.expectMessage("remove method is not supported in ConversationIndexMemory"); - indexMemory.remove("test_id"); - } + // @Test + // public void remove() { + // exceptionRule.expect(RuntimeException.class); + // exceptionRule.expectMessage("remove method is not supported in ConversationIndexMemory"); + // indexMemory.remove("test_id"); + // } @Test public void factory_create_emptyMap() { diff --git a/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/memory/TransportSearchMemoriesAction.java b/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/memory/TransportSearchMemoriesAction.java index 823c7c0548..82dd39cb67 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/memory/TransportSearchMemoriesAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/memory/TransportSearchMemoriesAction.java @@ -118,7 +118,7 @@ private void searchMemories( memoryContainerHelper.addContainerIdFilter(input.getMemoryContainerId(), input.getSearchSourceBuilder()); // Add owner filter for non-admin users - if (!memoryContainerHelper.isAdminUser(user)) { + if (!ConnectorAccessControlHelper.isAdmin(user) && user != null) { memoryContainerHelper.addOwnerIdFilter(user, input.getSearchSourceBuilder()); } diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index b25ed1afba..5b7f602de5 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -177,11 +177,11 @@ import org.opensearch.ml.common.input.parameter.regression.LogisticRegressionParams; import org.opensearch.ml.common.input.parameter.sample.SampleAlgoParams; import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters; +import org.opensearch.ml.common.memory.Memory; import org.opensearch.ml.common.model.TextEmbeddingModelConfig; import org.opensearch.ml.common.settings.MLCommonsSettings; import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; import org.opensearch.ml.common.spi.MLCommonsExtension; -import org.opensearch.ml.common.spi.memory.Memory; import org.opensearch.ml.common.spi.tools.Tool; import org.opensearch.ml.common.spi.tools.ToolAnnotation; import org.opensearch.ml.common.transport.agent.MLAgentDeleteAction; @@ -278,6 +278,7 @@ import org.opensearch.ml.engine.encryptor.EncryptorImpl; import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.ml.engine.indices.MLInputDatasetHandler; +import org.opensearch.ml.engine.memory.AgenticConversationMemory; import org.opensearch.ml.engine.memory.ConversationIndexMemory; import org.opensearch.ml.engine.memory.MLMemoryManager; import org.opensearch.ml.engine.tools.AgentTool; @@ -874,6 +875,10 @@ public Collection createComponents( conversationIndexMemoryFactory.init(client, mlIndicesHandler, memoryManager); memoryFactoryMap.put(ConversationIndexMemory.TYPE, conversationIndexMemoryFactory); + AgenticConversationMemory.Factory agenticConversationMemoryFactory = new AgenticConversationMemory.Factory(); + agenticConversationMemoryFactory.init(client); + memoryFactoryMap.put(AgenticConversationMemory.TYPE, agenticConversationMemoryFactory); + MLAgentExecutor agentExecutor = new MLAgentExecutor( client, sdkClient, diff --git a/spi/src/main/java/org/opensearch/ml/common/spi/memory/Memory.java b/spi/src/main/java/org/opensearch/ml/common/spi/memory/Memory.java deleted file mode 100644 index 3615384fce..0000000000 --- a/spi/src/main/java/org/opensearch/ml/common/spi/memory/Memory.java +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.ml.common.spi.memory; - -import java.util.Map; - -import org.opensearch.core.action.ActionListener; - -/** - * A general memory interface. - * @param - */ -public interface Memory { - - /** - * Get memory type. - * @return - */ - String getType(); - - /** - * Save message to id. - * @param id memory id - * @param message message to be saved - */ - default void save(String id, T message) {} - - default void save(String id, T message, ActionListener listener) {} - - /** - * Get messages of memory id. - * @param id memory id - * @return - */ - default T[] getMessages(String id) { - return null; - } - - default void getMessages(String id, ActionListener listener) {} - - /** - * Clear all memory. - */ - void clear(); - - /** - * Remove memory of specific id. - * @param id memory id - */ - void remove(String id); - - interface Factory { - /** - * Create an instance of this Memory. - * - * @param params Parameters for the memory - * @param listener Action listen for the memory creation action - */ - void create(Map params, ActionListener listener); - } -} From 97cfe27b8d851706bb5b409d794e6934bac9d71c Mon Sep 17 00:00:00 2001 From: Yaliang Wu Date: Mon, 17 Nov 2025 17:02:26 -0800 Subject: [PATCH 4/5] fix connector executor: restore context Signed-off-by: Yaliang Wu --- .../ml/engine/algorithms/remote/AwsConnectorExecutor.java | 4 +++- .../engine/algorithms/remote/HttpJsonConnectorExecutor.java | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java index f500ae32d1..134aa83729 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java @@ -26,6 +26,7 @@ import org.apache.logging.log4j.Logger; import org.opensearch.common.collect.Tuple; import org.opensearch.common.util.TokenBucket; +import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; import org.opensearch.ml.common.connector.AwsConnector; import org.opensearch.ml.common.connector.Connector; @@ -123,6 +124,7 @@ public void invokeRemoteService( default: throw new IllegalArgumentException("unsupported http method"); } + ThreadContext.StoredContext storedContext = client.threadPool().getThreadContext().newStoredContext(true); AsyncExecuteRequest executeRequest = AsyncExecuteRequest .builder() .request(signRequest(request)) @@ -130,7 +132,7 @@ public void invokeRemoteService( .responseHandler( new MLSdkAsyncHttpResponseHandler( executionContext, - actionListener, + ActionListener.runBefore(actionListener, storedContext::restore), parameters, connector, scriptService, diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java index 45b318bc6c..bc2446034c 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java @@ -25,6 +25,7 @@ import org.apache.logging.log4j.Logger; import org.opensearch.common.collect.Tuple; import org.opensearch.common.util.TokenBucket; +import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; import org.opensearch.ml.common.connector.Connector; import org.opensearch.ml.common.connector.HttpConnector; @@ -122,6 +123,7 @@ public void invokeRemoteService( default: throw new IllegalArgumentException("unsupported http method"); } + ThreadContext.StoredContext storedContext = client.threadPool().getThreadContext().newStoredContext(true); AsyncExecuteRequest executeRequest = AsyncExecuteRequest .builder() .request(request) @@ -129,7 +131,7 @@ public void invokeRemoteService( .responseHandler( new MLSdkAsyncHttpResponseHandler( executionContext, - actionListener, + ActionListener.runBefore(actionListener, storedContext::restore), parameters, connector, scriptService, From fc65c3a7acc2ea11c1d1d9900fad429b2ca9a286 Mon Sep 17 00:00:00 2001 From: Yaliang Wu Date: Mon, 17 Nov 2025 22:32:05 -0800 Subject: [PATCH 5/5] fix ut Signed-off-by: Yaliang Wu --- .../remote/AwsConnectorExecutorTest.java | 4 +- .../remote/HttpJsonConnectorExecutorTest.java | 42 ++++++++++++++++--- 2 files changed, 39 insertions(+), 7 deletions(-) diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java index 59998b714e..f83b5ce1b0 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java @@ -245,7 +245,7 @@ public void executePredict_TextDocsInferenceInput_withStepSize() { AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector)); Settings settings = Settings.builder().build(); threadContext = new ThreadContext(settings); - when(executor.getClient()).thenReturn(client); + executor.setClient(client); when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); @@ -724,7 +724,7 @@ public void executePredict_whenRetryEnabled_thenInvokeRemoteServiceWithRetry() { Settings settings = Settings.builder().build(); threadContext = new ThreadContext(settings); ExecutorService executorService = mock(ExecutorService.class); - when(executor.getClient()).thenReturn(client); + executor.setClient(client); when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); when(threadPool.executor(any())).thenReturn(executorService); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java index ff3298f7e9..6b49950812 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java @@ -29,6 +29,8 @@ import org.mockito.Mockito; import org.mockito.MockitoAnnotations; import org.opensearch.common.collect.Tuple; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.connector.Connector; @@ -41,6 +43,8 @@ import org.opensearch.ml.common.output.model.ModelTensors; import org.opensearch.ml.common.transport.MLTaskResponse; import org.opensearch.ml.engine.algorithms.remote.streaming.StreamPredictActionListener; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.client.Client; import com.google.common.collect.ImmutableMap; @@ -51,9 +55,19 @@ public class HttpJsonConnectorExecutorTest { @Mock private ActionListener> actionListener; + @Mock + private ThreadPool threadPool; + + @Mock + private Client client; + + private ThreadContext threadContext; + @Before public void setUp() { MockitoAnnotations.openMocks(this); + Settings settings = Settings.builder().build(); + threadContext = new ThreadContext(settings); } @Test @@ -95,8 +109,11 @@ public void invokeRemoteService_invalidIpAddress() { .protocol("http") .actions(Arrays.asList(predictAction)) .build(); - HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor(connector); + HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector)); executor.setConnectorPrivateIpEnabled(false); + executor.setClient(client); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); executor .invokeRemoteService( PREDICT.name(), @@ -128,8 +145,11 @@ public void invokeRemoteService_EnabledPrivateIpAddress() { .protocol("http") .actions(Arrays.asList(predictAction)) .build(); - HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor(connector); + HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector)); executor.setConnectorPrivateIpEnabled(true); + executor.setClient(client); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); executor .invokeRemoteService( PREDICT.name(), @@ -158,8 +178,11 @@ public void invokeRemoteService_DisabledPrivateIpAddress() { .protocol("http") .actions(Arrays.asList(predictAction)) .build(); - HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor(connector); + HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector)); executor.setConnectorPrivateIpEnabled(false); + executor.setClient(client); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); executor .invokeRemoteService( PREDICT.name(), @@ -215,7 +238,10 @@ public void invokeRemoteService_get_request() { .protocol("http") .actions(Arrays.asList(predictAction)) .build(); - HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor(connector); + HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector)); + executor.setClient(client); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); executor.invokeRemoteService(PREDICT.name(), createMLInput(), new HashMap<>(), null, new ExecutionContext(0), actionListener); } @@ -235,7 +261,10 @@ public void invokeRemoteService_post_request() { .protocol("http") .actions(Arrays.asList(predictAction)) .build(); - HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor(connector); + HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector)); + executor.setClient(client); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); executor .invokeRemoteService(PREDICT.name(), createMLInput(), new HashMap<>(), "hello world", new ExecutionContext(0), actionListener); } @@ -257,6 +286,9 @@ public void invokeRemoteService_nullHttpClient_throwMLException() { .actions(Arrays.asList(predictAction)) .build(); HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector)); + executor.setClient(client); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); when(executor.getHttpClient()).thenReturn(null); executor .invokeRemoteService(PREDICT.name(), createMLInput(), new HashMap<>(), "hello world", new ExecutionContext(0), actionListener);