diff --git a/common/src/main/java/org/opensearch/ml/common/CommonValue.java b/common/src/main/java/org/opensearch/ml/common/CommonValue.java index 351171ede6..1f7bfac8ad 100644 --- a/common/src/main/java/org/opensearch/ml/common/CommonValue.java +++ b/common/src/main/java/org/opensearch/ml/common/CommonValue.java @@ -54,6 +54,7 @@ public class CommonValue { public static final String TASK_POLLING_JOB_INDEX = ".ml_commons_task_polling_job"; public static final String MCP_SESSION_MANAGEMENT_INDEX = ".plugins-ml-mcp-session-management"; public static final String MCP_TOOLS_INDEX = ".plugins-ml-mcp-tools"; + public static final String ML_CONTEXT_MANAGEMENT_TEMPLATES_INDEX = ".plugins-ml-context-management-templates"; // index created in 3.1 to track all ml jobs created via job scheduler public static final String ML_JOBS_INDEX = ".plugins-ml-jobs"; public static final Set stopWordsIndices = ImmutableSet.of(".plugins-ml-stop-words"); @@ -76,6 +77,7 @@ public class CommonValue { public static final String ML_LONG_MEMORY_HISTORY_INDEX_MAPPING_PATH = "index-mappings/ml_memory_long_term_history.json"; public static final String ML_MCP_SESSION_MANAGEMENT_INDEX_MAPPING_PATH = "index-mappings/ml_mcp_session_management.json"; public static final String ML_MCP_TOOLS_INDEX_MAPPING_PATH = "index-mappings/ml_mcp_tools.json"; + public static final String ML_CONTEXT_MANAGEMENT_TEMPLATES_INDEX_MAPPING_PATH = "index-mappings/ml_context_management_templates.json"; public static final String ML_JOBS_INDEX_MAPPING_PATH = "index-mappings/ml_jobs.json"; public static final String ML_INDEX_INSIGHT_CONFIG_INDEX_MAPPING_PATH = "index-mappings/ml_index_insight_config.json"; public static final String ML_INDEX_INSIGHT_STORAGE_INDEX_MAPPING_PATH = "index-mappings/ml_index_insight_storage.json"; diff --git a/common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java b/common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java index b66a23f11e..5e512a394f 100644 --- a/common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java +++ b/common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java @@ -30,6 +30,7 @@ import org.opensearch.ml.common.CommonValue; import org.opensearch.ml.common.MLAgentType; import org.opensearch.ml.common.MLModel; +import org.opensearch.ml.common.contextmanager.ContextManagementTemplate; import org.opensearch.telemetry.metrics.tags.Tags; import lombok.Builder; @@ -51,6 +52,8 @@ public class MLAgent implements ToXContentObject, Writeable { public static final String LAST_UPDATED_TIME_FIELD = "last_updated_time"; public static final String APP_TYPE_FIELD = "app_type"; public static final String IS_HIDDEN_FIELD = "is_hidden"; + public static final String CONTEXT_MANAGEMENT_NAME_FIELD = "context_management_name"; + public static final String CONTEXT_MANAGEMENT_FIELD = "context_management"; private static final String LLM_INTERFACE_FIELD = "_llm_interface"; private static final String TAG_VALUE_UNKNOWN = "unknown"; private static final String TAG_MEMORY_TYPE = "memory_type"; @@ -58,6 +61,7 @@ public class MLAgent implements ToXContentObject, Writeable { public static final int AGENT_NAME_MAX_LENGTH = 128; private static final Version MINIMAL_SUPPORTED_VERSION_FOR_HIDDEN_AGENT = CommonValue.VERSION_2_13_0; + private static final Version MINIMAL_SUPPORTED_VERSION_FOR_CONTEXT_MANAGEMENT = CommonValue.VERSION_3_3_0; private String name; private String type; @@ -71,6 +75,8 @@ public class MLAgent implements ToXContentObject, Writeable { private Instant lastUpdateTime; private String appType; private Boolean isHidden; + private String contextManagementName; + private ContextManagementTemplate contextManagement; private final String tenantId; @Builder(toBuilder = true) @@ -86,6 +92,8 @@ public MLAgent( Instant lastUpdateTime, String appType, Boolean isHidden, + String contextManagementName, + ContextManagementTemplate contextManagement, String tenantId ) { this.name = name; @@ -100,6 +108,8 @@ public MLAgent( this.appType = appType; // is_hidden field isn't going to be set by user. It will be set by the code. this.isHidden = isHidden; + this.contextManagementName = contextManagementName; + this.contextManagement = contextManagement; this.tenantId = tenantId; validate(); } @@ -128,6 +138,17 @@ private void validate() { } } } + validateContextManagement(); + } + + private void validateContextManagement() { + if (contextManagementName != null && contextManagement != null) { + throw new IllegalArgumentException("Cannot specify both context_management_name and context_management"); + } + + if (contextManagement != null && !contextManagement.isValid()) { + throw new IllegalArgumentException("Invalid context management configuration"); + } } private void validateMLAgentType(String agentType) { @@ -171,6 +192,12 @@ public MLAgent(StreamInput input) throws IOException { if (streamInputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_HIDDEN_AGENT)) { isHidden = input.readOptionalBoolean(); } + if (streamInputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_CONTEXT_MANAGEMENT)) { + contextManagementName = input.readOptionalString(); + if (input.readBoolean()) { + contextManagement = new ContextManagementTemplate(input); + } + } this.tenantId = streamInputVersion.onOrAfter(VERSION_2_19_0) ? input.readOptionalString() : null; validate(); } @@ -214,6 +241,15 @@ public void writeTo(StreamOutput out) throws IOException { if (streamOutputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_HIDDEN_AGENT)) { out.writeOptionalBoolean(isHidden); } + if (streamOutputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_CONTEXT_MANAGEMENT)) { + out.writeOptionalString(contextManagementName); + if (contextManagement != null) { + out.writeBoolean(true); + contextManagement.writeTo(out); + } else { + out.writeBoolean(false); + } + } if (streamOutputVersion.onOrAfter(VERSION_2_19_0)) { out.writeOptionalString(tenantId); } @@ -256,6 +292,12 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (isHidden != null) { builder.field(MLModel.IS_HIDDEN_FIELD, isHidden); } + if (contextManagementName != null) { + builder.field(CONTEXT_MANAGEMENT_NAME_FIELD, contextManagementName); + } + if (contextManagement != null) { + builder.field(CONTEXT_MANAGEMENT_FIELD, contextManagement); + } if (tenantId != null) { builder.field(TENANT_ID_FIELD, tenantId); } @@ -283,6 +325,8 @@ private static MLAgent parseCommonFields(XContentParser parser, boolean parseHid Instant lastUpdateTime = null; String appType = null; boolean isHidden = false; + String contextManagementName = null; + ContextManagementTemplate contextManagement = null; String tenantId = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); @@ -329,6 +373,12 @@ private static MLAgent parseCommonFields(XContentParser parser, boolean parseHid if (parseHidden) isHidden = parser.booleanValue(); break; + case CONTEXT_MANAGEMENT_NAME_FIELD: + contextManagementName = parser.text(); + break; + case CONTEXT_MANAGEMENT_FIELD: + contextManagement = ContextManagementTemplate.parse(parser); + break; case TENANT_ID_FIELD: tenantId = parser.textOrNull(); break; @@ -351,6 +401,8 @@ private static MLAgent parseCommonFields(XContentParser parser, boolean parseHid .lastUpdateTime(lastUpdateTime) .appType(appType) .isHidden(isHidden) + .contextManagementName(contextManagementName) + .contextManagement(contextManagement) .tenantId(tenantId) .build(); } @@ -384,4 +436,47 @@ public Tags getTags() { return tags; } + + /** + * Check if this agent has context management configuration + * @return true if agent has either context management name or inline configuration + */ + public boolean hasContextManagement() { + return contextManagementName != null || contextManagement != null; + } + + /** + * Get the effective context management configuration for this agent. + * This method prioritizes inline configuration over template reference. + * Note: Template resolution requires external service call and should be handled by the caller. + * + * @return the inline context management configuration, or null if using template reference or no configuration + */ + public ContextManagementTemplate getInlineContextManagement() { + return contextManagement; + } + + /** + * Check if this agent uses a context management template reference + * @return true if agent references a context management template by name + */ + public boolean hasContextManagementTemplate() { + return contextManagementName != null; + } + + /** + * Check if this agent has inline context management configuration + * @return true if agent has inline context management configuration + */ + public boolean hasInlineContextManagement() { + return contextManagement != null; + } + + /** + * Get the context management template name if this agent references one + * @return the template name, or null if no template reference + */ + public String getContextManagementTemplateName() { + return contextManagementName; + } } diff --git a/common/src/main/java/org/opensearch/ml/common/contextmanager/ActivationRule.java b/common/src/main/java/org/opensearch/ml/common/contextmanager/ActivationRule.java new file mode 100644 index 0000000000..c1529e6eda --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/contextmanager/ActivationRule.java @@ -0,0 +1,26 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.contextmanager; + +/** + * Interface for activation rules that determine when a context manager should execute. + * Activation rules evaluate runtime conditions based on the current context state. + */ +public interface ActivationRule { + + /** + * Evaluate whether the activation condition is met. + * @param context the current context state + * @return true if the condition is met and the manager should activate, false otherwise + */ + boolean evaluate(ContextManagerContext context); + + /** + * Get a description of this activation rule for logging and debugging. + * @return a human-readable description of the rule + */ + String getDescription(); +} diff --git a/common/src/main/java/org/opensearch/ml/common/contextmanager/ActivationRuleFactory.java b/common/src/main/java/org/opensearch/ml/common/contextmanager/ActivationRuleFactory.java new file mode 100644 index 0000000000..ccae1a4bf7 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/contextmanager/ActivationRuleFactory.java @@ -0,0 +1,148 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.contextmanager; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +import lombok.extern.log4j.Log4j2; + +/** + * Factory class for creating activation rules from configuration. + * Supports creating rules from configuration maps and combining multiple rules. + */ +@Log4j2 +public class ActivationRuleFactory { + + public static final String TOKENS_EXCEED_KEY = "tokens_exceed"; + public static final String MESSAGE_COUNT_EXCEED_KEY = "message_count_exceed"; + + /** + * Create activation rules from a configuration map. + * @param activationConfig the configuration map containing rule definitions + * @return a list of activation rules, or empty list if no valid rules found + */ + public static List createRules(Map activationConfig) { + List rules = new ArrayList<>(); + + if (activationConfig == null || activationConfig.isEmpty()) { + return rules; + } + + // Create tokens_exceed rule + if (activationConfig.containsKey(TOKENS_EXCEED_KEY)) { + try { + Object tokenValue = activationConfig.get(TOKENS_EXCEED_KEY); + int tokenThreshold = parseIntegerValue(tokenValue, TOKENS_EXCEED_KEY); + if (tokenThreshold > 0) { + rules.add(new TokensExceedRule(tokenThreshold)); + log.debug("Created TokensExceedRule with threshold: {}", tokenThreshold); + } else { + throw new IllegalArgumentException("Invalid token threshold value: " + tokenValue + ". Must be positive integer."); + } + } catch (Exception e) { + log.error("Failed to create TokensExceedRule: {}", e.getMessage()); + } + } + + // Create message_count_exceed rule + if (activationConfig.containsKey(MESSAGE_COUNT_EXCEED_KEY)) { + try { + Object messageValue = activationConfig.get(MESSAGE_COUNT_EXCEED_KEY); + int messageThreshold = parseIntegerValue(messageValue, MESSAGE_COUNT_EXCEED_KEY); + if (messageThreshold > 0) { + rules.add(new MessageCountExceedRule(messageThreshold)); + log.debug("Created MessageCountExceedRule with threshold: {}", messageThreshold); + } else { + throw new IllegalArgumentException( + "Invalid message count threshold value: " + messageValue + ". Must be positive integer." + ); + } + } catch (Exception e) { + log.error("Failed to create MessageCountExceedRule: {}", e.getMessage()); + } + } + + return rules; + } + + /** + * Create a composite rule that requires ALL rules to be satisfied (AND logic). + * @param rules the list of rules to combine + * @return a composite rule, or null if the list is empty + */ + public static ActivationRule createCompositeRule(List rules) { + if (rules == null || rules.isEmpty()) { + return null; + } + + if (rules.size() == 1) { + return rules.get(0); + } + + return new CompositeActivationRule(rules); + } + + /** + * Parse an integer value from configuration, handling various input types. + * @param value the value to parse + * @param fieldName the field name for error reporting + * @return the parsed integer value + * @throws IllegalArgumentException if the value cannot be parsed + */ + private static int parseIntegerValue(Object value, String fieldName) { + if (value instanceof Integer) { + return (Integer) value; + } else if (value instanceof Number) { + return ((Number) value).intValue(); + } else if (value instanceof String) { + try { + return Integer.parseInt((String) value); + } catch (NumberFormatException e) { + throw new IllegalArgumentException("Invalid integer value for " + fieldName + ": " + value); + } + } else { + throw new IllegalArgumentException("Unsupported value type for " + fieldName + ": " + value.getClass().getSimpleName()); + } + } + + /** + * Composite activation rule that implements AND logic for multiple rules. + */ + private static class CompositeActivationRule implements ActivationRule { + private final List rules; + + public CompositeActivationRule(List rules) { + this.rules = new ArrayList<>(rules); + } + + @Override + public boolean evaluate(ContextManagerContext context) { + // All rules must evaluate to true (AND logic) + for (ActivationRule rule : rules) { + if (!rule.evaluate(context)) { + return false; + } + } + return true; + } + + @Override + public String getDescription() { + StringBuilder sb = new StringBuilder(); + sb.append("composite_rule: ["); + for (int i = 0; i < rules.size(); i++) { + if (i > 0) { + sb.append(" AND "); + } + sb.append(rules.get(i).getDescription()); + } + sb.append("]"); + return sb.toString(); + } + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/contextmanager/CharacterBasedTokenCounter.java b/common/src/main/java/org/opensearch/ml/common/contextmanager/CharacterBasedTokenCounter.java new file mode 100644 index 0000000000..e9b87a20bc --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/contextmanager/CharacterBasedTokenCounter.java @@ -0,0 +1,89 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.contextmanager; + +import lombok.extern.log4j.Log4j2; + +/** + * Character-based token counter implementation. + * Uses a simple heuristic of approximately 4 characters per token. + * This is a fallback implementation when more sophisticated token counting is not available. + */ +@Log4j2 +public class CharacterBasedTokenCounter implements TokenCounter { + + private static final double CHARS_PER_TOKEN = 4.0; + + @Override + public int count(String text) { + if (text == null || text.isEmpty()) { + return 0; + } + return (int) Math.ceil(text.length() / CHARS_PER_TOKEN); + } + + @Override + public String truncateFromEnd(String text, int maxTokens) { + if (text == null || text.isEmpty()) { + return text; + } + + int currentTokens = count(text); + if (currentTokens <= maxTokens) { + return text; + } + + int maxChars = (int) (maxTokens * CHARS_PER_TOKEN); + if (maxChars >= text.length()) { + return text; + } + + return text.substring(0, maxChars); + } + + @Override + public String truncateFromBeginning(String text, int maxTokens) { + if (text == null || text.isEmpty()) { + return text; + } + + int currentTokens = count(text); + if (currentTokens <= maxTokens) { + return text; + } + + int maxChars = (int) (maxTokens * CHARS_PER_TOKEN); + if (maxChars >= text.length()) { + return text; + } + + return text.substring(text.length() - maxChars); + } + + @Override + public String truncateMiddle(String text, int maxTokens) { + if (text == null || text.isEmpty()) { + return text; + } + + int currentTokens = count(text); + if (currentTokens <= maxTokens) { + return text; + } + + int maxChars = (int) (maxTokens * CHARS_PER_TOKEN); + if (maxChars >= text.length()) { + return text; + } + + // Keep equal portions from beginning and end + int halfChars = maxChars / 2; + String beginning = text.substring(0, halfChars); + String end = text.substring(text.length() - halfChars); + + return beginning + end; + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/contextmanager/ContextManagementTemplate.java b/common/src/main/java/org/opensearch/ml/common/contextmanager/ContextManagementTemplate.java new file mode 100644 index 0000000000..40969b8c9a --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/contextmanager/ContextManagementTemplate.java @@ -0,0 +1,260 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.contextmanager; + +import java.io.IOException; +import java.time.Instant; +import java.util.List; +import java.util.Map; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; + +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + +/** + * Context Management Template defines which context managers to use and when. + * This class represents a registered configuration that can be applied to + * agent execution to enable dynamic context optimization. + */ +@Data +@NoArgsConstructor +@AllArgsConstructor +@Builder(toBuilder = true) +public class ContextManagementTemplate implements ToXContentObject, Writeable { + + public static final String NAME_FIELD = "name"; + public static final String DESCRIPTION_FIELD = "description"; + public static final String HOOKS_FIELD = "hooks"; + public static final String CREATED_TIME_FIELD = "created_time"; + public static final String LAST_MODIFIED_FIELD = "last_modified"; + public static final String CREATED_BY_FIELD = "created_by"; + + /** + * Unique name for the context management template + */ + private String name; + + /** + * Human-readable description of what this template does + */ + private String description; + + /** + * Map of hook names to lists of context manager configurations + */ + private Map> hooks; + + /** + * When this template was created + */ + private Instant createdTime; + + /** + * When this template was last modified + */ + private Instant lastModified; + + /** + * Who created this template + */ + private String createdBy; + + /** + * Constructor from StreamInput + */ + public ContextManagementTemplate(StreamInput input) throws IOException { + this.name = input.readString(); + this.description = input.readOptionalString(); + + // Read hooks map + int hooksSize = input.readInt(); + if (hooksSize > 0) { + this.hooks = input.readMap(StreamInput::readString, in -> { + try { + int listSize = in.readInt(); + List configs = new java.util.ArrayList<>(); + for (int i = 0; i < listSize; i++) { + configs.add(new ContextManagerConfig(in)); + } + return configs; + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + } + + this.createdTime = input.readOptionalInstant(); + this.lastModified = input.readOptionalInstant(); + this.createdBy = input.readOptionalString(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(name); + out.writeOptionalString(description); + + // Write hooks map + if (hooks != null) { + out.writeInt(hooks.size()); + out.writeMap(hooks, StreamOutput::writeString, (output, configs) -> { + try { + output.writeInt(configs.size()); + for (ContextManagerConfig config : configs) { + config.writeTo(output); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + } else { + out.writeInt(0); + } + + out.writeOptionalInstant(createdTime); + out.writeOptionalInstant(lastModified); + out.writeOptionalString(createdBy); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + + if (name != null) { + builder.field(NAME_FIELD, name); + } + if (description != null) { + builder.field(DESCRIPTION_FIELD, description); + } + if (hooks != null && !hooks.isEmpty()) { + builder.field(HOOKS_FIELD, hooks); + } + if (createdTime != null) { + builder.field(CREATED_TIME_FIELD, createdTime.toEpochMilli()); + } + if (lastModified != null) { + builder.field(LAST_MODIFIED_FIELD, lastModified.toEpochMilli()); + } + if (createdBy != null) { + builder.field(CREATED_BY_FIELD, createdBy); + } + + builder.endObject(); + return builder; + } + + /** + * Parse ContextManagementTemplate from XContentParser + */ + public static ContextManagementTemplate parse(XContentParser parser) throws IOException { + String name = null; + String description = null; + Map> hooks = null; + Instant createdTime = null; + Instant lastModified = null; + String createdBy = null; + + if (parser.currentToken() != XContentParser.Token.START_OBJECT) { + parser.nextToken(); + } + + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case NAME_FIELD: + name = parser.text(); + break; + case DESCRIPTION_FIELD: + description = parser.text(); + break; + case HOOKS_FIELD: + hooks = parseHooks(parser); + break; + case CREATED_TIME_FIELD: + createdTime = Instant.ofEpochMilli(parser.longValue()); + break; + case LAST_MODIFIED_FIELD: + lastModified = Instant.ofEpochMilli(parser.longValue()); + break; + case CREATED_BY_FIELD: + createdBy = parser.text(); + break; + default: + parser.skipChildren(); + break; + } + } + + return ContextManagementTemplate + .builder() + .name(name) + .description(description) + .hooks(hooks) + .createdTime(createdTime) + .lastModified(lastModified) + .createdBy(createdBy) + .build(); + } + + /** + * Parse hooks configuration from XContentParser + */ + private static Map> parseHooks(XContentParser parser) throws IOException { + Map> hooks = new java.util.HashMap<>(); + + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String hookName = parser.currentName(); + parser.nextToken(); // Move to START_ARRAY + + List configs = new java.util.ArrayList<>(); + while (parser.nextToken() != XContentParser.Token.END_ARRAY) { + configs.add(ContextManagerConfig.parse(parser)); + } + + hooks.put(hookName, configs); + } + + return hooks; + } + + /** + * Validate the template configuration + */ + public boolean isValid() { + if (name == null || name.trim().isEmpty()) { + return false; + } + + // Allow null hooks (no context management) but not empty hooks map (misconfiguration) + if (hooks != null) { + if (hooks.isEmpty()) { + return false; + } + + // Validate all context manager configs + for (List configs : hooks.values()) { + if (configs != null) { + for (ContextManagerConfig config : configs) { + if (!config.isValid()) { + return false; + } + } + } + } + } + + return true; + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/contextmanager/ContextManager.java b/common/src/main/java/org/opensearch/ml/common/contextmanager/ContextManager.java new file mode 100644 index 0000000000..325f98900a --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/contextmanager/ContextManager.java @@ -0,0 +1,42 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.contextmanager; + +import java.util.Map; + +/** + * Base interface for all context managers. + * Context managers are pluggable components that inspect and transform + * agent context components before they are sent to an LLM. + */ +public interface ContextManager { + + /** + * Get the type identifier for this context manager + * @return String identifying the manager type + */ + String getType(); + + /** + * Initialize the context manager with configuration + * @param config Configuration map for the manager + */ + void initialize(Map config); + + /** + * Check if this context manager should activate based on current context + * @param context The current context manager context + * @return true if the manager should execute, false otherwise + */ + boolean shouldActivate(ContextManagerContext context); + + /** + * Execute the context transformation + * @param context The context manager context to transform + */ + void execute(ContextManagerContext context); + +} diff --git a/common/src/main/java/org/opensearch/ml/common/contextmanager/ContextManagerConfig.java b/common/src/main/java/org/opensearch/ml/common/contextmanager/ContextManagerConfig.java new file mode 100644 index 0000000000..92755cb243 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/contextmanager/ContextManagerConfig.java @@ -0,0 +1,127 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.contextmanager; + +import java.io.IOException; +import java.util.Map; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; + +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; + +/** + * Configuration for a context manager within a context management template. + * This class holds the configuration details for how a specific context manager + * should be configured and when it should activate. + */ +@Data +@NoArgsConstructor +@AllArgsConstructor +public class ContextManagerConfig implements ToXContentObject, Writeable { + + public static final String TYPE_FIELD = "type"; + public static final String ACTIVATION_FIELD = "activation"; + public static final String CONFIG_FIELD = "config"; + + /** + * The type of context manager (e.g., "ToolsOutputTruncateManager") + */ + private String type; + + /** + * Activation conditions that determine when this manager should execute + */ + private Map activation; + + /** + * Configuration parameters specific to this manager type + */ + private Map config; + + /** + * Constructor from StreamInput + */ + public ContextManagerConfig(StreamInput input) throws IOException { + this.type = input.readString(); + this.activation = input.readMap(); + this.config = input.readMap(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(type); + out.writeMap(activation); + out.writeMap(config); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + + if (type != null) { + builder.field(TYPE_FIELD, type); + } + if (activation != null && !activation.isEmpty()) { + builder.field(ACTIVATION_FIELD, activation); + } + if (config != null && !config.isEmpty()) { + builder.field(CONFIG_FIELD, config); + } + + builder.endObject(); + return builder; + } + + /** + * Parse ContextManagerConfig from XContentParser + */ + public static ContextManagerConfig parse(XContentParser parser) throws IOException { + String type = null; + Map activation = null; + Map config = null; + + if (parser.currentToken() != XContentParser.Token.START_OBJECT) { + parser.nextToken(); + } + + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case TYPE_FIELD: + type = parser.text(); + break; + case ACTIVATION_FIELD: + activation = parser.map(); + break; + case CONFIG_FIELD: + config = parser.map(); + break; + default: + parser.skipChildren(); + break; + } + } + + return new ContextManagerConfig(type, activation, config); + } + + /** + * Validate the configuration + * @return true if configuration is valid, false otherwise + */ + public boolean isValid() { + return type != null && !type.trim().isEmpty(); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/contextmanager/ContextManagerContext.java b/common/src/main/java/org/opensearch/ml/common/contextmanager/ContextManagerContext.java new file mode 100644 index 0000000000..9854b78dba --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/contextmanager/ContextManagerContext.java @@ -0,0 +1,180 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.contextmanager; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.opensearch.ml.common.agent.MLToolSpec; +import org.opensearch.ml.common.conversation.Interaction; + +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + +/** + * Context object that contains all components of the agent execution context. + * This object is passed to context managers for inspection and transformation. + */ +@Data +@Builder +@NoArgsConstructor +@AllArgsConstructor +public class ContextManagerContext { + + /** + * The invocation state from the hook system + */ + private Map invocationState; + + /** + * The system prompt for the LLM + */ + private String systemPrompt; + + /** + * The chat history as a list of interactions + */ + @Builder.Default + private List chatHistory = new ArrayList<>(); + + /** + * The current user prompt/input + */ + private String userPrompt; + + /** + * The tool configurations available to the agent + */ + @Builder.Default + private List toolConfigs = new ArrayList<>(); + + /** + * The tool interactions/results from tool executions + */ + @Builder.Default + private List toolInteractions = new ArrayList<>(); + + /** + * Additional parameters for context processing + */ + @Builder.Default + private Map parameters = new HashMap<>(); + + /** + * Get the total token count for the current context. + * This is a utility method that can be used by context managers. + * @return estimated token count + */ + public int getEstimatedTokenCount() { + int tokenCount = 0; + + // Estimate tokens for system prompt + if (systemPrompt != null) { + tokenCount += estimateTokens(systemPrompt); + } + + // Estimate tokens for user prompt + if (userPrompt != null) { + tokenCount += estimateTokens(userPrompt); + } + + // Estimate tokens for chat history + for (Interaction interaction : chatHistory) { + if (interaction.getInput() != null) { + tokenCount += estimateTokens(interaction.getInput()); + } + if (interaction.getResponse() != null) { + tokenCount += estimateTokens(interaction.getResponse()); + } + } + + // Estimate tokens for tool interactions + for (String interaction : toolInteractions) { + tokenCount += estimateTokens(interaction); + } + + return tokenCount; + } + + /** + * Get the message count in chat history. + * @return number of messages in chat history + */ + public int getMessageCount() { + return chatHistory.size(); + } + + /** + * Simple token estimation based on character count. + * This is a fallback method - more sophisticated token counting should be implemented + * in dedicated TokenCounter implementations. + * @param text the text to estimate tokens for + * @return estimated token count + */ + private int estimateTokens(String text) { + if (text == null || text.isEmpty()) { + return 0; + } + // Rough estimation: 1 token per 4 characters + return (int) Math.ceil(text.length() / 4.0); + } + + /** + * Add a tool interaction to the context. + * @param interaction the tool interaction to add + */ + public void addToolInteraction(String interaction) { + if (toolInteractions == null) { + toolInteractions = new ArrayList<>(); + } + toolInteractions.add(interaction); + } + + /** + * Add an interaction to the chat history. + * @param interaction the interaction to add + */ + public void addChatHistoryInteraction(Interaction interaction) { + if (chatHistory == null) { + chatHistory = new ArrayList<>(); + } + chatHistory.add(interaction); + } + + /** + * Clear the chat history. + */ + public void clearChatHistory() { + if (chatHistory != null) { + chatHistory.clear(); + } + } + + /** + * Get a parameter value by key. + * @param key the parameter key + * @return the parameter value, or null if not found + */ + public Object getParameter(String key) { + return parameters != null ? parameters.get(key) : null; + } + + /** + * Set a parameter value. + * @param key the parameter key + * @param value the parameter value + */ + public void setParameter(String key, String value) { + if (parameters == null) { + parameters = new HashMap<>(); + } + parameters.put(key, value); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/contextmanager/ContextManagerHookProvider.java b/common/src/main/java/org/opensearch/ml/common/contextmanager/ContextManagerHookProvider.java new file mode 100644 index 0000000000..dfef018c87 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/contextmanager/ContextManagerHookProvider.java @@ -0,0 +1,199 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.contextmanager; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.opensearch.ml.common.hooks.EnhancedPostToolEvent; +import org.opensearch.ml.common.hooks.HookProvider; +import org.opensearch.ml.common.hooks.HookRegistry; +import org.opensearch.ml.common.hooks.PostMemoryEvent; +import org.opensearch.ml.common.hooks.PreLLMEvent; + +import lombok.extern.log4j.Log4j2; + +/** + * Hook provider that integrates context managers with the hook registry. + * This class manages the execution of context managers based on hook events. + */ +@Log4j2 +public class ContextManagerHookProvider implements HookProvider { + private final List contextManagers; + private final Map> hookToManagersMap; + + /** + * Constructor for ContextManagerHookProvider + * @param contextManagers List of context managers to register + */ + public ContextManagerHookProvider(List contextManagers) { + this.contextManagers = new ArrayList<>(contextManagers); + this.hookToManagersMap = new HashMap<>(); + + // Group managers by hook type based on their configuration + // This would typically be done based on the template configuration + // For now, we'll organize them by common hook types + organizeManagersByHook(); + } + + /** + * Register hook callbacks with the provided registry + * @param registry The HookRegistry to register callbacks with + */ + @Override + public void registerHooks(HookRegistry registry) { + // Only register callbacks for hooks that have managers configured + if (hookToManagersMap.containsKey("PRE_LLM")) { + registry.addCallback(PreLLMEvent.class, this::handlePreLLM); + } + if (hookToManagersMap.containsKey("POST_TOOL")) { + registry.addCallback(EnhancedPostToolEvent.class, this::handlePostTool); + } + if (hookToManagersMap.containsKey("POST_MEMORY")) { + registry.addCallback(PostMemoryEvent.class, this::handlePostMemory); + } + + log.info("Registered context manager hooks for {} managers", contextManagers.size()); + } + + /** + * Handle PreLLM hook events + * @param event The PreLLM event + */ + private void handlePreLLM(PreLLMEvent event) { + log.debug("Handling PreLLM event"); + executeManagersForHook("PRE_LLM", event.getContext()); + } + + /** + * Handle PostTool hook events + * @param event The EnhancedPostTool event + */ + private void handlePostTool(EnhancedPostToolEvent event) { + log.debug("Handling PostTool event"); + executeManagersForHook("POST_TOOL", event.getContext()); + } + + /** + * Handle PostMemory hook events + * @param event The PostMemory event + */ + private void handlePostMemory(PostMemoryEvent event) { + log.debug("Handling PostMemory event"); + executeManagersForHook("POST_MEMORY", event.getContext()); + } + + /** + * Execute context managers for a specific hook + * @param hookName The name of the hook + * @param context The context manager context + */ + private void executeManagersForHook(String hookName, ContextManagerContext context) { + List managers = hookToManagersMap.get(hookName); + if (managers != null && !managers.isEmpty()) { + log.debug("Executing {} context managers for hook: {}", managers.size(), hookName); + + for (ContextManager manager : managers) { + try { + if (manager.shouldActivate(context)) { + log.debug("Executing context manager: {}", manager.getType()); + manager.execute(context); + log.debug("Successfully executed context manager: {}", manager.getType()); + } else { + log.debug("Context manager {} activation conditions not met, skipping", manager.getType()); + } + } catch (Exception e) { + log.error("Context manager {} failed: {}", manager.getType(), e.getMessage(), e); + // Continue with other managers even if one fails + } + } + } else { + log.debug("No context managers registered for hook: {}", hookName); + } + } + + /** + * Organize managers by hook type + * This is a simplified implementation - in practice, this would be based on + * the context management template configuration + */ + private void organizeManagersByHook() { + // For now, we'll assign managers to hooks based on their type + // This would be replaced with actual template-based configuration + for (ContextManager manager : contextManagers) { + String managerType = manager.getType(); + + // Assign managers to appropriate hooks based on their type + if ("ToolsOutputTruncateManager".equals(managerType)) { + addManagerToHook("POST_TOOL", manager); + } else if ("SlidingWindowManager".equals(managerType) || "SummarizingManager".equals(managerType)) { + addManagerToHook("POST_MEMORY", manager); + addManagerToHook("PRE_LLM", manager); + } else if ("SystemPromptAugmentationManager".equals(managerType)) { + addManagerToHook("PRE_LLM", manager); + } else { + // Default to PRE_LLM for unknown types + addManagerToHook("PRE_LLM", manager); + } + } + } + + /** + * Add a manager to a specific hook + * @param hookName The hook name + * @param manager The context manager + */ + private void addManagerToHook(String hookName, ContextManager manager) { + hookToManagersMap.computeIfAbsent(hookName, k -> new ArrayList<>()).add(manager); + log.debug("Added manager {} to hook {}", manager.getType(), hookName); + } + + /** + * Update the hook-to-managers mapping based on template configuration + * @param hookConfiguration Map of hook names to manager configurations + */ + public void updateHookConfiguration(Map> hookConfiguration) { + hookToManagersMap.clear(); + + for (Map.Entry> entry : hookConfiguration.entrySet()) { + String hookName = entry.getKey(); + List configs = entry.getValue(); + + for (ContextManagerConfig config : configs) { + // Find the corresponding context manager + ContextManager manager = findManagerByType(config.getType()); + if (manager != null) { + addManagerToHook(hookName, manager); + } else { + log.warn("Context manager of type {} not found", config.getType()); + } + } + } + + log.info("Updated hook configuration with {} hooks", hookConfiguration.size()); + } + + /** + * Find a context manager by its type + * @param type The manager type + * @return The context manager or null if not found + */ + private ContextManager findManagerByType(String type) { + return contextManagers.stream().filter(manager -> type.equals(manager.getType())).findFirst().orElse(null); + } + + /** + * Get the number of managers registered for a specific hook + * @param hookName The hook name + * @return Number of managers + */ + public int getManagerCount(String hookName) { + List managers = hookToManagersMap.get(hookName); + return managers != null ? managers.size() : 0; + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/contextmanager/MessageCountExceedRule.java b/common/src/main/java/org/opensearch/ml/common/contextmanager/MessageCountExceedRule.java new file mode 100644 index 0000000000..f3a4d5c57e --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/contextmanager/MessageCountExceedRule.java @@ -0,0 +1,34 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.contextmanager; + +import lombok.AllArgsConstructor; +import lombok.Getter; + +/** + * Activation rule that triggers when the chat history message count exceeds a specified threshold. + */ +@AllArgsConstructor +@Getter +public class MessageCountExceedRule implements ActivationRule { + + private final int messageThreshold; + + @Override + public boolean evaluate(ContextManagerContext context) { + if (context == null) { + return false; + } + + int currentMessageCount = context.getMessageCount(); + return currentMessageCount > messageThreshold; + } + + @Override + public String getDescription() { + return "message_count_exceed: " + messageThreshold; + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/contextmanager/TokenCounter.java b/common/src/main/java/org/opensearch/ml/common/contextmanager/TokenCounter.java new file mode 100644 index 0000000000..42bbd813ee --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/contextmanager/TokenCounter.java @@ -0,0 +1,45 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.contextmanager; + +/** + * Interface for counting and truncating tokens in text. + * Provides methods for accurate token counting and various truncation strategies. + */ +public interface TokenCounter { + + /** + * Count the number of tokens in the given text. + * @param text the text to count tokens for + * @return the number of tokens + */ + int count(String text); + + /** + * Truncate text from the end to fit within the specified token limit. + * @param text the text to truncate + * @param maxTokens the maximum number of tokens to keep + * @return the truncated text + */ + String truncateFromEnd(String text, int maxTokens); + + /** + * Truncate text from the beginning to fit within the specified token limit. + * @param text the text to truncate + * @param maxTokens the maximum number of tokens to keep + * @return the truncated text + */ + String truncateFromBeginning(String text, int maxTokens); + + /** + * Truncate text from the middle to fit within the specified token limit. + * Preserves both beginning and end portions of the text. + * @param text the text to truncate + * @param maxTokens the maximum number of tokens to keep + * @return the truncated text + */ + String truncateMiddle(String text, int maxTokens); +} diff --git a/common/src/main/java/org/opensearch/ml/common/contextmanager/TokensExceedRule.java b/common/src/main/java/org/opensearch/ml/common/contextmanager/TokensExceedRule.java new file mode 100644 index 0000000000..e4bc0544f0 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/contextmanager/TokensExceedRule.java @@ -0,0 +1,34 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.contextmanager; + +import lombok.AllArgsConstructor; +import lombok.Getter; + +/** + * Activation rule that triggers when the context token count exceeds a specified threshold. + */ +@AllArgsConstructor +@Getter +public class TokensExceedRule implements ActivationRule { + + private final int tokenThreshold; + + @Override + public boolean evaluate(ContextManagerContext context) { + if (context == null) { + return false; + } + + int currentTokenCount = context.getEstimatedTokenCount(); + return currentTokenCount > tokenThreshold; + } + + @Override + public String getDescription() { + return "tokens_exceed: " + tokenThreshold; + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/contextmanager/package-info.java b/common/src/main/java/org/opensearch/ml/common/contextmanager/package-info.java new file mode 100644 index 0000000000..b8d8cb5cc4 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/contextmanager/package-info.java @@ -0,0 +1,21 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +/** + * Context management framework for OpenSearch ML-Commons. + * + * This package provides a pluggable context management system that allows for dynamic + * optimization of LLM context windows through configurable context managers. + * + * Key components: + * - {@link org.opensearch.ml.common.contextmanager.ContextManager}: Base interface for all context managers + * - {@link org.opensearch.ml.common.contextmanager.ContextManagerContext}: Context object containing all agent execution state + * - {@link org.opensearch.ml.common.contextmanager.ActivationRule}: Interface for rules that determine when managers should execute + * - {@link org.opensearch.ml.common.contextmanager.ActivationRuleFactory}: Factory for creating activation rules from configuration + * + * The system integrates with the existing hook framework to provide seamless context optimization + * during agent execution without breaking existing functionality. + */ +package org.opensearch.ml.common.contextmanager; diff --git a/common/src/main/java/org/opensearch/ml/common/hooks/EnhancedPostToolEvent.java b/common/src/main/java/org/opensearch/ml/common/hooks/EnhancedPostToolEvent.java new file mode 100644 index 0000000000..7db6341c9e --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/hooks/EnhancedPostToolEvent.java @@ -0,0 +1,46 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.hooks; + +import java.util.List; +import java.util.Map; + +import org.opensearch.ml.common.contextmanager.ContextManagerContext; + +/** + * Enhanced version of PostToolEvent that includes context manager context. + * This event is triggered after tool execution and provides access to both + * tool results and the full context, allowing context managers to modify + * tool outputs and other context components. + */ +public class EnhancedPostToolEvent extends PostToolEvent { + private final ContextManagerContext context; + + /** + * Constructor for EnhancedPostToolEvent + * @param toolResults List of tool execution results + * @param error Exception that occurred during tool execution, null if successful + * @param context The context manager context containing all context components + * @param invocationState The current state of the agent invocation + */ + public EnhancedPostToolEvent( + List> toolResults, + Exception error, + ContextManagerContext context, + Map invocationState + ) { + super(toolResults, error, invocationState); + this.context = context; + } + + /** + * Get the context manager context + * @return ContextManagerContext containing all context components + */ + public ContextManagerContext getContext() { + return context; + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/hooks/HookCallback.java b/common/src/main/java/org/opensearch/ml/common/hooks/HookCallback.java new file mode 100644 index 0000000000..13e7299e01 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/hooks/HookCallback.java @@ -0,0 +1,23 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.hooks; + +/** + * Functional interface for handling specific hook events. + * Implementations of this interface define the behavior to execute + * when a particular hook event is triggered. + * + * @param The type of HookEvent this callback handles + */ +@FunctionalInterface +public interface HookCallback { + + /** + * Handle the hook event + * @param event The hook event to handle + */ + void handle(T event); +} diff --git a/common/src/main/java/org/opensearch/ml/common/hooks/HookEvent.java b/common/src/main/java/org/opensearch/ml/common/hooks/HookEvent.java new file mode 100644 index 0000000000..c7f1503b61 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/hooks/HookEvent.java @@ -0,0 +1,33 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.hooks; + +import java.util.Map; + +/** + * Base class for all hook events in the ML agent lifecycle. + * Hook events are strongly-typed events that carry context information + * for different stages of agent execution. + */ +public abstract class HookEvent { + private final Map invocationState; + + /** + * Constructor for HookEvent + * @param invocationState The current state of the agent invocation + */ + protected HookEvent(Map invocationState) { + this.invocationState = invocationState; + } + + /** + * Get the invocation state + * @return Map containing the current invocation state + */ + public Map getInvocationState() { + return invocationState; + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/hooks/HookProvider.java b/common/src/main/java/org/opensearch/ml/common/hooks/HookProvider.java new file mode 100644 index 0000000000..d6612f6749 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/hooks/HookProvider.java @@ -0,0 +1,20 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.hooks; + +/** + * Interface for providers that register hook callbacks with the HookRegistry. + * Implementations of this interface define which hooks they want to listen to + * and provide the callback implementations. + */ +public interface HookProvider { + + /** + * Register hook callbacks with the provided registry + * @param registry The HookRegistry to register callbacks with + */ + void registerHooks(HookRegistry registry); +} diff --git a/common/src/main/java/org/opensearch/ml/common/hooks/HookRegistry.java b/common/src/main/java/org/opensearch/ml/common/hooks/HookRegistry.java new file mode 100644 index 0000000000..32076d0d78 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/hooks/HookRegistry.java @@ -0,0 +1,93 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.hooks; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +import lombok.extern.log4j.Log4j2; + +/** + * Registry for managing hook callbacks and event emission. + * This class manages the registration of callbacks for different hook event types + * and provides methods to emit events to registered callbacks. + */ +@Log4j2 +public class HookRegistry { + private final Map, List>> callbacks; + + /** + * Constructor for HookRegistry + */ + public HookRegistry() { + this.callbacks = new ConcurrentHashMap<>(); + } + + /** + * Add a callback for a specific hook event type + * @param eventType The class of the hook event + * @param callback The callback to execute when the event is emitted + * @param The type of hook event + */ + public void addCallback(Class eventType, HookCallback callback) { + callbacks.computeIfAbsent(eventType, k -> new ArrayList<>()).add(callback); + log.debug("Registered callback for event type: {}", eventType.getSimpleName()); + } + + /** + * Emit an event to all registered callbacks for that event type + * @param event The hook event to emit + * @param The type of hook event + */ + @SuppressWarnings("unchecked") + public void emit(T event) { + Class eventType = event.getClass(); + List> eventCallbacks = callbacks.get(eventType); + + log + .info( + "HookRegistry.emit() called for event type: {}, callbacks available: {}", + eventType.getSimpleName(), + eventCallbacks != null ? eventCallbacks.size() : 0 + ); + + if (eventCallbacks != null) { + log.info("Emitting {} event to {} callbacks", eventType.getSimpleName(), eventCallbacks.size()); + + for (HookCallback callback : eventCallbacks) { + try { + log.info("Executing callback: {}", callback.getClass().getSimpleName()); + ((HookCallback) callback).handle(event); + } catch (Exception e) { + log.error("Error executing hook callback for event type {}: {}", eventType.getSimpleName(), e.getMessage(), e); + // Continue with other callbacks even if one fails + } + } + } else { + log.warn("No callbacks registered for event type: {}", eventType.getSimpleName()); + } + } + + /** + * Get the number of registered callbacks for a specific event type + * @param eventType The class of the hook event + * @return Number of registered callbacks + */ + public int getCallbackCount(Class eventType) { + List> eventCallbacks = callbacks.get(eventType); + return eventCallbacks != null ? eventCallbacks.size() : 0; + } + + /** + * Clear all registered callbacks + */ + public void clear() { + callbacks.clear(); + log.debug("Cleared all hook callbacks"); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/hooks/PostMemoryEvent.java b/common/src/main/java/org/opensearch/ml/common/hooks/PostMemoryEvent.java new file mode 100644 index 0000000000..006f6e8069 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/hooks/PostMemoryEvent.java @@ -0,0 +1,50 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.hooks; + +import java.util.List; +import java.util.Map; + +import org.opensearch.ml.common.contextmanager.ContextManagerContext; +import org.opensearch.ml.common.conversation.Interaction; + +/** + * Hook event triggered after memory retrieval in the agent lifecycle. + * This event provides access to the retrieved chat history and context, + * allowing context managers to modify the memory before it's used. + */ +public class PostMemoryEvent extends HookEvent { + private final ContextManagerContext context; + private final List retrievedHistory; + + /** + * Constructor for PostMemoryEvent + * @param context The context manager context containing all context components + * @param retrievedHistory The chat history retrieved from memory + * @param invocationState The current state of the agent invocation + */ + public PostMemoryEvent(ContextManagerContext context, List retrievedHistory, Map invocationState) { + super(invocationState); + this.context = context; + this.retrievedHistory = retrievedHistory; + } + + /** + * Get the context manager context + * @return ContextManagerContext containing all context components + */ + public ContextManagerContext getContext() { + return context; + } + + /** + * Get the retrieved chat history + * @return List of interactions retrieved from memory + */ + public List getRetrievedHistory() { + return retrievedHistory; + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/hooks/PostToolEvent.java b/common/src/main/java/org/opensearch/ml/common/hooks/PostToolEvent.java new file mode 100644 index 0000000000..609d6028da --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/hooks/PostToolEvent.java @@ -0,0 +1,46 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.hooks; + +import java.util.List; +import java.util.Map; + +/** + * Hook event triggered after tool execution in the agent lifecycle. + * This event provides access to tool results and any errors that occurred. + */ +public class PostToolEvent extends HookEvent { + private final List> toolResults; + private final Exception error; + + /** + * Constructor for PostToolEvent + * @param toolResults List of tool execution results + * @param error Exception that occurred during tool execution, null if successful + * @param invocationState The current state of the agent invocation + */ + public PostToolEvent(List> toolResults, Exception error, Map invocationState) { + super(invocationState); + this.toolResults = toolResults; + this.error = error; + } + + /** + * Get the tool execution results + * @return List of tool results + */ + public List> getToolResults() { + return toolResults; + } + + /** + * Get the error that occurred during tool execution + * @return Exception if an error occurred, null otherwise + */ + public Exception getError() { + return error; + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/hooks/PreInvocationEvent.java b/common/src/main/java/org/opensearch/ml/common/hooks/PreInvocationEvent.java new file mode 100644 index 0000000000..42e0cc1d0c --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/hooks/PreInvocationEvent.java @@ -0,0 +1,23 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.hooks; + +import java.util.Map; + +import org.opensearch.ml.common.input.Input; + +public class PreInvocationEvent extends HookEvent { + private final Input input; + + public PreInvocationEvent(Input input, Map invocationState) { + super(invocationState); + this.input = input; + } + + public Input getInput() { + return input; + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/hooks/PreLLMEvent.java b/common/src/main/java/org/opensearch/ml/common/hooks/PreLLMEvent.java new file mode 100644 index 0000000000..1b82b04512 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/hooks/PreLLMEvent.java @@ -0,0 +1,37 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.hooks; + +import java.util.Map; + +import org.opensearch.ml.common.contextmanager.ContextManagerContext; + +/** + * Hook event triggered before LLM invocation in the agent lifecycle. + * This event provides access to the context that will be sent to the LLM, + * allowing context managers to modify it before the LLM call. + */ +public class PreLLMEvent extends HookEvent { + private final ContextManagerContext context; + + /** + * Constructor for PreLLMEvent + * @param context The context manager context containing all context components + * @param invocationState The current state of the agent invocation + */ + public PreLLMEvent(ContextManagerContext context, Map invocationState) { + super(invocationState); + this.context = context; + } + + /** + * Get the context manager context + * @return ContextManagerContext containing all context components + */ + public ContextManagerContext getContext() { + return context; + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/input/execute/agent/AgentMLInput.java b/common/src/main/java/org/opensearch/ml/common/input/execute/agent/AgentMLInput.java index 986d6eefef..c92ca43846 100644 --- a/common/src/main/java/org/opensearch/ml/common/input/execute/agent/AgentMLInput.java +++ b/common/src/main/java/org/opensearch/ml/common/input/execute/agent/AgentMLInput.java @@ -20,6 +20,7 @@ import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.dataset.MLInputDataset; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.hooks.HookRegistry; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.utils.StringUtils; @@ -47,6 +48,14 @@ public class AgentMLInput extends MLInput { @Setter private Boolean isAsync; + @Getter + @Setter + private HookRegistry hookRegistry; + + @Getter + @Setter + private String contextManagementName; + @Builder(builderMethodName = "AgentMLInputBuilder") public AgentMLInput(String agentId, String tenantId, FunctionName functionName, MLInputDataset inputDataset) { this(agentId, tenantId, functionName, inputDataset, false); @@ -72,6 +81,7 @@ public void writeTo(StreamOutput out) throws IOException { if (streamOutputVersion.onOrAfter(AgentMLInput.MINIMAL_SUPPORTED_VERSION_FOR_ASYNC_EXECUTION)) { out.writeOptionalBoolean(isAsync); } + // Note: contextManagementName and hookRegistry are not serialized as they are runtime-only fields } public AgentMLInput(StreamInput in) throws IOException { @@ -100,7 +110,12 @@ public AgentMLInput(XContentParser parser, FunctionName functionName) throws IOE tenantId = parser.textOrNull(); break; case PARAMETERS_FIELD: - Map parameters = StringUtils.getParameterMap(parser.map()); + Map parameterObjs = parser.map(); + Map parameters = StringUtils.getParameterMap(parameterObjs); + // Extract context_management from parameters + if (parameterObjs.containsKey("context_management")) { + contextManagementName = (String) parameterObjs.get("context_management"); + } inputDataset = new RemoteInferenceInputDataSet(parameters); break; case ASYNC_FIELD: diff --git a/common/src/main/java/org/opensearch/ml/common/model/ModelGuardrail.java b/common/src/main/java/org/opensearch/ml/common/model/ModelGuardrail.java index 9b1b6c6a81..b32e6471d7 100644 --- a/common/src/main/java/org/opensearch/ml/common/model/ModelGuardrail.java +++ b/common/src/main/java/org/opensearch/ml/common/model/ModelGuardrail.java @@ -7,6 +7,7 @@ import static java.util.concurrent.TimeUnit.SECONDS; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD; import static org.opensearch.ml.common.utils.StringUtils.gson; import java.io.IOException; @@ -125,13 +126,16 @@ public Boolean validate(String in, Map parameters) { guardrailModelParams.put("response_filter", responseFilter); } log.info("Guardrail resFilter: {}", responseFilter); + String tenantId = parameters != null ? parameters.get(TENANT_ID_FIELD) : null; ActionRequest request = new MLPredictionTaskRequest( modelId, RemoteInferenceMLInput .builder() .algorithm(FunctionName.REMOTE) .inputDataset(RemoteInferenceInputDataSet.builder().parameters(guardrailModelParams).build()) - .build() + .build(), + null, + tenantId ); client.execute(MLPredictionTaskAction.INSTANCE, request, new LatchedActionListener(actionListener, latch)); try { diff --git a/common/src/main/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentRequest.java index c73f2150aa..90096044f0 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentRequest.java @@ -48,11 +48,62 @@ public ActionRequestValidationException validate() { ActionRequestValidationException exception = null; if (mlAgent == null) { exception = addValidationError("ML agent can't be null", exception); + } else { + // Basic validation - check for conflicting configuration (following connector pattern) + if (mlAgent.getContextManagementName() != null && mlAgent.getContextManagement() != null) { + exception = addValidationError("Cannot specify both context_management_name and context_management", exception); + } + + // Validate context management template name + if (mlAgent.getContextManagementName() != null) { + exception = validateContextManagementTemplateName(mlAgent.getContextManagementName(), exception); + } + + // Validate inline context management configuration + if (mlAgent.getContextManagement() != null) { + exception = validateInlineContextManagement(mlAgent.getContextManagement(), exception); + } + } + + return exception; + } + + private ActionRequestValidationException validateContextManagementTemplateName( + String templateName, + ActionRequestValidationException exception + ) { + if (templateName == null || templateName.trim().isEmpty()) { + exception = addValidationError("Context management template name cannot be null or empty", exception); + } else if (templateName.length() > 256) { + exception = addValidationError("Context management template name cannot exceed 256 characters", exception); + } else if (!templateName.matches("^[a-zA-Z0-9._-]+$")) { + exception = addValidationError( + "Context management template name can only contain letters, numbers, underscores, hyphens, and dots", + exception + ); } + return exception; + } + private ActionRequestValidationException validateInlineContextManagement( + org.opensearch.ml.common.contextmanager.ContextManagementTemplate contextManagement, + ActionRequestValidationException exception + ) { + if (contextManagement.getHooks() != null) { + for (String hookName : contextManagement.getHooks().keySet()) { + if (!isValidHookName(hookName)) { + exception = addValidationError("Invalid hook name: " + hookName, exception); + } + } + } return exception; } + private boolean isValidHookName(String hookName) { + // Define valid hook names based on the system's supported hooks + return hookName.equals("POST_TOOL") || hookName.equals("PRE_LLM") || hookName.equals("PRE_TOOL") || hookName.equals("POST_LLM"); + } + @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLCreateContextManagementTemplateAction.java b/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLCreateContextManagementTemplateAction.java new file mode 100644 index 0000000000..b6116afa4f --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLCreateContextManagementTemplateAction.java @@ -0,0 +1,17 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.contextmanagement; + +import org.opensearch.action.ActionType; + +public class MLCreateContextManagementTemplateAction extends ActionType { + public static MLCreateContextManagementTemplateAction INSTANCE = new MLCreateContextManagementTemplateAction(); + public static final String NAME = "cluster:admin/opensearch/ml/context_management/create"; + + private MLCreateContextManagementTemplateAction() { + super(NAME, MLCreateContextManagementTemplateResponse::new); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLCreateContextManagementTemplateRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLCreateContextManagementTemplateRequest.java new file mode 100644 index 0000000000..ee98607505 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLCreateContextManagementTemplateRequest.java @@ -0,0 +1,90 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.contextmanagement; + +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.ml.common.contextmanager.ContextManagementTemplate; + +import lombok.AccessLevel; +import lombok.Builder; +import lombok.Getter; +import lombok.ToString; +import lombok.experimental.FieldDefaults; + +@Getter +@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) +@ToString +public class MLCreateContextManagementTemplateRequest extends ActionRequest { + + String templateName; + ContextManagementTemplate template; + + @Builder + public MLCreateContextManagementTemplateRequest(String templateName, ContextManagementTemplate template) { + this.templateName = templateName; + this.template = template; + } + + public MLCreateContextManagementTemplateRequest(StreamInput in) throws IOException { + super(in); + this.templateName = in.readString(); + this.template = new ContextManagementTemplate(in); + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException exception = null; + if (templateName == null || templateName.trim().isEmpty()) { + exception = addValidationError("Template name cannot be null or empty", exception); + } + if (template == null) { + exception = addValidationError("Context management template cannot be null", exception); + } else { + // Validate template structure + if (template.getName() == null || template.getName().trim().isEmpty()) { + exception = addValidationError("Template name in body cannot be null or empty", exception); + } + if (template.getHooks() == null || template.getHooks().isEmpty()) { + exception = addValidationError("Template must define at least one hook", exception); + } + } + return exception; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(templateName); + template.writeTo(out); + } + + public static MLCreateContextManagementTemplateRequest fromActionRequest(ActionRequest actionRequest) { + if (actionRequest instanceof MLCreateContextManagementTemplateRequest) { + return (MLCreateContextManagementTemplateRequest) actionRequest; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionRequest.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLCreateContextManagementTemplateRequest(input); + } + } catch (IOException e) { + throw new UncheckedIOException("Failed to parse ActionRequest into MLCreateContextManagementTemplateRequest", e); + } + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLCreateContextManagementTemplateResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLCreateContextManagementTemplateResponse.java new file mode 100644 index 0000000000..85265bb333 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLCreateContextManagementTemplateResponse.java @@ -0,0 +1,71 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.contextmanagement; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; + +import lombok.Getter; + +@Getter +public class MLCreateContextManagementTemplateResponse extends ActionResponse implements ToXContentObject { + public static final String TEMPLATE_NAME_FIELD = "template_name"; + public static final String STATUS_FIELD = "status"; + + private String templateName; + private String status; + + public MLCreateContextManagementTemplateResponse(StreamInput in) throws IOException { + super(in); + this.templateName = in.readString(); + this.status = in.readString(); + } + + public MLCreateContextManagementTemplateResponse(String templateName, String status) { + this.templateName = templateName; + this.status = status; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(templateName); + out.writeString(status); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(TEMPLATE_NAME_FIELD, templateName); + builder.field(STATUS_FIELD, status); + builder.endObject(); + return builder; + } + + public static MLCreateContextManagementTemplateResponse fromActionResponse(ActionResponse actionResponse) { + if (actionResponse instanceof MLCreateContextManagementTemplateResponse) { + return (MLCreateContextManagementTemplateResponse) actionResponse; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionResponse.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLCreateContextManagementTemplateResponse(input); + } + } catch (IOException e) { + throw new UncheckedIOException("Failed to parse ActionResponse into MLCreateContextManagementTemplateResponse", e); + } + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLDeleteContextManagementTemplateAction.java b/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLDeleteContextManagementTemplateAction.java new file mode 100644 index 0000000000..6074891afa --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLDeleteContextManagementTemplateAction.java @@ -0,0 +1,17 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.contextmanagement; + +import org.opensearch.action.ActionType; + +public class MLDeleteContextManagementTemplateAction extends ActionType { + public static MLDeleteContextManagementTemplateAction INSTANCE = new MLDeleteContextManagementTemplateAction(); + public static final String NAME = "cluster:admin/opensearch/ml/context_management/delete"; + + private MLDeleteContextManagementTemplateAction() { + super(NAME, MLDeleteContextManagementTemplateResponse::new); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLDeleteContextManagementTemplateRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLDeleteContextManagementTemplateRequest.java new file mode 100644 index 0000000000..e7b6e69200 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLDeleteContextManagementTemplateRequest.java @@ -0,0 +1,74 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.contextmanagement; + +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; + +import lombok.AccessLevel; +import lombok.Builder; +import lombok.Getter; +import lombok.ToString; +import lombok.experimental.FieldDefaults; + +@Getter +@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) +@ToString +public class MLDeleteContextManagementTemplateRequest extends ActionRequest { + + String templateName; + + @Builder + public MLDeleteContextManagementTemplateRequest(String templateName) { + this.templateName = templateName; + } + + public MLDeleteContextManagementTemplateRequest(StreamInput in) throws IOException { + super(in); + this.templateName = in.readString(); + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException exception = null; + if (templateName == null || templateName.trim().isEmpty()) { + exception = addValidationError("Template name cannot be null or empty", exception); + } + return exception; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(templateName); + } + + public static MLDeleteContextManagementTemplateRequest fromActionRequest(ActionRequest actionRequest) { + if (actionRequest instanceof MLDeleteContextManagementTemplateRequest) { + return (MLDeleteContextManagementTemplateRequest) actionRequest; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionRequest.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLDeleteContextManagementTemplateRequest(input); + } + } catch (IOException e) { + throw new UncheckedIOException("Failed to parse ActionRequest into MLDeleteContextManagementTemplateRequest", e); + } + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLDeleteContextManagementTemplateResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLDeleteContextManagementTemplateResponse.java new file mode 100644 index 0000000000..415323f932 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLDeleteContextManagementTemplateResponse.java @@ -0,0 +1,71 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.contextmanagement; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; + +import lombok.Getter; + +@Getter +public class MLDeleteContextManagementTemplateResponse extends ActionResponse implements ToXContentObject { + public static final String TEMPLATE_NAME_FIELD = "template_name"; + public static final String STATUS_FIELD = "status"; + + private String templateName; + private String status; + + public MLDeleteContextManagementTemplateResponse(StreamInput in) throws IOException { + super(in); + this.templateName = in.readString(); + this.status = in.readString(); + } + + public MLDeleteContextManagementTemplateResponse(String templateName, String status) { + this.templateName = templateName; + this.status = status; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(templateName); + out.writeString(status); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(TEMPLATE_NAME_FIELD, templateName); + builder.field(STATUS_FIELD, status); + builder.endObject(); + return builder; + } + + public static MLDeleteContextManagementTemplateResponse fromActionResponse(ActionResponse actionResponse) { + if (actionResponse instanceof MLDeleteContextManagementTemplateResponse) { + return (MLDeleteContextManagementTemplateResponse) actionResponse; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionResponse.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLDeleteContextManagementTemplateResponse(input); + } + } catch (IOException e) { + throw new UncheckedIOException("Failed to parse ActionResponse into MLDeleteContextManagementTemplateResponse", e); + } + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLGetContextManagementTemplateAction.java b/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLGetContextManagementTemplateAction.java new file mode 100644 index 0000000000..4220dafe25 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLGetContextManagementTemplateAction.java @@ -0,0 +1,17 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.contextmanagement; + +import org.opensearch.action.ActionType; + +public class MLGetContextManagementTemplateAction extends ActionType { + public static MLGetContextManagementTemplateAction INSTANCE = new MLGetContextManagementTemplateAction(); + public static final String NAME = "cluster:admin/opensearch/ml/context_management/get"; + + private MLGetContextManagementTemplateAction() { + super(NAME, MLGetContextManagementTemplateResponse::new); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLGetContextManagementTemplateRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLGetContextManagementTemplateRequest.java new file mode 100644 index 0000000000..f8f8061868 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLGetContextManagementTemplateRequest.java @@ -0,0 +1,74 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.contextmanagement; + +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; + +import lombok.AccessLevel; +import lombok.Builder; +import lombok.Getter; +import lombok.ToString; +import lombok.experimental.FieldDefaults; + +@Getter +@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) +@ToString +public class MLGetContextManagementTemplateRequest extends ActionRequest { + + String templateName; + + @Builder + public MLGetContextManagementTemplateRequest(String templateName) { + this.templateName = templateName; + } + + public MLGetContextManagementTemplateRequest(StreamInput in) throws IOException { + super(in); + this.templateName = in.readString(); + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException exception = null; + if (templateName == null || templateName.trim().isEmpty()) { + exception = addValidationError("Template name cannot be null or empty", exception); + } + return exception; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(templateName); + } + + public static MLGetContextManagementTemplateRequest fromActionRequest(ActionRequest actionRequest) { + if (actionRequest instanceof MLGetContextManagementTemplateRequest) { + return (MLGetContextManagementTemplateRequest) actionRequest; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionRequest.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLGetContextManagementTemplateRequest(input); + } + } catch (IOException e) { + throw new UncheckedIOException("Failed to parse ActionRequest into MLGetContextManagementTemplateRequest", e); + } + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLGetContextManagementTemplateResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLGetContextManagementTemplateResponse.java new file mode 100644 index 0000000000..309d4c88af --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLGetContextManagementTemplateResponse.java @@ -0,0 +1,62 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.contextmanagement; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.contextmanager.ContextManagementTemplate; + +import lombok.Getter; + +@Getter +public class MLGetContextManagementTemplateResponse extends ActionResponse implements ToXContentObject { + + private ContextManagementTemplate template; + + public MLGetContextManagementTemplateResponse(StreamInput in) throws IOException { + super(in); + this.template = new ContextManagementTemplate(in); + } + + public MLGetContextManagementTemplateResponse(ContextManagementTemplate template) { + this.template = template; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + template.writeTo(out); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + return template.toXContent(builder, params); + } + + public static MLGetContextManagementTemplateResponse fromActionResponse(ActionResponse actionResponse) { + if (actionResponse instanceof MLGetContextManagementTemplateResponse) { + return (MLGetContextManagementTemplateResponse) actionResponse; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionResponse.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLGetContextManagementTemplateResponse(input); + } + } catch (IOException e) { + throw new UncheckedIOException("Failed to parse ActionResponse into MLGetContextManagementTemplateResponse", e); + } + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLListContextManagementTemplatesAction.java b/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLListContextManagementTemplatesAction.java new file mode 100644 index 0000000000..2b18f92e20 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLListContextManagementTemplatesAction.java @@ -0,0 +1,17 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.contextmanagement; + +import org.opensearch.action.ActionType; + +public class MLListContextManagementTemplatesAction extends ActionType { + public static MLListContextManagementTemplatesAction INSTANCE = new MLListContextManagementTemplatesAction(); + public static final String NAME = "cluster:admin/opensearch/ml/context_management/list"; + + private MLListContextManagementTemplatesAction() { + super(NAME, MLListContextManagementTemplatesResponse::new); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLListContextManagementTemplatesRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLListContextManagementTemplatesRequest.java new file mode 100644 index 0000000000..7f86ad63f6 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLListContextManagementTemplatesRequest.java @@ -0,0 +1,74 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.contextmanagement; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; + +import lombok.AccessLevel; +import lombok.Builder; +import lombok.Getter; +import lombok.ToString; +import lombok.experimental.FieldDefaults; + +@Getter +@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) +@ToString +public class MLListContextManagementTemplatesRequest extends ActionRequest { + + int from; + int size; + + @Builder + public MLListContextManagementTemplatesRequest(int from, int size) { + this.from = from; + this.size = size; + } + + public MLListContextManagementTemplatesRequest(StreamInput in) throws IOException { + super(in); + this.from = in.readInt(); + this.size = in.readInt(); + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException exception = null; + // No specific validation needed for list request + return exception; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeInt(from); + out.writeInt(size); + } + + public static MLListContextManagementTemplatesRequest fromActionRequest(ActionRequest actionRequest) { + if (actionRequest instanceof MLListContextManagementTemplatesRequest) { + return (MLListContextManagementTemplatesRequest) actionRequest; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionRequest.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLListContextManagementTemplatesRequest(input); + } + } catch (IOException e) { + throw new UncheckedIOException("Failed to parse ActionRequest into MLListContextManagementTemplatesRequest", e); + } + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLListContextManagementTemplatesResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLListContextManagementTemplatesResponse.java new file mode 100644 index 0000000000..bc66395100 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLListContextManagementTemplatesResponse.java @@ -0,0 +1,77 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.contextmanagement; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.List; + +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.contextmanager.ContextManagementTemplate; + +import lombok.Getter; + +@Getter +public class MLListContextManagementTemplatesResponse extends ActionResponse implements ToXContentObject { + public static final String TEMPLATES_FIELD = "templates"; + public static final String TOTAL_FIELD = "total"; + + private List templates; + private long total; + + public MLListContextManagementTemplatesResponse(StreamInput in) throws IOException { + super(in); + this.templates = in.readList(ContextManagementTemplate::new); + this.total = in.readLong(); + } + + public MLListContextManagementTemplatesResponse(List templates, long total) { + this.templates = templates; + this.total = total; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeList(templates); + out.writeLong(total); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(TOTAL_FIELD, total); + builder.startArray(TEMPLATES_FIELD); + for (ContextManagementTemplate template : templates) { + template.toXContent(builder, params); + } + builder.endArray(); + builder.endObject(); + return builder; + } + + public static MLListContextManagementTemplatesResponse fromActionResponse(ActionResponse actionResponse) { + if (actionResponse instanceof MLListContextManagementTemplatesResponse) { + return (MLListContextManagementTemplatesResponse) actionResponse; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionResponse.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLListContextManagementTemplatesResponse(input); + } + } catch (IOException e) { + throw new UncheckedIOException("Failed to parse ActionResponse into MLListContextManagementTemplatesResponse", e); + } + } +} diff --git a/common/src/main/resources/index-mappings/ml_agent.json b/common/src/main/resources/index-mappings/ml_agent.json index 9d4deeca51..c530711fb2 100644 --- a/common/src/main/resources/index-mappings/ml_agent.json +++ b/common/src/main/resources/index-mappings/ml_agent.json @@ -43,6 +43,24 @@ "last_updated_time": { "type": "date", "format": "strict_date_time||epoch_millis" + }, + "context_management_name": { + "type": "keyword" + }, + "context_management": { + "type": "object", + "properties": { + "name": { + "type": "keyword" + }, + "description": { + "type": "text" + }, + "hooks": { + "type": "object", + "enabled": false + } + } } } } diff --git a/common/src/main/resources/index-mappings/ml_context_management_templates.json b/common/src/main/resources/index-mappings/ml_context_management_templates.json new file mode 100644 index 0000000000..534be6702d --- /dev/null +++ b/common/src/main/resources/index-mappings/ml_context_management_templates.json @@ -0,0 +1,26 @@ +{ + "dynamic": false, + "properties": { + "name": { + "type": "keyword" + }, + "description": { + "type": "text" + }, + "hooks": { + "type": "object", + "enabled": false + }, + "created_time": { + "type": "date", + "format": "strict_date_time||epoch_millis" + }, + "last_modified": { + "type": "date", + "format": "strict_date_time||epoch_millis" + }, + "created_by": { + "type": "keyword" + } + } +} \ No newline at end of file diff --git a/common/src/test/java/org/opensearch/ml/common/agent/MLAgentTest.java b/common/src/test/java/org/opensearch/ml/common/agent/MLAgentTest.java index da2f5f5c1e..cf0747603e 100644 --- a/common/src/test/java/org/opensearch/ml/common/agent/MLAgentTest.java +++ b/common/src/test/java/org/opensearch/ml/common/agent/MLAgentTest.java @@ -29,6 +29,7 @@ import org.opensearch.ml.common.CommonValue; import org.opensearch.ml.common.MLAgentType; import org.opensearch.ml.common.TestHelper; +import org.opensearch.ml.common.contextmanager.ContextManagementTemplate; import org.opensearch.search.SearchModule; public class MLAgentTest { @@ -65,6 +66,8 @@ public void constructor_NullName() { Instant.EPOCH, "test", false, + null, + null, null ); } @@ -86,6 +89,8 @@ public void constructor_NullType() { Instant.EPOCH, "test", false, + null, + null, null ); } @@ -107,6 +112,8 @@ public void constructor_NullLLMSpec() { Instant.EPOCH, "test", false, + null, + null, null ); } @@ -128,6 +135,8 @@ public void constructor_DuplicateTool() { Instant.EPOCH, "test", false, + null, + null, null ); } @@ -146,6 +155,8 @@ public void writeTo() throws IOException { Instant.EPOCH, "test", false, + null, + null, null ); BytesStreamOutput output = new BytesStreamOutput(); @@ -174,6 +185,8 @@ public void writeTo_NullLLM() throws IOException { Instant.EPOCH, "test", false, + null, + null, null ); BytesStreamOutput output = new BytesStreamOutput(); @@ -197,6 +210,8 @@ public void writeTo_NullTools() throws IOException { Instant.EPOCH, "test", false, + null, + null, null ); BytesStreamOutput output = new BytesStreamOutput(); @@ -220,6 +235,8 @@ public void writeTo_NullParameters() throws IOException { Instant.EPOCH, "test", false, + null, + null, null ); BytesStreamOutput output = new BytesStreamOutput(); @@ -243,6 +260,8 @@ public void writeTo_NullMemory() throws IOException { Instant.EPOCH, "test", false, + null, + null, null ); BytesStreamOutput output = new BytesStreamOutput(); @@ -279,6 +298,8 @@ public void toXContent() throws IOException { Instant.EPOCH, "test", false, + null, + null, null ); XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); @@ -336,6 +357,8 @@ public void fromStream() throws IOException { Instant.EPOCH, "test", false, + null, + null, null ); BytesStreamOutput output = new BytesStreamOutput(); @@ -367,6 +390,8 @@ public void constructor_InvalidAgentType() { Instant.EPOCH, "test", false, + null, + null, null ); } @@ -386,6 +411,8 @@ public void constructor_NonConversationalNoLLM() { Instant.EPOCH, "test", false, + null, + null, null ); assertNotNull(agent); // Ensuring object creation was successful without throwing an exception @@ -396,7 +423,22 @@ public void constructor_NonConversationalNoLLM() { @Test public void writeTo_ReadFrom_HiddenFlag_VersionCompatibility() throws IOException { - MLAgent agent = new MLAgent("test", "FLOW", "test", null, null, null, null, Instant.EPOCH, Instant.EPOCH, "test", true, null); + MLAgent agent = new MLAgent( + "test", + "FLOW", + "test", + null, + null, + null, + null, + Instant.EPOCH, + Instant.EPOCH, + "test", + true, + null, + null, + null + ); // Serialize and deserialize with an older version BytesStreamOutput output = new BytesStreamOutput(); @@ -460,6 +502,8 @@ public void getTags() { Instant.EPOCH, "test_app", true, + null, + null, null ); @@ -486,6 +530,8 @@ public void getTags_NullValues() { Instant.EPOCH, "test_app", null, + null, + null, null ); @@ -497,4 +543,325 @@ public void getTags_NullValues() { assertFalse(tagsMap.containsKey("memory_type")); assertFalse(tagsMap.containsKey("_llm_interface")); } + + @Test + public void constructor_ConflictingContextManagement() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("Cannot specify both context_management_name and context_management"); + + MLAgent agent = new MLAgent( + "test_agent", + MLAgentType.FLOW.name(), + "test description", + null, + null, + null, + null, + Instant.EPOCH, + Instant.EPOCH, + "test_app", + false, + "template_name", + new ContextManagementTemplate(), + null + ); + } + + @Test + public void hasContextManagement_WithTemplateName() { + MLAgent agent = new MLAgent( + "test_agent", + MLAgentType.FLOW.name(), + "test description", + null, + null, + null, + null, + Instant.EPOCH, + Instant.EPOCH, + "test_app", + false, + "template_name", + null, + null + ); + + assertTrue(agent.hasContextManagement()); + assertTrue(agent.hasContextManagementTemplate()); + assertEquals("template_name", agent.getContextManagementTemplateName()); + assertNull(agent.getInlineContextManagement()); + } + + @Test + public void hasContextManagement_WithInlineConfig() { + ContextManagementTemplate template = ContextManagementTemplate + .builder() + .name("test_template") + .description("test description") + .hooks(Map.of("POST_TOOL", List.of())) + .build(); + + MLAgent agent = new MLAgent( + "test_agent", + MLAgentType.FLOW.name(), + "test description", + null, + null, + null, + null, + Instant.EPOCH, + Instant.EPOCH, + "test_app", + false, + null, + template, + null + ); + + assertTrue(agent.hasContextManagement()); + assertFalse(agent.hasContextManagementTemplate()); + assertNull(agent.getContextManagementTemplateName()); + assertEquals(template, agent.getInlineContextManagement()); + } + + @Test + public void hasContextManagement_NoContextManagement() { + MLAgent agent = new MLAgent( + "test_agent", + MLAgentType.FLOW.name(), + "test description", + null, + null, + null, + null, + Instant.EPOCH, + Instant.EPOCH, + "test_app", + false, + null, + null, + null + ); + + assertFalse(agent.hasContextManagement()); + assertFalse(agent.hasContextManagementTemplate()); + assertNull(agent.getContextManagementTemplateName()); + assertNull(agent.getInlineContextManagement()); + } + + @Test + public void writeTo_ReadFrom_ContextManagementName() throws IOException { + MLAgent agent = new MLAgent( + "test_agent", + MLAgentType.FLOW.name(), + "test description", + null, + null, + null, + null, + Instant.EPOCH, + Instant.EPOCH, + "test_app", + false, + "template_name", + null, + null + ); + + BytesStreamOutput output = new BytesStreamOutput(); + output.setVersion(CommonValue.VERSION_3_3_0); + agent.writeTo(output); + + StreamInput streamInput = output.bytes().streamInput(); + streamInput.setVersion(CommonValue.VERSION_3_3_0); + MLAgent deserializedAgent = new MLAgent(streamInput); + + assertEquals("template_name", deserializedAgent.getContextManagementTemplateName()); + assertNull(deserializedAgent.getInlineContextManagement()); + assertTrue(deserializedAgent.hasContextManagement()); + assertTrue(deserializedAgent.hasContextManagementTemplate()); + } + + @Test + public void writeTo_ReadFrom_ContextManagementInline() throws IOException { + ContextManagementTemplate template = ContextManagementTemplate + .builder() + .name("test_template") + .description("test description") + .hooks(Map.of("POST_TOOL", List.of())) + .build(); + + MLAgent agent = new MLAgent( + "test_agent", + MLAgentType.FLOW.name(), + "test description", + null, + null, + null, + null, + Instant.EPOCH, + Instant.EPOCH, + "test_app", + false, + null, + template, + null + ); + + BytesStreamOutput output = new BytesStreamOutput(); + output.setVersion(CommonValue.VERSION_3_3_0); + agent.writeTo(output); + + StreamInput streamInput = output.bytes().streamInput(); + streamInput.setVersion(CommonValue.VERSION_3_3_0); + MLAgent deserializedAgent = new MLAgent(streamInput); + + assertNull(deserializedAgent.getContextManagementTemplateName()); + assertNotNull(deserializedAgent.getInlineContextManagement()); + assertEquals("test_template", deserializedAgent.getInlineContextManagement().getName()); + assertEquals("test description", deserializedAgent.getInlineContextManagement().getDescription()); + assertTrue(deserializedAgent.hasContextManagement()); + assertFalse(deserializedAgent.hasContextManagementTemplate()); + } + + @Test + public void writeTo_ReadFrom_ContextManagement_VersionCompatibility() throws IOException { + MLAgent agent = new MLAgent( + "test_agent", + MLAgentType.FLOW.name(), + "test description", + null, + null, + null, + null, + Instant.EPOCH, + Instant.EPOCH, + "test_app", + false, + "template_name", + null, + null + ); + + // Serialize with older version (before context management support) + BytesStreamOutput output = new BytesStreamOutput(); + output.setVersion(CommonValue.VERSION_3_2_0); + agent.writeTo(output); + + StreamInput streamInput = output.bytes().streamInput(); + streamInput.setVersion(CommonValue.VERSION_3_2_0); + MLAgent deserializedAgent = new MLAgent(streamInput); + + // Context management fields should be null for older versions + assertNull(deserializedAgent.getContextManagementTemplateName()); + assertNull(deserializedAgent.getInlineContextManagement()); + assertFalse(deserializedAgent.hasContextManagement()); + } + + @Test + public void parse_WithContextManagementName() throws IOException { + String jsonStr = "{\"name\":\"test\",\"type\":\"FLOW\",\"context_management_name\":\"template_name\"}"; + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + null, + jsonStr + ); + parser.nextToken(); + MLAgent agent = MLAgent.parseFromUserInput(parser); + + assertEquals("test", agent.getName()); + assertEquals("FLOW", agent.getType()); + assertEquals("template_name", agent.getContextManagementTemplateName()); + assertNull(agent.getInlineContextManagement()); + assertTrue(agent.hasContextManagement()); + assertTrue(agent.hasContextManagementTemplate()); + } + + @Test + public void parse_WithInlineContextManagement() throws IOException { + String jsonStr = + "{\"name\":\"test\",\"type\":\"FLOW\",\"context_management\":{\"name\":\"inline_template\",\"description\":\"test\",\"hooks\":{\"POST_TOOL\":[]}}}"; + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + null, + jsonStr + ); + parser.nextToken(); + MLAgent agent = MLAgent.parseFromUserInput(parser); + + assertEquals("test", agent.getName()); + assertEquals("FLOW", agent.getType()); + assertNull(agent.getContextManagementTemplateName()); + assertNotNull(agent.getInlineContextManagement()); + assertEquals("inline_template", agent.getInlineContextManagement().getName()); + assertEquals("test", agent.getInlineContextManagement().getDescription()); + assertTrue(agent.hasContextManagement()); + assertFalse(agent.hasContextManagementTemplate()); + } + + @Test + public void toXContent_WithContextManagementName() throws IOException { + MLAgent agent = new MLAgent( + "test", + "FLOW", + "test description", + null, + null, + null, + null, + Instant.EPOCH, + Instant.EPOCH, + "test_app", + false, + "template_name", + null, + null + ); + + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + agent.toXContent(builder, ToXContent.EMPTY_PARAMS); + String content = TestHelper.xContentBuilderToString(builder); + + assertTrue(content.contains("\"context_management_name\":\"template_name\"")); + assertFalse(content.contains("\"context_management\":")); + } + + @Test + public void toXContent_WithInlineContextManagement() throws IOException { + ContextManagementTemplate template = ContextManagementTemplate + .builder() + .name("inline_template") + .description("test description") + .hooks(Map.of("POST_TOOL", List.of())) + .build(); + + MLAgent agent = new MLAgent( + "test", + "FLOW", + "test description", + null, + null, + null, + null, + Instant.EPOCH, + Instant.EPOCH, + "test_app", + false, + null, + template, + null + ); + + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + agent.toXContent(builder, ToXContent.EMPTY_PARAMS); + String content = TestHelper.xContentBuilderToString(builder); + + assertFalse(content.contains("\"context_management_name\":")); + assertTrue(content.contains("\"context_management\":")); + assertTrue(content.contains("\"inline_template\"")); + } } diff --git a/common/src/test/java/org/opensearch/ml/common/contextmanager/CharacterBasedTokenCounterTest.java b/common/src/test/java/org/opensearch/ml/common/contextmanager/CharacterBasedTokenCounterTest.java new file mode 100644 index 0000000000..8eb5d7978f --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/contextmanager/CharacterBasedTokenCounterTest.java @@ -0,0 +1,164 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.contextmanager; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +/** + * Unit tests for CharacterBasedTokenCounter. + */ +public class CharacterBasedTokenCounterTest { + + private CharacterBasedTokenCounter tokenCounter; + + @Before + public void setUp() { + tokenCounter = new CharacterBasedTokenCounter(); + } + + @Test + public void testCountWithNullText() { + Assert.assertEquals(0, tokenCounter.count(null)); + } + + @Test + public void testCountWithEmptyText() { + Assert.assertEquals(0, tokenCounter.count("")); + } + + @Test + public void testCountWithShortText() { + String text = "Hi"; + int expectedTokens = (int) Math.ceil(text.length() / 4.0); + Assert.assertEquals(expectedTokens, tokenCounter.count(text)); + } + + @Test + public void testCountWithMediumText() { + String text = "This is a test message"; + int expectedTokens = (int) Math.ceil(text.length() / 4.0); + Assert.assertEquals(expectedTokens, tokenCounter.count(text)); + } + + @Test + public void testCountWithLongText() { + String text = "This is a very long text that should result in multiple tokens when counted using the character-based approach."; + int expectedTokens = (int) Math.ceil(text.length() / 4.0); + Assert.assertEquals(expectedTokens, tokenCounter.count(text)); + } + + @Test + public void testTruncateFromEndWithNullText() { + Assert.assertNull(tokenCounter.truncateFromEnd(null, 10)); + } + + @Test + public void testTruncateFromEndWithEmptyText() { + Assert.assertEquals("", tokenCounter.truncateFromEnd("", 10)); + } + + @Test + public void testTruncateFromEndWithShortText() { + String text = "Short"; + String result = tokenCounter.truncateFromEnd(text, 10); + Assert.assertEquals(text, result); + } + + @Test + public void testTruncateFromEndWithLongText() { + String text = "This is a very long text that needs to be truncated"; + String result = tokenCounter.truncateFromEnd(text, 5); + + Assert.assertNotNull(result); + Assert.assertTrue(result.length() < text.length()); + Assert.assertTrue(result.length() <= 5 * 4); // 5 tokens * 4 chars per token + Assert.assertTrue(text.startsWith(result)); + } + + @Test + public void testTruncateFromBeginningWithNullText() { + Assert.assertNull(tokenCounter.truncateFromBeginning(null, 10)); + } + + @Test + public void testTruncateFromBeginningWithEmptyText() { + Assert.assertEquals("", tokenCounter.truncateFromBeginning("", 10)); + } + + @Test + public void testTruncateFromBeginningWithShortText() { + String text = "Short"; + String result = tokenCounter.truncateFromBeginning(text, 10); + Assert.assertEquals(text, result); + } + + @Test + public void testTruncateFromBeginningWithLongText() { + String text = "This is a very long text that needs to be truncated"; + String result = tokenCounter.truncateFromBeginning(text, 5); + + Assert.assertNotNull(result); + Assert.assertTrue(result.length() < text.length()); + Assert.assertTrue(result.length() <= 5 * 4); // 5 tokens * 4 chars per token + Assert.assertTrue(text.endsWith(result)); + } + + @Test + public void testTruncateMiddleWithNullText() { + Assert.assertNull(tokenCounter.truncateMiddle(null, 10)); + } + + @Test + public void testTruncateMiddleWithEmptyText() { + Assert.assertEquals("", tokenCounter.truncateMiddle("", 10)); + } + + @Test + public void testTruncateMiddleWithShortText() { + String text = "Short"; + String result = tokenCounter.truncateMiddle(text, 10); + Assert.assertEquals(text, result); + } + + @Test + public void testTruncateMiddleWithLongText() { + String text = "This is a very long text that needs to be truncated from the middle"; + String result = tokenCounter.truncateMiddle(text, 5); + + Assert.assertNotNull(result); + Assert.assertTrue(result.length() < text.length()); + Assert.assertTrue(result.length() <= 5 * 4); // 5 tokens * 4 chars per token + + // Result should contain parts from both beginning and end + int halfChars = (5 * 4) / 2; + String expectedBeginning = text.substring(0, halfChars); + String expectedEnd = text.substring(text.length() - halfChars); + + Assert.assertTrue(result.startsWith(expectedBeginning)); + Assert.assertTrue(result.endsWith(expectedEnd)); + } + + @Test + public void testTruncateConsistency() { + String text = "This is a test text for truncation consistency"; + int maxTokens = 3; + + String fromEnd = tokenCounter.truncateFromEnd(text, maxTokens); + String fromBeginning = tokenCounter.truncateFromBeginning(text, maxTokens); + String fromMiddle = tokenCounter.truncateMiddle(text, maxTokens); + + // All truncated results should have similar token counts + int tokensFromEnd = tokenCounter.count(fromEnd); + int tokensFromBeginning = tokenCounter.count(fromBeginning); + int tokensFromMiddle = tokenCounter.count(fromMiddle); + + Assert.assertTrue(tokensFromEnd <= maxTokens); + Assert.assertTrue(tokensFromBeginning <= maxTokens); + Assert.assertTrue(tokensFromMiddle <= maxTokens); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentGetResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentGetResponseTest.java index 81a173dfde..58921f71b8 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentGetResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentGetResponseTest.java @@ -96,6 +96,8 @@ public void writeTo() throws IOException { Instant.EPOCH, "test", false, + null, + null, null ); MLAgentGetResponse mlAgentGetResponse = MLAgentGetResponse.builder().mlAgent(mlAgent).build(); @@ -115,7 +117,7 @@ public void writeTo() throws IOException { @Test public void toXContent() throws IOException { - mlAgent = new MLAgent("mock", MLAgentType.FLOW.name(), "test", null, null, null, null, null, null, "test", false, null); + mlAgent = new MLAgent("mock", MLAgentType.FLOW.name(), "test", null, null, null, null, null, null, "test", false, null, null, null); MLAgentGetResponse mlAgentGetResponse = MLAgentGetResponse.builder().mlAgent(mlAgent).build(); XContentBuilder builder = XContentFactory.jsonBuilder(); ToXContent.Params params = EMPTY_PARAMS; diff --git a/common/src/test/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentRequestTest.java index 94cbbeb7dd..da7f3d5623 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentRequestTest.java @@ -10,6 +10,9 @@ import java.io.IOException; import java.io.UncheckedIOException; import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import org.junit.Before; import org.junit.Rule; @@ -21,6 +24,8 @@ import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.ml.common.agent.MLAgent; import org.opensearch.ml.common.agent.MLToolSpec; +import org.opensearch.ml.common.contextmanager.ContextManagementTemplate; +import org.opensearch.ml.common.contextmanager.ContextManagerConfig; public class MLRegisterAgentRequestTest { @@ -111,4 +116,260 @@ public void writeTo(StreamOutput out) throws IOException { }; MLRegisterAgentRequest.fromActionRequest(actionRequest); } + + @Test + public void validate_ContextManagementConflict() { + // Create agent with both context management name and inline configuration + ContextManagementTemplate contextManagement = ContextManagementTemplate + .builder() + .name("test_template") + .hooks(createValidHooks()) + .build(); + + // This should throw an exception during MLAgent construction + try { + MLAgent agentWithConflict = MLAgent + .builder() + .name("test_agent") + .type("flow") + .contextManagementName("template_name") + .contextManagement(contextManagement) + .build(); + fail("Expected IllegalArgumentException to be thrown"); + } catch (IllegalArgumentException e) { + assertTrue(e.getMessage().contains("Cannot specify both context_management_name and context_management")); + } + } + + @Test + public void validate_ContextManagementTemplateName_Valid() { + MLAgent agentWithTemplateName = MLAgent + .builder() + .name("test_agent") + .type("flow") + .contextManagementName("valid_template_name") + .build(); + + MLRegisterAgentRequest request = new MLRegisterAgentRequest(agentWithTemplateName); + ActionRequestValidationException exception = request.validate(); + + assertNull(exception); + } + + @Test + public void validate_ContextManagementTemplateName_Empty() { + // Test empty template name - this should be caught at request validation level + MLAgent agentWithEmptyName = MLAgent.builder().name("test_agent").type("flow").contextManagementName("").build(); + + MLRegisterAgentRequest request = new MLRegisterAgentRequest(agentWithEmptyName); + ActionRequestValidationException exception = request.validate(); + + assertNotNull(exception); + assertTrue(exception.toString().contains("Context management template name cannot be null or empty")); + } + + @Test + public void validate_ContextManagementTemplateName_TooLong() { + // Test template name that's too long + String longName = "a".repeat(257); + MLAgent agentWithLongName = MLAgent.builder().name("test_agent").type("flow").contextManagementName(longName).build(); + + MLRegisterAgentRequest request = new MLRegisterAgentRequest(agentWithLongName); + ActionRequestValidationException exception = request.validate(); + + assertNotNull(exception); + assertTrue(exception.toString().contains("Context management template name cannot exceed 256 characters")); + } + + @Test + public void validate_ContextManagementTemplateName_InvalidCharacters() { + // Test template name with invalid characters + MLAgent agentWithInvalidName = MLAgent.builder().name("test_agent").type("flow").contextManagementName("invalid@name#").build(); + + MLRegisterAgentRequest request = new MLRegisterAgentRequest(agentWithInvalidName); + ActionRequestValidationException exception = request.validate(); + + assertNotNull(exception); + assertTrue( + exception + .toString() + .contains("Context management template name can only contain letters, numbers, underscores, hyphens, and dots") + ); + } + + @Test + public void validate_InlineContextManagement_Valid() { + ContextManagementTemplate validContextManagement = ContextManagementTemplate + .builder() + .name("test_template") + .hooks(createValidHooks()) + .build(); + + MLAgent agentWithInlineConfig = MLAgent.builder().name("test_agent").type("flow").contextManagement(validContextManagement).build(); + + MLRegisterAgentRequest request = new MLRegisterAgentRequest(agentWithInlineConfig); + ActionRequestValidationException exception = request.validate(); + + assertNull(exception); + } + + @Test + public void validate_InlineContextManagement_InvalidHookName() { + // Create a context management template with invalid hook name but valid structure + // This should pass MLAgent validation but fail request validation + Map> invalidHooks = new HashMap<>(); + invalidHooks.put("INVALID_HOOK", Arrays.asList(new ContextManagerConfig("ToolsOutputTruncateManager", null, null))); + + ContextManagementTemplate invalidContextManagement = ContextManagementTemplate + .builder() + .name("test_template") + .hooks(invalidHooks) + .build(); + + MLAgent agentWithInvalidConfig = MLAgent + .builder() + .name("test_agent") + .type("flow") + .contextManagement(invalidContextManagement) + .build(); + + MLRegisterAgentRequest request = new MLRegisterAgentRequest(agentWithInvalidConfig); + ActionRequestValidationException exception = request.validate(); + + assertNotNull(exception); + assertTrue(exception.toString().contains("Invalid hook name: INVALID_HOOK")); + } + + @Test + public void validate_InlineContextManagement_EmptyHooks() { + ContextManagementTemplate emptyHooksTemplate = ContextManagementTemplate + .builder() + .name("test_template") + .hooks(new HashMap<>()) + .build(); + + // This should throw an exception during MLAgent construction due to invalid context management + try { + MLAgent agentWithEmptyHooks = MLAgent.builder().name("test_agent").type("flow").contextManagement(emptyHooksTemplate).build(); + fail("Expected IllegalArgumentException to be thrown"); + } catch (IllegalArgumentException e) { + assertTrue(e.getMessage().contains("Invalid context management configuration")); + } + } + + @Test + public void validate_InlineContextManagement_InvalidContextManagerConfig() { + Map> hooksWithInvalidConfig = new HashMap<>(); + hooksWithInvalidConfig + .put( + "POST_TOOL", + Arrays + .asList( + new ContextManagerConfig(null, null, null) // Invalid: null type + ) + ); + + ContextManagementTemplate invalidTemplate = ContextManagementTemplate + .builder() + .name("test_template") + .hooks(hooksWithInvalidConfig) + .build(); + + // This should throw an exception during MLAgent construction due to invalid context management + try { + MLAgent agentWithInvalidConfig = MLAgent.builder().name("test_agent").type("flow").contextManagement(invalidTemplate).build(); + fail("Expected IllegalArgumentException to be thrown"); + } catch (IllegalArgumentException e) { + assertTrue(e.getMessage().contains("Invalid context management configuration")); + } + } + + @Test + public void validate_NoContextManagement_Valid() { + MLAgent agentWithoutContextManagement = MLAgent.builder().name("test_agent").type("flow").build(); + + MLRegisterAgentRequest request = new MLRegisterAgentRequest(agentWithoutContextManagement); + ActionRequestValidationException exception = request.validate(); + + assertNull(exception); + } + + @Test + public void validate_ContextManagementTemplateName_NullValue() { + // Test null template name - this should pass validation since null is acceptable + MLAgent agentWithNullName = MLAgent.builder().name("test_agent").type("flow").contextManagementName(null).build(); + + MLRegisterAgentRequest request = new MLRegisterAgentRequest(agentWithNullName); + ActionRequestValidationException exception = request.validate(); + + assertNull(exception); + } + + @Test + public void validate_ContextManagementTemplateName_Null() { + // Test null template name validation + MLAgent agentWithNullName = MLAgent.builder().name("test_agent").type("flow").contextManagementName(null).build(); + + MLRegisterAgentRequest request = new MLRegisterAgentRequest(agentWithNullName); + ActionRequestValidationException exception = request.validate(); + + // This should pass since null is handled differently than empty + assertNull(exception); + } + + @Test + public void validate_InlineContextManagement_NullHooks() { + // Test inline context management with null hooks + ContextManagementTemplate contextManagementWithNullHooks = ContextManagementTemplate + .builder() + .name("test_template") + .hooks(null) + .build(); + + MLAgent agentWithNullHooks = MLAgent + .builder() + .name("test_agent") + .type("flow") + .contextManagement(contextManagementWithNullHooks) + .build(); + + MLRegisterAgentRequest request = new MLRegisterAgentRequest(agentWithNullHooks); + ActionRequestValidationException exception = request.validate(); + + // Should pass since null hooks are handled gracefully + assertNull(exception); + } + + @Test + public void validate_HookName_AllValidTypes() { + // Test all valid hook names to improve branch coverage + Map> allValidHooks = new HashMap<>(); + allValidHooks.put("POST_TOOL", Arrays.asList(new ContextManagerConfig("ToolsOutputTruncateManager", null, null))); + allValidHooks.put("PRE_LLM", Arrays.asList(new ContextManagerConfig("SummarizationManager", null, null))); + allValidHooks.put("PRE_TOOL", Arrays.asList(new ContextManagerConfig("MemoryManager", null, null))); + allValidHooks.put("POST_LLM", Arrays.asList(new ContextManagerConfig("ConversationManager", null, null))); + + ContextManagementTemplate contextManagement = ContextManagementTemplate + .builder() + .name("test_template") + .hooks(allValidHooks) + .build(); + + MLAgent agentWithAllHooks = MLAgent.builder().name("test_agent").type("flow").contextManagement(contextManagement).build(); + + MLRegisterAgentRequest request = new MLRegisterAgentRequest(agentWithAllHooks); + ActionRequestValidationException exception = request.validate(); + + assertNull(exception); + } + + /** + * Helper method to create valid hooks configuration for testing + */ + private Map> createValidHooks() { + Map> hooks = new HashMap<>(); + hooks.put("POST_TOOL", Arrays.asList(new ContextManagerConfig("ToolsOutputTruncateManager", null, null))); + hooks.put("PRE_LLM", Arrays.asList(new ContextManagerConfig("SummarizationManager", null, null))); + return hooks; + } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/agents/AgentContextUtil.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/agents/AgentContextUtil.java new file mode 100644 index 0000000000..da2fd985f3 --- /dev/null +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/agents/AgentContextUtil.java @@ -0,0 +1,179 @@ +package org.opensearch.ml.engine.agents; + +import static org.opensearch.ml.engine.algorithms.agent.MLAgentExecutor.QUESTION; +import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.CHAT_HISTORY; +import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.INTERACTIONS; +import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.SYSTEM_PROMPT_FIELD; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ml.common.agent.MLToolSpec; +import org.opensearch.ml.common.contextmanager.ContextManagerContext; +import org.opensearch.ml.common.hooks.EnhancedPostToolEvent; +import org.opensearch.ml.common.hooks.HookRegistry; +import org.opensearch.ml.common.hooks.PreLLMEvent; +import org.opensearch.ml.common.spi.memory.Memory; +import org.opensearch.ml.common.utils.StringUtils; +import org.opensearch.ml.engine.memory.ConversationIndexMemory; + +public class AgentContextUtil { + private static final Logger log = LogManager.getLogger(AgentContextUtil.class); + + public static ContextManagerContext buildContextManagerContextForToolOutput( + String toolOutput, + Map parameters, + List toolSpecs, + Memory memory + ) { + ContextManagerContext.ContextManagerContextBuilder builder = ContextManagerContext.builder(); + + String systemPrompt = parameters.get(SYSTEM_PROMPT_FIELD); + if (systemPrompt != null) { + builder.systemPrompt(systemPrompt); + } + + String userPrompt = parameters.get(QUESTION); + if (userPrompt != null) { + builder.userPrompt(userPrompt); + } + + if (toolSpecs != null) { + builder.toolConfigs(toolSpecs); + } + + Map contextParameters = new HashMap<>(); + contextParameters.putAll(parameters); + contextParameters.put("_current_tool_output", toolOutput); + builder.parameters(contextParameters); + + return builder.build(); + } + + public static Object extractProcessedToolOutput(ContextManagerContext context) { + if (context.getParameters() != null) { + return context.getParameters().get("_current_tool_output"); + } + return null; + } + + public static Object extractFromContext(ContextManagerContext context, String key) { + if (context.getParameters() != null) { + return context.getParameters().get(key); + } + return null; + } + + public static ContextManagerContext buildContextManagerContext( + Map parameters, + List interactions, + List toolSpecs, + Memory memory + ) { + ContextManagerContext.ContextManagerContextBuilder builder = ContextManagerContext.builder(); + + String systemPrompt = parameters.get(SYSTEM_PROMPT_FIELD); + if (systemPrompt != null) { + builder.systemPrompt(systemPrompt); + } + + String userPrompt = parameters.get(QUESTION); + if (userPrompt != null) { + builder.userPrompt(userPrompt); + } + + if (memory instanceof ConversationIndexMemory) { + String chatHistory = parameters.get(CHAT_HISTORY); + // TODO to add chatHistory into context, currently there is no context manager working on chat_history + } + + if (toolSpecs != null) { + builder.toolConfigs(toolSpecs); + } + + builder.toolInteractions(interactions != null ? interactions : new ArrayList<>()); + + Map contextParameters = new HashMap<>(); + contextParameters.putAll(parameters); + builder.parameters(contextParameters); + + return builder.build(); + } + + public static ContextManagerContext emitPostToolHook( + Object toolOutput, + Map parameters, + List toolSpecs, + Memory memory, + HookRegistry hookRegistry + ) { + ContextManagerContext context = buildContextManagerContextForToolOutput( + StringUtils.toJson(toolOutput), + parameters, + toolSpecs, + memory + ); + + if (hookRegistry != null) { + try { + if (toolOutput == null) { + log.warn("Tool output is null, skipping POST_TOOL hook"); + return context; + } + EnhancedPostToolEvent event = new EnhancedPostToolEvent(null, null, context, new HashMap<>()); + hookRegistry.emit(event); + } catch (Exception e) { + log.error("Failed to emit POST_TOOL hook event", e); + } + } + return context; + } + + public static ContextManagerContext emitPreLLMHook( + Map parameters, + List interactions, + List toolSpecs, + Memory memory, + HookRegistry hookRegistry + ) { + ContextManagerContext context = buildContextManagerContext(parameters, interactions, toolSpecs, memory); + + try { + PreLLMEvent event = new PreLLMEvent(context, new HashMap<>()); + hookRegistry.emit(event); + return context; + + } catch (Exception e) { + log.error("Failed to emit PRE_LLM hook event", e); + return context; + } + } + + public static void updateParametersFromContext(Map parameters, ContextManagerContext context) { + if (context.getSystemPrompt() != null) { + parameters.put(SYSTEM_PROMPT_FIELD, context.getSystemPrompt()); + } + + if (context.getUserPrompt() != null) { + parameters.put(QUESTION, context.getUserPrompt()); + } + + if (context.getChatHistory() != null && !context.getChatHistory().isEmpty()) { + } + + if (context.getToolInteractions() != null && !context.getToolInteractions().isEmpty()) { + parameters.put(INTERACTIONS, ", " + String.join(", ", context.getToolInteractions())); + } + + if (context.getParameters() != null) { + for (Map.Entry entry : context.getParameters().entrySet()) { + parameters.put(entry.getKey(), entry.getValue()); + + } + } + } +} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java index 1594506cf4..451f19c9ea 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java @@ -51,7 +51,9 @@ import org.opensearch.ml.common.MLTaskType; import org.opensearch.ml.common.agent.MLAgent; import org.opensearch.ml.common.agent.MLMemorySpec; +import org.opensearch.ml.common.contextmanager.ContextManagementTemplate; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.hooks.HookRegistry; import org.opensearch.ml.common.input.Input; import org.opensearch.ml.common.input.execute.agent.AgentMLInput; import org.opensearch.ml.common.output.MLTaskOutput; @@ -64,6 +66,9 @@ import org.opensearch.ml.common.spi.memory.Memory; import org.opensearch.ml.common.spi.tools.Tool; import org.opensearch.ml.engine.Executable; +import org.opensearch.ml.engine.algorithms.contextmanager.SlidingWindowManager; +import org.opensearch.ml.engine.algorithms.contextmanager.SummarizationManager; +import org.opensearch.ml.engine.algorithms.contextmanager.ToolsOutputTruncateManager; import org.opensearch.ml.engine.annotation.Function; import org.opensearch.ml.engine.encryptor.Encryptor; import org.opensearch.ml.engine.indices.MLIndicesHandler; @@ -204,6 +209,12 @@ public void execute(Input input, ActionListener listener, TransportChann ) { ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); MLAgent mlAgent = MLAgent.parse(parser); + // Use existing HookRegistry from AgentMLInput if available (set by MLExecuteTaskRunner for template + // references) + // Otherwise create a fresh HookRegistry for agent execution + final HookRegistry hookRegistry = agentMLInput.getHookRegistry() != null + ? agentMLInput.getHookRegistry() + : new HookRegistry(); if (isMultiTenancyEnabled && !Objects.equals(tenantId, mlAgent.getTenantId())) { listener .onFailure( @@ -270,7 +281,8 @@ public void execute(Input input, ActionListener listener, TransportChann outputs, modelTensors, mlAgent, - channel + channel, + hookRegistry ); }, e -> { log.error("Failed to get existing interaction for regeneration", e); @@ -287,7 +299,8 @@ public void execute(Input input, ActionListener listener, TransportChann outputs, modelTensors, mlAgent, - channel + channel, + hookRegistry ); } }, ex -> { @@ -300,7 +313,8 @@ public void execute(Input input, ActionListener listener, TransportChann ConversationIndexMemory.Factory factory = (ConversationIndexMemory.Factory) memoryFactoryMap .get(memorySpec.getType()); if (factory != null) { - // memoryId exists, so create returns an object with existing memory, therefore name can + // memoryId exists, so create returns an object with existing + // memory, therefore name can // be null factory .create( @@ -319,7 +333,8 @@ public void execute(Input input, ActionListener listener, TransportChann modelTensors, listener, createdMemory, - channel + channel, + hookRegistry ), ex -> { log.error("Failed to find memory with memory_id: {}", memoryId, ex); @@ -340,7 +355,8 @@ public void execute(Input input, ActionListener listener, TransportChann modelTensors, listener, null, - channel + channel, + hookRegistry ); } } catch (Exception e) { @@ -370,10 +386,11 @@ public void execute(Input input, ActionListener listener, TransportChann /** * save root interaction and start execute the agent - * @param listener callback listener - * @param memory memory instance + * + * @param listener callback listener + * @param memory memory instance * @param inputDataSet input - * @param mlAgent agent to run + * @param mlAgent agent to run */ private void saveRootInteractionAndExecute( ActionListener listener, @@ -384,7 +401,8 @@ private void saveRootInteractionAndExecute( List outputs, List modelTensors, MLAgent mlAgent, - TransportChannel channel + TransportChannel channel, + HookRegistry hookRegistry ) { String appType = mlAgent.getAppType(); String question = inputDataSet.getParameters().get(QUESTION); @@ -419,7 +437,8 @@ private void saveRootInteractionAndExecute( modelTensors, listener, memory, - channel + channel, + hookRegistry ), e -> { log.error("Failed to regenerate for interaction {}", regenerateInteractionId, e); @@ -438,7 +457,8 @@ private void saveRootInteractionAndExecute( modelTensors, listener, memory, - channel + channel, + hookRegistry ); } }, ex -> { @@ -447,6 +467,229 @@ private void saveRootInteractionAndExecute( })); } + /** + * Process context management configuration and register context managers in + * hook registry + * + * @param mlAgent the ML agent with context management configuration + * @param hookRegistry the hook registry to register context managers with + * @param inputDataSet the input dataset to update with context management info + */ + private void processContextManagement(MLAgent mlAgent, HookRegistry hookRegistry, RemoteInferenceInputDataSet inputDataSet) { + try { + // Check if context_management is already specified in runtime parameters + String runtimeContextManagement = inputDataSet.getParameters().get("context_management"); + if (runtimeContextManagement != null && !runtimeContextManagement.trim().isEmpty()) { + log.info("Using runtime context management parameter: {}", runtimeContextManagement); + return; // Runtime parameter takes precedence, let MLExecuteTaskRunner handle it + } + + // Check if already processed to avoid duplicate registrations + if ("true".equals(inputDataSet.getParameters().get("context_management_processed"))) { + log.debug("Context management already processed for this execution, skipping"); + return; + } + + // Check if HookRegistry already has callbacks (from previous runtime setup) + // Don't override with inline configuration if runtime config is already active + if (hookRegistry.getCallbackCount(org.opensearch.ml.common.hooks.EnhancedPostToolEvent.class) > 0 + || hookRegistry.getCallbackCount(org.opensearch.ml.common.hooks.PreLLMEvent.class) > 0) { + log.info("HookRegistry already has active configuration, skipping inline context management"); + return; + } + + ContextManagementTemplate template = null; + String templateName = null; + + if (mlAgent.hasContextManagementTemplate()) { + // Template reference - would need to be resolved from template service + templateName = mlAgent.getContextManagementTemplateName(); + log.info("Agent '{}' has context management template reference: {}", mlAgent.getName(), templateName); + // For now, we'll pass the template name to parameters for MLExecuteTaskRunner + // to handle + inputDataSet.getParameters().put("context_management", templateName); + return; // Let MLExecuteTaskRunner handle template resolution + } else if (mlAgent.getInlineContextManagement() != null) { + // Inline template - process directly + template = mlAgent.getInlineContextManagement(); + templateName = template.getName(); + log.info("Agent '{}' has inline context management configuration: {}", mlAgent.getName(), templateName); + } + + if (template != null) { + // Process inline context management template + processInlineContextManagement(template, hookRegistry); + // Mark as processed to prevent MLExecuteTaskRunner from processing it again + inputDataSet.getParameters().put("context_management_processed", "true"); + inputDataSet.getParameters().put("context_management", templateName); + } + } catch (Exception e) { + log.error("Failed to process context management for agent '{}': {}", mlAgent.getName(), e.getMessage(), e); + // Don't fail the entire execution, just log the error + } + } + + /** + * Process inline context management template and register context managers + * + * @param template the context management template + * @param hookRegistry the hook registry to register with + */ + private void processInlineContextManagement(ContextManagementTemplate template, HookRegistry hookRegistry) { + try { + log.debug("Processing inline context management template: {}", template.getName()); + + // Fresh HookRegistry ensures no duplicate registrations + + // Create context managers from template configuration + List contextManagers = createContextManagers(template); + + if (!contextManagers.isEmpty()) { + // Create hook provider and register with hook registry + org.opensearch.ml.common.contextmanager.ContextManagerHookProvider hookProvider = + new org.opensearch.ml.common.contextmanager.ContextManagerHookProvider(contextManagers); + + // Update hook configuration based on template + hookProvider.updateHookConfiguration(template.getHooks()); + + // Register hooks with the registry + hookProvider.registerHooks(hookRegistry); + + log.info("Successfully registered {} context managers from template '{}'", contextManagers.size(), template.getName()); + } else { + log.warn("No context managers created from template '{}'", template.getName()); + } + } catch (Exception e) { + log.error("Failed to process inline context management template '{}': {}", template.getName(), e.getMessage(), e); + } + } + + /** + * Create context managers from template configuration + * + * @param template the context management template + * @return list of created context managers + */ + private List createContextManagers(ContextManagementTemplate template) { + List managers = new ArrayList<>(); + + try { + // Iterate through all hooks and their configurations + for (Map.Entry> entry : template + .getHooks() + .entrySet()) { + String hookName = entry.getKey(); + List configs = entry.getValue(); + + log.debug("Processing hook '{}' with {} configurations", hookName, configs.size()); + + for (org.opensearch.ml.common.contextmanager.ContextManagerConfig config : configs) { + try { + org.opensearch.ml.common.contextmanager.ContextManager manager = createContextManager(config); + if (manager != null) { + managers.add(manager); + log.debug("Created context manager: {} for hook: {}", config.getType(), hookName); + } + } catch (Exception e) { + log + .error( + "Failed to create context manager of type '{}' for hook '{}': {}", + config.getType(), + hookName, + e.getMessage(), + e + ); + } + } + } + } catch (Exception e) { + log.error("Failed to create context managers from template: {}", e.getMessage(), e); + } + + return managers; + } + + /** + * Create a single context manager from configuration + * + * @param config the context manager configuration + * @return the created context manager or null if creation failed + */ + private org.opensearch.ml.common.contextmanager.ContextManager createContextManager( + org.opensearch.ml.common.contextmanager.ContextManagerConfig config + ) { + try { + String type = config.getType(); + Map managerConfig = config.getConfig(); + + log.debug("Creating context manager of type: {}", type); + + // Create context manager based on type + switch (type) { + case ToolsOutputTruncateManager.TYPE: + return createToolsOutputTruncateManager(managerConfig); + case SlidingWindowManager.TYPE: + return createSlidingWindowManager(managerConfig); + case SummarizationManager.TYPE: + return createSummarizationManager(managerConfig); + default: + throw new IllegalArgumentException("Failed to create context manager, unknown manager type:"+type); + } + } catch (Exception e) { + log.error("Failed to create context manager: {}", e.getMessage(), e); + return null; + } + } + + /** + * Create ToolsOutputTruncateManager + */ + private org.opensearch.ml.common.contextmanager.ContextManager createToolsOutputTruncateManager(Map config) { + log.debug("Creating ToolsOutputTruncateManager with config: {}", config); + ToolsOutputTruncateManager manager = new ToolsOutputTruncateManager(); + manager.initialize(config != null ? config : new HashMap<>()); + return manager; + } + /** + * Create ToolsOutputTruncateManager + */ + private org.opensearch.ml.common.contextmanager.ContextManager createSlidingWindowManager(Map config) { + log.debug("Creating SlidingWindowManager with config: {}", config); + SlidingWindowManager manager = new SlidingWindowManager(); + manager.initialize(config != null ? config : new HashMap<>()); + return manager; + } + + /** + * Create SummarizationManager + */ + private org.opensearch.ml.common.contextmanager.ContextManager createSummarizationManager(Map config) { + log.debug("Creating SummarizationManager with config: {}", config); + SummarizationManager manager = new SummarizationManager(client); + manager.initialize(config != null ? config : new HashMap<>()); + return manager; + } + + /** + * Create SlidingWindowManager (used for MemoryManager type) + */ + private org.opensearch.ml.common.contextmanager.ContextManager createMemoryManager(Map config) { + log.debug("Creating SlidingWindowManager (MemoryManager) with config: {}", config); + SlidingWindowManager manager = new SlidingWindowManager(); + manager.initialize(config != null ? config : new HashMap<>()); + return manager; + } + + /** + * Create ConversationManager (placeholder - using SummarizationManager for now) + */ + private org.opensearch.ml.common.contextmanager.ContextManager createConversationManager(Map config) { + log.debug("Creating ConversationManager (using SummarizationManager as placeholder) with config: {}", config); + SummarizationManager manager = new SummarizationManager(client); + manager.initialize(config != null ? config : new HashMap<>()); + return manager; + } + private void executeAgent( RemoteInferenceInputDataSet inputDataSet, MLTask mlTask, @@ -457,7 +700,8 @@ private void executeAgent( List modelTensors, ActionListener listener, ConversationIndexMemory memory, - TransportChannel channel + TransportChannel channel, + HookRegistry hookRegistry ) { String mcpConnectorConfigJSON = (mlAgent.getParameters() != null) ? mlAgent.getParameters().get(MCP_CONNECTORS_FIELD) : null; if (mcpConnectorConfigJSON != null && !mlFeatureEnabledSetting.isMcpConnectorEnabled()) { @@ -466,10 +710,17 @@ private void executeAgent( return; } - MLAgentRunner mlAgentRunner = getAgentRunner(mlAgent); + // Check for agent-level context management configuration (following connector + // pattern) + if (mlAgent.hasContextManagement()) { + processContextManagement(mlAgent, hookRegistry, inputDataSet); + } + + MLAgentRunner mlAgentRunner = getAgentRunner(mlAgent, hookRegistry); String parentInteractionId = inputDataSet.getParameters().get(PARENT_INTERACTION_ID); - // If async is true, index ML task and return the taskID. Also add memoryID to the task if it exists + // If async is true, index ML task and return the taskID. Also add memoryID to + // the task if it exists if (isAsync) { Map agentResponse = new HashMap<>(); if (memoryId != null && !memoryId.isEmpty()) { @@ -606,7 +857,7 @@ private ActionListener createAsyncTaskUpdater( } @VisibleForTesting - protected MLAgentRunner getAgentRunner(MLAgent mlAgent) { + protected MLAgentRunner getAgentRunner(MLAgent mlAgent, HookRegistry hookRegistry) { final MLAgentType agentType = MLAgentType.from(mlAgent.getType().toUpperCase(Locale.ROOT)); switch (agentType) { case FLOW: @@ -640,7 +891,8 @@ protected MLAgentRunner getAgentRunner(MLAgent mlAgent) { toolFactories, memoryFactoryMap, sdkClient, - encryptor + encryptor, + hookRegistry ); case PLAN_EXECUTE_AND_REFLECT: return new MLPlanExecuteAndReflectAgentRunner( @@ -651,7 +903,8 @@ protected MLAgentRunner getAgentRunner(MLAgent mlAgent) { toolFactories, memoryFactoryMap, sdkClient, - encryptor + encryptor, + hookRegistry ); default: throw new IllegalArgumentException("Unsupported agent type: " + mlAgent.getType()); @@ -730,4 +983,5 @@ private void updateInteractionWithFailure(String interactionId, ConversationInde ); } } + } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index 103f3f89b3..6badaf31fa 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -60,7 +60,9 @@ import org.opensearch.ml.common.agent.LLMSpec; import org.opensearch.ml.common.agent.MLAgent; import org.opensearch.ml.common.agent.MLToolSpec; +import org.opensearch.ml.common.contextmanager.ContextManagerContext; import org.opensearch.ml.common.conversation.Interaction; +import org.opensearch.ml.common.hooks.HookRegistry; import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.common.output.model.ModelTensors; @@ -69,6 +71,7 @@ import org.opensearch.ml.common.spi.tools.Tool; import org.opensearch.ml.common.transport.MLTaskResponse; import org.opensearch.ml.common.utils.StringUtils; +import org.opensearch.ml.engine.agents.AgentContextUtil; import org.opensearch.ml.engine.encryptor.Encryptor; import org.opensearch.ml.engine.function_calling.FunctionCalling; import org.opensearch.ml.engine.function_calling.FunctionCallingFactory; @@ -76,8 +79,6 @@ import org.opensearch.ml.engine.memory.ConversationIndexMemory; import org.opensearch.ml.engine.memory.ConversationIndexMessage; import org.opensearch.ml.engine.tools.MLModelTool; -import org.opensearch.ml.repackage.com.google.common.collect.ImmutableMap; -import org.opensearch.ml.repackage.com.google.common.collect.Lists; import org.opensearch.remote.metadata.client.SdkClient; import org.opensearch.transport.TransportChannel; import org.opensearch.transport.client.Client; @@ -135,6 +136,7 @@ public class MLChatAgentRunner implements MLAgentRunner { private SdkClient sdkClient; private Encryptor encryptor; private StreamingWrapper streamingWrapper; + private static HookRegistry hookRegistry; public MLChatAgentRunner( Client client, @@ -145,6 +147,20 @@ public MLChatAgentRunner( Map memoryFactoryMap, SdkClient sdkClient, Encryptor encryptor + ) { + this(client, settings, clusterService, xContentRegistry, toolFactories, memoryFactoryMap, sdkClient, encryptor, null); + } + + public MLChatAgentRunner( + Client client, + Settings settings, + ClusterService clusterService, + NamedXContentRegistry xContentRegistry, + Map toolFactories, + Map memoryFactoryMap, + SdkClient sdkClient, + Encryptor encryptor, + HookRegistry hookRegistry ) { this.client = client; this.settings = settings; @@ -154,6 +170,7 @@ public MLChatAgentRunner( this.memoryFactoryMap = memoryFactoryMap; this.sdkClient = sdkClient; this.encryptor = encryptor; + this.hookRegistry = hookRegistry; } @Override @@ -194,7 +211,8 @@ public void run(MLAgent mlAgent, Map inputParams, ActionListener for (Interaction next : r) { String question = next.getInput(); String response = next.getResponse(); - // As we store the conversation with empty response first and then update when have final answer, + // As we store the conversation with empty response first and then update when + // have final answer, // filter out those in-flight requests when run in parallel if (Strings.isNullOrEmpty(response)) { continue; @@ -218,7 +236,8 @@ public void run(MLAgent mlAgent, Map inputParams, ActionListener } params.put(CHAT_HISTORY, chatHistoryBuilder.toString()); - // required for MLChatAgentRunnerTest.java, it requires chatHistory to be added to input params to validate + // required for MLChatAgentRunnerTest.java, it requires chatHistory to be added + // to input params to validate inputParams.put(CHAT_HISTORY, chatHistoryBuilder.toString()); } else { List chatHistory = new ArrayList<>(); @@ -239,7 +258,8 @@ public void run(MLAgent mlAgent, Map inputParams, ActionListener params.put(CHAT_HISTORY, String.join(", ", chatHistory) + ", "); params.put(NEW_CHAT_HISTORY, String.join(", ", chatHistory) + ", "); - // required for MLChatAgentRunnerTest.java, it requires chatHistory to be added to input params to validate + // required for MLChatAgentRunnerTest.java, it requires chatHistory to be added + // to input params to validate inputParams.put(CHAT_HISTORY, String.join(", ", chatHistory) + ", "); } } @@ -320,7 +340,7 @@ private void runReAct( StepListener lastStepListener = firstListener; StringBuilder scratchpadBuilder = new StringBuilder(); - List interactions = new CopyOnWriteArrayList<>(); + final List interactions = new CopyOnWriteArrayList<>(); StringSubstitutor tmpSubstitutor = new StringSubstitutor(Map.of(SCRATCHPAD, scratchpadBuilder.toString()), "${parameters.", "}"); AtomicReference newPrompt = new AtomicReference<>(tmpSubstitutor.replace(prompt)); @@ -336,6 +356,7 @@ private void runReAct( if (finalI % 2 == 0) { MLTaskResponse llmResponse = (MLTaskResponse) output; ModelTensorOutput tmpModelTensorOutput = (ModelTensorOutput) llmResponse.getOutput(); + List llmResponsePatterns = gson.fromJson(tmpParameters.get("llm_response_pattern"), List.class); Map modelOutput = parseLLMOutput( parameters, @@ -454,6 +475,7 @@ private void runReAct( ((ActionListener) nextStepListener).onResponse(res); } } else { + // output is now the processed output from POST_TOOL hook in runTool Object filteredOutput = filterToolOutput(lastToolParams, output); addToolOutputToAddtionalInfo(toolSpecMap, lastAction, additionalInfo, filteredOutput); @@ -466,6 +488,7 @@ private void runReAct( ); scratchpadBuilder.append(toolResponse).append("\n\n"); + // Save trace with processed output saveTraceData( conversationIndexMemory, "ReAct", @@ -482,7 +505,9 @@ private void runReAct( newPrompt.set(substitutor.replace(finalPrompt)); tmpParameters.put(PROMPT, newPrompt.get()); if (!interactions.isEmpty()) { - tmpParameters.put(INTERACTIONS, ", " + String.join(", ", interactions)); + String interactionsStr = String.join(", ", interactions); + // Set the interactions parameter - this will be processed by context management + tmpParameters.put(INTERACTIONS, ", " + interactionsStr); } sessionMsgAnswerBuilder.append(outputToOutputString(filteredOutput)); @@ -518,6 +543,26 @@ private void runReAct( ); return; } + // Emit PRE_LLM hook event + if (hookRegistry != null && !interactions.isEmpty()) { + List currentToolSpecs = new ArrayList<>(toolSpecMap.values()); + ContextManagerContext contextAfterEvent = AgentContextUtil + .emitPreLLMHook(tmpParameters, interactions, currentToolSpecs, memory, hookRegistry); + + // Check if context managers actually modified the interactions + List updatedInteractions = contextAfterEvent.getToolInteractions(); + + if (updatedInteractions != null && !updatedInteractions.equals(interactions)) { + interactions.clear(); + interactions.addAll(updatedInteractions); + + // Update parameters if context manager set INTERACTIONS + String contextInteractions = contextAfterEvent.getParameters().get(INTERACTIONS); + if (contextInteractions != null && !contextInteractions.isEmpty()) { + tmpParameters.put(INTERACTIONS, contextInteractions); + } + } + } ActionRequest request = streamingWrapper.createPredictionRequest(llm, tmpParameters, tenantId); streamingWrapper.executeRequest(request, (ActionListener) nextStepListener); } @@ -530,8 +575,29 @@ private void runReAct( } } + // Emit PRE_LLM hook event for initial LLM call + List initialToolSpecs = new ArrayList<>(toolSpecMap.values()); + tmpParameters.put("_llm_model_id", llm.getModelId()); + if (hookRegistry != null && !interactions.isEmpty()) { + ContextManagerContext contextAfterEvent = AgentContextUtil + .emitPreLLMHook(tmpParameters, interactions, initialToolSpecs, memory, hookRegistry); + + // Check if context managers actually modified the interactions + List updatedInteractions = contextAfterEvent.getToolInteractions(); + if (updatedInteractions != null && !updatedInteractions.equals(interactions)) { + interactions.clear(); + interactions.addAll(updatedInteractions); + + // Update parameters if context manager set INTERACTIONS + String contextInteractions = contextAfterEvent.getParameters().get(INTERACTIONS); + if (contextInteractions != null && !contextInteractions.isEmpty()) { + tmpParameters.put(INTERACTIONS, contextInteractions); + } + } + } ActionRequest request = streamingWrapper.createPredictionRequest(llm, tmpParameters, tenantId); streamingWrapper.executeRequest(request, firstListener); + } private static List createFinalAnswerTensors(List sessionId, List lastThought) { @@ -581,7 +647,9 @@ private static void addToolOutputToAddtionalInfo( List list = (List) additionalInfo.get(toolOutputKey); list.add(outputString); } else { - additionalInfo.put(toolOutputKey, Lists.newArrayList(outputString)); + List newList = new ArrayList<>(); + newList.add(outputString); + additionalInfo.put(toolOutputKey, newList); } } } @@ -602,8 +670,17 @@ private static void runTool( try { String finalAction = action; ActionListener toolListener = ActionListener.wrap(r -> { + // Emit POST_TOOL hook event - common for all tool executions + List postToolSpecs = new ArrayList<>(toolSpecMap.values()); + ContextManagerContext contextAfterPostTool = AgentContextUtil + .emitPostToolHook(r, tmpParameters, postToolSpecs, null, hookRegistry); + + // Extract processed output from POST_TOOL hook + String processedToolOutput = contextAfterPostTool.getParameters().get("_current_tool_output"); + Object processedOutput = processedToolOutput != null ? processedToolOutput : r; + if (functionCalling != null) { - String outputResponse = parseResponse(filterToolOutput(toolParams, r)); + String outputResponse = parseResponse(filterToolOutput(toolParams, processedOutput)); List> toolResults = List .of(Map.of(TOOL_CALL_ID, toolCallId, TOOL_RESULT, Map.of("text", outputResponse))); List llmMessages = functionCalling.supply(toolResults); @@ -614,30 +691,30 @@ private static void runTool( .add( substitute( tmpParameters.get(INTERACTION_TEMPLATE_TOOL_RESPONSE), - Map.of(TOOL_CALL_ID, toolCallId, "tool_response", processTextDoc(StringUtils.toJson(r))), + Map.of(TOOL_CALL_ID, toolCallId, "tool_response", processTextDoc(StringUtils.toJson(processedOutput))), INTERACTIONS_PREFIX ) ); } - nextStepListener.onResponse(r); + nextStepListener.onResponse(processedOutput); }, e -> { interactions .add( substitute( tmpParameters.get(INTERACTION_TEMPLATE_TOOL_RESPONSE), - Map.of(TOOL_CALL_ID, toolCallId, "tool_response", "Tool " + action + " failed: " + e.getMessage()), + Map + .of( + TOOL_CALL_ID, + toolCallId, + "tool_response", + "Tool " + action + " failed: " + StringUtils.processTextDoc(e.getMessage()) + ), INTERACTIONS_PREFIX ) ); nextStepListener .onResponse( - String - .format( - Locale.ROOT, - "Failed to run the tool %s with the error message %s.", - finalAction, - e.getMessage().replaceAll("\\n", "\n") - ) + String.format(Locale.ROOT, "Failed to run the tool %s with the error message %s.", finalAction, e.getMessage()) ); }); if (tools.get(action) instanceof MLModelTool) { @@ -666,9 +743,13 @@ private static void runTool( } /** - * In each tool runs, it copies agent parameters, which is tmpParameters into a new set of parameter llmToolTmpParameters, - * after the tool runs, normally llmToolTmpParameters will be discarded, but for some special parameters like SCRATCHPAD_NOTES_KEY, - * some new llmToolTmpParameters produced by the tool run can opt to be copied back to tmpParameters to share across tools in the same interaction + * In each tool runs, it copies agent parameters, which is tmpParameters into a + * new set of parameter llmToolTmpParameters, + * after the tool runs, normally llmToolTmpParameters will be discarded, but for + * some special parameters like SCRATCHPAD_NOTES_KEY, + * some new llmToolTmpParameters produced by the tool run can opt to be copied + * back to tmpParameters to share across tools in the same interaction + * * @param tmpParameters * @param llmToolTmpParameters */ @@ -863,7 +944,7 @@ public static void returnFinalResponse( ModelTensor .builder() .name("response") - .dataAsMap(ImmutableMap.of("response", finalAnswer2, ADDITIONAL_INFO_FIELD, additionalInfo)) + .dataAsMap(Map.of("response", finalAnswer2, ADDITIONAL_INFO_FIELD, additionalInfo)) .build() ) ); @@ -933,4 +1014,5 @@ private void saveMessage( memory.save(msgTemp, parentInteractionId, traceNumber.addAndGet(1), "LLM", listener); } } + } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java index 762b60ca5c..23586e0020 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java @@ -21,6 +21,7 @@ import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getCurrentDateTime; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getMcpToolSpecs; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getMlToolSpecs; +import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.INTERACTIONS; import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.LLM_INTERFACE; import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.MAX_ITERATION; import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.saveTraceData; @@ -53,9 +54,11 @@ import org.opensearch.ml.common.agent.LLMSpec; import org.opensearch.ml.common.agent.MLAgent; import org.opensearch.ml.common.agent.MLToolSpec; +import org.opensearch.ml.common.contextmanager.ContextManagerContext; import org.opensearch.ml.common.conversation.Interaction; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.exception.MLException; +import org.opensearch.ml.common.hooks.HookRegistry; import org.opensearch.ml.common.input.execute.agent.AgentMLInput; import org.opensearch.ml.common.input.remote.RemoteInferenceMLInput; import org.opensearch.ml.common.output.model.ModelTensor; @@ -69,6 +72,7 @@ import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; import org.opensearch.ml.common.utils.StringUtils; +import org.opensearch.ml.engine.agents.AgentContextUtil; import org.opensearch.ml.engine.encryptor.Encryptor; import org.opensearch.ml.engine.memory.ConversationIndexMemory; import org.opensearch.remote.metadata.client.SdkClient; @@ -92,6 +96,7 @@ public class MLPlanExecuteAndReflectAgentRunner implements MLAgentRunner { private final Map memoryFactoryMap; private SdkClient sdkClient; private Encryptor encryptor; + private HookRegistry hookRegistry; // flag to track if task has been updated with executor memory ids or not private boolean taskUpdated = false; private final Map taskUpdates = new HashMap<>(); @@ -163,7 +168,8 @@ public MLPlanExecuteAndReflectAgentRunner( Map toolFactories, Map memoryFactoryMap, SdkClient sdkClient, - Encryptor encryptor + Encryptor encryptor, + HookRegistry hookRegistry ) { this.client = client; this.settings = settings; @@ -173,6 +179,7 @@ public MLPlanExecuteAndReflectAgentRunner( this.memoryFactoryMap = memoryFactoryMap; this.sdkClient = sdkClient; this.encryptor = encryptor; + this.hookRegistry = hookRegistry; this.plannerPrompt = DEFAULT_PLANNER_PROMPT; this.plannerPromptTemplate = DEFAULT_PLANNER_PROMPT_TEMPLATE; this.reflectPrompt = DEFAULT_REFLECT_PROMPT; @@ -290,9 +297,9 @@ public void run(MLAgent mlAgent, Map apiParams, ActionListenerwrap(memory -> { + .create(allParams.get(USER_PROMPT_FIELD), memoryId, appType, ActionListener.wrap(memory -> { memory.getMessages(ActionListener.>wrap(interactions -> { - List completedSteps = new ArrayList<>(); + final List completedSteps = new ArrayList<>(); for (Interaction interaction : interactions) { String question = interaction.getInput(); String response = interaction.getResponse(); @@ -386,8 +393,41 @@ private void executePlanningLoop( ); return; } + MLPredictionTaskRequest request; + // Planner agent doesn't use INTERACTIONS for now, reusing the INTERACTIONS to pass over + // completedSteps to context management. + // TODO should refactor the completed steps as message array format, similar to chat agent. + + allParams.put("_llm_model_id", llm.getModelId()); + if (hookRegistry != null && !completedSteps.isEmpty()) { + + Map requestParams = new HashMap<>(allParams); + requestParams.put(INTERACTIONS, ", " + String.join(", ", completedSteps)); + try { + ContextManagerContext contextAfterEvent = AgentContextUtil + .emitPreLLMHook(requestParams, completedSteps, null, memory, hookRegistry); + + // Check if context managers actually modified the interactions + List updatedSteps = contextAfterEvent.getToolInteractions(); + if (updatedSteps != null && !updatedSteps.equals(completedSteps)) { + completedSteps.clear(); + completedSteps.addAll(updatedSteps); + + // Update parameters if context manager set INTERACTIONS + String contextInteractions = contextAfterEvent.getParameters().get(INTERACTIONS); + if (contextInteractions != null && !contextInteractions.isEmpty()) { + allParams.put(COMPLETED_STEPS_FIELD, contextInteractions); + // TODO should I always clear interactions after update the completed steps? + allParams.put(INTERACTIONS, ""); + } + } + } catch (Exception e) { + log.error("Failed to emit pre-LLM hook", e); + } - MLPredictionTaskRequest request = new MLPredictionTaskRequest( + } + + request = new MLPredictionTaskRequest( llm.getModelId(), RemoteInferenceMLInput .builder() @@ -443,6 +483,10 @@ private void executePlanningLoop( .inputDataset(RemoteInferenceInputDataSet.builder().parameters(reactParams).build()) .build(); + // Pass hookRegistry to internal agent execution + // TODO need to check if the agentInput already have the hookResgistry? + agentInput.setHookRegistry(hookRegistry); + MLExecuteTaskRequest executeRequest = new MLExecuteTaskRequest(FunctionName.AGENT, agentInput); client.execute(MLExecuteTaskAction.INSTANCE, executeRequest, ActionListener.wrap(executeResponse -> { diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/contextmanager/SlidingWindowManager.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/contextmanager/SlidingWindowManager.java new file mode 100644 index 0000000000..c541045aaf --- /dev/null +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/contextmanager/SlidingWindowManager.java @@ -0,0 +1,191 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.algorithms.contextmanager; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.opensearch.ml.common.contextmanager.ActivationRule; +import org.opensearch.ml.common.contextmanager.ActivationRuleFactory; +import org.opensearch.ml.common.contextmanager.ContextManager; +import org.opensearch.ml.common.contextmanager.ContextManagerContext; + +import lombok.extern.log4j.Log4j2; + +/** + * Context manager that implements a sliding window approach for tool interactions. + * Keeps only the most recent N interactions to prevent context window overflow. + * This manager ensures proper handling of different message types while tool execution flow. + */ +@Log4j2 +public class SlidingWindowManager implements ContextManager { + + public static final String TYPE = "SlidingWindowManager"; + + // Configuration keys + private static final String MAX_MESSAGES_KEY = "max_messages"; + + // Default values + private static final int DEFAULT_MAX_MESSAGES = 20; + + private int maxMessages; + private List activationRules; + + @Override + public String getType() { + return TYPE; + } + + @Override + public void initialize(Map config) { + // Initialize configuration with defaults + this.maxMessages = parseIntegerConfig(config, MAX_MESSAGES_KEY, DEFAULT_MAX_MESSAGES); + + if (this.maxMessages <= 0) { + log.warn("Invalid max_messages value: {}, using default {}", this.maxMessages, DEFAULT_MAX_MESSAGES); + this.maxMessages = DEFAULT_MAX_MESSAGES; + } + + // Initialize activation rules from config + @SuppressWarnings("unchecked") + Map activationConfig = (Map) config.get("activation"); + this.activationRules = ActivationRuleFactory.createRules(activationConfig); + + log.info("Initialized SlidingWindowManager: maxMessages={}", maxMessages); + } + + @Override + public boolean shouldActivate(ContextManagerContext context) { + if (activationRules == null || activationRules.isEmpty()) { + return true; + } + + for (ActivationRule rule : activationRules) { + if (!rule.evaluate(context)) { + log.debug("Activation rule not satisfied: {}", rule.getDescription()); + return false; + } + } + + log.debug("All activation rules satisfied, manager will execute"); + return true; + } + + @Override + public void execute(ContextManagerContext context) { + List interactions = context.getToolInteractions(); + + if (interactions == null || interactions.isEmpty()) { + log.debug("No tool interactions to process"); + return; + } + + if (interactions.isEmpty()) { + log.debug("No string interactions found in tool interactions"); + return; + } + + int originalSize = interactions.size(); + + if (originalSize <= maxMessages) { + log.debug("Interactions size ({}) is within limit ({}), no truncation needed", originalSize, maxMessages); + return; + } + + // Find safe start point to avoid breaking tool pairs + int startIndex = findSafeStartPoint(interactions, originalSize - maxMessages); + + // Keep the most recent interactions from safe start point + List updatedInteractions = new ArrayList<>(interactions.subList(startIndex, originalSize)); + + // Update toolInteractions in context to keep only the most recent ones + context.setToolInteractions(updatedInteractions); + + // Update the _interactions parameter with smaller size of updated interactions + Map parameters = context.getParameters(); + if (parameters == null) { + parameters = new HashMap<>(); + context.setParameters(parameters); + } + parameters.put("_interactions", ", " + String.join(", ", updatedInteractions)); + + int removedMessages = originalSize - updatedInteractions.size(); + log + .info( + "Applied sliding window: kept {} most recent interactions, removed {} older interactions", + updatedInteractions.size(), + removedMessages + ); + } + + private int parseIntegerConfig(Map config, String key, int defaultValue) { + Object value = config.get(key); + if (value == null) { + return defaultValue; + } + + try { + if (value instanceof Integer) { + return (Integer) value; + } else if (value instanceof Number) { + return ((Number) value).intValue(); + } else if (value instanceof String) { + return Integer.parseInt((String) value); + } else { + log.warn("Invalid type for config key '{}': {}, using default {}", key, value.getClass().getSimpleName(), defaultValue); + return defaultValue; + } + } catch (NumberFormatException e) { + log.warn("Invalid integer value for config key '{}': {}, using default {}", key, value, defaultValue); + return defaultValue; + } + } + + /** + * Find a safe start point that doesn't break assistant-tool message pairs + * Same logic as SummarizationManager but for finding start point + */ + private int findSafeStartPoint(List interactions, int targetStartPoint) { + if (targetStartPoint <= 0) { + return 0; + } + if (targetStartPoint >= interactions.size()) { + return interactions.size(); + } + + int startPoint = targetStartPoint; + + while (startPoint < interactions.size()) { + try { + String messageAtStart = interactions.get(startPoint); + + // Oldest message cannot be a toolResult because it needs a toolUse preceding it + boolean hasToolResult = messageAtStart.contains("toolResult"); + + // Oldest message can be a toolUse only if a toolResult immediately follows it + boolean hasToolUse = messageAtStart.contains("toolUse"); + boolean nextHasToolResult = false; + if (startPoint + 1 < interactions.size()) { + nextHasToolResult = interactions.get(startPoint + 1).contains("toolResult"); + } + + if (hasToolResult || (hasToolUse && startPoint + 1 < interactions.size() && !nextHasToolResult)) { + startPoint++; + } else { + break; + } + + } catch (Exception e) { + log.warn("Error checking message at index {}: {}", startPoint, e.getMessage()); + startPoint++; + } + } + + return startPoint; + } +} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/contextmanager/SummarizationManager.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/contextmanager/SummarizationManager.java new file mode 100644 index 0000000000..38ea1b4aac --- /dev/null +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/contextmanager/SummarizationManager.java @@ -0,0 +1,442 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.algorithms.contextmanager; + +import static java.lang.Math.min; +import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD; +import static org.opensearch.ml.common.FunctionName.REMOTE; +import static org.opensearch.ml.common.utils.StringUtils.processTextDoc; +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_RESPONSE_FILTER; +import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.INTERACTIONS; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.contextmanager.ActivationRule; +import org.opensearch.ml.common.contextmanager.ActivationRuleFactory; +import org.opensearch.ml.common.contextmanager.ContextManager; +import org.opensearch.ml.common.contextmanager.ContextManagerContext; +import org.opensearch.ml.common.dataset.MLInputDataset; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.output.MLOutput; +import org.opensearch.ml.common.output.model.ModelTensor; +import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.ml.common.transport.MLTaskResponse; +import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; +import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; +import org.opensearch.ml.common.utils.StringUtils; +import org.opensearch.transport.client.Client; + +import com.jayway.jsonpath.JsonPath; +import com.jayway.jsonpath.PathNotFoundException; + +import lombok.extern.log4j.Log4j2; + +/** + * Context manager that implements summarization approach for tool interactions. + * Summarizes older interactions while preserving recent ones to manage context + * window. + */ +@Log4j2 +public class SummarizationManager implements ContextManager { + + public static final String TYPE = "SummarizationManager"; + + // Configuration keys + private static final String SUMMARY_RATIO_KEY = "summary_ratio"; + private static final String PRESERVE_RECENT_MESSAGES_KEY = "preserve_recent_messages"; + private static final String SUMMARIZATION_MODEL_ID_KEY = "summarization_model_id"; + private static final String SUMMARIZATION_SYSTEM_PROMPT_KEY = "summarization_system_prompt"; + + // Default values + private static final double DEFAULT_SUMMARY_RATIO = 0.3; + private static final int DEFAULT_PRESERVE_RECENT_MESSAGES = 10; + private static final String DEFAULT_SUMMARIZATION_PROMPT = + "You are a interactions summarization agent. Summarize the provided interactions concisely while preserving key information and context."; + + protected double summaryRatio; + protected int preserveRecentMessages; + protected String summarizationModelId; + protected String summarizationSystemPrompt; + protected List activationRules; + private Client client; + + public SummarizationManager(Client client) { + this.client = client; + } + + @Override + public String getType() { + return TYPE; + } + + @Override + public void initialize(Map config) { + this.summaryRatio = parseDoubleConfig(config, SUMMARY_RATIO_KEY, DEFAULT_SUMMARY_RATIO); + this.preserveRecentMessages = parseIntegerConfig(config, PRESERVE_RECENT_MESSAGES_KEY, DEFAULT_PRESERVE_RECENT_MESSAGES); + this.summarizationModelId = (String) config.get(SUMMARIZATION_MODEL_ID_KEY); + this.summarizationSystemPrompt = (String) config.getOrDefault(SUMMARIZATION_SYSTEM_PROMPT_KEY, DEFAULT_SUMMARIZATION_PROMPT); + + // Validate summary ratio + if (summaryRatio < 0.1 || summaryRatio > 0.8) { + log.warn("Invalid summary_ratio value: {}, using default {}", summaryRatio, DEFAULT_SUMMARY_RATIO); + this.summaryRatio = DEFAULT_SUMMARY_RATIO; + } + + // Initialize activation rules from config + @SuppressWarnings("unchecked") + Map activationConfig = (Map) config.get("activation"); + this.activationRules = ActivationRuleFactory.createRules(activationConfig); + + log.info("Initialized SummarizationManager: summaryRatio={}, preserveRecentMessages={}", summaryRatio, preserveRecentMessages); + } + + @Override + public boolean shouldActivate(ContextManagerContext context) { + if (activationRules == null || activationRules.isEmpty()) { + return true; + } + + for (ActivationRule rule : activationRules) { + if (!rule.evaluate(context)) { + log.debug("Activation rule not satisfied: {}", rule.getDescription()); + return false; + } + } + + log.debug("All activation rules satisfied, manager will execute"); + return true; + } + + @Override + public void execute(ContextManagerContext context) { + List interactions = context.getToolInteractions(); + + if (interactions == null || interactions.isEmpty()) { + return; + } + + if (interactions.isEmpty()) { + log.debug("No string interactions found in tool interactions"); + return; + } + + int totalMessages = interactions.size(); + + // Calculate how many messages to summarize + int messagesToSummarizeCount = Math.max(1, (int) (totalMessages * summaryRatio)); + + // Ensure we don't summarize recent messages + messagesToSummarizeCount = min(messagesToSummarizeCount, totalMessages - preserveRecentMessages); + + if (messagesToSummarizeCount <= 0) { + return; + } + + // Find a safe cut point that doesn't break assistant-tool pairs + int safeCutPoint = findSafeCutPoint(interactions, messagesToSummarizeCount); + + if (safeCutPoint <= 0) { + return; + } + + // Extract messages to summarize and remaining messages + List messagesToSummarize = new ArrayList<>(interactions.subList(0, safeCutPoint)); + List remainingMessages = new ArrayList<>(interactions.subList(safeCutPoint, totalMessages)); + + // Get model ID + String modelId = summarizationModelId; + if (modelId == null) { + Map parameters = context.getParameters(); + if (parameters != null) { + modelId = (String) parameters.get("_llm_model_id"); + } + } + + if (modelId == null) { + log.error("No model ID available for summarization"); + return; + } + + // Prepare summarization parameters + Map summarizationParameters = new HashMap<>(); + summarizationParameters.put("prompt", "Help summarize the following" + StringUtils.toJson(String.join(",", messagesToSummarize))); + summarizationParameters.put("system_prompt", summarizationSystemPrompt); + + executeSummarization(context, modelId, summarizationParameters, safeCutPoint, remainingMessages, interactions); + } + + protected void executeSummarization( + ContextManagerContext context, + String modelId, + Map summarizationParameters, + int messagesToSummarizeCount, + List remainingMessages, + List originalInteractions + ) { + CountDownLatch latch = new CountDownLatch(1); + + try { + // Create ML input dataset for remote inference + MLInputDataset inputDataset = RemoteInferenceInputDataSet.builder().parameters(summarizationParameters).build(); + + // Create ML input + MLInput mlInput = MLInput.builder().algorithm(REMOTE).inputDataset(inputDataset).build(); + + // Create prediction request + String tenantId = (String) context.getParameter(TENANT_ID_FIELD); + MLPredictionTaskRequest request = MLPredictionTaskRequest + .builder() + .modelId(modelId) + .mlInput(mlInput) + .tenantId(tenantId) + .build(); + + // Execute prediction + ActionListener listener = ActionListener.wrap(response -> { + try { + String summary = extractSummaryFromResponse(response, context); + processSummarizationResult(context, summary, messagesToSummarizeCount, remainingMessages, originalInteractions); + } catch (Exception e) { + // Fallback to default behavior + processSummarizationResult( + context, + "Summarized " + messagesToSummarizeCount + " previous tool interactions", + messagesToSummarizeCount, + remainingMessages, + originalInteractions + ); + } finally { + latch.countDown(); + } + }, e -> { + try { + // Fallback to default behavior + processSummarizationResult( + context, + "Summarized " + messagesToSummarizeCount + " previous tool interactions", + messagesToSummarizeCount, + remainingMessages, + originalInteractions + ); + } finally { + latch.countDown(); + } + }); + + client.execute(MLPredictionTaskAction.INSTANCE, request, listener); + + // Wait for summarization to complete (30 second timeout) + latch.await(30, TimeUnit.SECONDS); + + } catch (Exception e) { + // Fallback to default behavior + processSummarizationResult( + context, + "Summarized " + messagesToSummarizeCount + " previous interactions", + messagesToSummarizeCount, + remainingMessages, + originalInteractions + ); + } + } + + protected void processSummarizationResult( + ContextManagerContext context, + String summary, + int messagesToSummarizeCount, + List remainingMessages, + List originalInteractions + ) { + try { + // Create summarized interaction + String summarizedInteraction = "{\"role\":\"assistant\",\"content\":\"Summarized previous interactions: " + + processTextDoc(summary) + + "\"}"; + + // Update interactions: summary + remaining messages + List updatedInteractions = new ArrayList<>(); + updatedInteractions.add(summarizedInteraction); + updatedInteractions.addAll(remainingMessages); + + // Update toolInteractions in context + context.setToolInteractions(updatedInteractions); + + // Update parameters + Map parameters = context.getParameters(); + if (parameters == null) { + parameters = new HashMap<>(); + } + parameters.put(INTERACTIONS, ", " + String.join(", ", updatedInteractions)); + context.setParameters(parameters); + + log + .info( + "Summarization completed: {} messages summarized, {} messages preserved", + messagesToSummarizeCount, + remainingMessages.size() + ); + + } catch (Exception e) { + log.error("Failed to process summarization result", e); + } + } + + private String extractSummaryFromResponse(MLTaskResponse response, ContextManagerContext context) { + try { + MLOutput output = response.getOutput(); + if (output instanceof ModelTensorOutput) { + ModelTensorOutput tensorOutput = (ModelTensorOutput) output; + List mlModelOutputs = tensorOutput.getMlModelOutputs(); + + if (mlModelOutputs != null && !mlModelOutputs.isEmpty()) { + List tensors = mlModelOutputs.get(0).getMlModelTensors(); + if (tensors != null && !tensors.isEmpty()) { + Map dataAsMap = tensors.get(0).getDataAsMap(); + + // Use LLM_RESPONSE_FILTER from agent configuration if available + Map parameters = context.getParameters(); + if (parameters != null + && parameters.containsKey(LLM_RESPONSE_FILTER) + && !parameters.get(LLM_RESPONSE_FILTER).isEmpty()) { + try { + String responseFilter = parameters.get(LLM_RESPONSE_FILTER); + Object filteredResponse = JsonPath.read(dataAsMap, responseFilter); + if (filteredResponse instanceof String) { + String result = ((String) filteredResponse).trim(); + return result; + } else { + String result = StringUtils.toJson(filteredResponse); + return result; + } + } catch (PathNotFoundException e) { + // Fall back to default parsing + } catch (Exception e) { + // Fall back to default parsing + } + } + + // Fallback to default parsing if no filter or filter fails + if (dataAsMap.size() == 1 && dataAsMap.containsKey("response")) { + Object responseObj = dataAsMap.get("response"); + if (responseObj instanceof String) { + return ((String) responseObj).trim(); + } + } + + // Last resort: return JSON representation + return StringUtils.toJson(dataAsMap); + } + } + } + } catch (Exception e) { + log.error("Failed to extract summary from response", e); + } + + return "Summary generation failed"; + } + + private double parseDoubleConfig(Map config, String key, double defaultValue) { + Object value = config.get(key); + if (value == null) { + return defaultValue; + } + + try { + if (value instanceof Double) { + return (Double) value; + } else if (value instanceof Number) { + return ((Number) value).doubleValue(); + } else if (value instanceof String) { + return Double.parseDouble((String) value); + } else { + log.warn("Invalid type for config key '{}': {}, using default {}", key, value.getClass().getSimpleName(), defaultValue); + return defaultValue; + } + } catch (NumberFormatException e) { + log.warn("Invalid double value for config key '{}': {}, using default {}", key, value, defaultValue); + return defaultValue; + } + } + + private int parseIntegerConfig(Map config, String key, int defaultValue) { + Object value = config.get(key); + if (value == null) { + return defaultValue; + } + + try { + if (value instanceof Integer) { + return (Integer) value; + } else if (value instanceof Number) { + return ((Number) value).intValue(); + } else if (value instanceof String) { + return Integer.parseInt((String) value); + } else { + log.warn("Invalid type for config key '{}': {}, using default {}", key, value.getClass().getSimpleName(), defaultValue); + return defaultValue; + } + } catch (NumberFormatException e) { + log.warn("Invalid integer value for config key '{}': {}, using default {}", key, value, defaultValue); + return defaultValue; + } + } + + /** + * Find a safe cut point that doesn't break assistant-tool message pairs + * Exact same logic as Strands agent + */ + private int findSafeCutPoint(List interactions, int targetCutPoint) { + if (targetCutPoint >= interactions.size()) { + return targetCutPoint; + } + // // the current agent logic is when odd number it's tool called result and even number is tool input, should always summarize for + // pairs, so the targetCutPoint needs to be even + // if (targetCutPoint%2==0){ + // return targetCutPoint; + // } else { + // return min(targetCutPoint+1,interactions.size()); + // } + int splitPoint = targetCutPoint; + + while (splitPoint < interactions.size()) { + try { + String messageAtSplit = interactions.get(splitPoint); + + // Oldest message cannot be a toolResult because it needs a toolUse preceding it + boolean hasToolResult = (messageAtSplit.contains("toolResult") || messageAtSplit.contains("tool_call_id")); + + // Oldest message can be a toolUse only if a toolResult immediately follows it + boolean hasToolUse = messageAtSplit.contains("toolUse"); + boolean nextHasToolResult = false; + // TODO we need better way to handle the tool result based on the llm interfaces. + if (splitPoint + 1 < interactions.size()) { + nextHasToolResult = (interactions.get(splitPoint + 1).contains("toolResult") + || messageAtSplit.contains("tool_call_id")); + } + + if (hasToolResult || (hasToolUse && splitPoint + 1 < interactions.size() && !nextHasToolResult)) { + splitPoint++; + } else { + break; + } + + } catch (Exception e) { + log.warn("Error checking message at index {}: {}", splitPoint, e.getMessage()); + splitPoint++; + } + } + + return splitPoint; + } +} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/contextmanager/ToolsOutputTruncateManager.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/contextmanager/ToolsOutputTruncateManager.java new file mode 100644 index 0000000000..4fa97c156d --- /dev/null +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/contextmanager/ToolsOutputTruncateManager.java @@ -0,0 +1,134 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.algorithms.contextmanager; + +import java.util.List; +import java.util.Map; + +import org.opensearch.ml.common.contextmanager.ActivationRule; +import org.opensearch.ml.common.contextmanager.ActivationRuleFactory; +import org.opensearch.ml.common.contextmanager.ContextManager; +import org.opensearch.ml.common.contextmanager.ContextManagerContext; + +import lombok.extern.log4j.Log4j2; + +/** + * Context manager that truncates tool output to prevent context window overflow. + * This manager processes the current tool output and applies length limits. + */ +@Log4j2 +public class ToolsOutputTruncateManager implements ContextManager { + + public static final String TYPE = "ToolsOutputTruncateManager"; + + // Configuration keys + private static final String MAX_OUTPUT_LENGTH_KEY = "max_output_length"; + + // Default values + private static final int DEFAULT_MAX_OUTPUT_LENGTH = 40000; + + private int maxOutputLength; + private List activationRules; + + @Override + public String getType() { + return TYPE; + } + + @Override + public void initialize(Map config) { + // Initialize configuration with defaults + this.maxOutputLength = parseIntegerConfig(config, MAX_OUTPUT_LENGTH_KEY, DEFAULT_MAX_OUTPUT_LENGTH); + + if (this.maxOutputLength <= 0) { + log.warn("Invalid max_output_length value: {}, using default {}", this.maxOutputLength, DEFAULT_MAX_OUTPUT_LENGTH); + this.maxOutputLength = DEFAULT_MAX_OUTPUT_LENGTH; + } + + // Initialize activation rules from config + @SuppressWarnings("unchecked") + Map activationConfig = (Map) config.get("activation"); + this.activationRules = ActivationRuleFactory.createRules(activationConfig); + + log.info("Initialized ToolsOutputTruncateManager: maxOutputLength={}", maxOutputLength); + } + + @Override + public boolean shouldActivate(ContextManagerContext context) { + if (activationRules == null || activationRules.isEmpty()) { + return true; + } + + for (ActivationRule rule : activationRules) { + if (!rule.evaluate(context)) { + log.debug("Activation rule not satisfied: {}", rule.getDescription()); + return false; + } + } + + log.debug("All activation rules satisfied, manager will execute"); + return true; + } + + @Override + public void execute(ContextManagerContext context) { + // Process current tool output from parameters + Map parameters = context.getParameters(); + if (parameters == null) { + log.debug("No parameters available for tool output truncation"); + return; + } + + Object currentToolOutput = parameters.get("_current_tool_output"); + if (currentToolOutput == null) { + log.debug("No current tool output to process"); + return; + } + + String outputString = currentToolOutput.toString(); + int originalLength = outputString.length(); + + if (originalLength <= maxOutputLength) { + log.debug("Tool output length ({}) is within limit ({}), no truncation needed", originalLength, maxOutputLength); + return; + } + + // Truncate the output + String truncatedOutput = outputString.substring(0, maxOutputLength); + + // Add truncation indicator + truncatedOutput += "... [Output truncated - original length: " + originalLength + " characters]"; + + // Update the current tool output in parameters + parameters.put("_current_tool_output", truncatedOutput); + + int truncatedLength = truncatedOutput.length(); + log.info("Tool output truncated: original length {} -> truncated length {}", originalLength, truncatedLength); + } + + private int parseIntegerConfig(Map config, String key, int defaultValue) { + Object value = config.get(key); + if (value == null) { + return defaultValue; + } + + try { + if (value instanceof Integer) { + return (Integer) value; + } else if (value instanceof Number) { + return ((Number) value).intValue(); + } else if (value instanceof String) { + return Integer.parseInt((String) value); + } else { + log.warn("Invalid type for config key '{}': {}, using default {}", key, value.getClass().getSimpleName(), defaultValue); + return defaultValue; + } + } catch (NumberFormatException e) { + log.warn("Invalid integer value for config key '{}': {}, using default {}", key, value, defaultValue); + return defaultValue; + } + } +} diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutorTest.java index efc84d8f8c..c3843434f7 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutorTest.java @@ -5,1617 +5,317 @@ package org.opensearch.ml.engine.algorithms.agent; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.when; -import static org.opensearch.cluster.node.DiscoveryNodeRole.CLUSTER_MANAGER_ROLE; -import static org.opensearch.ml.common.CommonValue.MCP_CONNECTORS_FIELD; -import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX; -import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_MCP_CONNECTOR_DISABLED_MESSAGE; -import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_MCP_CONNECTOR_ENABLED; -import static org.opensearch.ml.engine.algorithms.agent.MLAgentExecutor.MEMORY_ID; -import static org.opensearch.ml.engine.algorithms.agent.MLAgentExecutor.QUESTION; -import static org.opensearch.ml.engine.algorithms.agent.MLAgentExecutor.REGENERATE_INTERACTION_ID; +import static org.junit.Assert.*; +import static org.mockito.Mockito.*; -import java.io.IOException; -import java.net.InetAddress; import java.time.Instant; -import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Set; -import javax.naming.Context; - -import org.junit.Assert; import org.junit.Before; import org.junit.Test; -import org.mockito.ArgumentCaptor; -import org.mockito.Captor; import org.mockito.Mock; -import org.mockito.Mockito; import org.mockito.MockitoAnnotations; -import org.opensearch.OpenSearchException; -import org.opensearch.ResourceNotFoundException; -import org.opensearch.Version; -import org.opensearch.action.get.GetRequest; +import org.opensearch.OpenSearchStatusException; import org.opensearch.action.get.GetResponse; -import org.opensearch.action.index.IndexResponse; import org.opensearch.cluster.ClusterState; import org.opensearch.cluster.metadata.Metadata; -import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.action.ActionListener; -import org.opensearch.core.common.bytes.BytesReference; -import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.core.common.transport.TransportAddress; -import org.opensearch.core.index.shard.ShardId; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; -import org.opensearch.core.xcontent.ToXContent; -import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.index.get.GetResult; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLAgentType; import org.opensearch.ml.common.agent.LLMSpec; import org.opensearch.ml.common.agent.MLAgent; -import org.opensearch.ml.common.agent.MLMemorySpec; -import org.opensearch.ml.common.agent.MLToolSpec; -import org.opensearch.ml.common.conversation.Interaction; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.Input; import org.opensearch.ml.common.input.execute.agent.AgentMLInput; -import org.opensearch.ml.common.output.MLTaskOutput; import org.opensearch.ml.common.output.Output; -import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensorOutput; -import org.opensearch.ml.common.output.model.ModelTensors; import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; import org.opensearch.ml.common.spi.memory.Memory; import org.opensearch.ml.common.spi.tools.Tool; -import org.opensearch.ml.engine.memory.ConversationIndexMemory; -import org.opensearch.ml.engine.memory.MLMemoryManager; -import org.opensearch.ml.memory.action.conversation.CreateInteractionResponse; -import org.opensearch.ml.memory.action.conversation.GetInteractionAction; -import org.opensearch.ml.memory.action.conversation.GetInteractionResponse; +import org.opensearch.ml.engine.encryptor.Encryptor; import org.opensearch.remote.metadata.client.SdkClient; -import org.opensearch.remote.metadata.client.impl.SdkClientFactory; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportChannel; import org.opensearch.transport.client.Client; -import com.google.gson.Gson; - -import software.amazon.awssdk.utils.ImmutableMap; - +@SuppressWarnings({ "rawtypes" }) public class MLAgentExecutorTest { @Mock private Client client; - SdkClient sdkClient; - private Settings settings; - @Mock - private ClusterService clusterService; + @Mock - private ClusterState clusterState; + private SdkClient sdkClient; + @Mock - private Metadata metadata; + private ClusterService clusterService; + @Mock private NamedXContentRegistry xContentRegistry; + @Mock - private Map toolFactories; - @Mock - private Map memoryMap; + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + @Mock - private IndexResponse indexResponse; + private Encryptor encryptor; + @Mock private ThreadPool threadPool; + private ThreadContext threadContext; - @Mock - private Context context; - @Mock - private ConversationIndexMemory.Factory mockMemoryFactory; - @Mock - private ActionListener agentActionListener; - @Mock - private MLAgentRunner mlAgentRunner; @Mock - private ConversationIndexMemory memory; - @Mock - private MLMemoryManager memoryManager; - private MLAgentExecutor mlAgentExecutor; + private ThreadContext.StoredContext storedContext; @Mock - private MLFeatureEnabledSetting mlFeatureEnabledSetting; - - @Captor - private ArgumentCaptor objectCaptor; + private TransportChannel channel; - @Captor - private ArgumentCaptor exceptionCaptor; + @Mock + private ActionListener listener; - private DiscoveryNode localNode = new DiscoveryNode( - "mockClusterManagerNodeId", - "mockClusterManagerNodeId", - new TransportAddress(InetAddress.getLoopbackAddress(), 9300), - Collections.emptyMap(), - Collections.singleton(CLUSTER_MANAGER_ROLE), - Version.CURRENT - ); + @Mock + private GetResponse getResponse; - MLAgent mlAgent; + private MLAgentExecutor mlAgentExecutor; + private Map toolFactories; + private Map memoryFactoryMap; + private Settings settings; @Before - @SuppressWarnings("unchecked") public void setup() { MockitoAnnotations.openMocks(this); + settings = Settings.builder().build(); - sdkClient = SdkClientFactory.createSdkClient(client, NamedXContentRegistry.EMPTY, Collections.emptyMap()); + toolFactories = new HashMap<>(); + memoryFactoryMap = new HashMap<>(); threadContext = new ThreadContext(settings); - memoryMap = ImmutableMap.of("memoryType", mockMemoryFactory); - Mockito.doAnswer(invocation -> { - MLMemorySpec mlMemorySpec = MLMemorySpec.builder().type("memoryType").build(); - MLAgent mlAgent = MLAgent.builder().name("agent").memory(mlMemorySpec).type("flow").build(); - XContentBuilder content = mlAgent.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); - ActionListener listener = invocation.getArgument(1); - GetResponse getResponse = Mockito.mock(GetResponse.class); - Mockito.when(getResponse.isExists()).thenReturn(true); - Mockito.when(getResponse.getSourceAsBytesRef()).thenReturn(BytesReference.bytes(content)); - listener.onResponse(getResponse); - return null; - }).when(client).get(Mockito.any(), Mockito.any()); - Mockito.when(clusterService.state()).thenReturn(clusterState); - Mockito.when(clusterState.metadata()).thenReturn(metadata); - when(clusterService.localNode()).thenReturn(localNode); - Mockito.when(metadata.hasIndex(Mockito.anyString())).thenReturn(true); - Mockito.when(memory.getMemoryManager()).thenReturn(memoryManager); + when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); - when(this.clusterService.getSettings()).thenReturn(settings); - when(this.clusterService.getClusterSettings()).thenReturn(new ClusterSettings(settings, Set.of(ML_COMMONS_MCP_CONNECTOR_ENABLED))); - - // Mock MLFeatureEnabledSetting when(mlFeatureEnabledSetting.isMultiTenancyEnabled()).thenReturn(false); - when(mlFeatureEnabledSetting.isMcpConnectorEnabled()).thenReturn(true); - - settings = Settings.builder().build(); - mlAgentExecutor = Mockito - .spy( - new MLAgentExecutor( - client, - sdkClient, - settings, - clusterService, - xContentRegistry, - toolFactories, - memoryMap, - mlFeatureEnabledSetting, - null - ) - ); - - } - - @Test - public void test_NoAgentIndex() { - ModelTensor modelTensor = ModelTensor.builder().name("response").dataAsMap(ImmutableMap.of("test_key", "test_value")).build(); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onResponse(modelTensor); - return null; - }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any()); - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); - Mockito.when(metadata.hasIndex(Mockito.anyString())).thenReturn(false); - - mlAgentExecutor.execute(getAgentMLInput(), agentActionListener); - - Mockito.verify(agentActionListener).onFailure(exceptionCaptor.capture()); - Exception exception = exceptionCaptor.getValue(); - Assert.assertTrue(exception instanceof ResourceNotFoundException); - Assert.assertEquals(exception.getMessage(), "Agent index not found"); - } - - @Test(expected = IllegalArgumentException.class) - public void test_NullInput_ThrowsException() { - mlAgentExecutor.execute(null, agentActionListener); - } - - @Test(expected = IllegalArgumentException.class) - public void test_NonAgentInput_ThrowsException() { - Input input = new Input() { - @Override - public FunctionName getFunctionName() { - return null; - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - return null; - } - }; - mlAgentExecutor.execute(input, agentActionListener); - } - - @Test(expected = IllegalArgumentException.class) - public void test_NonInputData_ThrowsException() { - AgentMLInput agentMLInput = new AgentMLInput("test", null, FunctionName.AGENT, null); - mlAgentExecutor.execute(agentMLInput, agentActionListener); - } - - @Test(expected = IllegalArgumentException.class) - public void test_NonInputParas_ThrowsException() { - RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet.builder().parameters(null).build(); - AgentMLInput agentMLInput = new AgentMLInput("test", null, FunctionName.AGENT, inputDataSet); - mlAgentExecutor.execute(agentMLInput, agentActionListener); - } - - @Test - public void test_HappyCase_ReturnsResult() throws IOException { - ModelTensor modelTensor = ModelTensor.builder().name("response").dataAsMap(ImmutableMap.of("test_key", "test_value")).build(); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onResponse(modelTensor); - return null; - }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); - - GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); - Mockito.doAnswer(invocation -> { - // Extract the ActionListener argument from the method invocation - ActionListener listener = invocation.getArgument(1); - // Trigger the onResponse method of the ActionListener with the mock response - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class)); - - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onResponse(memory); - return null; - }).when(mockMemoryFactory).create(Mockito.eq(null), Mockito.eq("memoryId"), Mockito.any(), Mockito.any()); - - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); - mlAgentExecutor.execute(getAgentMLInput(), agentActionListener); - - Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); - ModelTensorOutput output = (ModelTensorOutput) objectCaptor.getValue(); - Assert.assertEquals(1, output.getMlModelOutputs().size()); - Assert.assertEquals(1, output.getMlModelOutputs().get(0).getMlModelTensors().size()); - Assert.assertEquals(modelTensor, output.getMlModelOutputs().get(0).getMlModelTensors().get(0)); - } - - @Test - public void test_AgentRunnerReturnsListOfModelTensor_ReturnsResult() throws IOException { - ModelTensor modelTensor1 = ModelTensor.builder().name("response1").dataAsMap(ImmutableMap.of("test_key", "test_value")).build(); - ModelTensor modelTensor2 = ModelTensor.builder().name("response2").dataAsMap(ImmutableMap.of("test_key", "test_value")).build(); - List response = Arrays.asList(modelTensor1, modelTensor2); - Mockito.doAnswer(invocation -> { - ActionListener> listener = invocation.getArgument(2); - listener.onResponse(response); - return null; - }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); - - GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); - Mockito.doAnswer(invocation -> { - // Extract the ActionListener argument from the method invocation - ActionListener listener = invocation.getArgument(1); - // Trigger the onResponse method of the ActionListener with the mock response - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class)); - - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onResponse(memory); - return null; - }).when(mockMemoryFactory).create(Mockito.eq(null), Mockito.eq("memoryId"), Mockito.any(), Mockito.any()); - - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); - mlAgentExecutor.execute(getAgentMLInput(), agentActionListener); - - Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); - ModelTensorOutput output = (ModelTensorOutput) objectCaptor.getValue(); - Assert.assertEquals(1, output.getMlModelOutputs().size()); - Assert.assertEquals(2, output.getMlModelOutputs().get(0).getMlModelTensors().size()); - Assert.assertEquals(response, output.getMlModelOutputs().get(0).getMlModelTensors()); - } - - @Test - public void test_AgentRunnerReturnsListOfModelTensors_ReturnsResult() throws IOException { - ModelTensor modelTensor1 = ModelTensor.builder().name("response1").dataAsMap(ImmutableMap.of("test_key", "test_value")).build(); - ModelTensors modelTensors1 = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor1)).build(); - ModelTensor modelTensor2 = ModelTensor.builder().name("response2").dataAsMap(ImmutableMap.of("test_key", "test_value")).build(); - ModelTensors modelTensors2 = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor2)).build(); - List response = Arrays.asList(modelTensors1, modelTensors2); - Mockito.doAnswer(invocation -> { - ActionListener> listener = invocation.getArgument(2); - listener.onResponse(response); - return null; - }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); - - GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); - Mockito.doAnswer(invocation -> { - // Extract the ActionListener argument from the method invocation - ActionListener listener = invocation.getArgument(1); - // Trigger the onResponse method of the ActionListener with the mock response - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class)); - - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onResponse(memory); - return null; - }).when(mockMemoryFactory).create(Mockito.eq(null), Mockito.eq("memoryId"), Mockito.any(), Mockito.any()); - - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); - mlAgentExecutor.execute(getAgentMLInput(), agentActionListener); - - Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); - ModelTensorOutput output = (ModelTensorOutput) objectCaptor.getValue(); - Assert.assertEquals(1, output.getMlModelOutputs().size()); - Assert.assertEquals(2, output.getMlModelOutputs().get(0).getMlModelTensors().size()); - Assert.assertEquals(Arrays.asList(modelTensor1, modelTensor2), output.getMlModelOutputs().get(0).getMlModelTensors()); - } - - @Test - public void test_AgentRunnerReturnsListOfString_ReturnsResult() throws IOException { - List response = Arrays.asList("response1", "response2"); - Mockito.doAnswer(invocation -> { - ActionListener> listener = invocation.getArgument(2); - listener.onResponse(response); - return null; - }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); - - GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); - Mockito.doAnswer(invocation -> { - // Extract the ActionListener argument from the method invocation - ActionListener listener = invocation.getArgument(1); - // Trigger the onResponse method of the ActionListener with the mock response - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class)); - - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onResponse(memory); - return null; - }).when(mockMemoryFactory).create(Mockito.eq(null), Mockito.eq("memoryId"), Mockito.any(), Mockito.any()); - - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); - mlAgentExecutor.execute(getAgentMLInput(), agentActionListener); - - Gson gson = new Gson(); - Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); - ModelTensorOutput output = (ModelTensorOutput) objectCaptor.getValue(); - Assert.assertEquals(1, output.getMlModelOutputs().size()); - Assert.assertEquals(1, output.getMlModelOutputs().get(0).getMlModelTensors().size()); - Assert.assertEquals(gson.toJson(response), output.getMlModelOutputs().get(0).getMlModelTensors().get(0).getResult()); - } - - @Test - public void test_AgentRunnerReturnsString_ReturnsResult() throws IOException { - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onResponse("response"); - return null; - }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); - GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); - Mockito.doAnswer(invocation -> { - // Extract the ActionListener argument from the method invocation - ActionListener listener = invocation.getArgument(1); - // Trigger the onResponse method of the ActionListener with the mock response - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class)); - - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onResponse(memory); - return null; - }).when(mockMemoryFactory).create(Mockito.eq(null), Mockito.eq("memoryId"), Mockito.any(), Mockito.any()); - - mlAgentExecutor.execute(getAgentMLInput(), agentActionListener); - - Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); - ModelTensorOutput output = (ModelTensorOutput) objectCaptor.getValue(); - Assert.assertEquals(1, output.getMlModelOutputs().size()); - Assert.assertEquals(1, output.getMlModelOutputs().get(0).getMlModelTensors().size()); - Assert.assertEquals("response", output.getMlModelOutputs().get(0).getMlModelTensors().get(0).getResult()); - } - - @Test - public void test_AgentRunnerReturnsModelTensorOutput_ReturnsResult() throws IOException { - ModelTensor modelTensor1 = ModelTensor.builder().name("response1").dataAsMap(ImmutableMap.of("test_key", "test_value")).build(); - ModelTensors modelTensors1 = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor1)).build(); - ModelTensor modelTensor2 = ModelTensor.builder().name("response2").dataAsMap(ImmutableMap.of("test_key", "test_value")).build(); - ModelTensors modelTensors2 = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor2)).build(); - List modelTensorsList = Arrays.asList(modelTensors1, modelTensors2); - ModelTensorOutput modelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(modelTensorsList).build(); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onResponse(modelTensorOutput); - return null; - }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); - GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); - Mockito.doAnswer(invocation -> { - // Extract the ActionListener argument from the method invocation - ActionListener listener = invocation.getArgument(1); - // Trigger the onResponse method of the ActionListener with the mock response - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class)); - - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onResponse(memory); - return null; - }).when(mockMemoryFactory).create(Mockito.eq(null), Mockito.eq("memoryId"), Mockito.any(), Mockito.any()); - - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); - mlAgentExecutor.execute(getAgentMLInput(), agentActionListener); - - Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); - ModelTensorOutput output = (ModelTensorOutput) objectCaptor.getValue(); - Assert.assertEquals(1, output.getMlModelOutputs().size()); - Assert.assertEquals(2, output.getMlModelOutputs().get(0).getMlModelTensors().size()); - Assert.assertEquals(Arrays.asList(modelTensor1, modelTensor2), output.getMlModelOutputs().get(0).getMlModelTensors()); - } - - @Test - public void test_CreateConversation_ReturnsResult() throws IOException { - ModelTensor modelTensor = ModelTensor.builder().name("response").dataAsMap(ImmutableMap.of("test_key", "test_value")).build(); - ConversationIndexMemory memory = Mockito.mock(ConversationIndexMemory.class); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onResponse(modelTensor); - return null; - }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); - - GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); - Mockito.doAnswer(invocation -> { - // Extract the ActionListener argument from the method invocation - ActionListener listener = invocation.getArgument(1); - // Trigger the onResponse method of the ActionListener with the mock response - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class)); - - CreateInteractionResponse interaction = Mockito.mock(CreateInteractionResponse.class); - Mockito.when(interaction.getId()).thenReturn("interaction_id"); - Mockito.doAnswer(invocation -> { - ActionListener responseActionListener = invocation.getArgument(4); - responseActionListener.onResponse(interaction); - return null; - }).when(memory).save(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); - - Mockito.doAnswer(invocation -> { - Mockito.when(memory.getConversationId()).thenReturn("conversation_id"); - ActionListener listener = invocation.getArgument(3); - listener.onResponse(memory); - return null; - }).when(mockMemoryFactory).create(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); - - Map params = new HashMap<>(); - RemoteInferenceInputDataSet dataset = RemoteInferenceInputDataSet.builder().parameters(params).build(); - AgentMLInput agentMLInput = new AgentMLInput("test", null, FunctionName.AGENT, dataset); - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); - mlAgentExecutor.execute(agentMLInput, agentActionListener); - - Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); - ModelTensorOutput output = (ModelTensorOutput) objectCaptor.getValue(); - Assert.assertEquals(1, output.getMlModelOutputs().size()); - Assert.assertEquals(1, output.getMlModelOutputs().get(0).getMlModelTensors().size()); - Assert.assertEquals(modelTensor, output.getMlModelOutputs().get(0).getMlModelTensors().get(0)); - } - - @Test - public void test_Regenerate_Validation() throws IOException { - Map params = new HashMap<>(); - params.put(REGENERATE_INTERACTION_ID, "foo"); - RemoteInferenceInputDataSet dataset = RemoteInferenceInputDataSet.builder().parameters(params).build(); - AgentMLInput agentMLInput = new AgentMLInput("test", null, FunctionName.AGENT, dataset); - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); - - GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); - Mockito.doAnswer(invocation -> { - // Extract the ActionListener argument from the method invocation - ActionListener listener = invocation.getArgument(1); - // Trigger the onResponse method of the ActionListener with the mock response - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class)); - mlAgentExecutor.execute(agentMLInput, agentActionListener); - - Mockito.verify(agentActionListener).onFailure(exceptionCaptor.capture()); - Exception exception = exceptionCaptor.getValue(); - Assert.assertTrue(exception instanceof IllegalArgumentException); - Assert.assertEquals(exception.getMessage(), "A memory ID must be provided to regenerate."); - } - - @Test - public void test_Regenerate_GetOriginalInteraction() throws IOException { - ModelTensor modelTensor = ModelTensor.builder().name("response").dataAsMap(ImmutableMap.of("test_key", "test_value")).build(); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onResponse(modelTensor); - return null; - }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any()); - - GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); - Mockito.doAnswer(invocation -> { - // Extract the ActionListener argument from the method invocation - ActionListener listener = invocation.getArgument(1); - // Trigger the onResponse method of the ActionListener with the mock response - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class)); - - CreateInteractionResponse interaction = Mockito.mock(CreateInteractionResponse.class); - Mockito.when(interaction.getId()).thenReturn("interaction_id"); - Mockito.doAnswer(invocation -> { - ActionListener responseActionListener = invocation.getArgument(4); - responseActionListener.onResponse(interaction); - return null; - }).when(memory).save(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); - - Mockito.doAnswer(invocation -> { - Mockito.when(memory.getConversationId()).thenReturn("conversation_id"); - ActionListener listener = invocation.getArgument(3); - listener.onResponse(memory); - return null; - }).when(mockMemoryFactory).create(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); - - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(Boolean.TRUE); - return null; - }).when(memoryManager).deleteInteractionAndTrace(Mockito.anyString(), Mockito.any()); - - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - GetInteractionResponse interactionResponse = Mockito.mock(GetInteractionResponse.class); - Interaction mockInteraction = Mockito.mock(Interaction.class); - Mockito.when(mockInteraction.getInput()).thenReturn("regenerate question"); - Mockito.when(interactionResponse.getInteraction()).thenReturn(mockInteraction); - listener.onResponse(interactionResponse); - return null; - }).when(client).execute(Mockito.eq(GetInteractionAction.INSTANCE), Mockito.any(), Mockito.any()); - - String interactionId = "bar-interaction"; - Map params = new HashMap<>(); - params.put(MEMORY_ID, "foo-memory"); - params.put(REGENERATE_INTERACTION_ID, interactionId); - RemoteInferenceInputDataSet dataset = RemoteInferenceInputDataSet.builder().parameters(params).build(); - AgentMLInput agentMLInput = new AgentMLInput("test", null, FunctionName.AGENT, dataset); - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); - mlAgentExecutor.execute(agentMLInput, agentActionListener); - - Mockito.verify(client, times(1)).execute(Mockito.eq(GetInteractionAction.INSTANCE), Mockito.any(), Mockito.any()); - Assert.assertEquals(params.get(QUESTION), "regenerate question"); - // original interaction got deleted - Mockito.verify(memoryManager, times(1)).deleteInteractionAndTrace(Mockito.eq(interactionId), Mockito.any()); - } - - @Test - public void test_Regenerate_OriginalInteraction_NotExist() throws IOException { - ModelTensor modelTensor = ModelTensor.builder().name("response").dataAsMap(ImmutableMap.of("test_key", "test_value")).build(); - ConversationIndexMemory memory = Mockito.mock(ConversationIndexMemory.class); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onResponse(modelTensor); - return null; - }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any()); - - GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); - Mockito.doAnswer(invocation -> { - // Extract the ActionListener argument from the method invocation - ActionListener listener = invocation.getArgument(1); - // Trigger the onResponse method of the ActionListener with the mock response - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class)); - - CreateInteractionResponse interaction = Mockito.mock(CreateInteractionResponse.class); - Mockito.when(interaction.getId()).thenReturn("interaction_id"); - Mockito.doAnswer(invocation -> { - ActionListener responseActionListener = invocation.getArgument(4); - responseActionListener.onResponse(interaction); - return null; - }).when(memory).save(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); - - Mockito.doAnswer(invocation -> { - Mockito.when(memory.getConversationId()).thenReturn("conversation_id"); - ActionListener listener = invocation.getArgument(3); - listener.onResponse(memory); - return null; - }).when(mockMemoryFactory).create(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); - - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onFailure(new ResourceNotFoundException("Interaction bar-interaction not found")); - return null; - }).when(client).execute(Mockito.eq(GetInteractionAction.INSTANCE), Mockito.any(), Mockito.any()); - - Map params = new HashMap<>(); - params.put(MEMORY_ID, "foo-memory"); - params.put(REGENERATE_INTERACTION_ID, "bar-interaction"); - RemoteInferenceInputDataSet dataset = RemoteInferenceInputDataSet.builder().parameters(params).build(); - AgentMLInput agentMLInput = new AgentMLInput("test", null, FunctionName.AGENT, dataset); - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); - mlAgentExecutor.execute(agentMLInput, agentActionListener); - - Mockito.verify(client, times(1)).execute(Mockito.eq(GetInteractionAction.INSTANCE), Mockito.any(), Mockito.any()); - Assert.assertNull(params.get(QUESTION)); - - Mockito.verify(agentActionListener).onFailure(exceptionCaptor.capture()); - Exception exception = exceptionCaptor.getValue(); - Assert.assertTrue(exception instanceof ResourceNotFoundException); - Assert.assertEquals(exception.getMessage(), "Interaction bar-interaction not found"); - } - - @Test - public void test_CreateFlowAgent() { - MLAgent mlAgent = MLAgent.builder().name("test_agent").type("flow").build(); - MLAgentRunner mlAgentRunner = mlAgentExecutor.getAgentRunner(mlAgent); - Assert.assertTrue(mlAgentRunner instanceof MLFlowAgentRunner); - } - - @Test - public void test_CreateChatAgent() { - LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build(); - MLAgent mlAgent = MLAgent.builder().name("test_agent").type(MLAgentType.CONVERSATIONAL.name()).llm(llmSpec).build(); - MLAgentRunner mlAgentRunner = mlAgentExecutor.getAgentRunner(mlAgent); - Assert.assertTrue(mlAgentRunner instanceof MLChatAgentRunner); - } - - @Test(expected = IllegalArgumentException.class) - public void test_InvalidAgent_ThrowsException() { - MLAgent mlAgent = MLAgent.builder().name("test_agent").type("illegal").build(); - mlAgentExecutor.getAgentRunner(mlAgent); - } - - @Test - public void test_GetModel_ThrowsException() { - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onFailure(new RuntimeException()); - return null; - }).when(client).get(Mockito.any(), Mockito.any()); - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); - mlAgentExecutor.execute(getAgentMLInput(), agentActionListener); - - Mockito.verify(agentActionListener).onFailure(exceptionCaptor.capture()); - Assert.assertNotNull(exceptionCaptor.getValue()); - } - @Test - public void test_GetModelDoesNotExist_ThrowsException() { - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - GetResponse getResponse = Mockito.mock(GetResponse.class); - Mockito.when(getResponse.isExists()).thenReturn(false); - listener.onResponse(getResponse); - return null; - }).when(client).get(Mockito.any(), Mockito.any()); - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); - mlAgentExecutor.execute(getAgentMLInput(), agentActionListener); - - Mockito.verify(agentActionListener).onFailure(exceptionCaptor.capture()); - Assert.assertNotNull(exceptionCaptor.getValue()); - } - - @Test - public void test_CreateConversationFailure_ThrowsException() { - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onFailure(new RuntimeException()); - return null; - }).when(mockMemoryFactory).create(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); - Map params = new HashMap<>(); - RemoteInferenceInputDataSet dataset = RemoteInferenceInputDataSet.builder().parameters(params).build(); - AgentMLInput agentMLInput = new AgentMLInput("test", null, FunctionName.AGENT, dataset); - mlAgentExecutor.execute(agentMLInput, agentActionListener); - - Mockito.verify(agentActionListener).onFailure(exceptionCaptor.capture()); - Assert.assertNotNull(exceptionCaptor.getValue()); - } - - @Test - public void test_CreateInteractionFailure_ThrowsException() { - ConversationIndexMemory memory = Mockito.mock(ConversationIndexMemory.class); - Mockito.doAnswer(invocation -> { - ActionListener responseActionListener = invocation.getArgument(4); - responseActionListener.onFailure(new RuntimeException()); - return null; - }).when(memory).save(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); - - Mockito.doAnswer(invocation -> { - Mockito.when(memory.getConversationId()).thenReturn("conversation_id"); - ActionListener listener = invocation.getArgument(3); - listener.onResponse(memory); - return null; - }).when(mockMemoryFactory).create(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); - Map params = new HashMap<>(); - RemoteInferenceInputDataSet dataset = RemoteInferenceInputDataSet.builder().parameters(params).build(); - AgentMLInput agentMLInput = new AgentMLInput("test", null, FunctionName.AGENT, dataset); - mlAgentExecutor.execute(agentMLInput, agentActionListener); - - Mockito.verify(agentActionListener).onFailure(exceptionCaptor.capture()); - Assert.assertNotNull(exceptionCaptor.getValue()); - } - - @Test - public void test_AgentRunnerFailure_ReturnsResult() { - ModelTensor modelTensor = ModelTensor.builder().name("response").dataAsMap(ImmutableMap.of("test_key", "test_value")).build(); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onFailure(new RuntimeException()); - return null; - }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any()); - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); - mlAgentExecutor.execute(getAgentMLInput(), agentActionListener); - - Mockito.verify(agentActionListener).onFailure(exceptionCaptor.capture()); - Assert.assertNotNull(exceptionCaptor.getValue()); - } - - @Test - public void test_AsyncMode_ReturnsTaskId() throws IOException { - ModelTensor modelTensor = ModelTensor.builder().name("response").result("test").build(); - GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); - sdkClient = SdkClientFactory.createSdkClient(client, NamedXContentRegistry.EMPTY, Collections.emptyMap()); - - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(), Mockito.any()); - - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onResponse(memory); - return null; - }).when(mockMemoryFactory).create(Mockito.eq(null), Mockito.eq("memoryId"), Mockito.any(), Mockito.any()); - - indexResponse = new IndexResponse(new ShardId(ML_TASK_INDEX, "_na_", 0), "task_id", 1, 0, 2, true); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(indexResponse); - return null; - }).when(client).index(any(), any()); - - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); - AgentMLInput input = getAgentMLInput(); - input.setIsAsync(true); - - mlAgentExecutor.execute(input, agentActionListener); - - Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); - MLTaskOutput result = (MLTaskOutput) objectCaptor.getValue(); - - Assert.assertEquals("task_id", result.getTaskId()); - Assert.assertEquals("RUNNING", result.getStatus()); - } - - @Test - public void test_AsyncMode_IndexTask_failure() throws IOException { - ModelTensor modelTensor = ModelTensor.builder().name("response").result("test").build(); - GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); - sdkClient = SdkClientFactory.createSdkClient(client, NamedXContentRegistry.EMPTY, Collections.emptyMap()); - - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(), Mockito.any()); - - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onResponse(memory); - return null; - }).when(mockMemoryFactory).create(Mockito.eq(null), Mockito.eq("memoryId"), Mockito.any(), Mockito.any()); - - doAnswer(invocation -> { - ActionListener actionListener = invocation.getArgument(1); - actionListener.onFailure(new Exception("Index Not Found")); - return null; - }).when(client).index(any(), any()); - - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); - AgentMLInput input = getAgentMLInput(); - input.setIsAsync(true); - - mlAgentExecutor.execute(input, agentActionListener); - - Mockito.verify(agentActionListener).onFailure(exceptionCaptor.capture()); - Assert.assertNotNull(exceptionCaptor.getValue()); - } - - @Test - public void test_mcp_connector_requires_mcp_connector_enabled() throws IOException { - // Create an MLAgent with MCP connectors in parameters - Map parameters = new HashMap<>(); - parameters.put(MCP_CONNECTORS_FIELD, "[{\"connector_id\": \"test-connector\"}]"); - - MLAgent mlAgentWithMcpConnectors = new MLAgent( - "test", - MLAgentType.FLOW.name(), - "test", - new LLMSpec("test_model", Map.of("test_key", "test_value")), - Collections.emptyList(), - parameters, - new MLMemorySpec("memoryType", "123", 0), - Instant.EPOCH, - Instant.EPOCH, - "test", - false, - null + // Mock ClusterService for the agent index check + ClusterState clusterState = mock(ClusterState.class); + Metadata metadata = mock(Metadata.class); + when(clusterService.state()).thenReturn(clusterState); + when(clusterState.metadata()).thenReturn(metadata); + when(metadata.hasIndex(anyString())).thenReturn(false); // Simulate index not found + + mlAgentExecutor = new MLAgentExecutor( + client, + sdkClient, + settings, + clusterService, + xContentRegistry, + toolFactories, + memoryFactoryMap, + mlFeatureEnabledSetting, + encryptor ); - - // Create GetResponse with the MLAgent that has MCP connectors - XContentBuilder content = mlAgentWithMcpConnectors.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); - BytesReference bytesReference = BytesReference.bytes(content); - GetResult getResult = new GetResult("indexName", "test-agent-id", 111l, 111l, 111l, true, bytesReference, null, null); - GetResponse agentGetResponse = new GetResponse(getResult); - - // Create a new MLAgentExecutor with MCP connector disabled - MLFeatureEnabledSetting disabledMcpSetting = Mockito.mock(MLFeatureEnabledSetting.class); - when(disabledMcpSetting.isMultiTenancyEnabled()).thenReturn(false); - when(disabledMcpSetting.isMcpConnectorEnabled()).thenReturn(false); - - MLAgentExecutor mlAgentExecutorWithDisabledMcp = Mockito - .spy( - new MLAgentExecutor( - client, - sdkClient, - settings, - clusterService, - xContentRegistry, - toolFactories, - memoryMap, - disabledMcpSetting, - null - ) - ); - - // Mock the agent get response - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class)); - - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onResponse(memory); - return null; - }).when(mockMemoryFactory).create(Mockito.eq(null), Mockito.eq("memoryId"), Mockito.any(), Mockito.any()); - - // Mock the agent runner - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutorWithDisabledMcp).getAgentRunner(Mockito.any()); - - // Execute the agent - mlAgentExecutorWithDisabledMcp.execute(getAgentMLInput(), agentActionListener); - - // Verify that the execution fails with the correct error message - Mockito.verify(agentActionListener).onFailure(exceptionCaptor.capture()); - Exception exception = exceptionCaptor.getValue(); - Assert.assertTrue(exception instanceof OpenSearchException); - Assert.assertEquals(exception.getMessage(), ML_COMMONS_MCP_CONNECTOR_DISABLED_MESSAGE); } @Test - public void test_query_planning_agentic_search_enabled() throws IOException { - // Create an MLAgent with QueryPlanningTool - MLAgent mlAgentWithQueryPlanning = new MLAgent( - "test", - MLAgentType.FLOW.name(), - "test", - new LLMSpec("test_model", Map.of("test_key", "test_value")), - List - .of( - new MLToolSpec( - "QueryPlanningTool", - "QueryPlanningTool", - "QueryPlanningTool", - Collections.emptyMap(), - Collections.emptyMap(), - false, - Collections.emptyMap(), - null, - null - ) - ), - Map.of("test", "test"), - new MLMemorySpec("memoryType", "123", 0), - Instant.EPOCH, - Instant.EPOCH, - "test", - false, - null - ); - - // Create GetResponse with the MLAgent that has QueryPlanningTool - XContentBuilder content = mlAgentWithQueryPlanning.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); - BytesReference bytesReference = BytesReference.bytes(content); - GetResult getResult = new GetResult("indexName", "test-agent-id", 111l, 111l, 111l, true, bytesReference, null, null); - GetResponse agentGetResponse = new GetResponse(getResult); - - // Create a new MLAgentExecutor with agentic search enabled - MLFeatureEnabledSetting enabledSearchSetting = Mockito.mock(MLFeatureEnabledSetting.class); - when(enabledSearchSetting.isMultiTenancyEnabled()).thenReturn(false); - when(enabledSearchSetting.isMcpConnectorEnabled()).thenReturn(true); - - MLAgentExecutor mlAgentExecutorWithEnabledSearch = Mockito - .spy( - new MLAgentExecutor( - client, - sdkClient, - settings, - clusterService, - xContentRegistry, - toolFactories, - memoryMap, - enabledSearchSetting, - null - ) - ); - - // Mock the agent get response - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class)); - - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onResponse(memory); - return null; - }).when(mockMemoryFactory).create(Mockito.eq(null), Mockito.eq("memoryId"), Mockito.any(), Mockito.any()); - - // Mock the agent runner - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutorWithEnabledSearch).getAgentRunner(Mockito.any()); - - // Mock successful execution - ModelTensor modelTensor = ModelTensor.builder().name("response").dataAsMap(ImmutableMap.of("test_key", "test_value")).build(); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onResponse(modelTensor); - return null; - }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); - - // Execute the agent - mlAgentExecutorWithEnabledSearch.execute(getAgentMLInput(), agentActionListener); - - // Verify that the execution succeeds - Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); - ModelTensorOutput output = (ModelTensorOutput) objectCaptor.getValue(); - Assert.assertEquals(1, output.getMlModelOutputs().size()); - Assert.assertEquals(1, output.getMlModelOutputs().get(0).getMlModelTensors().size()); - Assert.assertEquals(modelTensor, output.getMlModelOutputs().get(0).getMlModelTensors().get(0)); - } - - private AgentMLInput getAgentMLInput() { - Map params = new HashMap<>(); - params.put(MLAgentExecutor.MEMORY_ID, "memoryId"); - params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "parentInteractionId"); - RemoteInferenceInputDataSet dataset = RemoteInferenceInputDataSet.builder().parameters(params).build(); - return new AgentMLInput("test", null, FunctionName.AGENT, dataset); - } - - public GetResponse prepareMLAgent(String agentId, boolean isHidden, String tenantId) throws IOException { - - mlAgent = new MLAgent( - "test", - MLAgentType.CONVERSATIONAL.name(), - "test", - new LLMSpec("test_model", Map.of("test_key", "test_value")), - List - .of( - new MLToolSpec( - "memoryType", - "test", - "test", - Collections.emptyMap(), - Collections.emptyMap(), - false, - Collections.emptyMap(), - null, - null - ) - ), - Map.of("test", "test"), - new MLMemorySpec("memoryType", "123", 0), - Instant.EPOCH, - Instant.EPOCH, - "test", - isHidden, - tenantId - ); - - XContentBuilder content = mlAgent.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); - BytesReference bytesReference = BytesReference.bytes(content); - GetResult getResult = new GetResult("indexName", agentId, 111l, 111l, 111l, true, bytesReference, null, null); - return new GetResponse(getResult); + public void testConstructor() { + assertNotNull(mlAgentExecutor); + assertEquals(client, mlAgentExecutor.getClient()); + assertEquals(settings, mlAgentExecutor.getSettings()); + assertEquals(clusterService, mlAgentExecutor.getClusterService()); + assertEquals(xContentRegistry, mlAgentExecutor.getXContentRegistry()); + assertEquals(toolFactories, mlAgentExecutor.getToolFactories()); + assertEquals(memoryFactoryMap, mlAgentExecutor.getMemoryFactoryMap()); + assertEquals(mlFeatureEnabledSetting, mlAgentExecutor.getMlFeatureEnabledSetting()); + assertEquals(encryptor, mlAgentExecutor.getEncryptor()); } @Test - public void test_BothParentAndRegenerateInteractionId_ThrowsException() throws IOException { - GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class)); - - Map params = new HashMap<>(); - params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "parent-123"); - params.put(REGENERATE_INTERACTION_ID, "regenerate-456"); - RemoteInferenceInputDataSet dataset = RemoteInferenceInputDataSet.builder().parameters(params).build(); - AgentMLInput agentMLInput = new AgentMLInput("test", null, FunctionName.AGENT, dataset); + public void testOnMultiTenancyEnabledChanged() { + mlAgentExecutor.onMultiTenancyEnabledChanged(true); + assertTrue(mlAgentExecutor.getIsMultiTenancyEnabled()); - mlAgentExecutor.execute(agentMLInput, agentActionListener); - - Mockito.verify(agentActionListener).onFailure(exceptionCaptor.capture()); - Exception exception = exceptionCaptor.getValue(); - Assert.assertTrue(exception instanceof IllegalArgumentException); - Assert - .assertEquals( - exception.getMessage(), - "Provide either `parent_interaction_id` to update an existing interaction, or `regenerate_interaction_id` to create a new one." - ); + mlAgentExecutor.onMultiTenancyEnabledChanged(false); + assertFalse(mlAgentExecutor.getIsMultiTenancyEnabled()); } @Test - public void test_ExistingConversation_WithMemoryAndParentInteractionId() throws IOException { - ModelTensor modelTensor = ModelTensor.builder().name("response").dataAsMap(ImmutableMap.of("test_key", "test_value")).build(); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onResponse(modelTensor); - return null; - }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); - - GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class)); - - // Mock memory factory for existing conversation - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onResponse(memory); - return null; - }).when(mockMemoryFactory).create(Mockito.eq(null), Mockito.eq("existing-memory"), Mockito.any(), Mockito.any()); - - Map params = new HashMap<>(); - params.put(MEMORY_ID, "existing-memory"); - params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "existing-parent"); - RemoteInferenceInputDataSet dataset = RemoteInferenceInputDataSet.builder().parameters(params).build(); - AgentMLInput agentMLInput = new AgentMLInput("test", null, FunctionName.AGENT, dataset); + public void testExecuteWithWrongInputType() { + // Test with non-AgentMLInput - create a mock Input that's not AgentMLInput + Input wrongInput = mock(Input.class); - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); - mlAgentExecutor.execute(agentMLInput, agentActionListener); - - // Verify memory factory was called with null question and existing memory_id - Mockito.verify(mockMemoryFactory).create(Mockito.eq(null), Mockito.eq("existing-memory"), Mockito.any(), Mockito.any()); - Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); - ModelTensorOutput output = (ModelTensorOutput) objectCaptor.getValue(); - Assert.assertEquals(1, output.getMlModelOutputs().size()); + try { + mlAgentExecutor.execute(wrongInput, listener, channel); + fail("Expected IllegalArgumentException"); + } catch (IllegalArgumentException exception) { + assertEquals("wrong input", exception.getMessage()); + } } @Test - public void test_AgentFailure_UpdatesInteractionWithFailure() throws IOException { - RuntimeException testException = new RuntimeException("Agent execution failed"); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onFailure(testException); - return null; - }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); - - GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class)); - - // Mock memory factory for existing conversation - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onResponse(memory); - return null; - }).when(mockMemoryFactory).create(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); - - Map params = new HashMap<>(); - params.put(MEMORY_ID, "test-memory"); - params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "test-parent"); - RemoteInferenceInputDataSet dataset = RemoteInferenceInputDataSet.builder().parameters(params).build(); - AgentMLInput agentMLInput = new AgentMLInput("test", null, FunctionName.AGENT, dataset); - - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); - mlAgentExecutor.execute(agentMLInput, agentActionListener); + public void testExecuteWithNullInputDataSet() { + AgentMLInput agentInput = new AgentMLInput("test-agent", null, FunctionName.AGENT, null); - // Verify failure was propagated to listener - Mockito.verify(agentActionListener).onFailure(exceptionCaptor.capture()); - Assert.assertEquals(testException, exceptionCaptor.getValue()); - - // Verify interaction was updated with failure message - ArgumentCaptor> updateCaptor = ArgumentCaptor.forClass(Map.class); - Mockito.verify(memoryManager).updateInteraction(Mockito.eq("test-parent"), updateCaptor.capture(), Mockito.any()); - Map updateContent = updateCaptor.getValue(); - Assert.assertTrue(updateContent.get("response").toString().contains("Agent execution failed")); + try { + mlAgentExecutor.execute(agentInput, listener, channel); + fail("Expected IllegalArgumentException"); + } catch (IllegalArgumentException exception) { + assertEquals("Agent input data can not be empty.", exception.getMessage()); + } } @Test - public void test_ExistingConversation_MemoryCreationFailure() throws IOException { - GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class)); - - // Mock memory factory failure for existing conversation - RuntimeException memoryException = new RuntimeException("Memory creation failed"); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onFailure(memoryException); - return null; - }).when(mockMemoryFactory).create(Mockito.eq(null), Mockito.eq("existing-memory"), Mockito.any(), Mockito.any()); - - Map params = new HashMap<>(); - params.put(MEMORY_ID, "existing-memory"); - params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "existing-parent"); - RemoteInferenceInputDataSet dataset = RemoteInferenceInputDataSet.builder().parameters(params).build(); - AgentMLInput agentMLInput = new AgentMLInput("test", null, FunctionName.AGENT, dataset); - - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); - mlAgentExecutor.execute(agentMLInput, agentActionListener); + public void testExecuteWithNullParameters() { + RemoteInferenceInputDataSet dataset = RemoteInferenceInputDataSet.builder().build(); + AgentMLInput agentInput = new AgentMLInput("test-agent", null, FunctionName.AGENT, dataset); - Mockito.verify(agentActionListener).onFailure(exceptionCaptor.capture()); - Exception exception = exceptionCaptor.getValue(); - Assert.assertEquals(memoryException, exception); + try { + mlAgentExecutor.execute(agentInput, listener, channel); + fail("Expected IllegalArgumentException"); + } catch (IllegalArgumentException exception) { + assertEquals("Agent input data can not be empty.", exception.getMessage()); + } } @Test - public void test_ExecuteAgent_SyncMode() throws IOException { - ModelTensor modelTensor = ModelTensor.builder().name("response").dataAsMap(ImmutableMap.of("test_key", "test_value")).build(); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onResponse(modelTensor); - return null; - }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); - - GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class)); - - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onResponse(memory); - return null; - }).when(mockMemoryFactory).create(Mockito.eq(null), Mockito.eq("memoryId"), Mockito.any(), Mockito.any()); - - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); - - Map params = new HashMap<>(); - params.put(QUESTION, "test question"); - params.put(MEMORY_ID, "memoryId"); - params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "parentInteractionId"); - RemoteInferenceInputDataSet dataset = RemoteInferenceInputDataSet.builder().parameters(params).build(); - AgentMLInput agentMLInput = new AgentMLInput("test", null, FunctionName.AGENT, dataset); - agentMLInput.setIsAsync(false); - - mlAgentExecutor.execute(agentMLInput, agentActionListener); - - Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); - ModelTensorOutput output = (ModelTensorOutput) objectCaptor.getValue(); - Assert.assertEquals(1, output.getMlModelOutputs().size()); - Assert.assertEquals(1, output.getMlModelOutputs().get(0).getMlModelTensors().size()); - Assert.assertEquals(modelTensor, output.getMlModelOutputs().get(0).getMlModelTensors().get(0)); + public void testExecuteWithMultiTenancyEnabledButNoTenantId() { + when(mlFeatureEnabledSetting.isMultiTenancyEnabled()).thenReturn(true); + mlAgentExecutor.onMultiTenancyEnabledChanged(true); + + Map parameters = Collections.singletonMap("question", "test question"); + RemoteInferenceInputDataSet dataset = RemoteInferenceInputDataSet.builder() + .parameters(parameters) + .build(); + AgentMLInput agentInput = new AgentMLInput("test-agent", null, FunctionName.AGENT, dataset); + + try { + mlAgentExecutor.execute(agentInput, listener, channel); + fail("Expected OpenSearchStatusException"); + } catch (OpenSearchStatusException exception) { + assertEquals("You don't have permission to access this resource", exception.getMessage()); + assertEquals(RestStatus.FORBIDDEN, exception.status()); + } } @Test - public void test_ExecuteAgent_AsyncMode() throws IOException { - ModelTensor modelTensor = ModelTensor.builder().name("response").dataAsMap(ImmutableMap.of("test_key", "test_value")).build(); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onResponse(modelTensor); - return null; - }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any()); - - GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class)); - - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onResponse(memory); - return null; - }).when(mockMemoryFactory).create(Mockito.eq(null), Mockito.eq("memoryId"), Mockito.any(), Mockito.any()); + public void testExecuteWithAgentIndexNotFound() { + Map parameters = Collections.singletonMap("question", "test question"); + RemoteInferenceInputDataSet dataset = RemoteInferenceInputDataSet.builder().parameters(parameters).build(); + AgentMLInput agentInput = new AgentMLInput("test-agent", null, FunctionName.AGENT, dataset); - indexResponse = new IndexResponse(new ShardId(ML_TASK_INDEX, "_na_", 0), "task-123", 1, 0, 2, true); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(indexResponse); - return null; - }).when(mlAgentExecutor).indexMLTask(Mockito.any(), Mockito.any()); + // Since we can't mock static methods easily, we'll test a different scenario + // This test would need the actual MLIndicesHandler behavior + mlAgentExecutor.execute(agentInput, listener, channel); - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); - - Map params = new HashMap<>(); - params.put(QUESTION, "test question"); - params.put(MEMORY_ID, "memoryId"); - params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "parentInteractionId"); - RemoteInferenceInputDataSet dataset = RemoteInferenceInputDataSet.builder().parameters(params).build(); - AgentMLInput agentMLInput = new AgentMLInput("test", null, FunctionName.AGENT, dataset); - agentMLInput.setIsAsync(true); - - mlAgentExecutor.execute(agentMLInput, agentActionListener); - - Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); - MLTaskOutput output = (MLTaskOutput) objectCaptor.getValue(); - Assert.assertEquals("task-123", output.getTaskId()); - Assert.assertEquals("RUNNING", output.getStatus()); + // Verify that the listener was called (the actual behavior will depend on the implementation) + verify(listener, timeout(5000).atLeastOnce()).onFailure(any()); } @Test - public void test_UpdateInteractionWithFailure() throws IOException { - RuntimeException testException = new RuntimeException("Test failure message"); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onFailure(testException); - return null; - }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); - - GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class)); - - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onResponse(memory); - return null; - }).when(mockMemoryFactory).create(Mockito.eq(null), Mockito.eq("memoryId"), Mockito.any(), Mockito.any()); - - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); - - Map params = new HashMap<>(); - params.put(MEMORY_ID, "memoryId"); - params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "test-parent-id"); - RemoteInferenceInputDataSet dataset = RemoteInferenceInputDataSet.builder().parameters(params).build(); - AgentMLInput agentMLInput = new AgentMLInput("test", null, FunctionName.AGENT, dataset); - - mlAgentExecutor.execute(agentMLInput, agentActionListener); - - ArgumentCaptor> updateCaptor = ArgumentCaptor.forClass(Map.class); - Mockito.verify(memoryManager).updateInteraction(Mockito.eq("test-parent-id"), updateCaptor.capture(), Mockito.any()); - Map updateContent = updateCaptor.getValue(); - Assert.assertEquals("Agent execution failed: Test failure message", updateContent.get("response")); + public void testGetAgentRunnerWithFlowAgent() { + MLAgent agent = createTestAgent(MLAgentType.FLOW.name()); + MLAgentRunner runner = mlAgentExecutor.getAgentRunner(agent, null); + assertNotNull(runner); + assertTrue(runner instanceof MLFlowAgentRunner); } @Test - public void test_ConversationMemoryCreationFailure() throws IOException { - GetResponse agentGetResponse = prepareMLAgent("test-agent-id", true, null); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class)); - - RuntimeException memoryException = new RuntimeException("Failed to read conversation memory"); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onFailure(memoryException); - return null; - }).when(mockMemoryFactory).create(Mockito.eq("test question"), Mockito.eq(null), Mockito.any(), Mockito.any()); - - Map params = new HashMap<>(); - params.put(QUESTION, "test question"); - RemoteInferenceInputDataSet dataset = RemoteInferenceInputDataSet.builder().parameters(params).build(); - AgentMLInput agentMLInput = new AgentMLInput("test", null, FunctionName.AGENT, dataset); - - mlAgentExecutor.execute(agentMLInput, agentActionListener); - - Mockito.verify(agentActionListener).onFailure(exceptionCaptor.capture()); - Exception exception = exceptionCaptor.getValue(); - Assert.assertEquals(memoryException, exception); + public void testGetAgentRunnerWithConversationalFlowAgent() { + MLAgent agent = createTestAgent(MLAgentType.CONVERSATIONAL_FLOW.name()); + MLAgentRunner runner = mlAgentExecutor.getAgentRunner(agent, null); + assertNotNull(runner); + assertTrue(runner instanceof MLConversationalFlowAgentRunner); } @Test - public void test_AsyncExecution_NullOutput() throws IOException { - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onResponse(null); - return null; - }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any()); - - GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class)); - - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onResponse(memory); - return null; - }).when(mockMemoryFactory).create(Mockito.eq(null), Mockito.eq("memoryId"), Mockito.any(), Mockito.any()); - - indexResponse = new IndexResponse(new ShardId(ML_TASK_INDEX, "_na_", 0), "task-123", 1, 0, 2, true); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(indexResponse); - return null; - }).when(mlAgentExecutor).indexMLTask(Mockito.any(), Mockito.any()); - - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); - - AgentMLInput input = getAgentMLInput(); - input.setIsAsync(true); - mlAgentExecutor.execute(input, agentActionListener); - - Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); - MLTaskOutput output = (MLTaskOutput) objectCaptor.getValue(); - Assert.assertNotNull(output.getTaskId()); + public void testGetAgentRunnerWithConversationalAgent() { + MLAgent agent = createTestAgent(MLAgentType.CONVERSATIONAL.name()); + MLAgentRunner runner = mlAgentExecutor.getAgentRunner(agent, null); + assertNotNull(runner); + assertTrue(runner instanceof MLChatAgentRunner); } @Test - public void test_AsyncExecution_Failure() throws IOException { - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onFailure(new RuntimeException("Agent execution failed")); - return null; - }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any()); - - GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class)); - - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onResponse(memory); - return null; - }).when(mockMemoryFactory).create(Mockito.eq(null), Mockito.eq("memoryId"), Mockito.any(), Mockito.any()); - - indexResponse = new IndexResponse(new ShardId(ML_TASK_INDEX, "_na_", 0), "task-123", 1, 0, 2, true); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(indexResponse); - return null; - }).when(mlAgentExecutor).indexMLTask(Mockito.any(), Mockito.any()); - - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); - - AgentMLInput input = getAgentMLInput(); - input.setIsAsync(true); - mlAgentExecutor.execute(input, agentActionListener); - - Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); - MLTaskOutput output = (MLTaskOutput) objectCaptor.getValue(); - Assert.assertNotNull(output.getTaskId()); + public void testGetAgentRunnerWithPlanExecuteAndReflectAgent() { + MLAgent agent = createTestAgent(MLAgentType.PLAN_EXECUTE_AND_REFLECT.name()); + MLAgentRunner runner = mlAgentExecutor.getAgentRunner(agent, null); + assertNotNull(runner); + assertTrue(runner instanceof MLPlanExecuteAndReflectAgentRunner); } @Test - public void test_UpdateInteractionFailure_LogLines() throws IOException { - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onFailure(new RuntimeException("Test failure")); - return null; - }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); - - GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class)); + public void testGetAgentRunnerWithUnsupportedAgentType() { + // Create a mock MLAgent instead of using the constructor that validates + MLAgent agent = mock(MLAgent.class); + when(agent.getType()).thenReturn("UNSUPPORTED_TYPE"); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onResponse(memory); - return null; - }).when(mockMemoryFactory).create(Mockito.eq(null), Mockito.eq("memoryId"), Mockito.any(), Mockito.any()); - - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onResponse(true); - return null; - }).when(memoryManager).updateInteraction(Mockito.any(), Mockito.any(), Mockito.any()); - - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); - - Map params = new HashMap<>(); - params.put(MEMORY_ID, "memoryId"); - params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "test-parent-id"); - RemoteInferenceInputDataSet dataset = RemoteInferenceInputDataSet.builder().parameters(params).build(); - AgentMLInput agentMLInput = new AgentMLInput("test", null, FunctionName.AGENT, dataset); - - mlAgentExecutor.execute(agentMLInput, agentActionListener); - - Mockito.verify(memoryManager).updateInteraction(Mockito.eq("test-parent-id"), Mockito.any(), Mockito.any()); + try { + mlAgentExecutor.getAgentRunner(agent, null); + fail("Expected IllegalArgumentException"); + } catch (IllegalArgumentException exception) { + assertEquals("Wrong Agent type", exception.getMessage()); + } } @Test - public void test_UpdateInteractionFailure_ErrorCallback() throws IOException { - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onFailure(new RuntimeException("Test failure")); - return null; - }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); - - GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class)); - - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onResponse(memory); - return null; - }).when(mockMemoryFactory).create(Mockito.eq(null), Mockito.eq("memoryId"), Mockito.any(), Mockito.any()); - - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onFailure(new RuntimeException("Update failed")); - return null; - }).when(memoryManager).updateInteraction(Mockito.any(), Mockito.any(), Mockito.any()); + public void testProcessOutputWithModelTensorOutput() throws Exception { + ModelTensorOutput output = mock(ModelTensorOutput.class); + when(output.getMlModelOutputs()).thenReturn(Collections.emptyList()); - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); + List modelTensors = new java.util.ArrayList<>(); - Map params = new HashMap<>(); - params.put(MEMORY_ID, "memoryId"); - params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "test-parent-id"); - RemoteInferenceInputDataSet dataset = RemoteInferenceInputDataSet.builder().parameters(params).build(); - AgentMLInput agentMLInput = new AgentMLInput("test", null, FunctionName.AGENT, dataset); + mlAgentExecutor.processOutput(output, modelTensors); - mlAgentExecutor.execute(agentMLInput, agentActionListener); - - Mockito.verify(memoryManager).updateInteraction(Mockito.eq("test-parent-id"), Mockito.any(), Mockito.any()); + verify(output).getMlModelOutputs(); } @Test - public void test_AsyncTaskUpdate_SuccessCallback() throws IOException { - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onResponse("success"); - return null; - }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any()); - - GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class)); - - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onResponse(memory); - return null; - }).when(mockMemoryFactory).create(Mockito.eq(null), Mockito.eq("memoryId"), Mockito.any(), Mockito.any()); - - indexResponse = new IndexResponse(new ShardId(ML_TASK_INDEX, "_na_", 0), "task-123", 1, 0, 2, true); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(indexResponse); - return null; - }).when(mlAgentExecutor).indexMLTask(Mockito.any(), Mockito.any()); - - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); + public void testProcessOutputWithString() throws Exception { + String output = "test response"; + List modelTensors = new java.util.ArrayList<>(); - AgentMLInput input = getAgentMLInput(); - input.setIsAsync(true); - mlAgentExecutor.execute(input, agentActionListener); + mlAgentExecutor.processOutput(output, modelTensors); - Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); + assertEquals(1, modelTensors.size()); + assertEquals("response", modelTensors.get(0).getName()); + assertEquals("test response", modelTensors.get(0).getResult()); } - @Test - public void test_AsyncTaskUpdate_FailureCallback() throws IOException { - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onFailure(new RuntimeException("Agent failed")); - return null; - }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any()); - - GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class)); - - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onResponse(memory); - return null; - }).when(mockMemoryFactory).create(Mockito.eq(null), Mockito.eq("memoryId"), Mockito.any(), Mockito.any()); - - indexResponse = new IndexResponse(new ShardId(ML_TASK_INDEX, "_na_", 0), "task-123", 1, 0, 2, true); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(indexResponse); - return null; - }).when(mlAgentExecutor).indexMLTask(Mockito.any(), Mockito.any()); - - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); - - AgentMLInput input = getAgentMLInput(); - input.setIsAsync(true); - mlAgentExecutor.execute(input, agentActionListener); - - Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); + private MLAgent createTestAgent(String type) { + return MLAgent + .builder() + .name("test-agent") + .type(type) + .description("Test agent") + .llm(LLMSpec.builder().modelId("test-model").parameters(Collections.emptyMap()).build()) + .tools(Collections.emptyList()) + .parameters(Collections.emptyMap()) + .memory(null) + .createdTime(Instant.now()) + .lastUpdateTime(Instant.now()) + .appType("test-app") + .build(); } @Test - public void test_AgentRunnerException() throws IOException { - // Reset mocks to ensure clean state - Mockito.reset(mlAgentRunner); - - RuntimeException testException = new RuntimeException("Agent runner threw exception"); - Mockito.doThrow(testException).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); - - GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class)); - - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onResponse(memory); - return null; - }).when(mockMemoryFactory).create(Mockito.eq(null), Mockito.eq("memoryId"), Mockito.any(), Mockito.any()); - - indexResponse = new IndexResponse(new ShardId(ML_TASK_INDEX, "_na_", 0), "task-123", 1, 0, 2, true); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(indexResponse); - return null; - }).when(mlAgentExecutor).indexMLTask(Mockito.any(), Mockito.any()); + public void testContextManagementProcessedFlagPreventsReprocessing() { + // Test that the context_management_processed flag prevents duplicate processing + Map parameters = new HashMap<>(); - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); + // First check - should allow processing + boolean shouldProcess1 = !"true".equals(parameters.get("context_management_processed")); + assertTrue("First call should allow processing", shouldProcess1); - AgentMLInput input = getAgentMLInput(); - input.setIsAsync(true); - mlAgentExecutor.execute(input, agentActionListener); + // Mark as processed (simulating what the method does) + parameters.put("context_management_processed", "true"); - Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); - MLTaskOutput output = (MLTaskOutput) objectCaptor.getValue(); - Assert.assertEquals("task-123", output.getTaskId()); + // Second check - should prevent processing + boolean shouldProcess2 = !"true".equals(parameters.get("context_management_processed")); + assertFalse("Second call should prevent processing", shouldProcess2); } } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java index f6c3e3618e..c63db9df4f 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java @@ -710,7 +710,7 @@ public void testToolParameters() { // Verify the size of parameters passed in the tool run method. ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Map.class); verify(firstTool).run((Map) argumentCaptor.capture(), any()); - assertEquals(15, ((Map) argumentCaptor.getValue()).size()); + assertEquals(16, ((Map) argumentCaptor.getValue()).size()); Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); ModelTensorOutput modelTensorOutput = (ModelTensorOutput) objectCaptor.getValue(); @@ -738,7 +738,7 @@ public void testToolUseOriginalInput() { // Verify the size of parameters passed in the tool run method. ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Map.class); verify(firstTool).run((Map) argumentCaptor.capture(), any()); - assertEquals(16, ((Map) argumentCaptor.getValue()).size()); + assertEquals(17, ((Map) argumentCaptor.getValue()).size()); assertEquals("raw input", ((Map) argumentCaptor.getValue()).get("input")); Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); @@ -804,7 +804,7 @@ public void testToolConfig() { // Verify the size of parameters passed in the tool run method. ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Map.class); verify(firstTool).run((Map) argumentCaptor.capture(), any()); - assertEquals(16, ((Map) argumentCaptor.getValue()).size()); + assertEquals(17, ((Map) argumentCaptor.getValue()).size()); // The value of input should be "config_value". assertEquals("config_value", ((Map) argumentCaptor.getValue()).get("input")); @@ -834,7 +834,7 @@ public void testToolConfigWithInputPlaceholder() { // Verify the size of parameters passed in the tool run method. ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Map.class); verify(firstTool).run((Map) argumentCaptor.capture(), any()); - assertEquals(16, ((Map) argumentCaptor.getValue()).size()); + assertEquals(17, ((Map) argumentCaptor.getValue()).size()); // The value of input should be replaced with the value associated with the key "key2" of the first tool. assertEquals("value2", ((Map) argumentCaptor.getValue()).get("input")); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunnerTest.java index 7ed4e91b1c..00a6edde13 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunnerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunnerTest.java @@ -136,6 +136,7 @@ public void setup() { // memory mlMemorySpec = new MLMemorySpec(ConversationIndexMemory.TYPE, "uuid", 10); + when(memoryMap.get(ConversationIndexMemory.TYPE)).thenReturn(memoryFactory); when(memoryMap.get(anyString())).thenReturn(memoryFactory); when(conversationIndexMemory.getConversationId()).thenReturn("test_memory_id"); when(conversationIndexMemory.getMemoryManager()).thenReturn(mlMemoryManager); @@ -171,7 +172,8 @@ public void setup() { toolFactories, memoryMap, sdkClient, - encryptor + encryptor, + null ); // Setup tools diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/contextmanager/SlidingWindowManagerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/contextmanager/SlidingWindowManagerTest.java new file mode 100644 index 0000000000..692c0ebf7c --- /dev/null +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/contextmanager/SlidingWindowManagerTest.java @@ -0,0 +1,234 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.algorithms.contextmanager; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.Map; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.opensearch.ml.common.contextmanager.ContextManagerContext; + +/** + * Unit tests for SlidingWindowManager. + */ +public class SlidingWindowManagerTest { + + private SlidingWindowManager manager; + private ContextManagerContext context; + + @Before + public void setUp() { + manager = new SlidingWindowManager(); + context = ContextManagerContext.builder().toolInteractions(new ArrayList<>()).parameters(new HashMap<>()).build(); + } + + @Test + public void testGetType() { + Assert.assertEquals("SlidingWindowManager", manager.getType()); + } + + @Test + public void testInitializeWithDefaults() { + Map config = new HashMap<>(); + manager.initialize(config); + + // Should initialize with default values without throwing exceptions + Assert.assertNotNull(manager); + } + + @Test + public void testInitializeWithCustomConfig() { + Map config = new HashMap<>(); + config.put("max_messages", 10); + + manager.initialize(config); + + // Should initialize without throwing exceptions + Assert.assertNotNull(manager); + } + + @Test + public void testShouldActivateWithNoRules() { + Map config = new HashMap<>(); + manager.initialize(config); + + // Should always activate when no rules are defined + Assert.assertTrue(manager.shouldActivate(context)); + } + + @Test + public void testExecuteWithEmptyToolInteractions() { + Map config = new HashMap<>(); + manager.initialize(config); + + // Should handle empty tool interactions gracefully + manager.execute(context); + + Assert.assertTrue(context.getToolInteractions().isEmpty()); + } + + @Test + public void testExecuteWithSmallToolInteractions() { + Map config = new HashMap<>(); + config.put("max_messages", 10); + manager.initialize(config); + + // Add fewer interactions than the limit + addToolInteractionsToContext(5); + int originalSize = context.getToolInteractions().size(); + + manager.execute(context); + + // Tool interactions should remain unchanged + Assert.assertEquals(originalSize, context.getToolInteractions().size()); + } + + @Test + public void testExecuteWithLargeToolInteractions() { + Map config = new HashMap<>(); + config.put("max_messages", 5); + manager.initialize(config); + + // Add more interactions than the limit + addToolInteractionsToContext(10); + + manager.execute(context); + + // Tool interactions should be truncated to the limit + Assert.assertEquals(5, context.getToolInteractions().size()); + + // Parameters should be updated with truncated interactions + String interactionsParam = (String) context.getParameters().get("_interactions"); + Assert.assertNotNull(interactionsParam); + + // Should contain only the last 5 interactions + String[] interactions = interactionsParam.substring(2).split(", "); // Remove ", " prefix + Assert.assertEquals(5, interactions.length); + + // Should keep the most recent interactions (6-10) + for (int i = 0; i < interactions.length; i++) { + String expected = "Tool output " + (6 + i); + Assert.assertEquals(expected, interactions[i]); + } + + // Verify toolInteractions also contain the most recent ones + for (int i = 0; i < context.getToolInteractions().size(); i++) { + String expected = "Tool output " + (6 + i); + String actual = context.getToolInteractions().get(i); + Assert.assertEquals(expected, actual); + } + } + + @Test + public void testExecuteKeepsMostRecentInteractions() { + Map config = new HashMap<>(); + config.put("max_messages", 3); + manager.initialize(config); + + // Add interactions with identifiable content + addToolInteractionsToContext(7); + + manager.execute(context); + + // Should keep the last 3 interactions (5, 6, 7) + String interactionsParam = (String) context.getParameters().get("_interactions"); + String[] interactions = interactionsParam.substring(2).split(", "); + Assert.assertEquals(3, interactions.length); + Assert.assertEquals("Tool output 5", interactions[0]); + Assert.assertEquals("Tool output 6", interactions[1]); + Assert.assertEquals("Tool output 7", interactions[2]); + } + + @Test + public void testExecuteWithExactLimit() { + Map config = new HashMap<>(); + config.put("max_messages", 5); + manager.initialize(config); + + // Add exactly the limit number of interactions + addToolInteractionsToContext(5); + + manager.execute(context); + + // Parameters should not be updated since no truncation needed + Assert.assertNull(context.getParameters().get("_interactions")); + } + + @Test + public void testExecuteWithNullToolInteractions() { + Map config = new HashMap<>(); + manager.initialize(config); + + context.setToolInteractions(null); + + // Should handle null tool interactions gracefully + manager.execute(context); + + // Should not throw exception + Assert.assertNull(context.getToolInteractions()); + } + + @Test + public void testExecuteWithNonStringOutputs() { + Map config = new HashMap<>(); + config.put("max_messages", 1); // Set to 1 to force truncation + manager.initialize(config); + + // Add tool interactions as strings + context.getToolInteractions().add("123"); // Integer as string + context.getToolInteractions().add("String output"); // String output + + manager.execute(context); + + // Should process all string interactions and set _interactions parameter + Assert.assertNotNull(context.getParameters().get("_interactions")); + Assert.assertEquals(1, context.getToolInteractions().size()); // Should keep only 1 + } + + @Test + public void testInvalidMaxMessagesConfig() { + Map config = new HashMap<>(); + config.put("max_messages", "invalid_number"); + + // Should handle invalid config gracefully and use default + manager.initialize(config); + + Assert.assertNotNull(manager); + } + + @Test + public void testExecuteWithNullParameters() { + Map config = new HashMap<>(); + config.put("max_messages", 3); + manager.initialize(config); + + // Set parameters to null + context.setParameters(null); + addToolInteractionsToContext(5); + + manager.execute(context); + + // Should create new parameters map and update it + Assert.assertNotNull(context.getParameters()); + String interactionsParam = (String) context.getParameters().get("_interactions"); + Assert.assertNotNull(interactionsParam); + + String[] interactions = interactionsParam.substring(2).split(", "); + Assert.assertEquals(3, interactions.length); + } + + /** + * Helper method to add tool interactions to the context. + */ + private void addToolInteractionsToContext(int count) { + for (int i = 1; i <= count; i++) { + context.getToolInteractions().add("Tool output " + i); + } + } +} diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/contextmanager/SummarizationManagerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/contextmanager/SummarizationManagerTest.java new file mode 100644 index 0000000000..6f769d2645 --- /dev/null +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/contextmanager/SummarizationManagerTest.java @@ -0,0 +1,326 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.algorithms.contextmanager; + +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_RESPONSE_FILTER; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.ml.common.contextmanager.ContextManagerContext; +import org.opensearch.ml.common.output.model.ModelTensor; +import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.ml.common.transport.MLTaskResponse; +import org.opensearch.transport.client.Client; + +/** + * Unit tests for SummarizationManager. + */ +public class SummarizationManagerTest { + + @Mock + private Client client; + + private SummarizationManager manager; + private ContextManagerContext context; + + @Before + public void setUp() { + MockitoAnnotations.openMocks(this); + manager = new SummarizationManager(client); + context = ContextManagerContext.builder().toolInteractions(new ArrayList<>()).parameters(new HashMap<>()).build(); + } + + @Test + public void testGetType() { + Assert.assertEquals("SummarizationManager", manager.getType()); + } + + @Test + public void testInitializeWithDefaults() { + Map config = new HashMap<>(); + manager.initialize(config); + + Assert.assertEquals(0.3, manager.summaryRatio, 0.001); + Assert.assertEquals(10, manager.preserveRecentMessages); + } + + @Test + public void testInitializeWithCustomConfig() { + Map config = new HashMap<>(); + config.put("summary_ratio", 0.5); + config.put("preserve_recent_messages", 5); + config.put("summarization_model_id", "test-model"); + config.put("summarization_system_prompt", "Custom prompt"); + + manager.initialize(config); + + Assert.assertEquals(0.5, manager.summaryRatio, 0.001); + Assert.assertEquals(5, manager.preserveRecentMessages); + Assert.assertEquals("test-model", manager.summarizationModelId); + Assert.assertEquals("Custom prompt", manager.summarizationSystemPrompt); + } + + @Test + public void testInitializeWithInvalidSummaryRatio() { + Map config = new HashMap<>(); + config.put("summary_ratio", 0.9); // Invalid - too high + + manager.initialize(config); + + // Should use default value + Assert.assertEquals(0.3, manager.summaryRatio, 0.001); + } + + @Test + public void testShouldActivateWithNoRules() { + Map config = new HashMap<>(); + manager.initialize(config); + + Assert.assertTrue(manager.shouldActivate(context)); + } + + @Test + public void testExecuteWithEmptyToolInteractions() { + Map config = new HashMap<>(); + manager.initialize(config); + + manager.execute(context); + + Assert.assertTrue(context.getToolInteractions().isEmpty()); + } + + @Test + public void testExecuteWithInsufficientMessages() { + Map config = new HashMap<>(); + config.put("preserve_recent_messages", 10); + manager.initialize(config); + + // Add only 5 interactions - not enough to summarize + addToolInteractionsToContext(5); + + manager.execute(context); + + // Should remain unchanged + Assert.assertEquals(5, context.getToolInteractions().size()); + } + + @Test + public void testExecuteWithNoModelId() { + Map config = new HashMap<>(); + manager.initialize(config); + + addToolInteractionsToContext(20); + + manager.execute(context); + + // Should remain unchanged due to missing model ID + Assert.assertEquals(20, context.getToolInteractions().size()); + } + + @Test + public void testExecuteWithNonStringOutputs() { + Map config = new HashMap<>(); + manager.initialize(config); + + // Add tool interactions as strings + context.getToolInteractions().add("123"); // Integer as string + context.getToolInteractions().add("String output"); // String output + + manager.execute(context); + + // Should handle gracefully - only 2 string interactions, not enough to summarize + Assert.assertEquals(2, context.getToolInteractions().size()); + } + + @Test + public void testProcessSummarizationResult() { + Map config = new HashMap<>(); + manager.initialize(config); + + addToolInteractionsToContext(10); + List remainingMessages = List.of("Message 6", "Message 7", "Message 8", "Message 9", "Message 10"); + + manager.processSummarizationResult(context, "Test summary", 5, remainingMessages, context.getToolInteractions()); + + // Should have 1 summary + 5 remaining = 6 total + Assert.assertEquals(6, context.getToolInteractions().size()); + + // First should be summary + String firstOutput = context.getToolInteractions().get(0); + Assert.assertTrue(firstOutput.contains("Test summary")); + } + + @Test + public void testExtractSummaryFromResponseWithLLMResponseFilter() throws Exception { + Map config = new HashMap<>(); + manager.initialize(config); + + // Set up context with LLM_RESPONSE_FILTER + Map parameters = new HashMap<>(); + parameters.put(LLM_RESPONSE_FILTER, "$.choices[0].message.content"); + context.setParameters(parameters); + + // Create mock response with OpenAI-style structure + Map responseData = new HashMap<>(); + Map choice = new HashMap<>(); + Map message = new HashMap<>(); + message.put("content", "This is the extracted summary content"); + choice.put("message", message); + responseData.put("choices", List.of(choice)); + + MLTaskResponse mockResponse = createMockMLTaskResponse(responseData); + + // Use reflection to access the private method + java.lang.reflect.Method extractMethod = SummarizationManager.class + .getDeclaredMethod("extractSummaryFromResponse", MLTaskResponse.class, ContextManagerContext.class); + extractMethod.setAccessible(true); + + String result = (String) extractMethod.invoke(manager, mockResponse, context); + + Assert.assertEquals("This is the extracted summary content", result); + } + + @Test + public void testExtractSummaryFromResponseWithBedrockResponseFilter() throws Exception { + Map config = new HashMap<>(); + manager.initialize(config); + + // Set up context with Bedrock-style LLM_RESPONSE_FILTER + Map parameters = new HashMap<>(); + parameters.put(LLM_RESPONSE_FILTER, "$.output.message.content[0].text"); + context.setParameters(parameters); + + // Create mock response with Bedrock-style structure + Map responseData = new HashMap<>(); + Map output = new HashMap<>(); + Map message = new HashMap<>(); + Map content = new HashMap<>(); + content.put("text", "Bedrock extracted summary"); + message.put("content", List.of(content)); + output.put("message", message); + responseData.put("output", output); + + MLTaskResponse mockResponse = createMockMLTaskResponse(responseData); + + // Use reflection to access the private method + java.lang.reflect.Method extractMethod = SummarizationManager.class + .getDeclaredMethod("extractSummaryFromResponse", MLTaskResponse.class, ContextManagerContext.class); + extractMethod.setAccessible(true); + + String result = (String) extractMethod.invoke(manager, mockResponse, context); + + Assert.assertEquals("Bedrock extracted summary", result); + } + + @Test + public void testExtractSummaryFromResponseWithInvalidFilter() throws Exception { + Map config = new HashMap<>(); + manager.initialize(config); + + // Set up context with invalid LLM_RESPONSE_FILTER path + Map parameters = new HashMap<>(); + parameters.put(LLM_RESPONSE_FILTER, "$.invalid.path"); + context.setParameters(parameters); + + // Create mock response with simple structure + Map responseData = new HashMap<>(); + responseData.put("response", "Fallback summary content"); + + MLTaskResponse mockResponse = createMockMLTaskResponse(responseData); + + // Use reflection to access the private method + java.lang.reflect.Method extractMethod = SummarizationManager.class + .getDeclaredMethod("extractSummaryFromResponse", MLTaskResponse.class, ContextManagerContext.class); + extractMethod.setAccessible(true); + + String result = (String) extractMethod.invoke(manager, mockResponse, context); + + // Should fall back to default parsing + Assert.assertEquals("Fallback summary content", result); + } + + @Test + public void testExtractSummaryFromResponseWithoutFilter() throws Exception { + Map config = new HashMap<>(); + manager.initialize(config); + + // Context without LLM_RESPONSE_FILTER + Map parameters = new HashMap<>(); + context.setParameters(parameters); + + // Create mock response with simple structure + Map responseData = new HashMap<>(); + responseData.put("response", "Default parsed summary"); + + MLTaskResponse mockResponse = createMockMLTaskResponse(responseData); + + // Use reflection to access the private method + java.lang.reflect.Method extractMethod = SummarizationManager.class + .getDeclaredMethod("extractSummaryFromResponse", MLTaskResponse.class, ContextManagerContext.class); + extractMethod.setAccessible(true); + + String result = (String) extractMethod.invoke(manager, mockResponse, context); + + Assert.assertEquals("Default parsed summary", result); + } + + @Test + public void testExtractSummaryFromResponseWithEmptyFilter() throws Exception { + Map config = new HashMap<>(); + manager.initialize(config); + + // Set up context with empty LLM_RESPONSE_FILTER + Map parameters = new HashMap<>(); + parameters.put(LLM_RESPONSE_FILTER, ""); + context.setParameters(parameters); + + // Create mock response + Map responseData = new HashMap<>(); + responseData.put("response", "Empty filter fallback"); + + MLTaskResponse mockResponse = createMockMLTaskResponse(responseData); + + // Use reflection to access the private method + java.lang.reflect.Method extractMethod = SummarizationManager.class + .getDeclaredMethod("extractSummaryFromResponse", MLTaskResponse.class, ContextManagerContext.class); + extractMethod.setAccessible(true); + + String result = (String) extractMethod.invoke(manager, mockResponse, context); + + Assert.assertEquals("Empty filter fallback", result); + } + + /** + * Helper method to create a mock MLTaskResponse with the given data. + */ + private MLTaskResponse createMockMLTaskResponse(Map responseData) { + ModelTensor tensor = ModelTensor.builder().dataAsMap(responseData).build(); + + ModelTensors tensors = ModelTensors.builder().mlModelTensors(List.of(tensor)).build(); + + ModelTensorOutput output = ModelTensorOutput.builder().mlModelOutputs(List.of(tensors)).build(); + + return MLTaskResponse.builder().output(output).build(); + } + + /** + * Helper method to add tool interactions to the context. + */ + private void addToolInteractionsToContext(int count) { + for (int i = 1; i <= count; i++) { + context.getToolInteractions().add("Tool output " + i); + } + } +} diff --git a/plugin/build.gradle b/plugin/build.gradle index bc6f05f48f..75c5eb8931 100644 --- a/plugin/build.gradle +++ b/plugin/build.gradle @@ -648,6 +648,16 @@ configurations.all { resolutionStrategy.force "jakarta.json:jakarta.json-api:2.1.3" resolutionStrategy.force "org.opensearch:opensearch:${opensearch_version}" resolutionStrategy.force "org.bouncycastle:bcprov-jdk18on:1.78.1" + // Force consistent Netty versions to resolve conflicts + resolutionStrategy.force 'io.netty:netty-codec-http:4.1.124.Final' + resolutionStrategy.force 'io.netty:netty-codec-http2:4.1.124.Final' + resolutionStrategy.force 'io.netty:netty-codec:4.1.124.Final' + resolutionStrategy.force 'io.netty:netty-transport:4.1.124.Final' + resolutionStrategy.force 'io.netty:netty-common:4.1.124.Final' + resolutionStrategy.force 'io.netty:netty-buffer:4.1.124.Final' + resolutionStrategy.force 'io.netty:netty-handler:4.1.124.Final' + resolutionStrategy.force 'io.netty:netty-resolver:4.1.124.Final' + resolutionStrategy.force 'io.netty:netty-transport-native-unix-common:4.1.124.Final' resolutionStrategy.force 'io.projectreactor:reactor-core:3.7.0' resolutionStrategy.force "org.jetbrains.kotlin:kotlin-stdlib-jdk8:1.9.10" resolutionStrategy.force "org.jetbrains.kotlin:kotlin-stdlib:1.9.23" diff --git a/plugin/src/main/java/org/opensearch/ml/action/agent/MLAgentRegistrationValidator.java b/plugin/src/main/java/org/opensearch/ml/action/agent/MLAgentRegistrationValidator.java new file mode 100644 index 0000000000..dc9ea439d8 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/action/agent/MLAgentRegistrationValidator.java @@ -0,0 +1,261 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.agent; + +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.action.contextmanagement.ContextManagementTemplateService; +import org.opensearch.ml.common.agent.MLAgent; +import org.opensearch.ml.common.exception.MLResourceNotFoundException; + +import lombok.extern.log4j.Log4j2; + +/** + * Validator for ML Agent registration that performs advanced validation + * requiring service dependencies. + * This validator handles validation that cannot be performed in the request + * object itself, + * such as template existence checking. + */ +@Log4j2 +public class MLAgentRegistrationValidator { + + private final ContextManagementTemplateService contextManagementTemplateService; + + public MLAgentRegistrationValidator(ContextManagementTemplateService contextManagementTemplateService) { + this.contextManagementTemplateService = contextManagementTemplateService; + } + + /** + * Validates an ML agent for registration, performing all necessary validation checks. + * This is the main validation entry point that orchestrates all validation steps. + * + * @param agent the ML agent to validate + * @param listener callback for validation result - onResponse(true) if valid, onFailure with exception if not + */ + public void validateAgentForRegistration(MLAgent agent, ActionListener listener) { + try { + log.debug("Starting agent registration validation for agent: {}", agent.getName()); + + // First, perform basic context management configuration validation + String configError = validateContextManagementConfiguration(agent); + if (configError != null) { + log.error("Agent registration validation failed - configuration error: {}", configError); + listener.onFailure(new IllegalArgumentException(configError)); + return; + } + + // If agent has a context management template reference, validate template access + if (agent.getContextManagementName() != null) { + validateContextManagementTemplateAccess(agent.getContextManagementName(), ActionListener.wrap(templateAccessValid -> { + log.debug("Agent registration validation completed successfully for agent: {}", agent.getName()); + listener.onResponse(true); + }, templateAccessError -> { + log.error("Agent registration validation failed - template access error: {}", templateAccessError.getMessage()); + listener.onFailure(templateAccessError); + })); + } else { + // No template reference, validation is complete + log.debug("Agent registration validation completed successfully for agent: {}", agent.getName()); + listener.onResponse(true); + } + } catch (Exception e) { + log.error("Unexpected error during agent registration validation", e); + listener.onFailure(new IllegalArgumentException("Agent validation failed: " + e.getMessage())); + } + } + + /** + * Validates context management template access (following connector access validation pattern). + * This method checks if the template exists and if the user has access to it. + * + * @param templateName the context management template name to validate + * @param listener callback for validation result - onResponse(true) if accessible, onFailure with exception if not + */ + public void validateContextManagementTemplateAccess(String templateName, ActionListener listener) { + try { + log.debug("Validating context management template access: {}", templateName); + + contextManagementTemplateService.getTemplate(templateName, ActionListener.wrap(template -> { + log.debug("Context management template access validation passed: {}", templateName); + listener.onResponse(true); + }, exception -> { + log.error("Context management template access validation failed: {}", templateName, exception); + if (exception instanceof MLResourceNotFoundException) { + listener.onFailure(new IllegalArgumentException("Context management template not found: " + templateName)); + } else { + listener + .onFailure( + new IllegalArgumentException("Failed to validate context management template: " + exception.getMessage()) + ); + } + })); + } catch (Exception e) { + log.error("Unexpected error during context management template access validation", e); + listener.onFailure(new IllegalArgumentException("Context management template validation failed: " + e.getMessage())); + } + } + + /** + * Validates context management configuration structure and requirements. + * This method performs comprehensive validation of context management settings. + * + * @param agent the ML agent to validate + * @return validation error message if invalid, null if valid + */ + public String validateContextManagementConfiguration(MLAgent agent) { + // Check for conflicting configuration (both name and inline config specified) + if (agent.getContextManagementName() != null && agent.getContextManagement() != null) { + return "Cannot specify both context_management_name and context_management"; + } + + // Validate context management template name if specified + if (agent.getContextManagementName() != null) { + String templateNameError = validateContextManagementTemplateName(agent.getContextManagementName()); + if (templateNameError != null) { + return templateNameError; + } + } + + // Validate inline context management configuration if specified + if (agent.getContextManagement() != null) { + String inlineConfigError = validateInlineContextManagementConfiguration(agent.getContextManagement()); + if (inlineConfigError != null) { + return inlineConfigError; + } + } + + return null; // Valid + } + + /** + * Validates the context management template name format and basic requirements. + * + * @param templateName the template name to validate + * @return validation error message if invalid, null if valid + */ + private String validateContextManagementTemplateName(String templateName) { + if (templateName == null || templateName.trim().isEmpty()) { + return "Context management template name cannot be null or empty"; + } + + if (templateName.length() > 256) { + return "Context management template name cannot exceed 256 characters"; + } + + if (!templateName.matches("^[a-zA-Z0-9_\\-\\.]+$")) { + return "Context management template name can only contain letters, numbers, underscores, hyphens, and dots"; + } + + return null; // Valid + } + + /** + * Validates the inline context management configuration structure and content. + * + * @param contextManagement the context management configuration to validate + * @return validation error message if invalid, null if valid + */ + private String validateInlineContextManagementConfiguration( + org.opensearch.ml.common.contextmanager.ContextManagementTemplate contextManagement + ) { + // Use the built-in validation from ContextManagementTemplate + if (!contextManagement.isValid()) { + return "Invalid context management configuration: configuration must have a name and at least one hook with valid context manager configurations"; + } + + // Additional validation for specific requirements + if (contextManagement.getName() == null || contextManagement.getName().trim().isEmpty()) { + return "Context management configuration name cannot be null or empty"; + } + + if (contextManagement.getHooks() == null || contextManagement.getHooks().isEmpty()) { + return "Context management configuration must define at least one hook"; + } + + // Validate hook names and configurations + return validateContextManagementHooks(contextManagement.getHooks()); + } + + /** + * Validates context management hooks configuration. + * + * @param hooks the hooks configuration to validate + * @return validation error message if invalid, null if valid + */ + private String validateContextManagementHooks( + java.util.Map> hooks + ) { + // Define valid hook names + java.util.Set validHookNames = java.util.Set + .of("PRE_TOOL", "POST_TOOL", "PRE_LLM", "POST_LLM", "PRE_EXECUTION", "POST_EXECUTION"); + + for (java.util.Map.Entry> entry : hooks + .entrySet()) { + String hookName = entry.getKey(); + java.util.List configs = entry.getValue(); + + // Validate hook name + if (!validHookNames.contains(hookName)) { + return "Invalid hook name: " + hookName + ". Valid hook names are: " + validHookNames; + } + + // Validate hook configurations + if (configs == null || configs.isEmpty()) { + return "Hook " + hookName + " must have at least one context manager configuration"; + } + + for (int i = 0; i < configs.size(); i++) { + org.opensearch.ml.common.contextmanager.ContextManagerConfig config = configs.get(i); + if (!config.isValid()) { + return "Invalid context manager configuration at index " + + i + + " in hook " + + hookName + + ": type cannot be null or empty"; + } + + // Validate context manager type + if (config.getType() != null) { + String typeError = validateContextManagerType(config.getType(), hookName, i); + if (typeError != null) { + return typeError; + } + } + } + } + + return null; // Valid + } + + /** + * Validates context manager type for known types. + * + * @param type the context manager type to validate + * @param hookName the hook name for error reporting + * @param index the configuration index for error reporting + * @return validation error message if invalid, null if valid + */ + private String validateContextManagerType(String type, String hookName, int index) { + // Define known context manager types + java.util.Set knownTypes = java.util.Set + .of("ToolsOutputTruncateManager", "SummarizationManager", "MemoryManager", "ConversationManager"); + + // For now, we'll allow unknown types to provide flexibility for future context + // manager types + // This provides extensibility while still validating known ones + if (!knownTypes.contains(type)) { + log + .debug( + "Unknown context manager type '{}' in hook '{}' at index {}. This may be a custom or future type.", + type, + hookName, + index + ); + } + + return null; // Valid - we allow unknown types for extensibility + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/action/agents/TransportRegisterAgentAction.java b/plugin/src/main/java/org/opensearch/ml/action/agents/TransportRegisterAgentAction.java index e91d27c9bb..f4d3597f3d 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/agents/TransportRegisterAgentAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/agents/TransportRegisterAgentAction.java @@ -26,6 +26,8 @@ import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.action.agent.MLAgentRegistrationValidator; +import org.opensearch.ml.action.contextmanagement.ContextManagementTemplateService; import org.opensearch.ml.common.MLAgentType; import org.opensearch.ml.common.agent.MLAgent; import org.opensearch.ml.common.agent.MLToolSpec; @@ -56,6 +58,7 @@ public class TransportRegisterAgentAction extends HandledTransportAction listener) { + // Validate context management configuration (following connector pattern) + if (agent.hasContextManagementTemplate()) { + // Validate context management template access (similar to connector access validation) + String templateName = agent.getContextManagementTemplateName(); + agentRegistrationValidator.validateContextManagementTemplateAccess(templateName, ActionListener.wrap(hasAccess -> { + if (Boolean.TRUE.equals(hasAccess)) { + continueAgentRegistration(agent, listener); + } else { + listener + .onFailure( + new IllegalArgumentException( + "You don't have permission to use the context management template provided, template name: " + templateName + ) + ); + } + }, e -> { + log.error("You don't have permission to use the context management template provided, template name: {}", templateName, e); + listener.onFailure(e); + })); + } else if (agent.getInlineContextManagement() != null) { + // Validate inline context management configuration only if it exists (similar to inline connector validation) + validateInlineContextManagement(agent); + continueAgentRegistration(agent, listener); + } else { + // No context management configuration - that's fine, continue with registration + continueAgentRegistration(agent, listener); + } + } + + private void validateInlineContextManagement(MLAgent agent) { + if (agent.getInlineContextManagement() == null) { + log + .error( + "You must provide context management content when creating an agent without providing context management template name!" + ); + throw new IllegalArgumentException( + "You must provide context management content when creating an agent without context management template name!" + ); + } + + // Validate inline context management configuration structure + if (!agent.getInlineContextManagement().isValid()) { + log + .error( + "Invalid context management configuration: configuration must have a name and at least one hook with valid context manager configurations" + ); + throw new IllegalArgumentException( + "Invalid context management configuration: configuration must have a name and at least one hook with valid context manager configurations" + ); + } + } + + private void continueAgentRegistration(MLAgent agent, ActionListener listener) { String mcpConnectorConfigJSON = (agent.getParameters() != null) ? agent.getParameters().get(MCP_CONNECTORS_FIELD) : null; if (mcpConnectorConfigJSON != null && !mlFeatureEnabledSetting.isMcpConnectorEnabled()) { // MCP connector provided as tools but MCP feature is disabled, so abort. diff --git a/plugin/src/main/java/org/opensearch/ml/action/contextmanagement/ContextManagementIndexUtils.java b/plugin/src/main/java/org/opensearch/ml/action/contextmanagement/ContextManagementIndexUtils.java new file mode 100644 index 0000000000..8ecc0a0e2f --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/action/contextmanagement/ContextManagementIndexUtils.java @@ -0,0 +1,96 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.contextmanagement; + +import org.opensearch.action.admin.indices.create.CreateIndexRequest; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.transport.client.Client; + +import lombok.extern.log4j.Log4j2; + +/** + * Utility class for managing context management template indices. + * Handles index creation, mapping definition, and settings configuration. + */ +@Log4j2 +public class ContextManagementIndexUtils { + + public static final String CONTEXT_MANAGEMENT_TEMPLATES_INDEX = "ml_context_management_templates"; + + private final Client client; + private final ClusterService clusterService; + + public ContextManagementIndexUtils(Client client, ClusterService clusterService) { + this.client = client; + this.clusterService = clusterService; + } + + /** + * Check if the context management templates index exists + * @return true if the index exists, false otherwise + */ + public boolean doesIndexExist() { + return clusterService.state().metadata().hasIndex(CONTEXT_MANAGEMENT_TEMPLATES_INDEX); + } + + /** + * Create the context management templates index if it doesn't exist + * @param listener ActionListener for the response + */ + public void createIndexIfNotExists(ActionListener listener) { + if (doesIndexExist()) { + log.debug("Context management templates index already exists"); + listener.onResponse(true); + return; + } + + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + ActionListener wrappedListener = ActionListener.runBefore(listener, context::restore); + + CreateIndexRequest createIndexRequest = new CreateIndexRequest(CONTEXT_MANAGEMENT_TEMPLATES_INDEX).settings(getIndexSettings()); + + client.admin().indices().create(createIndexRequest, ActionListener.wrap(createIndexResponse -> { + log.info("Successfully created context management templates index"); + wrappedListener.onResponse(true); + }, exception -> { + if (exception instanceof org.opensearch.ResourceAlreadyExistsException) { + log.debug("Context management templates index already exists"); + wrappedListener.onResponse(true); + } else { + log.error("Failed to create context management templates index", exception); + wrappedListener.onFailure(exception); + } + })); + } catch (Exception e) { + log.error("Error creating context management templates index", e); + listener.onFailure(e); + } + } + + /** + * Get the index settings for context management templates + * @return Settings for the index + */ + private Settings getIndexSettings() { + return Settings + .builder() + .put("index.number_of_shards", 1) + .put("index.number_of_replicas", 1) + .put("index.auto_expand_replicas", "0-1") + .build(); + } + + /** + * Get the index name for context management templates + * @return The index name + */ + public static String getIndexName() { + return CONTEXT_MANAGEMENT_TEMPLATES_INDEX; + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/action/contextmanagement/ContextManagementTemplateService.java b/plugin/src/main/java/org/opensearch/ml/action/contextmanagement/ContextManagementTemplateService.java new file mode 100644 index 0000000000..d754375688 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/action/contextmanagement/ContextManagementTemplateService.java @@ -0,0 +1,316 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.contextmanagement; + +import static org.opensearch.common.xcontent.json.JsonXContent.jsonXContent; + +import java.time.Instant; +import java.util.List; + +import org.opensearch.action.delete.DeleteRequest; +import org.opensearch.action.delete.DeleteResponse; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.Strings; +import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.index.query.MatchAllQueryBuilder; +import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.ml.common.contextmanager.ContextManagementTemplate; +import org.opensearch.ml.common.exception.MLResourceNotFoundException; +import org.opensearch.ml.engine.indices.MLIndicesHandler; +import org.opensearch.ml.utils.RestActionUtils; +import org.opensearch.search.SearchHit; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.transport.client.Client; + +import lombok.extern.log4j.Log4j2; + +/** + * Service for managing context management templates in OpenSearch. + * Provides CRUD operations for storing and retrieving context management configurations. + */ +@Log4j2 +public class ContextManagementTemplateService { + + private final MLIndicesHandler mlIndicesHandler; + private final Client client; + private final ClusterService clusterService; + private final ContextManagementIndexUtils indexUtils; + + @Inject + public ContextManagementTemplateService(MLIndicesHandler mlIndicesHandler, Client client, ClusterService clusterService) { + this.mlIndicesHandler = mlIndicesHandler; + this.client = client; + this.clusterService = clusterService; + this.indexUtils = new ContextManagementIndexUtils(client, clusterService); + } + + /** + * Save a context management template to OpenSearch + * @param templateName The name of the template + * @param template The template to save + * @param listener ActionListener for the response + */ + public void saveTemplate(String templateName, ContextManagementTemplate template, ActionListener listener) { + try { + // Validate template + if (!template.isValid()) { + listener.onFailure(new IllegalArgumentException("Invalid context management template")); + return; + } + + User user = RestActionUtils.getUserContext(client); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + ActionListener wrappedListener = ActionListener.runBefore(listener, context::restore); + + // Set timestamps + Instant now = Instant.now(); + if (template.getCreatedTime() == null) { + template.setCreatedTime(now); + } + template.setLastModified(now); + + // Set created by if not already set + if (template.getCreatedBy() == null && user != null) { + template.setCreatedBy(user.getName()); + } + + // Ensure index exists first + indexUtils.createIndexIfNotExists(ActionListener.wrap(indexCreated -> { + // Check if template with same name already exists + validateUniqueTemplateName(template.getName(), ActionListener.wrap(exists -> { + if (exists) { + wrappedListener + .onFailure( + new IllegalArgumentException( + "A context management template with name '" + template.getName() + "' already exists" + ) + ); + return; + } + + // Create the index request with proper JSON serialization + IndexRequest indexRequest = new IndexRequest(ContextManagementIndexUtils.getIndexName()) + .id(template.getName()) + .source(template.toXContent(jsonXContent.contentBuilder(), ToXContentObject.EMPTY_PARAMS)) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + + // Execute the index operation + client.index(indexRequest, ActionListener.wrap(indexResponse -> { + log.info("Context management template saved successfully: {}", template.getName()); + wrappedListener.onResponse(true); + }, exception -> { + log.error("Failed to save context management template: {}", template.getName(), exception); + wrappedListener.onFailure(exception); + })); + }, wrappedListener::onFailure)); + }, wrappedListener::onFailure)); + } + } catch (Exception e) { + log.error("Error saving context management template", e); + listener.onFailure(e); + } + } + + /** + * Get a context management template by name + * @param templateName The name of the template to retrieve + * @param listener ActionListener for the response + */ + public void getTemplate(String templateName, ActionListener listener) { + try { + if (Strings.isNullOrEmpty(templateName)) { + listener.onFailure(new IllegalArgumentException("Template name cannot be null or empty")); + return; + } + + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + ActionListener wrappedListener = ActionListener.runBefore(listener, context::restore); + + GetRequest getRequest = new GetRequest(ContextManagementIndexUtils.getIndexName(), templateName); + + client.get(getRequest, ActionListener.wrap(getResponse -> { + if (!getResponse.isExists()) { + wrappedListener + .onFailure(new MLResourceNotFoundException("Context management template not found: " + templateName)); + return; + } + + try { + XContentParser parser = createXContentParserFromRegistry( + NamedXContentRegistry.EMPTY, + LoggingDeprecationHandler.INSTANCE, + getResponse.getSourceAsBytesRef() + ); + ContextManagementTemplate template = ContextManagementTemplate.parse(parser); + wrappedListener.onResponse(template); + } catch (Exception e) { + log.error("Failed to parse context management template: {}", templateName, e); + wrappedListener.onFailure(e); + } + }, exception -> { + if (exception instanceof IndexNotFoundException) { + wrappedListener + .onFailure(new MLResourceNotFoundException("Context management template not found: " + templateName)); + } else { + log.error("Failed to get context management template: {}", templateName, exception); + wrappedListener.onFailure(exception); + } + })); + } + } catch (Exception e) { + log.error("Error getting context management template", e); + listener.onFailure(e); + } + } + + /** + * List all context management templates + * @param listener ActionListener for the response + */ + public void listTemplates(ActionListener> listener) { + listTemplates(0, 1000, listener); + } + + /** + * List context management templates with pagination + * @param from Starting index for pagination + * @param size Number of templates to return + * @param listener ActionListener for the response + */ + public void listTemplates(int from, int size, ActionListener> listener) { + try { + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + ActionListener> wrappedListener = ActionListener.runBefore(listener, context::restore); + + SearchRequest searchRequest = new SearchRequest(ContextManagementIndexUtils.getIndexName()); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(new MatchAllQueryBuilder()).from(from).size(size); + searchRequest.source(searchSourceBuilder); + + client.search(searchRequest, ActionListener.wrap(searchResponse -> { + try { + List templates = new java.util.ArrayList<>(); + for (SearchHit hit : searchResponse.getHits().getHits()) { + XContentParser parser = createXContentParserFromRegistry( + NamedXContentRegistry.EMPTY, + LoggingDeprecationHandler.INSTANCE, + hit.getSourceRef() + ); + ContextManagementTemplate template = ContextManagementTemplate.parse(parser); + templates.add(template); + } + wrappedListener.onResponse(templates); + } catch (Exception e) { + log.error("Failed to parse context management templates", e); + wrappedListener.onFailure(e); + } + }, exception -> { + if (exception instanceof IndexNotFoundException) { + // Return empty list if index doesn't exist + wrappedListener.onResponse(new java.util.ArrayList<>()); + } else { + log.error("Failed to list context management templates", exception); + wrappedListener.onFailure(exception); + } + })); + } + } catch (Exception e) { + log.error("Error listing context management templates", e); + listener.onFailure(e); + } + } + + /** + * Delete a context management template by name + * @param templateName The name of the template to delete + * @param listener ActionListener for the response + */ + public void deleteTemplate(String templateName, ActionListener listener) { + try { + if (Strings.isNullOrEmpty(templateName)) { + listener.onFailure(new IllegalArgumentException("Template name cannot be null or empty")); + return; + } + + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + ActionListener wrappedListener = ActionListener.runBefore(listener, context::restore); + + DeleteRequest deleteRequest = new DeleteRequest(ContextManagementIndexUtils.getIndexName(), templateName) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + + client.delete(deleteRequest, ActionListener.wrap(deleteResponse -> { + boolean deleted = deleteResponse.getResult() == DeleteResponse.Result.DELETED; + if (deleted) { + log.info("Context management template deleted successfully: {}", templateName); + } else { + log.warn("Context management template not found for deletion: {}", templateName); + } + wrappedListener.onResponse(deleted); + }, exception -> { + if (exception instanceof IndexNotFoundException) { + wrappedListener.onResponse(false); + } else { + log.error("Failed to delete context management template: {}", templateName, exception); + wrappedListener.onFailure(exception); + } + })); + } + } catch (Exception e) { + log.error("Error deleting context management template", e); + listener.onFailure(e); + } + } + + /** + * Validate that a template name is unique + * @param templateName The template name to check + * @param listener ActionListener for the response (true if exists, false if unique) + */ + private void validateUniqueTemplateName(String templateName, ActionListener listener) { + try { + SearchRequest searchRequest = new SearchRequest(ContextManagementIndexUtils.getIndexName()); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(new TermQueryBuilder("_id", templateName)).size(1); + searchRequest.source(searchSourceBuilder); + + client.search(searchRequest, ActionListener.wrap(searchResponse -> { + boolean exists = searchResponse.getHits().getTotalHits() != null && searchResponse.getHits().getTotalHits().value() > 0; + listener.onResponse(exists); + }, exception -> { + if (exception instanceof IndexNotFoundException) { + // Index doesn't exist, so template name is unique + listener.onResponse(false); + } else { + listener.onFailure(exception); + } + })); + } catch (Exception e) { + listener.onFailure(e); + } + } + + /** + * Create XContentParser from registry - utility method + */ + private XContentParser createXContentParserFromRegistry( + NamedXContentRegistry xContentRegistry, + LoggingDeprecationHandler deprecationHandler, + org.opensearch.core.common.bytes.BytesReference bytesReference + ) throws java.io.IOException { + return MediaTypeRegistry.JSON.xContent().createParser(xContentRegistry, deprecationHandler, bytesReference.streamInput()); + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/action/contextmanagement/ContextManagerFactory.java b/plugin/src/main/java/org/opensearch/ml/action/contextmanagement/ContextManagerFactory.java new file mode 100644 index 0000000000..86b26b4d6b --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/action/contextmanagement/ContextManagerFactory.java @@ -0,0 +1,120 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.contextmanagement; + +import java.util.Map; + +import org.opensearch.common.inject.Inject; +import org.opensearch.ml.common.contextmanager.ActivationRuleFactory; +import org.opensearch.ml.common.contextmanager.ContextManager; +import org.opensearch.ml.common.contextmanager.ContextManagerConfig; +import org.opensearch.ml.engine.algorithms.contextmanager.SlidingWindowManager; +import org.opensearch.ml.engine.algorithms.contextmanager.SummarizationManager; +import org.opensearch.ml.engine.algorithms.contextmanager.ToolsOutputTruncateManager; +import org.opensearch.transport.client.Client; + +import lombok.extern.log4j.Log4j2; + +/** + * Factory for creating context manager instances from configuration. + * This factory creates the appropriate context manager based on the type + * specified in the configuration and initializes it with the provided settings. + */ +@Log4j2 +public class ContextManagerFactory { + + private final ActivationRuleFactory activationRuleFactory; + private final Client client; + + @Inject + public ContextManagerFactory(ActivationRuleFactory activationRuleFactory, Client client) { + this.activationRuleFactory = activationRuleFactory; + this.client = client; + } + + /** + * Create a context manager instance from configuration + * @param config The context manager configuration + * @return The created context manager instance + * @throws IllegalArgumentException if the manager type is not supported + */ + public ContextManager createContextManager(ContextManagerConfig config) { + if (config == null || config.getType() == null) { + throw new IllegalArgumentException("Context manager configuration and type cannot be null"); + } + + String type = config.getType(); + Map managerConfig = config.getConfig(); + Map activationConfig = config.getActivation(); + + log.debug("Creating context manager of type: {}", type); + + ContextManager manager; + switch (type) { + case "ToolsOutputTruncateManager": + manager = createToolsOutputTruncateManager(managerConfig); + break; + case "SlidingWindowManager": + manager = createSlidingWindowManager(managerConfig); + break; + case "SummarizationManager": + manager = createSummarizationManager(managerConfig); + break; + default: + throw new IllegalArgumentException("Unsupported context manager type: " + type); + } + + // Initialize the manager with configuration + try { + // Merge activation and manager config for initialization + Map fullConfig = new java.util.HashMap<>(); + if (managerConfig != null) { + fullConfig.putAll(managerConfig); + } + if (activationConfig != null) { + fullConfig.put("activation", activationConfig); + } + + manager.initialize(fullConfig); + log.debug("Successfully created and initialized context manager: {}", type); + return manager; + } catch (Exception e) { + log.error("Failed to initialize context manager of type: {}", type, e); + throw new RuntimeException("Failed to initialize context manager: " + type, e); + } + } + + /** + * Create a ToolsOutputTruncateManager instance + */ + private ContextManager createToolsOutputTruncateManager(Map config) { + return new ToolsOutputTruncateManager(); + } + + /** + * Create a SlidingWindowManager instance + */ + private ContextManager createSlidingWindowManager(Map config) { + return new SlidingWindowManager(); + } + + /** + * Create a SummarizationManager instance + */ + private ContextManager createSummarizationManager(Map config) { + return new SummarizationManager(client); + } + + // Add more factory methods for other context manager types as they are implemented + + // private ContextManager createSummarizingManager(Map config) { + // return new SummarizingManager(); + // } + + // private ContextManager createSystemPromptAugmentationManager(Map config) { + // return new SystemPromptAugmentationManager(); + // } +} diff --git a/plugin/src/main/java/org/opensearch/ml/action/contextmanagement/CreateContextManagementTemplateTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/contextmanagement/CreateContextManagementTemplateTransportAction.java new file mode 100644 index 0000000000..d5377d1bf1 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/action/contextmanagement/CreateContextManagementTemplateTransportAction.java @@ -0,0 +1,67 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.contextmanagement; + +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.transport.contextmanagement.MLCreateContextManagementTemplateAction; +import org.opensearch.ml.common.transport.contextmanagement.MLCreateContextManagementTemplateRequest; +import org.opensearch.ml.common.transport.contextmanagement.MLCreateContextManagementTemplateResponse; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; +import org.opensearch.transport.client.Client; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class CreateContextManagementTemplateTransportAction extends + HandledTransportAction { + + private final Client client; + private final ContextManagementTemplateService contextManagementTemplateService; + + @Inject + public CreateContextManagementTemplateTransportAction( + TransportService transportService, + ActionFilters actionFilters, + Client client, + ContextManagementTemplateService contextManagementTemplateService + ) { + super(MLCreateContextManagementTemplateAction.NAME, transportService, actionFilters, MLCreateContextManagementTemplateRequest::new); + this.client = client; + this.contextManagementTemplateService = contextManagementTemplateService; + } + + @Override + protected void doExecute( + Task task, + MLCreateContextManagementTemplateRequest request, + ActionListener listener + ) { + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + log.info("Creating context management template: {}", request.getTemplateName()); + + contextManagementTemplateService.saveTemplate(request.getTemplateName(), request.getTemplate(), ActionListener.wrap(success -> { + if (success) { + log.info("Successfully created context management template: {}", request.getTemplateName()); + listener.onResponse(new MLCreateContextManagementTemplateResponse(request.getTemplateName(), "created")); + } else { + log.error("Failed to create context management template: {}", request.getTemplateName()); + listener.onFailure(new RuntimeException("Failed to create context management template")); + } + }, exception -> { + log.error("Error creating context management template: {}", request.getTemplateName(), exception); + listener.onFailure(exception); + })); + } catch (Exception e) { + log.error("Unexpected error creating context management template: {}", request.getTemplateName(), e); + listener.onFailure(e); + } + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/action/contextmanagement/DeleteContextManagementTemplateTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/contextmanagement/DeleteContextManagementTemplateTransportAction.java new file mode 100644 index 0000000000..6c025bc927 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/action/contextmanagement/DeleteContextManagementTemplateTransportAction.java @@ -0,0 +1,67 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.contextmanagement; + +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.transport.contextmanagement.MLDeleteContextManagementTemplateAction; +import org.opensearch.ml.common.transport.contextmanagement.MLDeleteContextManagementTemplateRequest; +import org.opensearch.ml.common.transport.contextmanagement.MLDeleteContextManagementTemplateResponse; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; +import org.opensearch.transport.client.Client; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class DeleteContextManagementTemplateTransportAction extends + HandledTransportAction { + + private final Client client; + private final ContextManagementTemplateService contextManagementTemplateService; + + @Inject + public DeleteContextManagementTemplateTransportAction( + TransportService transportService, + ActionFilters actionFilters, + Client client, + ContextManagementTemplateService contextManagementTemplateService + ) { + super(MLDeleteContextManagementTemplateAction.NAME, transportService, actionFilters, MLDeleteContextManagementTemplateRequest::new); + this.client = client; + this.contextManagementTemplateService = contextManagementTemplateService; + } + + @Override + protected void doExecute( + Task task, + MLDeleteContextManagementTemplateRequest request, + ActionListener listener + ) { + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + log.info("Deleting context management template: {}", request.getTemplateName()); + + contextManagementTemplateService.deleteTemplate(request.getTemplateName(), ActionListener.wrap(success -> { + if (success) { + log.info("Successfully deleted context management template: {}", request.getTemplateName()); + listener.onResponse(new MLDeleteContextManagementTemplateResponse(request.getTemplateName(), "deleted")); + } else { + log.warn("Context management template not found for deletion: {}", request.getTemplateName()); + listener.onFailure(new RuntimeException("Context management template not found: " + request.getTemplateName())); + } + }, exception -> { + log.error("Error deleting context management template: {}", request.getTemplateName(), exception); + listener.onFailure(exception); + })); + } catch (Exception e) { + log.error("Unexpected error deleting context management template: {}", request.getTemplateName(), e); + listener.onFailure(e); + } + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/action/contextmanagement/GetContextManagementTemplateTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/contextmanagement/GetContextManagementTemplateTransportAction.java new file mode 100644 index 0000000000..011b6852c0 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/action/contextmanagement/GetContextManagementTemplateTransportAction.java @@ -0,0 +1,67 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.contextmanagement; + +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.transport.contextmanagement.MLGetContextManagementTemplateAction; +import org.opensearch.ml.common.transport.contextmanagement.MLGetContextManagementTemplateRequest; +import org.opensearch.ml.common.transport.contextmanagement.MLGetContextManagementTemplateResponse; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; +import org.opensearch.transport.client.Client; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class GetContextManagementTemplateTransportAction extends + HandledTransportAction { + + private final Client client; + private final ContextManagementTemplateService contextManagementTemplateService; + + @Inject + public GetContextManagementTemplateTransportAction( + TransportService transportService, + ActionFilters actionFilters, + Client client, + ContextManagementTemplateService contextManagementTemplateService + ) { + super(MLGetContextManagementTemplateAction.NAME, transportService, actionFilters, MLGetContextManagementTemplateRequest::new); + this.client = client; + this.contextManagementTemplateService = contextManagementTemplateService; + } + + @Override + protected void doExecute( + Task task, + MLGetContextManagementTemplateRequest request, + ActionListener listener + ) { + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + log.debug("Getting context management template: {}", request.getTemplateName()); + + contextManagementTemplateService.getTemplate(request.getTemplateName(), ActionListener.wrap(template -> { + if (template != null) { + log.debug("Successfully retrieved context management template: {}", request.getTemplateName()); + listener.onResponse(new MLGetContextManagementTemplateResponse(template)); + } else { + log.warn("Context management template not found: {}", request.getTemplateName()); + listener.onFailure(new RuntimeException("Context management template not found: " + request.getTemplateName())); + } + }, exception -> { + log.error("Error getting context management template: {}", request.getTemplateName(), exception); + listener.onFailure(exception); + })); + } catch (Exception e) { + log.error("Unexpected error getting context management template: {}", request.getTemplateName(), e); + listener.onFailure(e); + } + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/action/contextmanagement/ListContextManagementTemplatesTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/contextmanagement/ListContextManagementTemplatesTransportAction.java new file mode 100644 index 0000000000..7667ac6cc6 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/action/contextmanagement/ListContextManagementTemplatesTransportAction.java @@ -0,0 +1,63 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.contextmanagement; + +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.transport.contextmanagement.MLListContextManagementTemplatesAction; +import org.opensearch.ml.common.transport.contextmanagement.MLListContextManagementTemplatesRequest; +import org.opensearch.ml.common.transport.contextmanagement.MLListContextManagementTemplatesResponse; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; +import org.opensearch.transport.client.Client; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class ListContextManagementTemplatesTransportAction extends + HandledTransportAction { + + private final Client client; + private final ContextManagementTemplateService contextManagementTemplateService; + + @Inject + public ListContextManagementTemplatesTransportAction( + TransportService transportService, + ActionFilters actionFilters, + Client client, + ContextManagementTemplateService contextManagementTemplateService + ) { + super(MLListContextManagementTemplatesAction.NAME, transportService, actionFilters, MLListContextManagementTemplatesRequest::new); + this.client = client; + this.contextManagementTemplateService = contextManagementTemplateService; + } + + @Override + protected void doExecute( + Task task, + MLListContextManagementTemplatesRequest request, + ActionListener listener + ) { + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + log.debug("Listing context management templates from: {} size: {}", request.getFrom(), request.getSize()); + + contextManagementTemplateService.listTemplates(request.getFrom(), request.getSize(), ActionListener.wrap(templates -> { + log.debug("Successfully retrieved {} context management templates", templates.size()); + // For now, return the size as total. In a real implementation, you'd get the actual total count + listener.onResponse(new MLListContextManagementTemplatesResponse(templates, templates.size())); + }, exception -> { + log.error("Error listing context management templates", exception); + listener.onFailure(exception); + })); + } catch (Exception e) { + log.error("Unexpected error listing context management templates", e); + listener.onFailure(e); + } + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index 62de34961e..b25ed1afba 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -89,6 +89,12 @@ import org.opensearch.ml.action.connector.SearchConnectorTransportAction; import org.opensearch.ml.action.connector.TransportCreateConnectorAction; import org.opensearch.ml.action.connector.UpdateConnectorTransportAction; +import org.opensearch.ml.action.contextmanagement.ContextManagementTemplateService; +import org.opensearch.ml.action.contextmanagement.ContextManagerFactory; +import org.opensearch.ml.action.contextmanagement.CreateContextManagementTemplateTransportAction; +import org.opensearch.ml.action.contextmanagement.DeleteContextManagementTemplateTransportAction; +import org.opensearch.ml.action.contextmanagement.GetContextManagementTemplateTransportAction; +import org.opensearch.ml.action.contextmanagement.ListContextManagementTemplatesTransportAction; import org.opensearch.ml.action.controller.CreateControllerTransportAction; import org.opensearch.ml.action.controller.DeleteControllerTransportAction; import org.opensearch.ml.action.controller.DeployControllerTransportAction; @@ -191,6 +197,10 @@ import org.opensearch.ml.common.transport.connector.MLCreateConnectorAction; import org.opensearch.ml.common.transport.connector.MLExecuteConnectorAction; import org.opensearch.ml.common.transport.connector.MLUpdateConnectorAction; +import org.opensearch.ml.common.transport.contextmanagement.MLCreateContextManagementTemplateAction; +import org.opensearch.ml.common.transport.contextmanagement.MLDeleteContextManagementTemplateAction; +import org.opensearch.ml.common.transport.contextmanagement.MLGetContextManagementTemplateAction; +import org.opensearch.ml.common.transport.contextmanagement.MLListContextManagementTemplatesAction; import org.opensearch.ml.common.transport.controller.MLControllerDeleteAction; import org.opensearch.ml.common.transport.controller.MLControllerGetAction; import org.opensearch.ml.common.transport.controller.MLCreateControllerAction; @@ -320,15 +330,16 @@ import org.opensearch.ml.processor.MLInferenceIngestProcessor; import org.opensearch.ml.processor.MLInferenceSearchRequestProcessor; import org.opensearch.ml.processor.MLInferenceSearchResponseProcessor; -import org.opensearch.ml.repackage.com.google.common.collect.ImmutableList; import org.opensearch.ml.rest.RestMLAddMemoriesAction; import org.opensearch.ml.rest.RestMLCancelBatchJobAction; import org.opensearch.ml.rest.RestMLCreateConnectorAction; +import org.opensearch.ml.rest.RestMLCreateContextManagementTemplateAction; import org.opensearch.ml.rest.RestMLCreateControllerAction; import org.opensearch.ml.rest.RestMLCreateMemoryContainerAction; import org.opensearch.ml.rest.RestMLCreateSessionAction; import org.opensearch.ml.rest.RestMLDeleteAgentAction; import org.opensearch.ml.rest.RestMLDeleteConnectorAction; +import org.opensearch.ml.rest.RestMLDeleteContextManagementTemplateAction; import org.opensearch.ml.rest.RestMLDeleteControllerAction; import org.opensearch.ml.rest.RestMLDeleteMemoriesByQueryAction; import org.opensearch.ml.rest.RestMLDeleteMemoryAction; @@ -342,6 +353,7 @@ import org.opensearch.ml.rest.RestMLGetAgentAction; import org.opensearch.ml.rest.RestMLGetConfigAction; import org.opensearch.ml.rest.RestMLGetConnectorAction; +import org.opensearch.ml.rest.RestMLGetContextManagementTemplateAction; import org.opensearch.ml.rest.RestMLGetControllerAction; import org.opensearch.ml.rest.RestMLGetIndexInsightAction; import org.opensearch.ml.rest.RestMLGetIndexInsightConfigAction; @@ -351,6 +363,7 @@ import org.opensearch.ml.rest.RestMLGetModelGroupAction; import org.opensearch.ml.rest.RestMLGetTaskAction; import org.opensearch.ml.rest.RestMLGetToolAction; +import org.opensearch.ml.rest.RestMLListContextManagementTemplatesAction; import org.opensearch.ml.rest.RestMLListToolsAction; import org.opensearch.ml.rest.RestMLPredictionAction; import org.opensearch.ml.rest.RestMLPredictionStreamAction; @@ -448,6 +461,7 @@ import org.opensearch.watcher.ResourceWatcherService; import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; import lombok.SneakyThrows; @@ -616,7 +630,11 @@ public MachineLearningPlugin() {} new ActionHandler<>(MLMcpToolsListAction.INSTANCE, TransportMcpToolsListAction.class), new ActionHandler<>(MLMcpToolsUpdateAction.INSTANCE, TransportMcpToolsUpdateAction.class), new ActionHandler<>(MLMcpToolsUpdateOnNodesAction.INSTANCE, TransportMcpToolsUpdateOnNodesAction.class), - new ActionHandler<>(MLMcpServerAction.INSTANCE, TransportMcpServerAction.class) + new ActionHandler<>(MLMcpServerAction.INSTANCE, TransportMcpServerAction.class), + new ActionHandler<>(MLCreateContextManagementTemplateAction.INSTANCE, CreateContextManagementTemplateTransportAction.class), + new ActionHandler<>(MLGetContextManagementTemplateAction.INSTANCE, GetContextManagementTemplateTransportAction.class), + new ActionHandler<>(MLListContextManagementTemplatesAction.INSTANCE, ListContextManagementTemplatesTransportAction.class), + new ActionHandler<>(MLDeleteContextManagementTemplateAction.INSTANCE, DeleteContextManagementTemplateTransportAction.class) ); } @@ -784,6 +802,17 @@ public Collection createComponents( nodeHelper, mlEngine ); + // Create context management services + ContextManagementTemplateService contextManagementTemplateService = new ContextManagementTemplateService( + mlIndicesHandler, + client, + clusterService + ); + ContextManagerFactory contextManagerFactory = new ContextManagerFactory( + new org.opensearch.ml.common.contextmanager.ActivationRuleFactory(), + client + ); + mlExecuteTaskRunner = new MLExecuteTaskRunner( threadPool, clusterService, @@ -794,7 +823,9 @@ public Collection createComponents( mlTaskDispatcher, mlCircuitBreakerService, nodeHelper, - mlEngine + mlEngine, + contextManagementTemplateService, + contextManagerFactory ); // Register thread-safe ML objects here. @@ -1070,6 +1101,15 @@ public List getRestHandlers( RestMLMcpToolsRemoveAction restMLRemoveMcpToolsAction = new RestMLMcpToolsRemoveAction(clusterService, mlFeatureEnabledSetting); RestMLMcpToolsListAction restMLListMcpToolsAction = new RestMLMcpToolsListAction(mlFeatureEnabledSetting); RestMLMcpToolsUpdateAction restMLMcpToolsUpdateAction = new RestMLMcpToolsUpdateAction(clusterService, mlFeatureEnabledSetting); + RestMLCreateContextManagementTemplateAction restMLCreateContextManagementTemplateAction = + new RestMLCreateContextManagementTemplateAction(mlFeatureEnabledSetting); + RestMLGetContextManagementTemplateAction restMLGetContextManagementTemplateAction = new RestMLGetContextManagementTemplateAction( + mlFeatureEnabledSetting + ); + RestMLListContextManagementTemplatesAction restMLListContextManagementTemplatesAction = + new RestMLListContextManagementTemplatesAction(mlFeatureEnabledSetting); + RestMLDeleteContextManagementTemplateAction restMLDeleteContextManagementTemplateAction = + new RestMLDeleteContextManagementTemplateAction(mlFeatureEnabledSetting); return ImmutableList .of( restMLStatsAction, @@ -1146,7 +1186,11 @@ public List getRestHandlers( restMLListMcpToolsAction, restMLMcpToolsUpdateAction, restMLPutIndexInsightConfigAction, - restMLGetIndexInsightConfigAction + restMLGetIndexInsightConfigAction, + restMLCreateContextManagementTemplateAction, + restMLGetContextManagementTemplateAction, + restMLListContextManagementTemplatesAction, + restMLDeleteContextManagementTemplateAction ); } diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLCreateContextManagementTemplateAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLCreateContextManagementTemplateAction.java new file mode 100644 index 0000000000..34387ccb81 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLCreateContextManagementTemplateAction.java @@ -0,0 +1,89 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; +import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_TEMPLATE_NAME; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; + +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.contextmanager.ContextManagementTemplate; +import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; +import org.opensearch.ml.common.transport.contextmanagement.MLCreateContextManagementTemplateAction; +import org.opensearch.ml.common.transport.contextmanagement.MLCreateContextManagementTemplateRequest; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; +import org.opensearch.transport.client.node.NodeClient; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; + +public class RestMLCreateContextManagementTemplateAction extends BaseRestHandler { + private static final String ML_CREATE_CONTEXT_MANAGEMENT_TEMPLATE_ACTION = "ml_create_context_management_template_action"; + private final MLFeatureEnabledSetting mlFeatureEnabledSetting; + + /** + * Constructor + */ + public RestMLCreateContextManagementTemplateAction(MLFeatureEnabledSetting mlFeatureEnabledSetting) { + this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; + } + + @Override + public String getName() { + return ML_CREATE_CONTEXT_MANAGEMENT_TEMPLATE_ACTION; + } + + @Override + public List routes() { + return ImmutableList + .of( + new Route( + RestRequest.Method.PUT, + String.format(Locale.ROOT, "%s/context_management/{%s}", ML_BASE_URI, PARAMETER_TEMPLATE_NAME) + ) + ); + } + + @Override + public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + MLCreateContextManagementTemplateRequest createRequest = getRequest(request); + return channel -> client + .execute(MLCreateContextManagementTemplateAction.INSTANCE, createRequest, new RestToXContentListener<>(channel)); + } + + /** + * Creates a MLCreateContextManagementTemplateRequest from a RestRequest + * + * @param request RestRequest + * @return MLCreateContextManagementTemplateRequest + */ + @VisibleForTesting + MLCreateContextManagementTemplateRequest getRequest(RestRequest request) throws IOException { + if (!mlFeatureEnabledSetting.isAgentFrameworkEnabled()) { + throw new IllegalStateException("Agent framework is disabled"); + } + + String templateName = request.param(PARAMETER_TEMPLATE_NAME); + if (templateName == null || templateName.trim().isEmpty()) { + throw new IllegalArgumentException("Template name is required"); + } + + XContentParser parser = request.contentParser(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + ContextManagementTemplate template = ContextManagementTemplate.parse(parser); + + // Set the template name from URL parameter + template = template.toBuilder().name(templateName).build(); + + return new MLCreateContextManagementTemplateRequest(templateName, template); + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLDeleteContextManagementTemplateAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLDeleteContextManagementTemplateAction.java new file mode 100644 index 0000000000..1dbde7f216 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLDeleteContextManagementTemplateAction.java @@ -0,0 +1,79 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; +import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_TEMPLATE_NAME; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; + +import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; +import org.opensearch.ml.common.transport.contextmanagement.MLDeleteContextManagementTemplateAction; +import org.opensearch.ml.common.transport.contextmanagement.MLDeleteContextManagementTemplateRequest; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; +import org.opensearch.transport.client.node.NodeClient; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; + +public class RestMLDeleteContextManagementTemplateAction extends BaseRestHandler { + private static final String ML_DELETE_CONTEXT_MANAGEMENT_TEMPLATE_ACTION = "ml_delete_context_management_template_action"; + private final MLFeatureEnabledSetting mlFeatureEnabledSetting; + + /** + * Constructor + */ + public RestMLDeleteContextManagementTemplateAction(MLFeatureEnabledSetting mlFeatureEnabledSetting) { + this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; + } + + @Override + public String getName() { + return ML_DELETE_CONTEXT_MANAGEMENT_TEMPLATE_ACTION; + } + + @Override + public List routes() { + return ImmutableList + .of( + new Route( + RestRequest.Method.DELETE, + String.format(Locale.ROOT, "%s/context_management/{%s}", ML_BASE_URI, PARAMETER_TEMPLATE_NAME) + ) + ); + } + + @Override + public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + MLDeleteContextManagementTemplateRequest deleteRequest = getRequest(request); + return channel -> client + .execute(MLDeleteContextManagementTemplateAction.INSTANCE, deleteRequest, new RestToXContentListener<>(channel)); + } + + /** + * Creates a MLDeleteContextManagementTemplateRequest from a RestRequest + * + * @param request RestRequest + * @return MLDeleteContextManagementTemplateRequest + */ + @VisibleForTesting + MLDeleteContextManagementTemplateRequest getRequest(RestRequest request) throws IOException { + if (!mlFeatureEnabledSetting.isAgentFrameworkEnabled()) { + throw new IllegalStateException("Agent framework is disabled"); + } + + String templateName = request.param(PARAMETER_TEMPLATE_NAME); + if (templateName == null || templateName.trim().isEmpty()) { + throw new IllegalArgumentException("Template name is required"); + } + + return new MLDeleteContextManagementTemplateRequest(templateName); + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetContextManagementTemplateAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetContextManagementTemplateAction.java new file mode 100644 index 0000000000..4089d0aac1 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetContextManagementTemplateAction.java @@ -0,0 +1,78 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; +import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_TEMPLATE_NAME; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; + +import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; +import org.opensearch.ml.common.transport.contextmanagement.MLGetContextManagementTemplateAction; +import org.opensearch.ml.common.transport.contextmanagement.MLGetContextManagementTemplateRequest; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; +import org.opensearch.transport.client.node.NodeClient; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; + +public class RestMLGetContextManagementTemplateAction extends BaseRestHandler { + private static final String ML_GET_CONTEXT_MANAGEMENT_TEMPLATE_ACTION = "ml_get_context_management_template_action"; + private final MLFeatureEnabledSetting mlFeatureEnabledSetting; + + /** + * Constructor + */ + public RestMLGetContextManagementTemplateAction(MLFeatureEnabledSetting mlFeatureEnabledSetting) { + this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; + } + + @Override + public String getName() { + return ML_GET_CONTEXT_MANAGEMENT_TEMPLATE_ACTION; + } + + @Override + public List routes() { + return ImmutableList + .of( + new Route( + RestRequest.Method.GET, + String.format(Locale.ROOT, "%s/context_management/{%s}", ML_BASE_URI, PARAMETER_TEMPLATE_NAME) + ) + ); + } + + @Override + public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + MLGetContextManagementTemplateRequest getRequest = getRequest(request); + return channel -> client.execute(MLGetContextManagementTemplateAction.INSTANCE, getRequest, new RestToXContentListener<>(channel)); + } + + /** + * Creates a MLGetContextManagementTemplateRequest from a RestRequest + * + * @param request RestRequest + * @return MLGetContextManagementTemplateRequest + */ + @VisibleForTesting + MLGetContextManagementTemplateRequest getRequest(RestRequest request) throws IOException { + if (!mlFeatureEnabledSetting.isAgentFrameworkEnabled()) { + throw new IllegalStateException("Agent framework is disabled"); + } + + String templateName = request.param(PARAMETER_TEMPLATE_NAME); + if (templateName == null || templateName.trim().isEmpty()) { + throw new IllegalArgumentException("Template name is required"); + } + + return new MLGetContextManagementTemplateRequest(templateName); + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLListContextManagementTemplatesAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLListContextManagementTemplatesAction.java new file mode 100644 index 0000000000..d5020bd5c3 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLListContextManagementTemplatesAction.java @@ -0,0 +1,70 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; + +import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; +import org.opensearch.ml.common.transport.contextmanagement.MLListContextManagementTemplatesAction; +import org.opensearch.ml.common.transport.contextmanagement.MLListContextManagementTemplatesRequest; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; +import org.opensearch.transport.client.node.NodeClient; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; + +public class RestMLListContextManagementTemplatesAction extends BaseRestHandler { + private static final String ML_LIST_CONTEXT_MANAGEMENT_TEMPLATES_ACTION = "ml_list_context_management_templates_action"; + private final MLFeatureEnabledSetting mlFeatureEnabledSetting; + + /** + * Constructor + */ + public RestMLListContextManagementTemplatesAction(MLFeatureEnabledSetting mlFeatureEnabledSetting) { + this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; + } + + @Override + public String getName() { + return ML_LIST_CONTEXT_MANAGEMENT_TEMPLATES_ACTION; + } + + @Override + public List routes() { + return ImmutableList.of(new Route(RestRequest.Method.GET, String.format(Locale.ROOT, "%s/context_management", ML_BASE_URI))); + } + + @Override + public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + MLListContextManagementTemplatesRequest listRequest = getRequest(request); + return channel -> client + .execute(MLListContextManagementTemplatesAction.INSTANCE, listRequest, new RestToXContentListener<>(channel)); + } + + /** + * Creates a MLListContextManagementTemplatesRequest from a RestRequest + * + * @param request RestRequest + * @return MLListContextManagementTemplatesRequest + */ + @VisibleForTesting + MLListContextManagementTemplatesRequest getRequest(RestRequest request) throws IOException { + if (!mlFeatureEnabledSetting.isAgentFrameworkEnabled()) { + throw new IllegalStateException("Agent framework is disabled"); + } + + int from = request.paramAsInt("from", 0); + int size = request.paramAsInt("size", 10); + + return new MLListContextManagementTemplatesRequest(from, size); + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLExecuteTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLExecuteTaskRunner.java index 73281e0333..436c659b9d 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLExecuteTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLExecuteTaskRunner.java @@ -15,10 +15,18 @@ import org.opensearch.cluster.service.ClusterService; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.ml.action.contextmanagement.ContextManagementTemplateService; +import org.opensearch.ml.action.contextmanagement.ContextManagerFactory; import org.opensearch.ml.breaker.MLCircuitBreakerService; import org.opensearch.ml.cluster.DiscoveryNodeHelper; import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.contextmanager.ContextManagementTemplate; +import org.opensearch.ml.common.contextmanager.ContextManager; +import org.opensearch.ml.common.contextmanager.ContextManagerConfig; +import org.opensearch.ml.common.contextmanager.ContextManagerHookProvider; +import org.opensearch.ml.common.hooks.HookRegistry; import org.opensearch.ml.common.input.Input; +import org.opensearch.ml.common.input.execute.agent.AgentMLInput; import org.opensearch.ml.common.transport.execute.MLExecuteStreamTaskAction; import org.opensearch.ml.common.transport.execute.MLExecuteTaskAction; import org.opensearch.ml.common.transport.execute.MLExecuteTaskRequest; @@ -50,6 +58,8 @@ public class MLExecuteTaskRunner extends MLTaskRunner wrappedListener = ActionListener.runBefore(listener, ) Input input = request.getInput(); FunctionName functionName = request.getFunctionName(); + + // Handle agent execution with context management + if (FunctionName.AGENT.equals(functionName) && input instanceof AgentMLInput) { + AgentMLInput agentInput = (AgentMLInput) input; + String contextManagementName = getEffectiveContextManagementName(agentInput); + + if (contextManagementName != null && !contextManagementName.trim().isEmpty()) { + // Execute agent with context management + executeAgentWithContextManagement(request, contextManagementName, channel, listener); + return; + } + } + if (FunctionName.METRICS_CORRELATION.equals(functionName)) { if (!isPythonModelEnabled) { Exception exception = new IllegalArgumentException("This algorithm is not enabled from settings"); @@ -163,10 +189,17 @@ protected void executeTask(MLExecuteTaskRequest request, ActionListener { - MLExecuteTaskResponse response = new MLExecuteTaskResponse(functionName, output); - listener.onResponse(response); - }, e -> { listener.onFailure(e); }), channel); + + // Default execution for all functions (including agents without context management) + try { + mlEngine.execute(input, ActionListener.wrap(output -> { + MLExecuteTaskResponse response = new MLExecuteTaskResponse(functionName, output); + listener.onResponse(response); + }, e -> { listener.onFailure(e); }), channel); + } catch (Exception e) { + log.error("Failed to execute ML function", e); + listener.onFailure(e); + } } catch (Exception e) { mlStats .createCounterStatIfAbsent(request.getFunctionName(), ActionName.EXECUTE, MLActionLevelStat.ML_ACTION_FAILURE_COUNT) @@ -178,4 +211,218 @@ protected void executeTask(MLExecuteTaskRequest request, ActionListener listener + ) { + log.debug("Executing agent with context management: {}", contextManagementName); + + // Lookup context management template + contextManagementTemplateService.getTemplate(contextManagementName, ActionListener.wrap(template -> { + if (template == null) { + listener.onFailure(new IllegalArgumentException("Context management template not found: " + contextManagementName)); + return; + } + + try { + // Create context managers from template + java.util.List contextManagers = createContextManagers(template); + + // Create HookRegistry with context managers + HookRegistry hookRegistry = createHookRegistry(contextManagers, template); + + // Set hook registry in agent input + AgentMLInput agentInput = (AgentMLInput) request.getInput(); + agentInput.setHookRegistry(hookRegistry); + + log + .info( + "Executing agent with context management template: {} using {} context managers", + contextManagementName, + contextManagers.size() + ); + + // Execute agent with hook registry + try { + mlEngine.execute(request.getInput(), ActionListener.wrap(output -> { + log.info("Agent execution completed successfully with context management"); + MLExecuteTaskResponse response = new MLExecuteTaskResponse(request.getFunctionName(), output); + listener.onResponse(response); + }, error -> { + log.error("Agent execution failed with context management", error); + listener.onFailure(error); + }), channel); + } catch (Exception e) { + log.error("Failed to execute agent with context management", e); + listener.onFailure(e); + } + + } catch (Exception e) { + log.error("Failed to create context managers from template: {}", contextManagementName, e); + listener.onFailure(e); + } + }, error -> { + log.error("Failed to retrieve context management template: {}", contextManagementName, error); + listener.onFailure(error); + })); + } + + /** + * Gets the effective context management name for an agent. + * Priority: 1) Runtime parameter from execution request, 2) Agent's stored configuration, 3) Runtime parameters set by MLAgentExecutor + * This follows the same pattern as MCP connectors. + * + * @param agentInput the agent ML input + * @return the effective context management name, or null if none configured + */ + private String getEffectiveContextManagementName(AgentMLInput agentInput) { + // Priority 1: Runtime parameter from execution request (user override) + String runtimeContextManagementName = agentInput.getContextManagementName(); + if (runtimeContextManagementName != null && !runtimeContextManagementName.trim().isEmpty()) { + log.debug("Using runtime context management name: {}", runtimeContextManagementName); + return runtimeContextManagementName; + } + + // Priority 2: Check agent's stored configuration directly + String agentId = agentInput.getAgentId(); + if (agentId != null) { + try { + // Use a blocking call to get the agent synchronously + // This is acceptable here since we're in the task execution path + java.util.concurrent.CompletableFuture future = new java.util.concurrent.CompletableFuture<>(); + + try ( + org.opensearch.common.util.concurrent.ThreadContext.StoredContext context = client + .threadPool() + .getThreadContext() + .stashContext() + ) { + client + .get( + new org.opensearch.action.get.GetRequest(org.opensearch.ml.common.CommonValue.ML_AGENT_INDEX, agentId), + org.opensearch.core.action.ActionListener.runBefore(org.opensearch.core.action.ActionListener.wrap(response -> { + if (response.isExists()) { + try { + org.opensearch.core.xcontent.XContentParser parser = + org.opensearch.common.xcontent.json.JsonXContent.jsonXContent + .createParser( + null, + org.opensearch.common.xcontent.LoggingDeprecationHandler.INSTANCE, + response.getSourceAsString() + ); + org.opensearch.core.xcontent.XContentParserUtils + .ensureExpectedToken( + org.opensearch.core.xcontent.XContentParser.Token.START_OBJECT, + parser.nextToken(), + parser + ); + org.opensearch.ml.common.agent.MLAgent mlAgent = org.opensearch.ml.common.agent.MLAgent + .parse(parser); + + if (mlAgent.hasContextManagementTemplate()) { + String templateName = mlAgent.getContextManagementTemplateName(); + future.complete(templateName); + } else { + future.complete(null); + } + } catch (Exception e) { + future.completeExceptionally(e); + } + } else { + future.complete(null); // Agent not found + } + }, future::completeExceptionally), context::restore) + ); + } + + // Wait for the result with a timeout + String contextManagementName = future.get(5, java.util.concurrent.TimeUnit.SECONDS); + if (contextManagementName != null && !contextManagementName.trim().isEmpty()) { + return contextManagementName; + } + } catch (Exception e) { + // Continue to fallback methods + } + } + + // Priority 3: Agent's runtime parameters (set by MLAgentExecutor in input parameters) + if (agentInput.getInputDataset() instanceof org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet) { + org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet dataset = + (org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet) agentInput.getInputDataset(); + + // Check if context management has already been processed by MLAgentExecutor (for inline templates) + String contextManagementProcessed = dataset.getParameters().get("context_management_processed"); + if ("true".equals(contextManagementProcessed)) { + log.debug("Context management already processed by MLAgentExecutor, skipping MLExecuteTaskRunner processing"); + return null; // Skip processing in MLExecuteTaskRunner + } + + // Handle template references (not processed by MLAgentExecutor) + String agentContextManagementName = dataset.getParameters().get("context_management"); + if (agentContextManagementName != null && !agentContextManagementName.trim().isEmpty()) { + return agentContextManagementName; + } + } + + return null; + } + + /** + * Create context managers from template configuration + */ + private java.util.List createContextManagers(ContextManagementTemplate template) { + java.util.List contextManagers = new java.util.ArrayList<>(); + + // Iterate through all hooks in the template + for (java.util.Map.Entry> entry : template.getHooks().entrySet()) { + String hookName = entry.getKey(); + java.util.List configs = entry.getValue(); + + for (ContextManagerConfig config : configs) { + try { + ContextManager manager = contextManagerFactory.createContextManager(config); + if (manager != null) { + contextManagers.add(manager); + log.debug("Created context manager: {} for hook: {}", config.getType(), hookName); + } else { + log.warn("Failed to create context manager of type: {}", config.getType()); + } + } catch (Exception e) { + log.error("Error creating context manager of type: {}", config.getType(), e); + // Continue with other managers + } + } + } + + log.info("Created {} context managers from template: {}", contextManagers.size(), template.getName()); + return contextManagers; + } + + /** + * Create HookRegistry with context managers + */ + private HookRegistry createHookRegistry(java.util.List contextManagers, ContextManagementTemplate template) { + HookRegistry hookRegistry = new HookRegistry(); + + if (!contextManagers.isEmpty()) { + // Create context manager hook provider + ContextManagerHookProvider hookProvider = new ContextManagerHookProvider(contextManagers); + + // Update hook configuration based on template + hookProvider.updateHookConfiguration(template.getHooks()); + + // Register hooks + hookProvider.registerHooks(hookRegistry); + + log.debug("Registered context manager hooks for {} managers", contextManagers.size()); + } + + return hookRegistry; + } + } diff --git a/plugin/src/main/java/org/opensearch/ml/utils/RestActionUtils.java b/plugin/src/main/java/org/opensearch/ml/utils/RestActionUtils.java index e5c11215c3..acbd70bb43 100644 --- a/plugin/src/main/java/org/opensearch/ml/utils/RestActionUtils.java +++ b/plugin/src/main/java/org/opensearch/ml/utils/RestActionUtils.java @@ -76,6 +76,7 @@ public class RestActionUtils { public static final String[] UI_METADATA_EXCLUDE = new String[] { "ui_metadata" }; public static final String PARAMETER_TOOL_NAME = "tool_name"; + public static final String PARAMETER_TEMPLATE_NAME = "template_name"; public static final String OPENDISTRO_SECURITY_CONFIG_PREFIX = "_opendistro_security_"; diff --git a/plugin/src/test/java/org/opensearch/ml/action/agent/MLAgentRegistrationValidatorTests.java b/plugin/src/test/java/org/opensearch/ml/action/agent/MLAgentRegistrationValidatorTests.java new file mode 100644 index 0000000000..24009b5094 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/action/agent/MLAgentRegistrationValidatorTests.java @@ -0,0 +1,413 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.agent; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.action.contextmanagement.ContextManagementTemplateService; +import org.opensearch.ml.common.agent.MLAgent; +import org.opensearch.ml.common.contextmanager.ContextManagementTemplate; +import org.opensearch.ml.common.contextmanager.ContextManagerConfig; +import org.opensearch.ml.common.exception.MLResourceNotFoundException; + +import junit.framework.TestCase; + +public class MLAgentRegistrationValidatorTests extends TestCase { + + private ContextManagementTemplateService mockTemplateService; + private MLAgentRegistrationValidator validator; + + @Before + public void setUp() { + mockTemplateService = mock(ContextManagementTemplateService.class); + validator = new MLAgentRegistrationValidator(mockTemplateService); + } + + @Test + public void testValidateAgentForRegistration_NoContextManagement() throws InterruptedException { + MLAgent agent = MLAgent.builder().name("test_agent").type("flow").build(); + + CountDownLatch latch = new CountDownLatch(1); + AtomicReference result = new AtomicReference<>(); + AtomicReference error = new AtomicReference<>(); + + validator.validateAgentForRegistration(agent, new ActionListener() { + @Override + public void onResponse(Boolean response) { + result.set(response); + latch.countDown(); + } + + @Override + public void onFailure(Exception e) { + error.set(e); + latch.countDown(); + } + }); + + assertTrue(latch.await(5, TimeUnit.SECONDS)); + assertTrue(result.get()); + assertNull(error.get()); + + // Verify template service was not called since no template reference + verify(mockTemplateService, never()).getTemplate(any(), any()); + } + + @Test + public void testValidateAgentForRegistration_TemplateExists() throws InterruptedException { + MLAgent agent = MLAgent.builder().name("test_agent").type("flow").contextManagementName("existing_template").build(); + + // Mock template service to return a template (exists) + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + ContextManagementTemplate template = ContextManagementTemplate.builder().name("existing_template").build(); + listener.onResponse(template); + return null; + }).when(mockTemplateService).getTemplate(eq("existing_template"), any()); + + CountDownLatch latch = new CountDownLatch(1); + AtomicReference result = new AtomicReference<>(); + AtomicReference error = new AtomicReference<>(); + + validator.validateAgentForRegistration(agent, new ActionListener() { + @Override + public void onResponse(Boolean response) { + result.set(response); + latch.countDown(); + } + + @Override + public void onFailure(Exception e) { + error.set(e); + latch.countDown(); + } + }); + + assertTrue(latch.await(5, TimeUnit.SECONDS)); + assertTrue(result.get()); + assertNull(error.get()); + + verify(mockTemplateService).getTemplate(eq("existing_template"), any()); + } + + @Test + public void testValidateAgentForRegistration_TemplateNotFound() throws InterruptedException { + MLAgent agent = MLAgent.builder().name("test_agent").type("flow").contextManagementName("nonexistent_template").build(); + + // Mock template service to return template not found + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new MLResourceNotFoundException("Context management template not found: nonexistent_template")); + return null; + }).when(mockTemplateService).getTemplate(eq("nonexistent_template"), any()); + + CountDownLatch latch = new CountDownLatch(1); + AtomicReference result = new AtomicReference<>(); + AtomicReference error = new AtomicReference<>(); + + validator.validateAgentForRegistration(agent, new ActionListener() { + @Override + public void onResponse(Boolean response) { + result.set(response); + latch.countDown(); + } + + @Override + public void onFailure(Exception e) { + error.set(e); + latch.countDown(); + } + }); + + assertTrue(latch.await(5, TimeUnit.SECONDS)); + assertNull(result.get()); + assertNotNull(error.get()); + assertTrue(error.get() instanceof IllegalArgumentException); + assertTrue(error.get().getMessage().contains("Context management template not found: nonexistent_template")); + + verify(mockTemplateService).getTemplate(eq("nonexistent_template"), any()); + } + + @Test + public void testValidateAgentForRegistration_TemplateServiceError() throws InterruptedException { + MLAgent agent = MLAgent.builder().name("test_agent").type("flow").contextManagementName("error_template").build(); + + // Mock template service to return an error + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new RuntimeException("Service error")); + return null; + }).when(mockTemplateService).getTemplate(eq("error_template"), any()); + + CountDownLatch latch = new CountDownLatch(1); + AtomicReference result = new AtomicReference<>(); + AtomicReference error = new AtomicReference<>(); + + validator.validateAgentForRegistration(agent, new ActionListener() { + @Override + public void onResponse(Boolean response) { + result.set(response); + latch.countDown(); + } + + @Override + public void onFailure(Exception e) { + error.set(e); + latch.countDown(); + } + }); + + assertTrue(latch.await(5, TimeUnit.SECONDS)); + assertNull(result.get()); + assertNotNull(error.get()); + assertTrue(error.get() instanceof IllegalArgumentException); + assertTrue(error.get().getMessage().contains("Failed to validate context management template")); + + verify(mockTemplateService).getTemplate(eq("error_template"), any()); + } + + @Test + public void testValidateAgentForRegistration_InlineContextManagement() throws InterruptedException { + Map> hooks = new HashMap<>(); + hooks.put("POST_TOOL", List.of(new ContextManagerConfig("ToolsOutputTruncateManager", null, null))); + + ContextManagementTemplate contextManagement = ContextManagementTemplate.builder().name("inline_template").hooks(hooks).build(); + + MLAgent agent = MLAgent.builder().name("test_agent").type("flow").contextManagement(contextManagement).build(); + + CountDownLatch latch = new CountDownLatch(1); + AtomicReference result = new AtomicReference<>(); + AtomicReference error = new AtomicReference<>(); + + validator.validateAgentForRegistration(agent, new ActionListener() { + @Override + public void onResponse(Boolean response) { + result.set(response); + latch.countDown(); + } + + @Override + public void onFailure(Exception e) { + error.set(e); + latch.countDown(); + } + }); + + assertTrue(latch.await(5, TimeUnit.SECONDS)); + assertTrue(result.get()); + assertNull(error.get()); + + // Verify template service was not called since using inline configuration + verify(mockTemplateService, never()).getTemplate(any(), any()); + } + + @Test + public void testValidateAgentForRegistration_InvalidTemplateName() throws InterruptedException { + MLAgent agent = MLAgent.builder().name("test_agent").type("flow").contextManagementName("invalid@name").build(); + + CountDownLatch latch = new CountDownLatch(1); + AtomicReference result = new AtomicReference<>(); + AtomicReference error = new AtomicReference<>(); + + validator.validateAgentForRegistration(agent, new ActionListener() { + @Override + public void onResponse(Boolean response) { + result.set(response); + latch.countDown(); + } + + @Override + public void onFailure(Exception e) { + error.set(e); + latch.countDown(); + } + }); + + assertTrue(latch.await(5, TimeUnit.SECONDS)); + assertNull(result.get()); + assertNotNull(error.get()); + assertTrue(error.get() instanceof IllegalArgumentException); + assertTrue( + error + .get() + .getMessage() + .contains("Context management template name can only contain letters, numbers, underscores, hyphens, and dots") + ); + + // Verify template service was not called due to early validation failure + verify(mockTemplateService, never()).getTemplate(any(), any()); + } + + @Test + public void testValidateAgentForRegistration_InvalidInlineConfiguration() throws InterruptedException { + Map> invalidHooks = new HashMap<>(); + invalidHooks.put("INVALID_HOOK", List.of(new ContextManagerConfig("ToolsOutputTruncateManager", null, null))); + + ContextManagementTemplate contextManagement = ContextManagementTemplate.builder().name("test_template").hooks(invalidHooks).build(); + + MLAgent agent = MLAgent.builder().name("test_agent").type("flow").contextManagement(contextManagement).build(); + + CountDownLatch latch = new CountDownLatch(1); + AtomicReference result = new AtomicReference<>(); + AtomicReference error = new AtomicReference<>(); + + validator.validateAgentForRegistration(agent, new ActionListener() { + @Override + public void onResponse(Boolean response) { + result.set(response); + latch.countDown(); + } + + @Override + public void onFailure(Exception e) { + error.set(e); + latch.countDown(); + } + }); + + assertTrue(latch.await(5, TimeUnit.SECONDS)); + assertNull(result.get()); + assertNotNull(error.get()); + assertTrue(error.get() instanceof IllegalArgumentException); + assertTrue(error.get().getMessage().contains("Invalid hook name: INVALID_HOOK")); + + // Verify template service was not called due to early validation failure + verify(mockTemplateService, never()).getTemplate(any(), any()); + } + + @Test + public void testValidateContextManagementConfiguration_ValidTemplateName() { + MLAgent agent = MLAgent.builder().name("test_agent").type("flow").contextManagementName("valid_template_name").build(); + + String result = validator.validateContextManagementConfiguration(agent); + assertNull(result); + } + + @Test + public void testValidateContextManagementConfiguration_ValidInlineConfig() { + Map> hooks = new HashMap<>(); + hooks.put("POST_TOOL", List.of(new ContextManagerConfig("ToolsOutputTruncateManager", null, null))); + + ContextManagementTemplate contextManagement = ContextManagementTemplate.builder().name("inline_template").hooks(hooks).build(); + + MLAgent agent = MLAgent.builder().name("test_agent").type("flow").contextManagement(contextManagement).build(); + + String result = validator.validateContextManagementConfiguration(agent); + assertNull(result); + } + + @Test + public void testValidateContextManagementConfiguration_EmptyTemplateName() { + MLAgent agent = MLAgent.builder().name("test_agent").type("flow").contextManagementName("").build(); + + String result = validator.validateContextManagementConfiguration(agent); + assertNotNull(result); + assertTrue(result.contains("Context management template name cannot be null or empty")); + } + + @Test + public void testValidateContextManagementConfiguration_TooLongTemplateName() { + String longName = "a".repeat(257); + MLAgent agent = MLAgent.builder().name("test_agent").type("flow").contextManagementName(longName).build(); + + String result = validator.validateContextManagementConfiguration(agent); + assertNotNull(result); + assertTrue(result.contains("Context management template name cannot exceed 256 characters")); + } + + @Test + public void testValidateContextManagementConfiguration_InvalidTemplateNameCharacters() { + MLAgent agent = MLAgent.builder().name("test_agent").type("flow").contextManagementName("invalid@name#").build(); + + String result = validator.validateContextManagementConfiguration(agent); + assertNotNull(result); + assertTrue(result.contains("Context management template name can only contain letters, numbers, underscores, hyphens, and dots")); + } + + @Test + public void testValidateContextManagementConfiguration_InvalidHookName() { + Map> invalidHooks = new HashMap<>(); + invalidHooks.put("INVALID_HOOK", List.of(new ContextManagerConfig("ToolsOutputTruncateManager", null, null))); + + ContextManagementTemplate contextManagement = ContextManagementTemplate.builder().name("test_template").hooks(invalidHooks).build(); + + MLAgent agent = MLAgent.builder().name("test_agent").type("flow").contextManagement(contextManagement).build(); + + String result = validator.validateContextManagementConfiguration(agent); + assertNotNull(result); + assertTrue(result.contains("Invalid hook name: INVALID_HOOK")); + } + + @Test + public void testValidateContextManagementConfiguration_EmptyHookConfigs() { + Map> emptyHooks = new HashMap<>(); + emptyHooks.put("POST_TOOL", List.of()); + + ContextManagementTemplate contextManagement = ContextManagementTemplate.builder().name("test_template").hooks(emptyHooks).build(); + + MLAgent agent = MLAgent.builder().name("test_agent").type("flow").contextManagement(contextManagement).build(); + + String result = validator.validateContextManagementConfiguration(agent); + assertNotNull(result); + assertTrue(result.contains("Hook POST_TOOL must have at least one context manager configuration")); + } + + @Test + public void testValidateContextManagementConfiguration_Conflict() { + // This test should verify that the MLAgent constructor throws an exception + // when both context management name and inline config are provided + Map> hooks = new HashMap<>(); + hooks.put("POST_TOOL", List.of(new ContextManagerConfig("ToolsOutputTruncateManager", null, null))); + + ContextManagementTemplate contextManagement = ContextManagementTemplate.builder().name("inline_template").hooks(hooks).build(); + + try { + MLAgent agent = MLAgent + .builder() + .name("test_agent") + .type("flow") + .contextManagementName("template_name") + .contextManagement(contextManagement) + .build(); + fail("Expected IllegalArgumentException to be thrown"); + } catch (IllegalArgumentException e) { + assertEquals("Cannot specify both context_management_name and context_management", e.getMessage()); + } + } + + @Test + public void testValidateContextManagementConfiguration_InvalidInlineConfig() { + // This test should verify that the MLAgent constructor throws an exception + // when invalid context management configuration is provided + ContextManagementTemplate invalidContextManagement = ContextManagementTemplate + .builder() + .name("invalid_template") + .hooks(new HashMap<>()) + .build(); + + try { + MLAgent agent = MLAgent.builder().name("test_agent").type("flow").contextManagement(invalidContextManagement).build(); + fail("Expected IllegalArgumentException to be thrown"); + } catch (IllegalArgumentException e) { + assertEquals("Invalid context management configuration", e.getMessage()); + } + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/action/agents/DeleteAgentTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/agents/DeleteAgentTransportActionTests.java index 63fd5216e2..5ec7f38311 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/agents/DeleteAgentTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/agents/DeleteAgentTransportActionTests.java @@ -353,6 +353,8 @@ private GetResponse prepareMLAgent(String agentId, boolean isHidden, String tena Instant.EPOCH, "test", isHidden, + null, // contextManagementName + null, // contextManagement tenantId ); diff --git a/plugin/src/test/java/org/opensearch/ml/action/agents/GetAgentTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/agents/GetAgentTransportActionTests.java index 8a5e081855..7bed2a7225 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/agents/GetAgentTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/agents/GetAgentTransportActionTests.java @@ -335,6 +335,8 @@ public GetResponse prepareMLAgent(String agentId, boolean isHidden, String tenan Instant.EPOCH, "test", isHidden, + null, // contextManagementName + null, // contextManagement tenantId ); diff --git a/plugin/src/test/java/org/opensearch/ml/action/agents/RegisterAgentTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/agents/RegisterAgentTransportActionTests.java index b9e7323b7e..0818845cac 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/agents/RegisterAgentTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/agents/RegisterAgentTransportActionTests.java @@ -42,6 +42,7 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.core.index.shard.ShardId; import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.action.contextmanagement.ContextManagementTemplateService; import org.opensearch.ml.common.MLAgentType; import org.opensearch.ml.common.agent.LLMSpec; import org.opensearch.ml.common.agent.MLAgent; @@ -96,6 +97,9 @@ public class RegisterAgentTransportActionTests extends OpenSearchTestCase { @Mock private MLFeatureEnabledSetting mlFeatureEnabledSetting; + @Mock + private ContextManagementTemplateService contextManagementTemplateService; + @Before public void setup() throws IOException { MockitoAnnotations.openMocks(this); @@ -116,7 +120,8 @@ public void setup() throws IOException { sdkClient, mlIndicesHandler, clusterService, - mlFeatureEnabledSetting + mlFeatureEnabledSetting, + contextManagementTemplateService ); indexResponse = new IndexResponse(new ShardId(ML_AGENT_INDEX, "_na_", 0), "AGENT_ID", 1, 0, 2, true); } @@ -510,7 +515,8 @@ public void test_execute_registerAgent_MCPConnectorDisabled() { sdkClient, mlIndicesHandler, clusterService, - mlFeatureEnabledSetting + mlFeatureEnabledSetting, + contextManagementTemplateService ); disabledAction.doExecute(task, request, actionListener); diff --git a/plugin/src/test/java/org/opensearch/ml/action/contextmanagement/ContextManagementIndexUtilsTests.java b/plugin/src/test/java/org/opensearch/ml/action/contextmanagement/ContextManagementIndexUtilsTests.java new file mode 100644 index 0000000000..c8d1c8953f --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/action/contextmanagement/ContextManagementIndexUtilsTests.java @@ -0,0 +1,231 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.contextmanagement; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.*; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.ResourceAlreadyExistsException; +import org.opensearch.action.admin.indices.create.CreateIndexRequest; +import org.opensearch.action.admin.indices.create.CreateIndexResponse; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.metadata.Metadata; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.client.AdminClient; +import org.opensearch.transport.client.Client; +import org.opensearch.transport.client.IndicesAdminClient; + +public class ContextManagementIndexUtilsTests extends OpenSearchTestCase { + + @Mock + private Client client; + + @Mock + private ClusterService clusterService; + + @Mock + private ThreadPool threadPool; + + @Mock + private AdminClient adminClient; + + @Mock + private IndicesAdminClient indicesAdminClient; + + private ContextManagementIndexUtils contextManagementIndexUtils; + + @Before + public void setUp() throws Exception { + super.setUp(); + MockitoAnnotations.openMocks(this); + + // Create a real ThreadContext instead of mocking it + Settings settings = Settings.builder().build(); + ThreadContext threadContext = new ThreadContext(settings); + + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + when(client.admin()).thenReturn(adminClient); + when(adminClient.indices()).thenReturn(indicesAdminClient); + + contextManagementIndexUtils = new ContextManagementIndexUtils(client, clusterService); + } + + @Test + public void testGetIndexName() { + String indexName = ContextManagementIndexUtils.getIndexName(); + assertEquals("ml_context_management_templates", indexName); + } + + @Test + public void testDoesIndexExist_True() { + ClusterState clusterState = mock(ClusterState.class); + Metadata metadata = mock(Metadata.class); + + when(clusterService.state()).thenReturn(clusterState); + when(clusterState.metadata()).thenReturn(metadata); + when(metadata.hasIndex("ml_context_management_templates")).thenReturn(true); + + boolean exists = contextManagementIndexUtils.doesIndexExist(); + assertTrue(exists); + } + + @Test + public void testDoesIndexExist_False() { + ClusterState clusterState = mock(ClusterState.class); + Metadata metadata = mock(Metadata.class); + + when(clusterService.state()).thenReturn(clusterState); + when(clusterState.metadata()).thenReturn(metadata); + when(metadata.hasIndex("ml_context_management_templates")).thenReturn(false); + + boolean exists = contextManagementIndexUtils.doesIndexExist(); + assertFalse(exists); + } + + @Test + public void testCreateIndexIfNotExists_IndexAlreadyExists() { + // Mock index already exists + ClusterState clusterState = mock(ClusterState.class); + Metadata metadata = mock(Metadata.class); + when(clusterService.state()).thenReturn(clusterState); + when(clusterState.metadata()).thenReturn(metadata); + when(metadata.hasIndex("ml_context_management_templates")).thenReturn(true); + + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + + contextManagementIndexUtils.createIndexIfNotExists(listener); + + verify(listener).onResponse(true); + verify(indicesAdminClient, never()).create(any(), any()); + } + + @Test + public void testCreateIndexIfNotExists_Success() { + // Mock index doesn't exist + ClusterState clusterState = mock(ClusterState.class); + Metadata metadata = mock(Metadata.class); + when(clusterService.state()).thenReturn(clusterState); + when(clusterState.metadata()).thenReturn(metadata); + when(metadata.hasIndex("ml_context_management_templates")).thenReturn(false); + + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + + // Mock successful index creation + doAnswer(invocation -> { + @SuppressWarnings("unchecked") + ActionListener createListener = invocation.getArgument(1); + CreateIndexResponse response = mock(CreateIndexResponse.class); + createListener.onResponse(response); + return null; + }).when(indicesAdminClient).create(any(CreateIndexRequest.class), any()); + + contextManagementIndexUtils.createIndexIfNotExists(listener); + + verify(listener).onResponse(true); + + // Verify the create request was made with correct settings + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(CreateIndexRequest.class); + verify(indicesAdminClient).create(requestCaptor.capture(), any()); + + CreateIndexRequest request = requestCaptor.getValue(); + assertEquals("ml_context_management_templates", request.index()); + + Settings indexSettings = request.settings(); + assertEquals("1", indexSettings.get("index.number_of_shards")); + assertEquals("1", indexSettings.get("index.number_of_replicas")); + assertEquals("0-1", indexSettings.get("index.auto_expand_replicas")); + } + + @Test + public void testCreateIndexIfNotExists_ResourceAlreadyExistsException() { + // Mock index doesn't exist initially + ClusterState clusterState = mock(ClusterState.class); + Metadata metadata = mock(Metadata.class); + when(clusterService.state()).thenReturn(clusterState); + when(clusterState.metadata()).thenReturn(metadata); + when(metadata.hasIndex("ml_context_management_templates")).thenReturn(false); + + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + + // Mock ResourceAlreadyExistsException (race condition) + doAnswer(invocation -> { + @SuppressWarnings("unchecked") + ActionListener createListener = invocation.getArgument(1); + createListener.onFailure(new ResourceAlreadyExistsException("Index already exists")); + return null; + }).when(indicesAdminClient).create(any(CreateIndexRequest.class), any()); + + contextManagementIndexUtils.createIndexIfNotExists(listener); + + verify(listener).onResponse(true); + } + + @Test + public void testCreateIndexIfNotExists_OtherException() { + // Mock index doesn't exist + ClusterState clusterState = mock(ClusterState.class); + Metadata metadata = mock(Metadata.class); + when(clusterService.state()).thenReturn(clusterState); + when(clusterState.metadata()).thenReturn(metadata); + when(metadata.hasIndex("ml_context_management_templates")).thenReturn(false); + + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + + RuntimeException testException = new RuntimeException("Test exception"); + + // Mock other exception + doAnswer(invocation -> { + @SuppressWarnings("unchecked") + ActionListener createListener = invocation.getArgument(1); + createListener.onFailure(testException); + return null; + }).when(indicesAdminClient).create(any(CreateIndexRequest.class), any()); + + contextManagementIndexUtils.createIndexIfNotExists(listener); + + verify(listener).onFailure(testException); + } + + @Test + public void testCreateIndexIfNotExists_UnexpectedException() { + // Mock index doesn't exist + ClusterState clusterState = mock(ClusterState.class); + Metadata metadata = mock(Metadata.class); + when(clusterService.state()).thenReturn(clusterState); + when(clusterState.metadata()).thenReturn(metadata); + when(metadata.hasIndex("ml_context_management_templates")).thenReturn(false); + + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + + RuntimeException testException = new RuntimeException("Unexpected error"); + + // Mock unexpected exception during setup + when(client.admin()).thenThrow(testException); + + contextManagementIndexUtils.createIndexIfNotExists(listener); + + verify(listener).onFailure(testException); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/action/contextmanagement/ContextManagementTemplateServiceTests.java b/plugin/src/test/java/org/opensearch/ml/action/contextmanagement/ContextManagementTemplateServiceTests.java new file mode 100644 index 0000000000..c8b7391908 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/action/contextmanagement/ContextManagementTemplateServiceTests.java @@ -0,0 +1,351 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.contextmanagement; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.*; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.contextmanager.ContextManagementTemplate; +import org.opensearch.ml.engine.indices.MLIndicesHandler; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.client.Client; + +public class ContextManagementTemplateServiceTests extends OpenSearchTestCase { + + @Mock + private MLIndicesHandler mlIndicesHandler; + + @Mock + private Client client; + + @Mock + private ClusterService clusterService; + + @Mock + private ThreadPool threadPool; + + private ContextManagementTemplateService contextManagementTemplateService; + + @Before + public void setUp() throws Exception { + super.setUp(); + MockitoAnnotations.openMocks(this); + + // Create a real ThreadContext instead of mocking it + Settings settings = Settings.builder().build(); + ThreadContext threadContext = new ThreadContext(settings); + + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + + // Mock cluster service dependencies for proper setup + org.opensearch.cluster.ClusterState clusterState = mock(org.opensearch.cluster.ClusterState.class); + org.opensearch.cluster.metadata.Metadata metadata = mock(org.opensearch.cluster.metadata.Metadata.class); + when(clusterService.state()).thenReturn(clusterState); + when(clusterState.metadata()).thenReturn(metadata); + when(metadata.hasIndex(anyString())).thenReturn(false); // Default to index not existing + + contextManagementTemplateService = new ContextManagementTemplateService(mlIndicesHandler, client, clusterService); + } + + @Test + public void testConstructor() { + assertNotNull(contextManagementTemplateService); + } + + @Test + public void testSaveTemplate_InvalidTemplate() { + String templateName = "test_template"; + ContextManagementTemplate template = mock(ContextManagementTemplate.class); + when(template.isValid()).thenReturn(false); + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + + contextManagementTemplateService.saveTemplate(templateName, template, listener); + + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(IllegalArgumentException.class); + verify(listener).onFailure(exceptionCaptor.capture()); + assertEquals("Invalid context management template", exceptionCaptor.getValue().getMessage()); + } + + @Test + public void testListTemplates_WithPagination() { + @SuppressWarnings("unchecked") + ActionListener> listener = mock(ActionListener.class); + + contextManagementTemplateService.listTemplates(5, 20, listener); + + // Verify that the method was called - the actual OpenSearch interaction would be complex to mock + // This at least exercises the method signature and basic flow + verify(client).threadPool(); + } + + @Test + public void testListTemplates_DefaultPagination() { + @SuppressWarnings("unchecked") + ActionListener> listener = mock(ActionListener.class); + + contextManagementTemplateService.listTemplates(listener); + + // Verify that the method was called - this exercises the default pagination path + verify(client).threadPool(); + } + + @Test + public void testSaveTemplate_NullTemplate() { + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + + contextManagementTemplateService.saveTemplate("test_template", null, listener); + + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener).onFailure(exceptionCaptor.capture()); + assertTrue(exceptionCaptor.getValue() instanceof NullPointerException); + } + + @Test + public void testSaveTemplate_ValidTemplate() { + ContextManagementTemplate template = mock(ContextManagementTemplate.class); + when(template.isValid()).thenReturn(true); + when(template.getName()).thenReturn("test_template"); + when(template.getCreatedTime()).thenReturn(null); + when(template.getCreatedBy()).thenReturn(null); + + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + + contextManagementTemplateService.saveTemplate("test_template", template, listener); + + // Verify template validation was called - the method will fail due to complex mocking requirements + // but this covers the validation path and timestamp setting + verify(template).isValid(); + verify(template).getCreatedTime(); + verify(template).getCreatedBy(); + verify(template).setCreatedTime(any(java.time.Instant.class)); + verify(template).setLastModified(any(java.time.Instant.class)); + } + + @Test + public void testSaveTemplate_ExceptionInTryBlock() { + // Test exception handling in the outer try-catch block + ContextManagementTemplate template = mock(ContextManagementTemplate.class); + when(template.isValid()).thenThrow(new RuntimeException("Validation error")); + + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + + contextManagementTemplateService.saveTemplate("test_template", template, listener); + + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener).onFailure(exceptionCaptor.capture()); + assertTrue(exceptionCaptor.getValue() instanceof RuntimeException); + assertEquals("Validation error", exceptionCaptor.getValue().getMessage()); + } + + @Test + public void testGetTemplate_NullTemplateName() { + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + + contextManagementTemplateService.getTemplate(null, listener); + + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(IllegalArgumentException.class); + verify(listener).onFailure(exceptionCaptor.capture()); + assertEquals("Template name cannot be null or empty", exceptionCaptor.getValue().getMessage()); + } + + @Test + public void testGetTemplate_EmptyTemplateName() { + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + + contextManagementTemplateService.getTemplate("", listener); + + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(IllegalArgumentException.class); + verify(listener).onFailure(exceptionCaptor.capture()); + assertEquals("Template name cannot be null or empty", exceptionCaptor.getValue().getMessage()); + } + + @Test + public void testDeleteTemplate_NullTemplateName() { + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + + contextManagementTemplateService.deleteTemplate(null, listener); + + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(IllegalArgumentException.class); + verify(listener).onFailure(exceptionCaptor.capture()); + assertEquals("Template name cannot be null or empty", exceptionCaptor.getValue().getMessage()); + } + + @Test + public void testDeleteTemplate_EmptyTemplateName() { + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + + contextManagementTemplateService.deleteTemplate("", listener); + + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(IllegalArgumentException.class); + verify(listener).onFailure(exceptionCaptor.capture()); + assertEquals("Template name cannot be null or empty", exceptionCaptor.getValue().getMessage()); + } + + @Test + public void testGetTemplate_ExceptionInTryBlock() { + // Test exception handling in the outer try-catch block by making threadPool throw + when(client.threadPool()).thenThrow(new RuntimeException("ThreadPool error")); + + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + + contextManagementTemplateService.getTemplate("test_template", listener); + + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener).onFailure(exceptionCaptor.capture()); + assertTrue(exceptionCaptor.getValue() instanceof RuntimeException); + assertEquals("ThreadPool error", exceptionCaptor.getValue().getMessage()); + } + + @Test + public void testDeleteTemplate_ExceptionInTryBlock() { + // Test exception handling in the outer try-catch block by making threadPool throw + when(client.threadPool()).thenThrow(new RuntimeException("ThreadPool error")); + + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + + contextManagementTemplateService.deleteTemplate("test_template", listener); + + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener).onFailure(exceptionCaptor.capture()); + assertTrue(exceptionCaptor.getValue() instanceof RuntimeException); + assertEquals("ThreadPool error", exceptionCaptor.getValue().getMessage()); + } + + @Test + public void testListTemplates_ExceptionInTryBlock() { + // Test exception handling in the outer try-catch block by making threadPool throw + when(client.threadPool()).thenThrow(new RuntimeException("ThreadPool error")); + + @SuppressWarnings("unchecked") + ActionListener> listener = mock(ActionListener.class); + + contextManagementTemplateService.listTemplates(listener); + + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener).onFailure(exceptionCaptor.capture()); + assertTrue(exceptionCaptor.getValue() instanceof RuntimeException); + assertEquals("ThreadPool error", exceptionCaptor.getValue().getMessage()); + } + + @Test + public void testListTemplates_WithPaginationExceptionInTryBlock() { + // Test exception handling in the outer try-catch block for paginated version + when(client.threadPool()).thenThrow(new RuntimeException("ThreadPool error")); + + @SuppressWarnings("unchecked") + ActionListener> listener = mock(ActionListener.class); + + contextManagementTemplateService.listTemplates(10, 50, listener); + + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener).onFailure(exceptionCaptor.capture()); + assertTrue(exceptionCaptor.getValue() instanceof RuntimeException); + assertEquals("ThreadPool error", exceptionCaptor.getValue().getMessage()); + } + + @Test + public void testListTemplates_NullListener() { + // This should not throw an exception, but we can test that the method handles it gracefully + try { + contextManagementTemplateService.listTemplates(null); + // If we get here without exception, that's fine - the method should handle null listeners gracefully + } catch (Exception e) { + // If an exception is thrown, it should be a meaningful one + assertTrue(e instanceof IllegalArgumentException || e instanceof NullPointerException); + } + } + + @Test + public void testGetTemplate_WhitespaceTemplateName() { + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + + contextManagementTemplateService.getTemplate(" ", listener); + + // Whitespace is not considered empty by Strings.isNullOrEmpty(), so it will proceed + // This tests the branch where template name is not null/empty but contains only whitespace + verify(client).threadPool(); + } + + @Test + public void testDeleteTemplate_WhitespaceTemplateName() { + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + + contextManagementTemplateService.deleteTemplate(" ", listener); + + // Whitespace is not considered empty by Strings.isNullOrEmpty(), so it will proceed + // This tests the branch where template name is not null/empty but contains only whitespace + verify(client).threadPool(); + } + + @Test + public void testSaveTemplate_TemplateWithExistingCreatedTime() { + ContextManagementTemplate template = mock(ContextManagementTemplate.class); + when(template.isValid()).thenReturn(true); + when(template.getName()).thenReturn("test_template"); + when(template.getCreatedTime()).thenReturn(java.time.Instant.now()); // Already has created time + when(template.getCreatedBy()).thenReturn("existing_user"); // Already has created by + + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + + contextManagementTemplateService.saveTemplate("test_template", template, listener); + + // Verify template validation was called and existing values were checked + verify(template).isValid(); + verify(template).getCreatedTime(); + verify(template).getCreatedBy(); + // Should call setLastModified but not setCreatedTime or setCreatedBy since they exist + verify(template).setLastModified(any(java.time.Instant.class)); + verify(template, never()).setCreatedTime(any(java.time.Instant.class)); + verify(template, never()).setCreatedBy(anyString()); + } + + @Test + public void testSaveTemplate_TemplateWithNullCreatedBy() { + ContextManagementTemplate template = mock(ContextManagementTemplate.class); + when(template.isValid()).thenReturn(true); + when(template.getName()).thenReturn("test_template"); + when(template.getCreatedTime()).thenReturn(null); + when(template.getCreatedBy()).thenReturn(null); + + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + + contextManagementTemplateService.saveTemplate("test_template", template, listener); + + // Verify template validation was called + verify(template).isValid(); + verify(template).getCreatedTime(); + verify(template).getCreatedBy(); + // Should set both created time and last modified + verify(template).setCreatedTime(any(java.time.Instant.class)); + verify(template).setLastModified(any(java.time.Instant.class)); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/action/contextmanagement/ContextManagerFactoryTests.java b/plugin/src/test/java/org/opensearch/ml/action/contextmanagement/ContextManagerFactoryTests.java new file mode 100644 index 0000000000..f196da28d9 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/action/contextmanagement/ContextManagerFactoryTests.java @@ -0,0 +1,167 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.contextmanagement; + +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import static org.mockito.Mockito.mock; + +import java.util.Map; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.ml.common.contextmanager.ActivationRuleFactory; +import org.opensearch.ml.common.contextmanager.ContextManager; +import org.opensearch.ml.common.contextmanager.ContextManagerConfig; +import org.opensearch.ml.engine.algorithms.contextmanager.SlidingWindowManager; +import org.opensearch.ml.engine.algorithms.contextmanager.SummarizationManager; +import org.opensearch.ml.engine.algorithms.contextmanager.ToolsOutputTruncateManager; +import org.opensearch.transport.client.Client; + +public class ContextManagerFactoryTests { + + private ContextManagerFactory contextManagerFactory; + private ActivationRuleFactory activationRuleFactory; + private Client client; + + @Before + public void setUp() { + activationRuleFactory = mock(ActivationRuleFactory.class); + client = mock(Client.class); + contextManagerFactory = new ContextManagerFactory(activationRuleFactory, client); + } + + @Test + public void testCreateContextManager_ToolsOutputTruncateManager() { + // Arrange + ContextManagerConfig config = new ContextManagerConfig("ToolsOutputTruncateManager", null, null); + + // Act + ContextManager contextManager = contextManagerFactory.createContextManager(config); + + // Assert + assertNotNull(contextManager); + assertTrue(contextManager instanceof ToolsOutputTruncateManager); + } + + @Test + public void testCreateContextManager_ToolsOutputTruncateManagerWithParameters() { + // Arrange + Map parameters = Map.of("maxLength", 1000); + ContextManagerConfig config = new ContextManagerConfig("ToolsOutputTruncateManager", parameters, null); + + // Act + ContextManager contextManager = contextManagerFactory.createContextManager(config); + + // Assert + assertNotNull(contextManager); + assertTrue(contextManager instanceof ToolsOutputTruncateManager); + } + + @Test + public void testCreateContextManager_SlidingWindowManager() { + // Arrange + ContextManagerConfig config = new ContextManagerConfig("SlidingWindowManager", null, null); + + // Act + ContextManager contextManager = contextManagerFactory.createContextManager(config); + + // Assert + assertNotNull(contextManager); + assertTrue(contextManager instanceof SlidingWindowManager); + } + + @Test + public void testCreateContextManager_SummarizationManager() { + // Arrange + ContextManagerConfig config = new ContextManagerConfig("SummarizationManager", null, null); + + // Act + ContextManager contextManager = contextManagerFactory.createContextManager(config); + + // Assert + assertNotNull(contextManager); + assertTrue(contextManager instanceof SummarizationManager); + } + + @Test + public void testCreateContextManager_UnknownType() { + // Arrange + ContextManagerConfig config = new ContextManagerConfig("UnknownManager", null, null); + + // Act & Assert + try { + contextManagerFactory.createContextManager(config); + fail("Expected IllegalArgumentException for unknown manager type"); + } catch (IllegalArgumentException e) { + assertTrue(e.getMessage().contains("Unsupported context manager type")); + } + } + + @Test + public void testCreateContextManager_NullConfig() { + // Act & Assert + try { + contextManagerFactory.createContextManager(null); + fail("Expected IllegalArgumentException for null config"); + } catch (IllegalArgumentException e) { + assertTrue(e.getMessage().contains("cannot be null")); + } + } + + @Test + public void testCreateContextManager_NullType() { + // Arrange + ContextManagerConfig config = new ContextManagerConfig(null, null, null); + + // Act & Assert + try { + contextManagerFactory.createContextManager(config); + fail("Expected IllegalArgumentException for null type"); + } catch (IllegalArgumentException e) { + assertTrue(e.getMessage().contains("cannot be null")); + } + } + + @Test + public void testCreateContextManager_EmptyType() { + // Arrange + ContextManagerConfig config = new ContextManagerConfig("", null, null); + + // Act & Assert + try { + contextManagerFactory.createContextManager(config); + fail("Expected IllegalArgumentException for empty type"); + } catch (IllegalArgumentException e) { + assertTrue(e.getMessage().contains("Unsupported context manager type")); + } + } + + @Test + public void testContextManagerHookProvider_SelectiveRegistration() { + // Test that ContextManagerHookProvider only registers hooks for configured managers + java.util.Map> hookToManagersMap = new java.util.HashMap<>(); + + // Test 1: Only POST_TOOL configured + hookToManagersMap.put("POST_TOOL", java.util.Arrays.asList("ToolsOutputTruncateManager")); + + // Simulate the registration logic + java.util.Set registeredHooks = new java.util.HashSet<>(); + if (hookToManagersMap.containsKey("PRE_LLM")) { + registeredHooks.add("PRE_LLM"); + } + if (hookToManagersMap.containsKey("POST_TOOL")) { + registeredHooks.add("POST_TOOL"); + } + if (hookToManagersMap.containsKey("POST_MEMORY")) { + registeredHooks.add("POST_MEMORY"); + } + + // Assert only POST_TOOL is registered + assertTrue("Should only register POST_TOOL hook", registeredHooks.size() == 1 && registeredHooks.contains("POST_TOOL")); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/action/contextmanagement/CreateContextManagementTemplateTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/contextmanagement/CreateContextManagementTemplateTransportActionTests.java new file mode 100644 index 0000000000..859cc1fbd7 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/action/contextmanagement/CreateContextManagementTemplateTransportActionTests.java @@ -0,0 +1,196 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.contextmanagement; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.*; + +import java.util.Collections; +import java.util.Map; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.contextmanager.ContextManagementTemplate; +import org.opensearch.ml.common.contextmanager.ContextManagerConfig; +import org.opensearch.ml.common.transport.contextmanagement.MLCreateContextManagementTemplateRequest; +import org.opensearch.ml.common.transport.contextmanagement.MLCreateContextManagementTemplateResponse; +import org.opensearch.tasks.Task; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; +import org.opensearch.transport.client.Client; + +public class CreateContextManagementTemplateTransportActionTests extends OpenSearchTestCase { + + @Mock + private Client client; + + @Mock + private ThreadPool threadPool; + + @Mock + private TransportService transportService; + + @Mock + private ActionFilters actionFilters; + + @Mock + private ContextManagementTemplateService contextManagementTemplateService; + + @InjectMocks + private CreateContextManagementTemplateTransportAction transportAction; + + private ThreadContext threadContext; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + + Settings settings = Settings.builder().build(); + threadContext = new ThreadContext(settings); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + + transportAction = new CreateContextManagementTemplateTransportAction( + transportService, + actionFilters, + client, + contextManagementTemplateService + ); + } + + @Test + public void testDoExecute_Success() { + String templateName = "test_template"; + ContextManagementTemplate template = createTestTemplate(); + MLCreateContextManagementTemplateRequest request = new MLCreateContextManagementTemplateRequest(templateName, template); + Task task = mock(Task.class); + ActionListener listener = mock(ActionListener.class); + + // Mock successful template save + doAnswer(invocation -> { + ActionListener saveListener = invocation.getArgument(2); + saveListener.onResponse(true); + return null; + }).when(contextManagementTemplateService).saveTemplate(eq(templateName), eq(template), any()); + + transportAction.doExecute(task, request, listener); + + ArgumentCaptor responseCaptor = ArgumentCaptor + .forClass(MLCreateContextManagementTemplateResponse.class); + verify(listener).onResponse(responseCaptor.capture()); + + MLCreateContextManagementTemplateResponse response = responseCaptor.getValue(); + assertEquals(templateName, response.getTemplateName()); + assertEquals("created", response.getStatus()); + } + + @Test + public void testDoExecute_SaveFailure() { + String templateName = "test_template"; + ContextManagementTemplate template = createTestTemplate(); + MLCreateContextManagementTemplateRequest request = new MLCreateContextManagementTemplateRequest(templateName, template); + Task task = mock(Task.class); + ActionListener listener = mock(ActionListener.class); + + // Mock failed template save + doAnswer(invocation -> { + ActionListener saveListener = invocation.getArgument(2); + saveListener.onResponse(false); + return null; + }).when(contextManagementTemplateService).saveTemplate(eq(templateName), eq(template), any()); + + transportAction.doExecute(task, request, listener); + + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(RuntimeException.class); + verify(listener).onFailure(exceptionCaptor.capture()); + + RuntimeException exception = exceptionCaptor.getValue(); + assertEquals("Failed to create context management template", exception.getMessage()); + } + + @Test + public void testDoExecute_SaveException() { + String templateName = "test_template"; + ContextManagementTemplate template = createTestTemplate(); + MLCreateContextManagementTemplateRequest request = new MLCreateContextManagementTemplateRequest(templateName, template); + Task task = mock(Task.class); + ActionListener listener = mock(ActionListener.class); + + RuntimeException saveException = new RuntimeException("Database error"); + + // Mock exception during template save + doAnswer(invocation -> { + ActionListener saveListener = invocation.getArgument(2); + saveListener.onFailure(saveException); + return null; + }).when(contextManagementTemplateService).saveTemplate(eq(templateName), eq(template), any()); + + transportAction.doExecute(task, request, listener); + + verify(listener).onFailure(saveException); + } + + @Test + public void testDoExecute_UnexpectedException() { + String templateName = "test_template"; + ContextManagementTemplate template = createTestTemplate(); + MLCreateContextManagementTemplateRequest request = new MLCreateContextManagementTemplateRequest(templateName, template); + Task task = mock(Task.class); + ActionListener listener = mock(ActionListener.class); + + RuntimeException unexpectedException = new RuntimeException("Unexpected error"); + + // Mock unexpected exception + doThrow(unexpectedException).when(contextManagementTemplateService).saveTemplate(any(), any(), any()); + + transportAction.doExecute(task, request, listener); + + verify(listener).onFailure(unexpectedException); + } + + @Test + public void testDoExecute_VerifyServiceCall() { + String templateName = "test_template"; + ContextManagementTemplate template = createTestTemplate(); + MLCreateContextManagementTemplateRequest request = new MLCreateContextManagementTemplateRequest(templateName, template); + Task task = mock(Task.class); + ActionListener listener = mock(ActionListener.class); + + // Mock successful template save + doAnswer(invocation -> { + ActionListener saveListener = invocation.getArgument(2); + saveListener.onResponse(true); + return null; + }).when(contextManagementTemplateService).saveTemplate(eq(templateName), eq(template), any()); + + transportAction.doExecute(task, request, listener); + + // Verify the service was called with correct parameters + verify(contextManagementTemplateService).saveTemplate(eq(templateName), eq(template), any()); + } + + private ContextManagementTemplate createTestTemplate() { + Map config = Collections.singletonMap("summary_ratio", 0.3); + ContextManagerConfig contextManagerConfig = new ContextManagerConfig("SummarizationManager", null, config); + + return ContextManagementTemplate + .builder() + .name("test_template") + .description("Test template") + .hooks(Collections.singletonMap("PreLLMEvent", Collections.singletonList(contextManagerConfig))) + .build(); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/action/contextmanagement/DeleteContextManagementTemplateTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/contextmanagement/DeleteContextManagementTemplateTransportActionTests.java new file mode 100644 index 0000000000..a160fabc59 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/action/contextmanagement/DeleteContextManagementTemplateTransportActionTests.java @@ -0,0 +1,174 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.contextmanagement; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.*; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.transport.contextmanagement.MLDeleteContextManagementTemplateRequest; +import org.opensearch.ml.common.transport.contextmanagement.MLDeleteContextManagementTemplateResponse; +import org.opensearch.tasks.Task; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; +import org.opensearch.transport.client.Client; + +public class DeleteContextManagementTemplateTransportActionTests extends OpenSearchTestCase { + + @Mock + private Client client; + + @Mock + private ThreadPool threadPool; + + @Mock + private TransportService transportService; + + @Mock + private ActionFilters actionFilters; + + @Mock + private ContextManagementTemplateService contextManagementTemplateService; + + @InjectMocks + private DeleteContextManagementTemplateTransportAction transportAction; + + private ThreadContext threadContext; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + + Settings settings = Settings.builder().build(); + threadContext = new ThreadContext(settings); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + + transportAction = new DeleteContextManagementTemplateTransportAction( + transportService, + actionFilters, + client, + contextManagementTemplateService + ); + } + + @Test + public void testDoExecute_Success() { + String templateName = "test_template"; + MLDeleteContextManagementTemplateRequest request = new MLDeleteContextManagementTemplateRequest(templateName); + Task task = mock(Task.class); + ActionListener listener = mock(ActionListener.class); + + // Mock successful template deletion + doAnswer(invocation -> { + ActionListener deleteListener = invocation.getArgument(1); + deleteListener.onResponse(true); + return null; + }).when(contextManagementTemplateService).deleteTemplate(eq(templateName), any()); + + transportAction.doExecute(task, request, listener); + + ArgumentCaptor responseCaptor = ArgumentCaptor + .forClass(MLDeleteContextManagementTemplateResponse.class); + verify(listener).onResponse(responseCaptor.capture()); + + MLDeleteContextManagementTemplateResponse response = responseCaptor.getValue(); + assertEquals(templateName, response.getTemplateName()); + assertEquals("deleted", response.getStatus()); + } + + @Test + public void testDoExecute_DeleteFailure() { + String templateName = "test_template"; + MLDeleteContextManagementTemplateRequest request = new MLDeleteContextManagementTemplateRequest(templateName); + Task task = mock(Task.class); + ActionListener listener = mock(ActionListener.class); + + // Mock failed template deletion + doAnswer(invocation -> { + ActionListener deleteListener = invocation.getArgument(1); + deleteListener.onResponse(false); + return null; + }).when(contextManagementTemplateService).deleteTemplate(eq(templateName), any()); + + transportAction.doExecute(task, request, listener); + + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(RuntimeException.class); + verify(listener).onFailure(exceptionCaptor.capture()); + + RuntimeException exception = exceptionCaptor.getValue(); + assertEquals("Context management template not found: test_template", exception.getMessage()); + } + + @Test + public void testDoExecute_ServiceException() { + String templateName = "test_template"; + MLDeleteContextManagementTemplateRequest request = new MLDeleteContextManagementTemplateRequest(templateName); + Task task = mock(Task.class); + ActionListener listener = mock(ActionListener.class); + + RuntimeException serviceException = new RuntimeException("Database error"); + + // Mock exception during template deletion + doAnswer(invocation -> { + ActionListener deleteListener = invocation.getArgument(1); + deleteListener.onFailure(serviceException); + return null; + }).when(contextManagementTemplateService).deleteTemplate(eq(templateName), any()); + + transportAction.doExecute(task, request, listener); + + verify(listener).onFailure(serviceException); + } + + @Test + public void testDoExecute_UnexpectedException() { + String templateName = "test_template"; + MLDeleteContextManagementTemplateRequest request = new MLDeleteContextManagementTemplateRequest(templateName); + Task task = mock(Task.class); + ActionListener listener = mock(ActionListener.class); + + RuntimeException unexpectedException = new RuntimeException("Unexpected error"); + + // Mock unexpected exception + doThrow(unexpectedException).when(contextManagementTemplateService).deleteTemplate(any(), any()); + + transportAction.doExecute(task, request, listener); + + verify(listener).onFailure(unexpectedException); + } + + @Test + public void testDoExecute_VerifyServiceCall() { + String templateName = "test_template"; + MLDeleteContextManagementTemplateRequest request = new MLDeleteContextManagementTemplateRequest(templateName); + Task task = mock(Task.class); + ActionListener listener = mock(ActionListener.class); + + // Mock successful template deletion + doAnswer(invocation -> { + ActionListener deleteListener = invocation.getArgument(1); + deleteListener.onResponse(true); + return null; + }).when(contextManagementTemplateService).deleteTemplate(eq(templateName), any()); + + transportAction.doExecute(task, request, listener); + + // Verify the service was called with correct parameters + verify(contextManagementTemplateService).deleteTemplate(eq(templateName), any()); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/action/contextmanagement/GetContextManagementTemplateTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/contextmanagement/GetContextManagementTemplateTransportActionTests.java new file mode 100644 index 0000000000..4bb1328518 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/action/contextmanagement/GetContextManagementTemplateTransportActionTests.java @@ -0,0 +1,192 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.contextmanagement; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.*; + +import java.util.Collections; +import java.util.Map; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.contextmanager.ContextManagementTemplate; +import org.opensearch.ml.common.contextmanager.ContextManagerConfig; +import org.opensearch.ml.common.transport.contextmanagement.MLGetContextManagementTemplateRequest; +import org.opensearch.ml.common.transport.contextmanagement.MLGetContextManagementTemplateResponse; +import org.opensearch.tasks.Task; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; +import org.opensearch.transport.client.Client; + +public class GetContextManagementTemplateTransportActionTests extends OpenSearchTestCase { + + @Mock + private Client client; + + @Mock + private ThreadPool threadPool; + + @Mock + private TransportService transportService; + + @Mock + private ActionFilters actionFilters; + + @Mock + private ContextManagementTemplateService contextManagementTemplateService; + + @InjectMocks + private GetContextManagementTemplateTransportAction transportAction; + + private ThreadContext threadContext; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + + Settings settings = Settings.builder().build(); + threadContext = new ThreadContext(settings); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + + transportAction = new GetContextManagementTemplateTransportAction( + transportService, + actionFilters, + client, + contextManagementTemplateService + ); + } + + @Test + public void testDoExecute_Success() { + String templateName = "test_template"; + ContextManagementTemplate template = createTestTemplate(); + MLGetContextManagementTemplateRequest request = new MLGetContextManagementTemplateRequest(templateName); + Task task = mock(Task.class); + ActionListener listener = mock(ActionListener.class); + + // Mock successful template retrieval + doAnswer(invocation -> { + ActionListener getListener = invocation.getArgument(1); + getListener.onResponse(template); + return null; + }).when(contextManagementTemplateService).getTemplate(eq(templateName), any()); + + transportAction.doExecute(task, request, listener); + + ArgumentCaptor responseCaptor = ArgumentCaptor + .forClass(MLGetContextManagementTemplateResponse.class); + verify(listener).onResponse(responseCaptor.capture()); + + MLGetContextManagementTemplateResponse response = responseCaptor.getValue(); + assertEquals(template, response.getTemplate()); + } + + @Test + public void testDoExecute_TemplateNotFound() { + String templateName = "nonexistent_template"; + MLGetContextManagementTemplateRequest request = new MLGetContextManagementTemplateRequest(templateName); + Task task = mock(Task.class); + ActionListener listener = mock(ActionListener.class); + + // Mock template not found (null response) + doAnswer(invocation -> { + ActionListener getListener = invocation.getArgument(1); + getListener.onResponse(null); + return null; + }).when(contextManagementTemplateService).getTemplate(eq(templateName), any()); + + transportAction.doExecute(task, request, listener); + + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(RuntimeException.class); + verify(listener).onFailure(exceptionCaptor.capture()); + + RuntimeException exception = exceptionCaptor.getValue(); + assertEquals("Context management template not found: " + templateName, exception.getMessage()); + } + + @Test + public void testDoExecute_ServiceException() { + String templateName = "test_template"; + MLGetContextManagementTemplateRequest request = new MLGetContextManagementTemplateRequest(templateName); + Task task = mock(Task.class); + ActionListener listener = mock(ActionListener.class); + + RuntimeException serviceException = new RuntimeException("Database error"); + + // Mock exception during template retrieval + doAnswer(invocation -> { + ActionListener getListener = invocation.getArgument(1); + getListener.onFailure(serviceException); + return null; + }).when(contextManagementTemplateService).getTemplate(eq(templateName), any()); + + transportAction.doExecute(task, request, listener); + + verify(listener).onFailure(serviceException); + } + + @Test + public void testDoExecute_UnexpectedException() { + String templateName = "test_template"; + MLGetContextManagementTemplateRequest request = new MLGetContextManagementTemplateRequest(templateName); + Task task = mock(Task.class); + ActionListener listener = mock(ActionListener.class); + + RuntimeException unexpectedException = new RuntimeException("Unexpected error"); + + // Mock unexpected exception + doThrow(unexpectedException).when(contextManagementTemplateService).getTemplate(any(), any()); + + transportAction.doExecute(task, request, listener); + + verify(listener).onFailure(unexpectedException); + } + + @Test + public void testDoExecute_VerifyServiceCall() { + String templateName = "test_template"; + ContextManagementTemplate template = createTestTemplate(); + MLGetContextManagementTemplateRequest request = new MLGetContextManagementTemplateRequest(templateName); + Task task = mock(Task.class); + ActionListener listener = mock(ActionListener.class); + + // Mock successful template retrieval + doAnswer(invocation -> { + ActionListener getListener = invocation.getArgument(1); + getListener.onResponse(template); + return null; + }).when(contextManagementTemplateService).getTemplate(eq(templateName), any()); + + transportAction.doExecute(task, request, listener); + + // Verify the service was called with correct parameters + verify(contextManagementTemplateService).getTemplate(eq(templateName), any()); + } + + private ContextManagementTemplate createTestTemplate() { + Map config = Collections.singletonMap("summary_ratio", 0.3); + ContextManagerConfig contextManagerConfig = new ContextManagerConfig("SummarizationManager", null, config); + + return ContextManagementTemplate + .builder() + .name("test_template") + .description("Test template") + .hooks(Collections.singletonMap("PreLLMEvent", Collections.singletonList(contextManagerConfig))) + .build(); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/action/contextmanagement/ListContextManagementTemplatesTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/contextmanagement/ListContextManagementTemplatesTransportActionTests.java new file mode 100644 index 0000000000..c5951ee868 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/action/contextmanagement/ListContextManagementTemplatesTransportActionTests.java @@ -0,0 +1,235 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.contextmanagement; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.*; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.contextmanager.ContextManagementTemplate; +import org.opensearch.ml.common.contextmanager.ContextManagerConfig; +import org.opensearch.ml.common.transport.contextmanagement.MLListContextManagementTemplatesRequest; +import org.opensearch.ml.common.transport.contextmanagement.MLListContextManagementTemplatesResponse; +import org.opensearch.tasks.Task; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; +import org.opensearch.transport.client.Client; + +public class ListContextManagementTemplatesTransportActionTests extends OpenSearchTestCase { + + @Mock + private Client client; + + @Mock + private ThreadPool threadPool; + + @Mock + private TransportService transportService; + + @Mock + private ActionFilters actionFilters; + + @Mock + private ContextManagementTemplateService contextManagementTemplateService; + + @InjectMocks + private ListContextManagementTemplatesTransportAction transportAction; + + private ThreadContext threadContext; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + + Settings settings = Settings.builder().build(); + threadContext = new ThreadContext(settings); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + + transportAction = new ListContextManagementTemplatesTransportAction( + transportService, + actionFilters, + client, + contextManagementTemplateService + ); + } + + @Test + public void testDoExecute_Success() { + int from = 0; + int size = 10; + MLListContextManagementTemplatesRequest request = new MLListContextManagementTemplatesRequest(from, size); + Task task = mock(Task.class); + ActionListener listener = mock(ActionListener.class); + + List templates = Arrays.asList(createTestTemplate("template1"), createTestTemplate("template2")); + + // Mock successful template listing + doAnswer(invocation -> { + ActionListener> listListener = invocation.getArgument(2); + listListener.onResponse(templates); + return null; + }).when(contextManagementTemplateService).listTemplates(eq(from), eq(size), any()); + + transportAction.doExecute(task, request, listener); + + ArgumentCaptor responseCaptor = ArgumentCaptor + .forClass(MLListContextManagementTemplatesResponse.class); + verify(listener).onResponse(responseCaptor.capture()); + + MLListContextManagementTemplatesResponse response = responseCaptor.getValue(); + assertEquals(templates, response.getTemplates()); + } + + @Test + public void testDoExecute_EmptyList() { + int from = 0; + int size = 10; + MLListContextManagementTemplatesRequest request = new MLListContextManagementTemplatesRequest(from, size); + Task task = mock(Task.class); + ActionListener listener = mock(ActionListener.class); + + List emptyTemplates = Collections.emptyList(); + + // Mock empty template list + doAnswer(invocation -> { + ActionListener> listListener = invocation.getArgument(2); + listListener.onResponse(emptyTemplates); + return null; + }).when(contextManagementTemplateService).listTemplates(eq(from), eq(size), any()); + + transportAction.doExecute(task, request, listener); + + ArgumentCaptor responseCaptor = ArgumentCaptor + .forClass(MLListContextManagementTemplatesResponse.class); + verify(listener).onResponse(responseCaptor.capture()); + + MLListContextManagementTemplatesResponse response = responseCaptor.getValue(); + assertEquals(emptyTemplates, response.getTemplates()); + assertTrue(response.getTemplates().isEmpty()); + } + + @Test + public void testDoExecute_ServiceException() { + int from = 0; + int size = 10; + MLListContextManagementTemplatesRequest request = new MLListContextManagementTemplatesRequest(from, size); + Task task = mock(Task.class); + ActionListener listener = mock(ActionListener.class); + + RuntimeException serviceException = new RuntimeException("Database error"); + + // Mock exception during template listing + doAnswer(invocation -> { + ActionListener> listListener = invocation.getArgument(2); + listListener.onFailure(serviceException); + return null; + }).when(contextManagementTemplateService).listTemplates(eq(from), eq(size), any()); + + transportAction.doExecute(task, request, listener); + + verify(listener).onFailure(serviceException); + } + + @Test + public void testDoExecute_UnexpectedException() { + int from = 0; + int size = 10; + MLListContextManagementTemplatesRequest request = new MLListContextManagementTemplatesRequest(from, size); + Task task = mock(Task.class); + ActionListener listener = mock(ActionListener.class); + + RuntimeException unexpectedException = new RuntimeException("Unexpected error"); + + // Mock unexpected exception + doThrow(unexpectedException).when(contextManagementTemplateService).listTemplates(anyInt(), anyInt(), any()); + + transportAction.doExecute(task, request, listener); + + verify(listener).onFailure(unexpectedException); + } + + @Test + public void testDoExecute_VerifyServiceCall() { + int from = 5; + int size = 20; + MLListContextManagementTemplatesRequest request = new MLListContextManagementTemplatesRequest(from, size); + Task task = mock(Task.class); + ActionListener listener = mock(ActionListener.class); + + List templates = Arrays.asList(createTestTemplate("template1")); + + // Mock successful template listing + doAnswer(invocation -> { + ActionListener> listListener = invocation.getArgument(2); + listListener.onResponse(templates); + return null; + }).when(contextManagementTemplateService).listTemplates(eq(from), eq(size), any()); + + transportAction.doExecute(task, request, listener); + + // Verify the service was called with correct parameters + verify(contextManagementTemplateService).listTemplates(eq(from), eq(size), any()); + } + + @Test + public void testDoExecute_CustomPagination() { + int from = 10; + int size = 5; + MLListContextManagementTemplatesRequest request = new MLListContextManagementTemplatesRequest(from, size); + Task task = mock(Task.class); + ActionListener listener = mock(ActionListener.class); + + List templates = Arrays.asList(createTestTemplate("template3")); + + // Mock successful template listing with custom pagination + doAnswer(invocation -> { + ActionListener> listListener = invocation.getArgument(2); + listListener.onResponse(templates); + return null; + }).when(contextManagementTemplateService).listTemplates(eq(from), eq(size), any()); + + transportAction.doExecute(task, request, listener); + + ArgumentCaptor responseCaptor = ArgumentCaptor + .forClass(MLListContextManagementTemplatesResponse.class); + verify(listener).onResponse(responseCaptor.capture()); + + MLListContextManagementTemplatesResponse response = responseCaptor.getValue(); + assertEquals(templates, response.getTemplates()); + + // Verify the service was called with custom pagination parameters + verify(contextManagementTemplateService).listTemplates(eq(from), eq(size), any()); + } + + private ContextManagementTemplate createTestTemplate(String name) { + Map config = Collections.singletonMap("summary_ratio", 0.3); + ContextManagerConfig contextManagerConfig = new ContextManagerConfig("SummarizationManager", null, config); + + return ContextManagementTemplate + .builder() + .name(name) + .description("Test template " + name) + .hooks(Collections.singletonMap("PreLLMEvent", Collections.singletonList(contextManagerConfig))) + .build(); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLCreateContextManagementTemplateActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLCreateContextManagementTemplateActionTests.java new file mode 100644 index 0000000000..ea585e0459 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLCreateContextManagementTemplateActionTests.java @@ -0,0 +1,216 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.*; +import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_TEMPLATE_NAME; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.Strings; +import org.opensearch.core.common.bytes.BytesArray; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; +import org.opensearch.ml.common.transport.contextmanagement.MLCreateContextManagementTemplateAction; +import org.opensearch.ml.common.transport.contextmanagement.MLCreateContextManagementTemplateRequest; +import org.opensearch.rest.RestChannel; +import org.opensearch.rest.RestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestRequest; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.client.node.NodeClient; + +public class RestMLCreateContextManagementTemplateActionTests extends OpenSearchTestCase { + private RestMLCreateContextManagementTemplateAction restAction; + private NodeClient client; + private ThreadPool threadPool; + + @Mock + private RestChannel channel; + + @Mock + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + when(mlFeatureEnabledSetting.isAgentFrameworkEnabled()).thenReturn(true); + restAction = new RestMLCreateContextManagementTemplateAction(mlFeatureEnabledSetting); + + threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); + client = spy(new NodeClient(Settings.EMPTY, threadPool)); + + doAnswer(invocation -> { + // Mock successful execution - actionListener not used in this test + return null; + }).when(client).execute(eq(MLCreateContextManagementTemplateAction.INSTANCE), any(), any()); + } + + @Override + public void tearDown() throws Exception { + super.tearDown(); + threadPool.shutdown(); + client.close(); + } + + public void testConstructor() { + RestMLCreateContextManagementTemplateAction action = new RestMLCreateContextManagementTemplateAction(mlFeatureEnabledSetting); + assertNotNull(action); + } + + public void testGetName() { + String actionName = restAction.getName(); + assertFalse(Strings.isNullOrEmpty(actionName)); + assertEquals("ml_create_context_management_template_action", actionName); + } + + public void testRoutes() { + List routes = restAction.routes(); + assertNotNull(routes); + assertFalse(routes.isEmpty()); + RestHandler.Route route = routes.get(0); + assertEquals(RestRequest.Method.PUT, route.getMethod()); + assertEquals("/_plugins/_ml/context_management/{template_name}", route.getPath()); + } + + public void testPrepareRequest() throws Exception { + RestRequest request = getRestRequest(); + restAction.handleRequest(request, channel, client); + + ArgumentCaptor argumentCaptor = ArgumentCaptor + .forClass(MLCreateContextManagementTemplateRequest.class); + verify(client, times(1)).execute(eq(MLCreateContextManagementTemplateAction.INSTANCE), argumentCaptor.capture(), any()); + String templateName = argumentCaptor.getValue().getTemplateName(); + assertEquals("test_template", templateName); + } + + public void testPrepareRequestReturnsRestChannelConsumer() throws Exception { + RestRequest request = getRestRequest(); + Object consumer = restAction.prepareRequest(request, client); + + assertNotNull(consumer); + + // Execute the consumer to test the actual execution path using reflection + java.lang.reflect.Method acceptMethod = consumer.getClass().getMethod("accept", Object.class); + acceptMethod.invoke(consumer, channel); + + ArgumentCaptor argumentCaptor = ArgumentCaptor + .forClass(MLCreateContextManagementTemplateRequest.class); + verify(client, times(1)).execute(eq(MLCreateContextManagementTemplateAction.INSTANCE), argumentCaptor.capture(), any()); + String templateName = argumentCaptor.getValue().getTemplateName(); + assertEquals("test_template", templateName); + } + + public void testPrepareRequestWithAgentFrameworkDisabled() { + when(mlFeatureEnabledSetting.isAgentFrameworkEnabled()).thenReturn(false); + RestRequest request = getRestRequest(); + + assertThrows(IllegalStateException.class, () -> restAction.handleRequest(request, channel, client)); + } + + public void testGetRequestWithAgentFrameworkDisabled() { + when(mlFeatureEnabledSetting.isAgentFrameworkEnabled()).thenReturn(false); + RestRequest request = getRestRequest(); + + IllegalStateException exception = assertThrows(IllegalStateException.class, () -> restAction.getRequest(request)); + assertEquals("Agent framework is disabled", exception.getMessage()); + } + + public void testGetRequestWithMissingTemplateName() { + Map params = new HashMap<>(); + // No template name parameter + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withParams(params) + .withContent(new BytesArray(getValidTemplateContent()), XContentType.JSON) + .build(); + + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> restAction.getRequest(request)); + assertEquals("Template name is required", exception.getMessage()); + } + + public void testGetRequestWithEmptyTemplateName() { + Map params = new HashMap<>(); + params.put(PARAMETER_TEMPLATE_NAME, ""); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withParams(params) + .withContent(new BytesArray(getValidTemplateContent()), XContentType.JSON) + .build(); + + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> restAction.getRequest(request)); + assertEquals("Template name is required", exception.getMessage()); + } + + public void testGetRequestWithWhitespaceOnlyTemplateName() { + Map params = new HashMap<>(); + params.put(PARAMETER_TEMPLATE_NAME, " "); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withParams(params) + .withContent(new BytesArray(getValidTemplateContent()), XContentType.JSON) + .build(); + + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> restAction.getRequest(request)); + assertEquals("Template name is required", exception.getMessage()); + } + + public void testGetRequestWithInvalidJsonContent() { + Map params = new HashMap<>(); + params.put(PARAMETER_TEMPLATE_NAME, "test_template"); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withParams(params) + .withContent(new BytesArray("invalid json"), XContentType.JSON) + .build(); + + assertThrows(Exception.class, () -> restAction.getRequest(request)); + } + + public void testGetRequestWithValidInput() throws Exception { + RestRequest request = getRestRequest(); + MLCreateContextManagementTemplateRequest result = restAction.getRequest(request); + + assertNotNull(result); + assertEquals("test_template", result.getTemplateName()); + assertNotNull(result.getTemplate()); + } + + private RestRequest getRestRequest() { + Map params = new HashMap<>(); + params.put(PARAMETER_TEMPLATE_NAME, "test_template"); + + return new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withParams(params) + .withContent(new BytesArray(getValidTemplateContent()), XContentType.JSON) + .build(); + } + + private String getValidTemplateContent() { + return "{\n" + + " \"description\": \"Test template\",\n" + + " \"hooks\": {\n" + + " \"PreLLMEvent\": [\n" + + " {\n" + + " \"type\": \"SummarizationManager\",\n" + + " \"config\": {\n" + + " \"summary_ratio\": 0.3,\n" + + " \"preserve_recent_messages\": 10\n" + + " }\n" + + " }\n" + + " ]\n" + + " }\n" + + "}"; + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLDeleteContextManagementTemplateActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLDeleteContextManagementTemplateActionTests.java new file mode 100644 index 0000000000..520dd05d19 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLDeleteContextManagementTemplateActionTests.java @@ -0,0 +1,181 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.*; +import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_TEMPLATE_NAME; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.common.Strings; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; +import org.opensearch.ml.common.transport.contextmanagement.MLDeleteContextManagementTemplateAction; +import org.opensearch.ml.common.transport.contextmanagement.MLDeleteContextManagementTemplateRequest; +import org.opensearch.rest.RestChannel; +import org.opensearch.rest.RestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestRequest; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.client.node.NodeClient; + +public class RestMLDeleteContextManagementTemplateActionTests extends OpenSearchTestCase { + private RestMLDeleteContextManagementTemplateAction restAction; + private NodeClient client; + private ThreadPool threadPool; + + @Mock + private RestChannel channel; + + @Mock + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + when(mlFeatureEnabledSetting.isAgentFrameworkEnabled()).thenReturn(true); + restAction = new RestMLDeleteContextManagementTemplateAction(mlFeatureEnabledSetting); + + threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); + client = spy(new NodeClient(Settings.EMPTY, threadPool)); + + doAnswer(invocation -> { + // Mock successful execution - actionListener not used in this test + return null; + }).when(client).execute(eq(MLDeleteContextManagementTemplateAction.INSTANCE), any(), any()); + } + + @Override + public void tearDown() throws Exception { + super.tearDown(); + threadPool.shutdown(); + client.close(); + } + + public void testConstructor() { + RestMLDeleteContextManagementTemplateAction action = new RestMLDeleteContextManagementTemplateAction(mlFeatureEnabledSetting); + assertNotNull(action); + } + + public void testGetName() { + String actionName = restAction.getName(); + assertFalse(Strings.isNullOrEmpty(actionName)); + assertEquals("ml_delete_context_management_template_action", actionName); + } + + public void testRoutes() { + List routes = restAction.routes(); + assertNotNull(routes); + assertFalse(routes.isEmpty()); + RestHandler.Route route = routes.get(0); + assertEquals(RestRequest.Method.DELETE, route.getMethod()); + assertEquals("/_plugins/_ml/context_management/{template_name}", route.getPath()); + } + + public void testPrepareRequest() throws Exception { + RestRequest request = getRestRequest(); + restAction.handleRequest(request, channel, client); + + ArgumentCaptor argumentCaptor = ArgumentCaptor + .forClass(MLDeleteContextManagementTemplateRequest.class); + verify(client, times(1)).execute(eq(MLDeleteContextManagementTemplateAction.INSTANCE), argumentCaptor.capture(), any()); + String templateName = argumentCaptor.getValue().getTemplateName(); + assertEquals("test_template", templateName); + } + + public void testPrepareRequestWithAgentFrameworkDisabled() { + when(mlFeatureEnabledSetting.isAgentFrameworkEnabled()).thenReturn(false); + RestRequest request = getRestRequest(); + + assertThrows(IllegalStateException.class, () -> restAction.handleRequest(request, channel, client)); + } + + public void testGetRequestWithAgentFrameworkDisabled() { + when(mlFeatureEnabledSetting.isAgentFrameworkEnabled()).thenReturn(false); + RestRequest request = getRestRequest(); + + IllegalStateException exception = assertThrows(IllegalStateException.class, () -> restAction.getRequest(request)); + assertEquals("Agent framework is disabled", exception.getMessage()); + } + + public void testGetRequestWithMissingTemplateName() { + Map params = new HashMap<>(); + // No template name parameter + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).build(); + + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> restAction.getRequest(request)); + assertEquals("Template name is required", exception.getMessage()); + } + + public void testGetRequestWithEmptyTemplateName() { + Map params = new HashMap<>(); + params.put(PARAMETER_TEMPLATE_NAME, ""); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).build(); + + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> restAction.getRequest(request)); + assertEquals("Template name is required", exception.getMessage()); + } + + public void testGetRequestWithWhitespaceTemplateName() { + Map params = new HashMap<>(); + params.put(PARAMETER_TEMPLATE_NAME, " "); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).build(); + + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> restAction.getRequest(request)); + assertEquals("Template name is required", exception.getMessage()); + } + + public void testPrepareRequestReturnsRestChannelConsumer() throws Exception { + RestRequest request = getRestRequest(); + Object consumer = restAction.prepareRequest(request, client); + + assertNotNull(consumer); + + // Execute the consumer to test the actual execution path using reflection + java.lang.reflect.Method acceptMethod = consumer.getClass().getMethod("accept", Object.class); + acceptMethod.invoke(consumer, channel); + + ArgumentCaptor argumentCaptor = ArgumentCaptor + .forClass(MLDeleteContextManagementTemplateRequest.class); + verify(client, times(1)).execute(eq(MLDeleteContextManagementTemplateAction.INSTANCE), argumentCaptor.capture(), any()); + String templateName = argumentCaptor.getValue().getTemplateName(); + assertEquals("test_template", templateName); + } + + public void testPrepareRequestWithWhitespaceTemplateName() { + Map params = new HashMap<>(); + params.put(PARAMETER_TEMPLATE_NAME, " "); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).build(); + + assertThrows(IllegalArgumentException.class, () -> restAction.getRequest(request)); + } + + public void testGetRequestWithValidInput() throws Exception { + RestRequest request = getRestRequest(); + MLDeleteContextManagementTemplateRequest result = restAction.getRequest(request); + + assertNotNull(result); + assertEquals("test_template", result.getTemplateName()); + } + + private RestRequest getRestRequest() { + Map params = new HashMap<>(); + params.put(PARAMETER_TEMPLATE_NAME, "test_template"); + + return new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).build(); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLGetContextManagementTemplateActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLGetContextManagementTemplateActionTests.java new file mode 100644 index 0000000000..abe0d3edaa --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLGetContextManagementTemplateActionTests.java @@ -0,0 +1,173 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.*; +import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_TEMPLATE_NAME; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.common.Strings; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; +import org.opensearch.ml.common.transport.contextmanagement.MLGetContextManagementTemplateAction; +import org.opensearch.ml.common.transport.contextmanagement.MLGetContextManagementTemplateRequest; +import org.opensearch.rest.RestChannel; +import org.opensearch.rest.RestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestRequest; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.client.node.NodeClient; + +public class RestMLGetContextManagementTemplateActionTests extends OpenSearchTestCase { + private RestMLGetContextManagementTemplateAction restAction; + private NodeClient client; + private ThreadPool threadPool; + + @Mock + private RestChannel channel; + + @Mock + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + when(mlFeatureEnabledSetting.isAgentFrameworkEnabled()).thenReturn(true); + restAction = new RestMLGetContextManagementTemplateAction(mlFeatureEnabledSetting); + + threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); + client = spy(new NodeClient(Settings.EMPTY, threadPool)); + + doAnswer(invocation -> { + // Mock successful execution - actionListener not used in this test + return null; + }).when(client).execute(eq(MLGetContextManagementTemplateAction.INSTANCE), any(), any()); + } + + @Override + public void tearDown() throws Exception { + super.tearDown(); + threadPool.shutdown(); + client.close(); + } + + public void testConstructor() { + RestMLGetContextManagementTemplateAction action = new RestMLGetContextManagementTemplateAction(mlFeatureEnabledSetting); + assertNotNull(action); + } + + public void testGetName() { + String actionName = restAction.getName(); + assertFalse(Strings.isNullOrEmpty(actionName)); + assertEquals("ml_get_context_management_template_action", actionName); + } + + public void testRoutes() { + List routes = restAction.routes(); + assertNotNull(routes); + assertFalse(routes.isEmpty()); + RestHandler.Route route = routes.get(0); + assertEquals(RestRequest.Method.GET, route.getMethod()); + assertEquals("/_plugins/_ml/context_management/{template_name}", route.getPath()); + } + + public void testPrepareRequest() throws Exception { + RestRequest request = getRestRequest(); + restAction.handleRequest(request, channel, client); + + ArgumentCaptor argumentCaptor = ArgumentCaptor + .forClass(MLGetContextManagementTemplateRequest.class); + verify(client, times(1)).execute(eq(MLGetContextManagementTemplateAction.INSTANCE), argumentCaptor.capture(), any()); + String templateName = argumentCaptor.getValue().getTemplateName(); + assertEquals("test_template", templateName); + } + + public void testPrepareRequestWithAgentFrameworkDisabled() { + when(mlFeatureEnabledSetting.isAgentFrameworkEnabled()).thenReturn(false); + RestRequest request = getRestRequest(); + + assertThrows(IllegalStateException.class, () -> restAction.handleRequest(request, channel, client)); + } + + public void testGetRequestWithAgentFrameworkDisabled() { + when(mlFeatureEnabledSetting.isAgentFrameworkEnabled()).thenReturn(false); + RestRequest request = getRestRequest(); + + IllegalStateException exception = assertThrows(IllegalStateException.class, () -> restAction.getRequest(request)); + assertEquals("Agent framework is disabled", exception.getMessage()); + } + + public void testGetRequestWithMissingTemplateName() { + Map params = new HashMap<>(); + // No template name parameter + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).build(); + + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> restAction.getRequest(request)); + assertEquals("Template name is required", exception.getMessage()); + } + + public void testGetRequestWithEmptyTemplateName() { + Map params = new HashMap<>(); + params.put(PARAMETER_TEMPLATE_NAME, ""); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).build(); + + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> restAction.getRequest(request)); + assertEquals("Template name is required", exception.getMessage()); + } + + public void testGetRequestWithWhitespaceTemplateName() { + Map params = new HashMap<>(); + params.put(PARAMETER_TEMPLATE_NAME, " "); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).build(); + + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> restAction.getRequest(request)); + assertEquals("Template name is required", exception.getMessage()); + } + + public void testPrepareRequestReturnsRestChannelConsumer() throws Exception { + RestRequest request = getRestRequest(); + Object consumer = restAction.prepareRequest(request, client); + + assertNotNull(consumer); + + // Execute the consumer to test the actual execution path using reflection + java.lang.reflect.Method acceptMethod = consumer.getClass().getMethod("accept", Object.class); + acceptMethod.invoke(consumer, channel); + + ArgumentCaptor argumentCaptor = ArgumentCaptor + .forClass(MLGetContextManagementTemplateRequest.class); + verify(client, times(1)).execute(eq(MLGetContextManagementTemplateAction.INSTANCE), argumentCaptor.capture(), any()); + String templateName = argumentCaptor.getValue().getTemplateName(); + assertEquals("test_template", templateName); + } + + public void testGetRequestWithValidInput() throws Exception { + RestRequest request = getRestRequest(); + MLGetContextManagementTemplateRequest result = restAction.getRequest(request); + + assertNotNull(result); + assertEquals("test_template", result.getTemplateName()); + } + + private RestRequest getRestRequest() { + Map params = new HashMap<>(); + params.put(PARAMETER_TEMPLATE_NAME, "test_template"); + + return new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).build(); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLListContextManagementTemplatesActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLListContextManagementTemplatesActionTests.java new file mode 100644 index 0000000000..d1f56a934a --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLListContextManagementTemplatesActionTests.java @@ -0,0 +1,184 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.*; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.common.Strings; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; +import org.opensearch.ml.common.transport.contextmanagement.MLListContextManagementTemplatesAction; +import org.opensearch.ml.common.transport.contextmanagement.MLListContextManagementTemplatesRequest; +import org.opensearch.rest.RestChannel; +import org.opensearch.rest.RestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestRequest; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.client.node.NodeClient; + +public class RestMLListContextManagementTemplatesActionTests extends OpenSearchTestCase { + private RestMLListContextManagementTemplatesAction restAction; + private NodeClient client; + private ThreadPool threadPool; + + @Mock + private RestChannel channel; + + @Mock + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + when(mlFeatureEnabledSetting.isAgentFrameworkEnabled()).thenReturn(true); + restAction = new RestMLListContextManagementTemplatesAction(mlFeatureEnabledSetting); + + threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); + client = spy(new NodeClient(Settings.EMPTY, threadPool)); + + doAnswer(invocation -> { + // Mock successful execution - actionListener not used in this test + return null; + }).when(client).execute(eq(MLListContextManagementTemplatesAction.INSTANCE), any(), any()); + } + + @Override + public void tearDown() throws Exception { + super.tearDown(); + threadPool.shutdown(); + client.close(); + } + + public void testConstructor() { + RestMLListContextManagementTemplatesAction action = new RestMLListContextManagementTemplatesAction(mlFeatureEnabledSetting); + assertNotNull(action); + } + + public void testGetName() { + String actionName = restAction.getName(); + assertFalse(Strings.isNullOrEmpty(actionName)); + assertEquals("ml_list_context_management_templates_action", actionName); + } + + public void testRoutes() { + List routes = restAction.routes(); + assertNotNull(routes); + assertFalse(routes.isEmpty()); + RestHandler.Route route = routes.get(0); + assertEquals(RestRequest.Method.GET, route.getMethod()); + assertEquals("/_plugins/_ml/context_management", route.getPath()); + } + + public void testPrepareRequest() throws Exception { + RestRequest request = getRestRequest(); + restAction.handleRequest(request, channel, client); + + ArgumentCaptor argumentCaptor = ArgumentCaptor + .forClass(MLListContextManagementTemplatesRequest.class); + verify(client, times(1)).execute(eq(MLListContextManagementTemplatesAction.INSTANCE), argumentCaptor.capture(), any()); + MLListContextManagementTemplatesRequest capturedRequest = argumentCaptor.getValue(); + assertEquals(0, capturedRequest.getFrom()); + assertEquals(10, capturedRequest.getSize()); + } + + public void testPrepareRequestWithCustomPagination() throws Exception { + RestRequest request = getRestRequestWithPagination(5, 20); + restAction.handleRequest(request, channel, client); + + ArgumentCaptor argumentCaptor = ArgumentCaptor + .forClass(MLListContextManagementTemplatesRequest.class); + verify(client, times(1)).execute(eq(MLListContextManagementTemplatesAction.INSTANCE), argumentCaptor.capture(), any()); + MLListContextManagementTemplatesRequest capturedRequest = argumentCaptor.getValue(); + assertEquals(5, capturedRequest.getFrom()); + assertEquals(20, capturedRequest.getSize()); + } + + public void testPrepareRequestWithAgentFrameworkDisabled() { + when(mlFeatureEnabledSetting.isAgentFrameworkEnabled()).thenReturn(false); + RestRequest request = getRestRequest(); + + assertThrows(IllegalStateException.class, () -> restAction.handleRequest(request, channel, client)); + } + + public void testGetRequestWithDefaultPagination() throws Exception { + RestRequest request = getRestRequest(); + MLListContextManagementTemplatesRequest result = restAction.getRequest(request); + + assertNotNull(result); + assertEquals(0, result.getFrom()); + assertEquals(10, result.getSize()); + } + + public void testGetRequestWithCustomPagination() throws Exception { + RestRequest request = getRestRequestWithPagination(15, 25); + MLListContextManagementTemplatesRequest result = restAction.getRequest(request); + + assertNotNull(result); + assertEquals(15, result.getFrom()); + assertEquals(25, result.getSize()); + } + + public void testGetRequestWithAgentFrameworkDisabled() { + when(mlFeatureEnabledSetting.isAgentFrameworkEnabled()).thenReturn(false); + RestRequest request = getRestRequest(); + + IllegalStateException exception = assertThrows(IllegalStateException.class, () -> restAction.getRequest(request)); + assertEquals("Agent framework is disabled", exception.getMessage()); + } + + public void testGetRequestWithInvalidPagination() throws Exception { + RestRequest request = getRestRequestWithPagination(-1, -5); + MLListContextManagementTemplatesRequest result = restAction.getRequest(request); + + assertNotNull(result); + // The REST action passes through the parameters as-is, validation happens at the service level + assertEquals(-1, result.getFrom()); + assertEquals(-5, result.getSize()); + } + + public void testPrepareRequestReturnsRestChannelConsumer() throws Exception { + RestRequest request = getRestRequest(); + Object consumer = restAction.prepareRequest(request, client); + + assertNotNull(consumer); + + // Execute the consumer to test the actual execution path using reflection + java.lang.reflect.Method acceptMethod = consumer.getClass().getMethod("accept", Object.class); + acceptMethod.invoke(consumer, channel); + + ArgumentCaptor argumentCaptor = ArgumentCaptor + .forClass(MLListContextManagementTemplatesRequest.class); + verify(client, times(1)).execute(eq(MLListContextManagementTemplatesAction.INSTANCE), argumentCaptor.capture(), any()); + MLListContextManagementTemplatesRequest capturedRequest = argumentCaptor.getValue(); + assertEquals(0, capturedRequest.getFrom()); + assertEquals(10, capturedRequest.getSize()); + } + + private RestRequest getRestRequest() { + Map params = new HashMap<>(); + return new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).build(); + } + + private RestRequest getRestRequestWithPagination(int from, int size) { + Map params = new HashMap<>(); + params.put("from", String.valueOf(from)); + params.put("size", String.valueOf(size)); + return new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).build(); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLExecuteTaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLExecuteTaskRunnerTests.java index 1f53744661..d1d0eca084 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLExecuteTaskRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLExecuteTaskRunnerTests.java @@ -33,6 +33,8 @@ import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.action.contextmanagement.ContextManagementTemplateService; +import org.opensearch.ml.action.contextmanagement.ContextManagerFactory; import org.opensearch.ml.breaker.MLCircuitBreakerService; import org.opensearch.ml.cluster.DiscoveryNodeHelper; import org.opensearch.ml.common.FunctionName; @@ -78,6 +80,10 @@ public class MLExecuteTaskRunnerTests extends OpenSearchTestCase { DiscoveryNodeHelper nodeHelper; @Mock ClusterApplierService clusterApplierService; + @Mock + ContextManagementTemplateService contextManagementTemplateService; + @Mock + ContextManagerFactory contextManagerFactory; @Rule public ExpectedException exceptionRule = ExpectedException.none(); @@ -138,7 +144,9 @@ public void setup() { mlTaskDispatcher, mlCircuitBreakerService, nodeHelper, - mlEngine + mlEngine, + contextManagementTemplateService, + contextManagerFactory ) );