diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index 103f3f89b3..89c0492a89 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -14,6 +14,7 @@ import static org.opensearch.ml.common.utils.ToolUtils.parseResponse; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.DISABLE_TRACE; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.INTERACTIONS_PREFIX; +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_RESPONSE_FILTER; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.PROMPT_CHAT_HISTORY_PREFIX; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.PROMPT_PREFIX; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.PROMPT_SUFFIX; @@ -34,6 +35,7 @@ import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.parseLLMOutput; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.substitute; import static org.opensearch.ml.engine.algorithms.agent.PromptTemplate.CHAT_HISTORY_PREFIX; +import static org.opensearch.ml.engine.algorithms.agent.PromptTemplate.SUMMARY_PROMPT_TEMPLATE; import static org.opensearch.ml.engine.tools.ReadFromScratchPadTool.SCRATCHPAD_NOTES_KEY; import java.security.PrivilegedActionException; @@ -57,10 +59,13 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.Strings; import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.agent.LLMSpec; import org.opensearch.ml.common.agent.MLAgent; import org.opensearch.ml.common.agent.MLToolSpec; import org.opensearch.ml.common.conversation.Interaction; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.remote.RemoteInferenceMLInput; import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.common.output.model.ModelTensors; @@ -68,6 +73,8 @@ import org.opensearch.ml.common.spi.memory.Message; import org.opensearch.ml.common.spi.tools.Tool; import org.opensearch.ml.common.transport.MLTaskResponse; +import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; +import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; import org.opensearch.ml.common.utils.StringUtils; import org.opensearch.ml.engine.encryptor.Encryptor; import org.opensearch.ml.engine.function_calling.FunctionCalling; @@ -83,6 +90,7 @@ import org.opensearch.transport.client.Client; import com.google.common.annotations.VisibleForTesting; +import com.jayway.jsonpath.JsonPath; import lombok.Data; import lombok.NoArgsConstructor; @@ -125,6 +133,8 @@ public class MLChatAgentRunner implements MLAgentRunner { private static final String DEFAULT_MAX_ITERATIONS = "10"; private static final String MAX_ITERATIONS_MESSAGE = "Agent reached maximum iterations (%d) without completing the task"; + private static final String MAX_ITERATIONS_SUMMARY_MESSAGE = MAX_ITERATIONS_MESSAGE + + ". Here's a summary of the steps completed so far:\n\n%s"; private Client client; private Settings settings; @@ -321,7 +331,6 @@ private void runReAct( StringBuilder scratchpadBuilder = new StringBuilder(); List interactions = new CopyOnWriteArrayList<>(); - StringSubstitutor tmpSubstitutor = new StringSubstitutor(Map.of(SCRATCHPAD, scratchpadBuilder.toString()), "${parameters.", "}"); AtomicReference newPrompt = new AtomicReference<>(tmpSubstitutor.replace(prompt)); tmpParameters.put(PROMPT, newPrompt.get()); @@ -413,7 +422,9 @@ private void runReAct( additionalInfo, lastThought, maxIterations, - tools + tools, + llm, + tenantId ); return; } @@ -514,7 +525,9 @@ private void runReAct( additionalInfo, lastThought, maxIterations, - tools + tools, + llm, + tenantId ); return; } @@ -887,11 +900,63 @@ private void handleMaxIterationsReached( Map additionalInfo, AtomicReference lastThought, int maxIterations, + Map tools, + LLMSpec llmSpec, + String tenantId + ) { + ActionListener responseListener = ActionListener.wrap(response -> { + sendTraditionalMaxIterationsResponse( + sessionId, + listener, + question, + parentInteractionId, + verbose, + traceDisabled, + traceTensors, + conversationIndexMemory, + traceNumber, + additionalInfo, + response, + tools + ); + }, listener::onFailure); + + generateLLMSummary( + traceTensors, + llmSpec, + tenantId, + ActionListener + .wrap( + summary -> responseListener + .onResponse(String.format(Locale.ROOT, MAX_ITERATIONS_SUMMARY_MESSAGE, maxIterations, summary)), + e -> { + log.error("Failed to generate LLM summary, using fallback strategy", e); + String fallbackResponse = (lastThought.get() != null + && !lastThought.get().isEmpty() + && !"null".equals(lastThought.get())) + ? String + .format("%s. Last thought: %s", String.format(MAX_ITERATIONS_MESSAGE, maxIterations), lastThought.get()) + : String.format(MAX_ITERATIONS_MESSAGE, maxIterations); + responseListener.onResponse(fallbackResponse); + } + ) + ); + } + + private void sendTraditionalMaxIterationsResponse( + String sessionId, + ActionListener listener, + String question, + String parentInteractionId, + boolean verbose, + boolean traceDisabled, + List traceTensors, + ConversationIndexMemory conversationIndexMemory, + AtomicInteger traceNumber, + Map additionalInfo, + String response, Map tools ) { - String incompleteResponse = (lastThought.get() != null && !lastThought.get().isEmpty() && !"null".equals(lastThought.get())) - ? String.format("%s. Last thought: %s", String.format(MAX_ITERATIONS_MESSAGE, maxIterations), lastThought.get()) - : String.format(MAX_ITERATIONS_MESSAGE, maxIterations); sendFinalAnswer( sessionId, listener, @@ -903,11 +968,104 @@ private void handleMaxIterationsReached( conversationIndexMemory, traceNumber, additionalInfo, - incompleteResponse + response ); cleanUpResource(tools); } + void generateLLMSummary(List stepsSummary, LLMSpec llmSpec, String tenantId, ActionListener listener) { + if (stepsSummary == null || stepsSummary.isEmpty()) { + listener.onFailure(new IllegalArgumentException("Steps summary cannot be null or empty")); + return; + } + + try { + Map summaryParams = new HashMap<>(); + if (llmSpec.getParameters() != null) { + summaryParams.putAll(llmSpec.getParameters()); + } + + // Convert ModelTensors to strings before joining + List stepStrings = new ArrayList<>(); + for (ModelTensors tensor : stepsSummary) { + if (tensor != null && tensor.getMlModelTensors() != null) { + for (ModelTensor modelTensor : tensor.getMlModelTensors()) { + if (modelTensor.getResult() != null) { + stepStrings.add(modelTensor.getResult()); + } else if (modelTensor.getDataAsMap() != null && modelTensor.getDataAsMap().containsKey("response")) { + stepStrings.add(String.valueOf(modelTensor.getDataAsMap().get("response"))); + } + } + } + } + String summaryPrompt = String.format(Locale.ROOT, SUMMARY_PROMPT_TEMPLATE, String.join("\n", stepStrings)); + summaryParams.put(PROMPT, summaryPrompt); + summaryParams.putIfAbsent(SYSTEM_PROMPT_FIELD, SUMMARY_PROMPT_TEMPLATE); + + ActionRequest request = new MLPredictionTaskRequest( + llmSpec.getModelId(), + RemoteInferenceMLInput + .builder() + .algorithm(FunctionName.REMOTE) + .inputDataset(RemoteInferenceInputDataSet.builder().parameters(summaryParams).build()) + .build(), + null, + tenantId + ); + client.execute(MLPredictionTaskAction.INSTANCE, request, ActionListener.wrap(mlTaskResponse -> { + String summary = extractSummaryFromResponse(mlTaskResponse); + if (summary == null) { + listener.onFailure(new RuntimeException("Empty or invalid LLM summary response")); + return; + } + listener.onResponse(summary); + }, listener::onFailure)); + } catch (Exception e) { + listener.onFailure(e); + } + } + + public String extractSummaryFromResponse(MLTaskResponse response) { + try { + ModelTensorOutput output = (ModelTensorOutput) response.getOutput(); + if (output == null || output.getMlModelOutputs() == null || output.getMlModelOutputs().isEmpty()) { + return null; + } + + ModelTensors tensors = output.getMlModelOutputs().getFirst(); + if (tensors == null || tensors.getMlModelTensors() == null || tensors.getMlModelTensors().isEmpty()) { + return null; + } + + ModelTensor tensor = tensors.getMlModelTensors().getFirst(); + if (tensor.getResult() != null) { + return tensor.getResult().trim(); + } + + if (tensor.getDataAsMap() == null) { + return null; + } + + Map dataMap = tensor.getDataAsMap(); + if (dataMap.containsKey("response")) { + return String.valueOf(dataMap.get("response")).trim(); + } + + if (dataMap.containsKey("output")) { + Object outputObj = JsonPath.read(dataMap, LLM_RESPONSE_FILTER); + if (outputObj != null) { + return String.valueOf(outputObj).trim(); + } + } + + log.error("Summary generate error. No result/response field found. Available fields: {}", dataMap.keySet()); + return null; + } catch (Exception e) { + log.error("Failed to extract summary from response", e); + throw new RuntimeException("Failed to extract summary from response", e); + } + } + private void saveMessage( ConversationIndexMemory memory, String question, diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java index 762b60ca5c..08f9e6dca8 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java @@ -33,6 +33,7 @@ import static org.opensearch.ml.engine.algorithms.agent.PromptTemplate.FINAL_RESULT_RESPONSE_INSTRUCTIONS; import static org.opensearch.ml.engine.algorithms.agent.PromptTemplate.PLANNER_RESPONSIBILITY; import static org.opensearch.ml.engine.algorithms.agent.PromptTemplate.PLAN_EXECUTE_REFLECT_RESPONSE_FORMAT; +import static org.opensearch.ml.engine.algorithms.agent.PromptTemplate.SUMMARY_PROMPT_TEMPLATE; import java.util.ArrayList; import java.util.HashMap; @@ -363,27 +364,8 @@ private void executePlanningLoop( int maxSteps = Integer.parseInt(allParams.getOrDefault(MAX_STEPS_EXECUTED_FIELD, DEFAULT_MAX_STEPS_EXECUTED)); String parentInteractionId = allParams.get(MLAgentExecutor.PARENT_INTERACTION_ID); - // completedSteps stores the step and its result, hence divide by 2 to find total steps completed - // on reaching max iteration, update parent interaction question with last executed step rather than task to allow continue using - // memory_id if (stepsExecuted >= maxSteps) { - String finalResult = String - .format( - "Max Steps Limit Reached. Use memory_id with same task to restart. \n " - + "Last executed step: %s, \n " - + "Last executed step result: %s", - completedSteps.get(completedSteps.size() - 2), - completedSteps.getLast() - ); - saveAndReturnFinalResult( - (ConversationIndexMemory) memory, - parentInteractionId, - allParams.get(EXECUTOR_AGENT_MEMORY_ID_FIELD), - allParams.get(EXECUTOR_AGENT_PARENT_INTERACTION_ID_FIELD), - finalResult, - null, - finalListener - ); + handleMaxStepsReached(llm, allParams, completedSteps, memory, parentInteractionId, finalListener); return; } @@ -740,4 +722,132 @@ static List createModelTensors( Map getTaskUpdates() { return taskUpdates; } + + private void handleMaxStepsReached( + LLMSpec llm, + Map allParams, + List completedSteps, + Memory memory, + String parentInteractionId, + ActionListener finalListener + ) { + int maxSteps = Integer.parseInt(allParams.getOrDefault(MAX_STEPS_EXECUTED_FIELD, DEFAULT_MAX_STEPS_EXECUTED)); + log.info("[SUMMARY] Max steps reached. Completed steps: {}", completedSteps.size()); + + ActionListener responseListener = ActionListener.wrap(response -> { + saveAndReturnFinalResult( + (ConversationIndexMemory) memory, + parentInteractionId, + allParams.get(EXECUTOR_AGENT_MEMORY_ID_FIELD), + allParams.get(EXECUTOR_AGENT_PARENT_INTERACTION_ID_FIELD), + response, + null, + finalListener + ); + }, finalListener::onFailure); + + generateSummary(llm, completedSteps, allParams.get(TENANT_ID_FIELD), ActionListener.wrap(summary -> { + log.info("Summary generated successfully"); + responseListener + .onResponse( + String.format("Max Steps Limit (%d) Reached. Here's a summary of the steps completed so far:\n\n%s", maxSteps, summary) + ); + }, e -> { + log.error("Summary generation failed, using fallback", e); + String fallbackResult = completedSteps.isEmpty() || completedSteps.size() < 2 + ? String.format("Max Steps Limit (%d) Reached. Use memory_id with same task to restart.", maxSteps) + : String + .format( + "Max Steps Limit (%d) Reached. Use memory_id with same task to restart. \n " + + "Last executed step: %s, \n " + + "Last executed step result: %s", + maxSteps, + completedSteps.get(completedSteps.size() - 2), + completedSteps.getLast() + ); + responseListener.onResponse(fallbackResult); + })); + } + + private void generateSummary(LLMSpec llmSpec, List completedSteps, String tenantId, ActionListener listener) { + if (completedSteps == null || completedSteps.isEmpty()) { + listener.onFailure(new IllegalArgumentException("Completed steps cannot be null or empty")); + return; + } + + try { + Map summaryParams = new HashMap<>(); + if (llmSpec.getParameters() != null) { + summaryParams.putAll(llmSpec.getParameters()); + } + + String summaryPrompt = String.format(Locale.ROOT, SUMMARY_PROMPT_TEMPLATE, String.join("\n", completedSteps)); + summaryParams.put(PROMPT_FIELD, summaryPrompt); + summaryParams.putIfAbsent(SYSTEM_PROMPT_FIELD, SUMMARY_PROMPT_TEMPLATE); + + MLPredictionTaskRequest request = new MLPredictionTaskRequest( + llmSpec.getModelId(), + RemoteInferenceMLInput + .builder() + .algorithm(FunctionName.REMOTE) + .inputDataset(RemoteInferenceInputDataSet.builder().parameters(summaryParams).build()) + .build(), + null, + tenantId + ); + + client.execute(MLPredictionTaskAction.INSTANCE, request, ActionListener.wrap(response -> { + String summary = extractSummaryFromResponse(response); + if (summary == null || summary.trim().isEmpty()) { + log.error("Extracted summary is empty"); + listener.onFailure(new RuntimeException("Empty or invalid LLM summary response")); + return; + } + listener.onResponse(summary); + }, listener::onFailure)); + } catch (Exception e) { + listener.onFailure(e); + } + } + + private String extractSummaryFromResponse(MLTaskResponse response) { + try { + ModelTensorOutput output = (ModelTensorOutput) response.getOutput(); + if (output == null || output.getMlModelOutputs() == null || output.getMlModelOutputs().isEmpty()) { + return null; + } + + ModelTensors tensors = output.getMlModelOutputs().getFirst(); + if (tensors == null || tensors.getMlModelTensors() == null || tensors.getMlModelTensors().isEmpty()) { + return null; + } + + ModelTensor tensor = tensors.getMlModelTensors().getFirst(); + if (tensor.getResult() != null) { + return tensor.getResult().trim(); + } + + if (tensor.getDataAsMap() == null) { + return null; + } + + Map dataMap = tensor.getDataAsMap(); + if (dataMap.containsKey(RESPONSE_FIELD)) { + return String.valueOf(dataMap.get(RESPONSE_FIELD)).trim(); + } + + if (dataMap.containsKey("output")) { + Object outputObj = JsonPath.read(dataMap, LLM_RESPONSE_FILTER); + if (outputObj != null) { + return String.valueOf(outputObj).trim(); + } + } + + log.error("Summary generate error. No result/response field found. Available fields: {}", dataMap.keySet()); + return null; + } catch (Exception e) { + log.error("Summary extraction failed", e); + throw new RuntimeException("Failed to extract summary from response", e); + } + } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/PromptTemplate.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/PromptTemplate.java index 9ff33ecaa9..1ea97f1ebc 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/PromptTemplate.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/PromptTemplate.java @@ -140,4 +140,7 @@ public class PromptTemplate { - Avoid making assumptions and relying on implicit knowledge. - Your response must be self-contained and ready for the planner to use without modification. Never end with a question. - Break complex searches into simpler queries when appropriate."""; + + public static final String SUMMARY_PROMPT_TEMPLATE = + "Please provide a concise summary of the following agent execution steps. Focus on what the agent was trying to accomplish and what progress was made:\n\n%s"; } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java index f6c3e3618e..8c0a288746 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java @@ -1031,11 +1031,19 @@ public void testMaxIterationsReached() { .tools(Arrays.asList(firstToolSpec)) .build(); - // Mock LLM response that doesn't contain final_answer to force max iterations - Mockito - .doAnswer(getLLMAnswer(ImmutableMap.of("thought", "", "action", FIRST_TOOL))) - .when(client) - .execute(any(ActionType.class), any(ActionRequest.class), isA(ActionListener.class)); + // Reset client mock for this test + Mockito.reset(client); + // First call: LLM response without final_answer to force max iterations + // Second call: Summary LLM response + Mockito.doAnswer(getLLMAnswer(ImmutableMap.of("thought", "", "action", FIRST_TOOL))).doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + ModelTensor modelTensor = ModelTensor.builder().result("The agent attempted to use the first tool").build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + MLTaskResponse mlTaskResponse = MLTaskResponse.builder().output(mlModelTensorOutput).build(); + listener.onResponse(mlTaskResponse); + return null; + }).when(client).execute(any(ActionType.class), any(ActionRequest.class), isA(ActionListener.class)); Map params = new HashMap<>(); params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "parent_interaction_id"); @@ -1051,9 +1059,15 @@ public void testMaxIterationsReached() { List agentOutput = modelTensorOutput.getMlModelOutputs().get(1).getMlModelTensors(); assertEquals(1, agentOutput.size()); - // Verify the response contains max iterations message + // Verify the response contains max iterations message with summary String response = (String) agentOutput.get(0).getDataAsMap().get("response"); - assertEquals("Agent reached maximum iterations (1) without completing the task", response); + assertTrue( + response + .startsWith( + "Agent reached maximum iterations (1) without completing the task. Here's a summary of the steps completed so far:" + ) + ); + assertTrue(response.contains("The agent attempted to use the first tool")); } @Test @@ -1070,9 +1084,17 @@ public void testMaxIterationsReachedWithValidThought() { .tools(Arrays.asList(firstToolSpec)) .build(); - // Mock LLM response with valid thought + // Reset client mock for this test + Mockito.reset(client); + // First call: LLM response with valid thought to trigger max iterations + // Second call: Summary LLM fails to trigger fallback Mockito .doAnswer(getLLMAnswer(ImmutableMap.of("thought", "I need to use the first tool", "action", FIRST_TOOL))) + .doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onFailure(new RuntimeException("LLM summary generation failed")); + return null; + }) .when(client) .execute(any(ActionType.class), any(ActionRequest.class), isA(ActionListener.class)); @@ -1090,7 +1112,7 @@ public void testMaxIterationsReachedWithValidThought() { List agentOutput = modelTensorOutput.getMlModelOutputs().get(1).getMlModelTensors(); assertEquals(1, agentOutput.size()); - // Verify the response contains the last valid thought instead of max iterations message + // Verify the response contains the last valid thought (fallback when summary fails) String response = (String) agentOutput.get(0).getDataAsMap().get("response"); assertEquals( "Agent reached maximum iterations (1) without completing the task. Last thought: I need to use the first tool", @@ -1171,4 +1193,226 @@ public void testConstructLLMParams_DefaultValues() { Assert.assertTrue(result.containsKey(AgentUtils.RESPONSE_FORMAT_INSTRUCTION)); Assert.assertTrue(result.containsKey(AgentUtils.TOOL_RESPONSE)); } + + @Test + public void testMaxIterationsWithSummaryEnabled() { + // Create LLM spec with max_iteration = 1 to simplify test + LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").parameters(Map.of("max_iteration", "1")).build(); + MLToolSpec firstToolSpec = MLToolSpec.builder().name(FIRST_TOOL).type(FIRST_TOOL).build(); + final MLAgent mlAgent = MLAgent + .builder() + .name("TestAgent") + .type(MLAgentType.CONVERSATIONAL.name()) + .llm(llmSpec) + .memory(mlMemorySpec) + .tools(Arrays.asList(firstToolSpec)) + .build(); + + // Reset and setup fresh mocks + Mockito.reset(client); + Mockito.reset(firstTool); + when(firstTool.getName()).thenReturn(FIRST_TOOL); + when(firstTool.validate(Mockito.anyMap())).thenReturn(true); + Mockito.doAnswer(generateToolResponse("First tool response")).when(firstTool).run(Mockito.anyMap(), any()); + + // First call: LLM response without final_answer to trigger max iterations + // Second call: Summary LLM response with result field instead of dataAsMap + Mockito + .doAnswer(getLLMAnswer(ImmutableMap.of("thought", "I need to analyze the data", "action", FIRST_TOOL))) + .doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + ModelTensor modelTensor = ModelTensor.builder().result("Summary: Analysis step was attempted").build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + MLTaskResponse mlTaskResponse = MLTaskResponse.builder().output(mlModelTensorOutput).build(); + listener.onResponse(mlTaskResponse); + return null; + }) + .when(client) + .execute(any(ActionType.class), any(ActionRequest.class), isA(ActionListener.class)); + + Map params = new HashMap<>(); + params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "parent_interaction_id"); + mlChatAgentRunner.run(mlAgent, params, agentActionListener); + + // Verify response is captured + verify(agentActionListener).onResponse(objectCaptor.capture()); + Object capturedResponse = objectCaptor.getValue(); + assertTrue(capturedResponse instanceof ModelTensorOutput); + + ModelTensorOutput modelTensorOutput = (ModelTensorOutput) capturedResponse; + List agentOutput = modelTensorOutput.getMlModelOutputs().get(1).getMlModelTensors(); + assertEquals(1, agentOutput.size()); + + // Verify the response contains summary message + String response = (String) agentOutput.get(0).getDataAsMap().get("response"); + assertTrue( + response + .startsWith( + "Agent reached maximum iterations (1) without completing the task. Here's a summary of the steps completed so far:" + ) + ); + assertTrue(response.contains("Summary: Analysis step was attempted")); + } + + @Test + public void testMaxIterationsWithSummaryDisabled() { + // Create LLM spec with max_iteration = 1 + LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").parameters(Map.of("max_iteration", "1")).build(); + MLToolSpec firstToolSpec = MLToolSpec.builder().name(FIRST_TOOL).type(FIRST_TOOL).build(); + final MLAgent mlAgent = MLAgent + .builder() + .name("TestAgent") + .type(MLAgentType.CONVERSATIONAL.name()) + .llm(llmSpec) + .memory(mlMemorySpec) + .tools(Arrays.asList(firstToolSpec)) + .build(); + + // Reset client mock for this test + Mockito.reset(client); + // First call: LLM response without final_answer to trigger max iterations + // Second call: Summary LLM fails + Mockito.doAnswer(getLLMAnswer(ImmutableMap.of("thought", "I need to use the tool", "action", FIRST_TOOL))).doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onFailure(new RuntimeException("LLM summary generation failed")); + return null; + }).when(client).execute(any(ActionType.class), any(ActionRequest.class), isA(ActionListener.class)); + + Map params = new HashMap<>(); + params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "parent_interaction_id"); + + mlChatAgentRunner.run(mlAgent, params, agentActionListener); + + // Verify response is captured + verify(agentActionListener).onResponse(objectCaptor.capture()); + Object capturedResponse = objectCaptor.getValue(); + assertTrue(capturedResponse instanceof ModelTensorOutput); + + ModelTensorOutput modelTensorOutput = (ModelTensorOutput) capturedResponse; + List agentOutput = modelTensorOutput.getMlModelOutputs().get(1).getMlModelTensors(); + assertEquals(1, agentOutput.size()); + + // Verify the response uses fallback strategy with last thought + String response = (String) agentOutput.get(0).getDataAsMap().get("response"); + assertEquals("Agent reached maximum iterations (1) without completing the task. Last thought: I need to use the tool", response); + } + + @Test + public void testExtractSummaryFromResponse() { + MLTaskResponse response = MLTaskResponse + .builder() + .output( + ModelTensorOutput + .builder() + .mlModelOutputs( + Arrays + .asList( + ModelTensors + .builder() + .mlModelTensors(Arrays.asList(ModelTensor.builder().result("Valid summary text").build())) + .build() + ) + ) + .build() + ) + .build(); + + String result = mlChatAgentRunner.extractSummaryFromResponse(response); + assertEquals("Valid summary text", result); + } + + @Test + public void testGenerateLLMSummaryWithNullSteps() { + LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build(); + ActionListener listener = Mockito.mock(ActionListener.class); + + mlChatAgentRunner.generateLLMSummary(null, llmSpec, "tenant", listener); + + verify(listener).onFailure(any(IllegalArgumentException.class)); + } + + @Test + public void testExtractSummaryFromResponse_WithResponseField() { + Map dataMap = new HashMap<>(); + dataMap.put("response", "Summary from response field"); + ModelTensor tensor = ModelTensor.builder().dataAsMap(dataMap).build(); + ModelTensors tensors = ModelTensors.builder().mlModelTensors(Arrays.asList(tensor)).build(); + ModelTensorOutput output = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(tensors)).build(); + MLTaskResponse response = MLTaskResponse.builder().output(output).build(); + + String result = mlChatAgentRunner.extractSummaryFromResponse(response); + assertEquals("Summary from response field", result); + } + + @Test + public void testExtractSummaryFromResponse_WithNullDataMap() { + ModelTensor tensor = ModelTensor.builder().build(); + ModelTensors tensors = ModelTensors.builder().mlModelTensors(Arrays.asList(tensor)).build(); + ModelTensorOutput output = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(tensors)).build(); + MLTaskResponse response = MLTaskResponse.builder().output(output).build(); + + String result = mlChatAgentRunner.extractSummaryFromResponse(response); + assertEquals(null, result); + } + + @Test + public void testExtractSummaryFromResponse_WithEmptyDataMap() { + Map dataMap = new HashMap<>(); + dataMap.put("other_field", "some value"); + ModelTensor tensor = ModelTensor.builder().dataAsMap(dataMap).build(); + ModelTensors tensors = ModelTensors.builder().mlModelTensors(Arrays.asList(tensor)).build(); + ModelTensorOutput output = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(tensors)).build(); + MLTaskResponse response = MLTaskResponse.builder().output(output).build(); + + String result = mlChatAgentRunner.extractSummaryFromResponse(response); + assertEquals(null, result); + } + + @Test + public void testExtractSummaryFromResponse_ThrowsException_FallbackStrategyUsed() { + LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").parameters(Map.of("max_iteration", "1")).build(); + MLToolSpec firstToolSpec = MLToolSpec.builder().name(FIRST_TOOL).type(FIRST_TOOL).build(); + final MLAgent mlAgent = MLAgent + .builder() + .name("TestAgent") + .type(MLAgentType.CONVERSATIONAL.name()) + .llm(llmSpec) + .memory(mlMemorySpec) + .tools(Arrays.asList(firstToolSpec)) + .build(); + + Mockito.reset(client); + Mockito.reset(firstTool); + when(firstTool.getName()).thenReturn(FIRST_TOOL); + when(firstTool.validate(Mockito.anyMap())).thenReturn(true); + Mockito.doAnswer(generateToolResponse("First tool response")).when(firstTool).run(Mockito.anyMap(), any()); + + Mockito.doAnswer(getLLMAnswer(ImmutableMap.of("thought", "Analyzing the problem", "action", FIRST_TOOL))).doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + Map invalidDataMap = new HashMap<>(); + invalidDataMap.put("output", new HashMap<>()); + ModelTensor modelTensor = ModelTensor.builder().dataAsMap(invalidDataMap).build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + MLTaskResponse mlTaskResponse = MLTaskResponse.builder().output(mlModelTensorOutput).build(); + listener.onResponse(mlTaskResponse); + return null; + }).when(client).execute(any(ActionType.class), any(ActionRequest.class), isA(ActionListener.class)); + + Map params = new HashMap<>(); + params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "parent_interaction_id"); + mlChatAgentRunner.run(mlAgent, params, agentActionListener, null); + + verify(agentActionListener).onResponse(objectCaptor.capture()); + Object capturedResponse = objectCaptor.getValue(); + assertTrue(capturedResponse instanceof ModelTensorOutput); + + ModelTensorOutput modelTensorOutput = (ModelTensorOutput) capturedResponse; + List agentOutput = modelTensorOutput.getMlModelOutputs().get(1).getMlModelTensors(); + assertEquals(1, agentOutput.size()); + + String response = (String) agentOutput.get(0).getDataAsMap().get("response"); + assertEquals("Agent reached maximum iterations (1) without completing the task. Last thought: Analyzing the problem", response); + } } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunnerTest.java index 7ed4e91b1c..f1906e8630 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunnerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunnerTest.java @@ -382,7 +382,153 @@ public void testMessageHistoryLimits() { assertEquals("3", executorParams.get("message_history_limit")); } - // ToDo: add test case for when max steps is reached + @Test + public void testMaxStepsReachedWithSummary() { + MLAgent mlAgent = createMLAgentWithTools(); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + ModelTensor modelTensor = ModelTensor.builder().result("Summary of work done").build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + when(mlTaskResponse.getOutput()).thenReturn(mlModelTensorOutput); + listener.onResponse(mlTaskResponse); + return null; + }).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(MLPredictionTaskRequest.class), any(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(updateResponse); + return null; + }).when(mlMemoryManager).updateInteraction(any(), any(), any()); + + Map params = new HashMap<>(); + params.put("question", "test question"); + params.put("parent_interaction_id", "test_parent_interaction_id"); + params.put("max_steps", "0"); + mlPlanExecuteAndReflectAgentRunner.run(mlAgent, params, agentActionListener); + + verify(agentActionListener).onResponse(objectCaptor.capture()); + Object response = objectCaptor.getValue(); + assertTrue(response instanceof ModelTensorOutput); + ModelTensorOutput modelTensorOutput = (ModelTensorOutput) response; + + List mlModelOutputs = modelTensorOutput.getMlModelOutputs(); + ModelTensor responseTensor = mlModelOutputs.get(1).getMlModelTensors().get(0); + String finalResponse = (String) responseTensor.getDataAsMap().get("response"); + assertTrue(finalResponse.contains("Max Steps Limit (0) Reached. Here's a summary of the steps completed so far:")); + assertTrue(finalResponse.contains("Summary of work done")); + } + + @Test + public void testMaxStepsReachedWithSummaryGeneration() { + MLAgent mlAgent = createMLAgentWithTools(); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + ModelTensor modelTensor = ModelTensor.builder().result("Generated summary of completed steps").build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + when(mlTaskResponse.getOutput()).thenReturn(mlModelTensorOutput); + listener.onResponse(mlTaskResponse); + return null; + }).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(MLPredictionTaskRequest.class), any(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(updateResponse); + return null; + }).when(mlMemoryManager).updateInteraction(any(), any(), any()); + + Map params = new HashMap<>(); + params.put("question", "test question"); + params.put("parent_interaction_id", "test_parent_interaction_id"); + params.put("max_steps", "0"); + mlPlanExecuteAndReflectAgentRunner.run(mlAgent, params, agentActionListener); + + verify(agentActionListener).onResponse(objectCaptor.capture()); + Object response = objectCaptor.getValue(); + assertTrue(response instanceof ModelTensorOutput); + ModelTensorOutput modelTensorOutput = (ModelTensorOutput) response; + + List mlModelOutputs = modelTensorOutput.getMlModelOutputs(); + ModelTensor responseTensor = mlModelOutputs.get(1).getMlModelTensors().get(0); + String finalResponse = (String) responseTensor.getDataAsMap().get("response"); + assertTrue(finalResponse.contains("Max Steps Limit (0) Reached. Here's a summary of the steps completed so far:")); + assertTrue(finalResponse.contains("Generated summary of completed steps")); + } + + @Test + public void testMaxStepsReachedWithSummaryFailure() { + MLAgent mlAgent = createMLAgentWithTools(); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onFailure(new RuntimeException("Summary generation failed")); + return null; + }).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(MLPredictionTaskRequest.class), any(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(updateResponse); + return null; + }).when(mlMemoryManager).updateInteraction(any(), any(), any()); + + Map params = new HashMap<>(); + params.put("question", "test question"); + params.put("parent_interaction_id", "test_parent_interaction_id"); + params.put("max_steps", "0"); + mlPlanExecuteAndReflectAgentRunner.run(mlAgent, params, agentActionListener); + + verify(agentActionListener).onResponse(objectCaptor.capture()); + Object response = objectCaptor.getValue(); + assertTrue(response instanceof ModelTensorOutput); + ModelTensorOutput modelTensorOutput = (ModelTensorOutput) response; + + List mlModelOutputs = modelTensorOutput.getMlModelOutputs(); + ModelTensor responseTensor = mlModelOutputs.get(1).getMlModelTensors().get(0); + String finalResponse = (String) responseTensor.getDataAsMap().get("response"); + assertTrue(finalResponse.contains("Max Steps Limit (0) Reached")); + } + + @Test + public void testMaxStepsReachedWithEmptyCompletedSteps() { + MLAgent mlAgent = createMLAgentWithTools(); + + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(0); + listener.onResponse(Collections.emptyList()); + return null; + }).when(conversationIndexMemory).getMessages(any(), anyInt()); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onFailure(new IllegalArgumentException("Completed steps cannot be null or empty")); + return null; + }).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(MLPredictionTaskRequest.class), any()); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(updateResponse); + return null; + }).when(mlMemoryManager).updateInteraction(any(), any(), any()); + + Map params = new HashMap<>(); + params.put("question", "test question"); + params.put("parent_interaction_id", "test_parent_interaction_id"); + params.put("max_steps", "0"); + mlPlanExecuteAndReflectAgentRunner.run(mlAgent, params, agentActionListener); + + verify(agentActionListener).onResponse(objectCaptor.capture()); + Object response = objectCaptor.getValue(); + assertTrue(response instanceof ModelTensorOutput); + ModelTensorOutput modelTensorOutput = (ModelTensorOutput) response; + + List mlModelOutputs = modelTensorOutput.getMlModelOutputs(); + ModelTensor responseTensor = mlModelOutputs.get(1).getMlModelTensors().get(0); + String finalResponse = (String) responseTensor.getDataAsMap().get("response"); + assertTrue(finalResponse.contains("Max Steps Limit (0) Reached")); + } private MLAgent createMLAgentWithTools() { LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build(); @@ -855,4 +1001,145 @@ public void testExecutionWithNullStepResult() { // Verify that onFailure was called with the expected exception verify(agentActionListener).onFailure(any(IllegalStateException.class)); } + + @Test + public void testMaxStepsWithSingleCompletedStep() { + MLAgent mlAgent = createMLAgentWithTools(); + + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(0); + listener.onResponse(Arrays.asList(Interaction.builder().id("i1").input("step1").response("").build())); + return null; + }).when(conversationIndexMemory).getMessages(any(), anyInt()); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(updateResponse); + return null; + }).when(mlMemoryManager).updateInteraction(any(), any(), any()); + + Map params = new HashMap<>(); + params.put("question", "test"); + params.put("parent_interaction_id", "pid"); + params.put("max_steps", "0"); + mlPlanExecuteAndReflectAgentRunner.run(mlAgent, params, agentActionListener); + + verify(agentActionListener).onResponse(objectCaptor.capture()); + String response = (String) ((ModelTensorOutput) objectCaptor.getValue()) + .getMlModelOutputs() + .get(1) + .getMlModelTensors() + .get(0) + .getDataAsMap() + .get("response"); + assertTrue(response.contains("Max Steps Limit (0) Reached")); + } + + @Test + public void testSummaryExtractionWithResultField() { + MLAgent mlAgent = createMLAgentWithTools(); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + ModelTensor tensor = ModelTensor.builder().result("Summary from result").build(); + when(mlTaskResponse.getOutput()) + .thenReturn( + ModelTensorOutput + .builder() + .mlModelOutputs(Arrays.asList(ModelTensors.builder().mlModelTensors(Arrays.asList(tensor)).build())) + .build() + ); + listener.onResponse(mlTaskResponse); + return null; + }).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(MLPredictionTaskRequest.class), any()); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(updateResponse); + return null; + }).when(mlMemoryManager).updateInteraction(any(), any(), any()); + + Map params = new HashMap<>(); + params.put("question", "test"); + params.put("parent_interaction_id", "pid"); + params.put("max_steps", "0"); + mlPlanExecuteAndReflectAgentRunner.run(mlAgent, params, agentActionListener); + + verify(agentActionListener).onResponse(objectCaptor.capture()); + String response = (String) ((ModelTensorOutput) objectCaptor.getValue()) + .getMlModelOutputs() + .get(1) + .getMlModelTensors() + .get(0) + .getDataAsMap() + .get("response"); + assertTrue(response.contains("Summary from result")); + } + + @Test + public void testSummaryExtractionWithEmptyResponse() { + MLAgent mlAgent = createMLAgentWithTools(); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + ModelTensor tensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", " ")).build(); + when(mlTaskResponse.getOutput()) + .thenReturn( + ModelTensorOutput + .builder() + .mlModelOutputs(Arrays.asList(ModelTensors.builder().mlModelTensors(Arrays.asList(tensor)).build())) + .build() + ); + listener.onResponse(mlTaskResponse); + return null; + }).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(MLPredictionTaskRequest.class), any()); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(updateResponse); + return null; + }).when(mlMemoryManager).updateInteraction(any(), any(), any()); + + Map params = new HashMap<>(); + params.put("question", "test"); + params.put("parent_interaction_id", "pid"); + params.put("max_steps", "0"); + mlPlanExecuteAndReflectAgentRunner.run(mlAgent, params, agentActionListener); + + verify(agentActionListener).onResponse(objectCaptor.capture()); + String response = (String) ((ModelTensorOutput) objectCaptor.getValue()) + .getMlModelOutputs() + .get(1) + .getMlModelTensors() + .get(0) + .getDataAsMap() + .get("response"); + assertTrue(response.contains("Max Steps Limit")); + } + + @Test + public void testSummaryExtractionWithNullOutput() { + MLAgent mlAgent = createMLAgentWithTools(); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + when(mlTaskResponse.getOutput()).thenReturn(null); + listener.onResponse(mlTaskResponse); + return null; + }).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(MLPredictionTaskRequest.class), any()); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(updateResponse); + return null; + }).when(mlMemoryManager).updateInteraction(any(), any(), any()); + + Map params = new HashMap<>(); + params.put("question", "test"); + params.put("parent_interaction_id", "pid"); + params.put("max_steps", "0"); + mlPlanExecuteAndReflectAgentRunner.run(mlAgent, params, agentActionListener); + + verify(agentActionListener).onResponse(any()); + } }