From aea82c7f9823b23ee7775d7666df437dfc54fbd3 Mon Sep 17 00:00:00 2001 From: Jiaru Jiang Date: Mon, 22 Sep 2025 14:51:51 +0800 Subject: [PATCH 01/24] init max step summary Signed-off-by: Jiaru Jiang # Conflicts: # ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java --- build.gradle | 2 + .../algorithms/agent/MLChatAgentRunner.java | 149 +++++++++++++++++- .../algorithms/agent/PromptTemplate.java | 3 + .../agent/MLChatAgentRunnerTest.java | 89 +++++++++++ 4 files changed, 241 insertions(+), 2 deletions(-) diff --git a/build.gradle b/build.gradle index 91267a067b..2df2f50656 100644 --- a/build.gradle +++ b/build.gradle @@ -44,6 +44,7 @@ buildscript { configurations.all { resolutionStrategy { force("org.eclipse.platform:org.eclipse.core.runtime:3.29.0") // for spotless transitive dependency CVE (for 3.26.100) + force("com.google.errorprone:error_prone_annotations:2.18.0") } } } @@ -95,6 +96,7 @@ subprojects { configurations.all { // Force spotless depending on newer version of guava due to CVE-2023-2976. Remove after spotless upgrades. resolutionStrategy.force "com.google.guava:guava:32.1.3-jre" + resolutionStrategy.force "com.google.errorprone:error_prone_annotations:2.18.0" resolutionStrategy.force 'org.apache.commons:commons-compress:1.26.0' resolutionStrategy.force "io.netty:netty-buffer:${versions.netty}" resolutionStrategy.force "io.netty:netty-codec:${versions.netty}" diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index 103f3f89b3..62f51db848 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -34,6 +34,7 @@ import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.parseLLMOutput; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.substitute; import static org.opensearch.ml.engine.algorithms.agent.PromptTemplate.CHAT_HISTORY_PREFIX; +import static org.opensearch.ml.engine.algorithms.agent.PromptTemplate.SUMMARY_PROMPT_TEMPLATE; import static org.opensearch.ml.engine.tools.ReadFromScratchPadTool.SCRATCHPAD_NOTES_KEY; import java.security.PrivilegedActionException; @@ -122,9 +123,12 @@ public class MLChatAgentRunner implements MLAgentRunner { public static final String INJECT_DATETIME_FIELD = "inject_datetime"; public static final String DATETIME_FORMAT_FIELD = "datetime_format"; public static final String SYSTEM_PROMPT_FIELD = "system_prompt"; + public static final String SUMMARIZE_WHEN_MAX_ITERATION = "summarize_when_max_iteration"; private static final String DEFAULT_MAX_ITERATIONS = "10"; private static final String MAX_ITERATIONS_MESSAGE = "Agent reached maximum iterations (%d) without completing the task"; + private static final String MAX_ITERATIONS_SUMMARY_MESSAGE = + "Agent reached maximum iterations (%d) without completing the task. 
Here's a summary of the steps taken:\n\n%s"; private Client client; private Settings settings; @@ -321,6 +325,7 @@ private void runReAct( StringBuilder scratchpadBuilder = new StringBuilder(); List interactions = new CopyOnWriteArrayList<>(); + List executionSteps = new CopyOnWriteArrayList<>(); StringSubstitutor tmpSubstitutor = new StringSubstitutor(Map.of(SCRATCHPAD, scratchpadBuilder.toString()), "${parameters.", "}"); AtomicReference newPrompt = new AtomicReference<>(tmpSubstitutor.replace(prompt)); @@ -379,6 +384,17 @@ private void runReAct( lastActionInput.set(actionInput); lastToolSelectionResponse.set(thoughtResponse); + // Record execution step for summary + if (thought != null && !"null".equals(thought) && !thought.trim().isEmpty()) { + executionSteps.add(String.format("Thought: %s", thought.trim())); + } + if (action != null && !"null".equals(action) && !action.trim().isEmpty()) { + String actionDesc = actionInput != null && !"null".equals(actionInput) + ? String.format("Action: %s(%s)", action.trim(), actionInput.trim()) + : String.format("Action: %s", action.trim()); + executionSteps.add(actionDesc); + } + traceTensors .add( ModelTensors @@ -413,7 +429,11 @@ private void runReAct( additionalInfo, lastThought, maxIterations, - tools + tools, + tmpParameters, + executionSteps, + llm, + tenantId ); return; } @@ -466,6 +486,10 @@ private void runReAct( ); scratchpadBuilder.append(toolResponse).append("\n\n"); + // Record tool result for summary + String outputSummary = outputToOutputString(filteredOutput); + executionSteps.add(String.format("Result: %s", outputSummary)); + saveTraceData( conversationIndexMemory, "ReAct", @@ -514,7 +538,11 @@ private void runReAct( additionalInfo, lastThought, maxIterations, - tools + tools, + tmpParameters, + executionSteps, + llm, + tenantId ); return; } @@ -875,6 +903,65 @@ public static void returnFinalResponse( } private void handleMaxIterationsReached( + String sessionId, + ActionListener listener, + String question, + String parentInteractionId, + boolean verbose, + boolean traceDisabled, + List traceTensors, + ConversationIndexMemory conversationIndexMemory, + AtomicInteger traceNumber, + Map additionalInfo, + AtomicReference lastThought, + int maxIterations, + Map tools, + Map parameters, + List executionSteps, + LLMSpec llmSpec, + String tenantId + ) { + boolean shouldSummarize = Boolean.parseBoolean(parameters.getOrDefault(SUMMARIZE_WHEN_MAX_ITERATION, "false")); + + if (shouldSummarize && !executionSteps.isEmpty()) { + generateLLMSummary(executionSteps, llmSpec, tenantId, ActionListener.wrap(summary -> { + String incompleteResponse = String.format(MAX_ITERATIONS_SUMMARY_MESSAGE, maxIterations, summary); + sendFinalAnswer( + sessionId, + listener, + question, + parentInteractionId, + verbose, + traceDisabled, + traceTensors, + conversationIndexMemory, + traceNumber, + additionalInfo, + incompleteResponse + ); + cleanUpResource(tools); + }, e -> { log.warn("Failed to generate LLM summary", e); })); + } else { + // Use traditional approach + sendTraditionalMaxIterationsResponse( + sessionId, + listener, + question, + parentInteractionId, + verbose, + traceDisabled, + traceTensors, + conversationIndexMemory, + traceNumber, + additionalInfo, + lastThought, + maxIterations, + tools + ); + } + } + + private void sendTraditionalMaxIterationsResponse( String sessionId, ActionListener listener, String question, @@ -908,6 +995,64 @@ private void handleMaxIterationsReached( cleanUpResource(tools); } + void generateLLMSummary(List stepsSummary, 
LLMSpec llmSpec, String tenantId, ActionListener listener) { + if (stepsSummary == null || stepsSummary.isEmpty()) { + listener.onFailure(new IllegalArgumentException("Steps summary cannot be null or empty")); + return; + } + + try { + Map summaryParams = new HashMap<>(); + if (llmSpec.getParameters() != null) { + summaryParams.putAll(llmSpec.getParameters()); + } + String summaryPrompt = String.format(SUMMARY_PROMPT_TEMPLATE, String.join("\n", stepsSummary)); + summaryParams.put("inputs", summaryPrompt); + summaryParams.put("prompt", summaryPrompt); + summaryParams.putIfAbsent("stop", gson.toJson(new String[] { "\n\n", "```" })); + + ActionRequest request = new MLPredictionTaskRequest( + llmSpec.getModelId(), + RemoteInferenceMLInput + .builder() + .algorithm(FunctionName.REMOTE) + .inputDataset(RemoteInferenceInputDataSet.builder().parameters(summaryParams).build()) + .build(), + null, + tenantId + ); + client.execute(MLPredictionTaskAction.INSTANCE, request, ActionListener.wrap(response -> { + String summary = extractSummaryFromResponse(response); + if (summary != null) { + listener.onResponse(summary); + } else { + listener.onFailure(new RuntimeException("Empty or invalid LLM summary response")); + } + }, listener::onFailure)); + } catch (Exception e) { + listener.onFailure(e); + } + } + + private String extractSummaryFromResponse(MLTaskResponse response) { + try { + String outputString = outputToOutputString(response.getOutput()); + if (outputString != null && !outputString.trim().isEmpty()) { + Map dataMap = gson.fromJson(outputString, Map.class); + if (dataMap.containsKey("response")) { + String summary = String.valueOf(dataMap.get("response")); + if (summary != null && !summary.trim().isEmpty() && !"null".equals(summary)) { + return summary.trim(); + } + } + } + return null; + } catch (Exception e) { + log.warn("Failed to extract summary from response", e); + return null; + } + } + private void saveMessage( ConversationIndexMemory memory, String question, diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/PromptTemplate.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/PromptTemplate.java index 9ff33ecaa9..234699b37c 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/PromptTemplate.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/PromptTemplate.java @@ -140,4 +140,7 @@ public class PromptTemplate { - Avoid making assumptions and relying on implicit knowledge. - Your response must be self-contained and ready for the planner to use without modification. Never end with a question. - Break complex searches into simpler queries when appropriate."""; + + public static final String SUMMARY_PROMPT_TEMPLATE = + "Please provide a concise summary of the following agent execution steps. 
Focus on what the agent was trying to accomplish and what progress was made:\n\n%s\n\nPlease respond in the following JSON format:\n{\"response\": \"your summary here\"}"; } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java index f6c3e3618e..98977adac0 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java @@ -1171,4 +1171,93 @@ public void testConstructLLMParams_DefaultValues() { Assert.assertTrue(result.containsKey(AgentUtils.RESPONSE_FORMAT_INSTRUCTION)); Assert.assertTrue(result.containsKey(AgentUtils.TOOL_RESPONSE)); } + + @Test + public void testMaxIterationsWithSummaryEnabled() { + // Create LLM spec with max_iteration = 1 to simplify test + LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").parameters(Map.of("max_iteration", "1")).build(); + MLToolSpec firstToolSpec = MLToolSpec.builder().name(FIRST_TOOL).type(FIRST_TOOL).build(); + final MLAgent mlAgent = MLAgent + .builder() + .name("TestAgent") + .type(MLAgentType.CONVERSATIONAL.name()) + .llm(llmSpec) + .memory(mlMemorySpec) + .tools(Arrays.asList(firstToolSpec)) + .build(); + + // Reset and setup fresh mocks + Mockito.reset(client); + // First call: LLM response without final_answer to trigger max iterations + // Second call: Summary LLM response + Mockito + .doAnswer(getLLMAnswer(ImmutableMap.of("thought", "I need to analyze the data", "action", FIRST_TOOL))) + .doAnswer(getLLMAnswer(ImmutableMap.of("response", "Summary: Analysis step was attempted"))) + .when(client) + .execute(any(ActionType.class), any(ActionRequest.class), isA(ActionListener.class)); + + Map params = new HashMap<>(); + params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "parent_interaction_id"); + params.put(MLChatAgentRunner.SUMMARIZE_WHEN_MAX_ITERATION, "true"); + + mlChatAgentRunner.run(mlAgent, params, agentActionListener); + + // Verify response is captured + verify(agentActionListener).onResponse(objectCaptor.capture()); + Object capturedResponse = objectCaptor.getValue(); + assertTrue(capturedResponse instanceof ModelTensorOutput); + + ModelTensorOutput modelTensorOutput = (ModelTensorOutput) capturedResponse; + List agentOutput = modelTensorOutput.getMlModelOutputs().get(1).getMlModelTensors(); + assertEquals(1, agentOutput.size()); + + // Verify the response contains summary message + String response = (String) agentOutput.get(0).getDataAsMap().get("response"); + assertTrue( + response.startsWith("Agent reached maximum iterations (1) without completing the task. 
Here's a summary of the steps taken:") + ); + assertTrue(response.contains("Summary: Analysis step was attempted")); + } + + @Test + public void testMaxIterationsWithSummaryDisabled() { + // Create LLM spec with max_iteration = 1 and summary disabled + LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").parameters(Map.of("max_iteration", "1")).build(); + MLToolSpec firstToolSpec = MLToolSpec.builder().name(FIRST_TOOL).type(FIRST_TOOL).build(); + final MLAgent mlAgent = MLAgent + .builder() + .name("TestAgent") + .type(MLAgentType.CONVERSATIONAL.name()) + .llm(llmSpec) + .memory(mlMemorySpec) + .tools(Arrays.asList(firstToolSpec)) + .build(); + + // Reset client mock for this test + Mockito.reset(client); + // Mock LLM response that doesn't contain final_answer + Mockito + .doAnswer(getLLMAnswer(ImmutableMap.of("thought", "I need to use the tool", "action", FIRST_TOOL))) + .when(client) + .execute(any(ActionType.class), any(ActionRequest.class), isA(ActionListener.class)); + + Map params = new HashMap<>(); + params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "parent_interaction_id"); + params.put(MLChatAgentRunner.SUMMARIZE_WHEN_MAX_ITERATION, "false"); + + mlChatAgentRunner.run(mlAgent, params, agentActionListener); + + // Verify response is captured + verify(agentActionListener).onResponse(objectCaptor.capture()); + Object capturedResponse = objectCaptor.getValue(); + assertTrue(capturedResponse instanceof ModelTensorOutput); + + ModelTensorOutput modelTensorOutput = (ModelTensorOutput) capturedResponse; + List agentOutput = modelTensorOutput.getMlModelOutputs().get(1).getMlModelTensors(); + assertEquals(1, agentOutput.size()); + + // Verify the response contains traditional max iterations message + String response = (String) agentOutput.get(0).getDataAsMap().get("response"); + assertEquals("Agent reached maximum iterations (1) without completing the task. Last thought: I need to use the tool", response); + } } From 4b94b46f975939bca8b1cc93cbaac53c006c7fc1 Mon Sep 17 00:00:00 2001 From: Jiaru Jiang Date: Mon, 22 Sep 2025 15:01:30 +0800 Subject: [PATCH 02/24] fix:recover build.gradle Signed-off-by: Jiaru Jiang --- build.gradle | 2 -- 1 file changed, 2 deletions(-) diff --git a/build.gradle b/build.gradle index 2df2f50656..91267a067b 100644 --- a/build.gradle +++ b/build.gradle @@ -44,7 +44,6 @@ buildscript { configurations.all { resolutionStrategy { force("org.eclipse.platform:org.eclipse.core.runtime:3.29.0") // for spotless transitive dependency CVE (for 3.26.100) - force("com.google.errorprone:error_prone_annotations:2.18.0") } } } @@ -96,7 +95,6 @@ subprojects { configurations.all { // Force spotless depending on newer version of guava due to CVE-2023-2976. Remove after spotless upgrades. 
resolutionStrategy.force "com.google.guava:guava:32.1.3-jre" - resolutionStrategy.force "com.google.errorprone:error_prone_annotations:2.18.0" resolutionStrategy.force 'org.apache.commons:commons-compress:1.26.0' resolutionStrategy.force "io.netty:netty-buffer:${versions.netty}" resolutionStrategy.force "io.netty:netty-codec:${versions.netty}" From ff3a90656a0267d06a06e278db4e3fc978d0085d Mon Sep 17 00:00:00 2001 From: Jiaru Jiang Date: Tue, 23 Sep 2025 10:12:14 +0800 Subject: [PATCH 03/24] add:increase test coverage Signed-off-by: Jiaru Jiang --- .../algorithms/agent/MLChatAgentRunner.java | 2 +- .../agent/MLChatAgentRunnerTest.java | 28 +++++++++++++++++++ 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index 62f51db848..1f50789222 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -1034,7 +1034,7 @@ void generateLLMSummary(List stepsSummary, LLMSpec llmSpec, String tenan } } - private String extractSummaryFromResponse(MLTaskResponse response) { + public String extractSummaryFromResponse(MLTaskResponse response) { try { String outputString = outputToOutputString(response.getOutput()); if (outputString != null && !outputString.trim().isEmpty()) { diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java index 98977adac0..64406500f1 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java @@ -1260,4 +1260,32 @@ public void testMaxIterationsWithSummaryDisabled() { String response = (String) agentOutput.get(0).getDataAsMap().get("response"); assertEquals("Agent reached maximum iterations (1) without completing the task. 
Last thought: I need to use the tool", response); } + + @Test + public void testExtractSummaryFromResponse() { + MLTaskResponse response = MLTaskResponse.builder() + .output(ModelTensorOutput.builder() + .mlModelOutputs(Arrays.asList( + ModelTensors.builder() + .mlModelTensors(Arrays.asList( + ModelTensor.builder() + .dataAsMap(ImmutableMap.of("response", "Valid summary text")) + .build())) + .build())) + .build()) + .build(); + + String result = mlChatAgentRunner.extractSummaryFromResponse(response); + assertEquals("Valid summary text", result); + } + + @Test + public void testGenerateLLMSummaryWithNullSteps() { + LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build(); + ActionListener listener = Mockito.mock(ActionListener.class); + + mlChatAgentRunner.generateLLMSummary(null, llmSpec, "tenant", listener); + + verify(listener).onFailure(any(IllegalArgumentException.class)); + } } From c7ce96b276bd37114ffcd36e5fbf6540d67c747f Mon Sep 17 00:00:00 2001 From: Jiaru Jiang Date: Tue, 23 Sep 2025 10:14:24 +0800 Subject: [PATCH 04/24] fix:spotlessApply Signed-off-by: Jiaru Jiang --- .../agent/MLChatAgentRunnerTest.java | 37 ++++++++++++------- 1 file changed, 24 insertions(+), 13 deletions(-) diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java index 64406500f1..b9814980e2 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java @@ -1263,18 +1263,29 @@ public void testMaxIterationsWithSummaryDisabled() { @Test public void testExtractSummaryFromResponse() { - MLTaskResponse response = MLTaskResponse.builder() - .output(ModelTensorOutput.builder() - .mlModelOutputs(Arrays.asList( - ModelTensors.builder() - .mlModelTensors(Arrays.asList( - ModelTensor.builder() - .dataAsMap(ImmutableMap.of("response", "Valid summary text")) - .build())) - .build())) - .build()) + MLTaskResponse response = MLTaskResponse + .builder() + .output( + ModelTensorOutput + .builder() + .mlModelOutputs( + Arrays + .asList( + ModelTensors + .builder() + .mlModelTensors( + Arrays + .asList( + ModelTensor.builder().dataAsMap(ImmutableMap.of("response", "Valid summary text")).build() + ) + ) + .build() + ) + ) + .build() + ) .build(); - + String result = mlChatAgentRunner.extractSummaryFromResponse(response); assertEquals("Valid summary text", result); } @@ -1283,9 +1294,9 @@ public void testExtractSummaryFromResponse() { public void testGenerateLLMSummaryWithNullSteps() { LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build(); ActionListener listener = Mockito.mock(ActionListener.class); - + mlChatAgentRunner.generateLLMSummary(null, llmSpec, "tenant", listener); - + verify(listener).onFailure(any(IllegalArgumentException.class)); } } From 9b3e14265dd80b65ba4b93b7219554a45d8f1ea1 Mon Sep 17 00:00:00 2001 From: Jiaru Jiang Date: Tue, 23 Sep 2025 15:01:35 +0800 Subject: [PATCH 05/24] fix:use traceTensor Signed-off-by: Jiaru Jiang --- .../algorithms/agent/MLChatAgentRunner.java | 37 ++++++------------- 1 file changed, 11 insertions(+), 26 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index 1f50789222..32aa4ae76c 100644 --- 
a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -127,8 +127,7 @@ public class MLChatAgentRunner implements MLAgentRunner { private static final String DEFAULT_MAX_ITERATIONS = "10"; private static final String MAX_ITERATIONS_MESSAGE = "Agent reached maximum iterations (%d) without completing the task"; - private static final String MAX_ITERATIONS_SUMMARY_MESSAGE = - "Agent reached maximum iterations (%d) without completing the task. Here's a summary of the steps taken:\n\n%s"; + private static final String MAX_ITERATIONS_SUMMARY_MESSAGE = MAX_ITERATIONS_MESSAGE + ". Here's a summary of the steps taken:\n\n%s"; private Client client; private Settings settings; @@ -325,8 +324,6 @@ private void runReAct( StringBuilder scratchpadBuilder = new StringBuilder(); List interactions = new CopyOnWriteArrayList<>(); - List executionSteps = new CopyOnWriteArrayList<>(); - StringSubstitutor tmpSubstitutor = new StringSubstitutor(Map.of(SCRATCHPAD, scratchpadBuilder.toString()), "${parameters.", "}"); AtomicReference newPrompt = new AtomicReference<>(tmpSubstitutor.replace(prompt)); tmpParameters.put(PROMPT, newPrompt.get()); @@ -384,17 +381,6 @@ private void runReAct( lastActionInput.set(actionInput); lastToolSelectionResponse.set(thoughtResponse); - // Record execution step for summary - if (thought != null && !"null".equals(thought) && !thought.trim().isEmpty()) { - executionSteps.add(String.format("Thought: %s", thought.trim())); - } - if (action != null && !"null".equals(action) && !action.trim().isEmpty()) { - String actionDesc = actionInput != null && !"null".equals(actionInput) - ? String.format("Action: %s(%s)", action.trim(), actionInput.trim()) - : String.format("Action: %s", action.trim()); - executionSteps.add(actionDesc); - } - traceTensors .add( ModelTensors @@ -431,7 +417,6 @@ private void runReAct( maxIterations, tools, tmpParameters, - executionSteps, llm, tenantId ); @@ -486,10 +471,6 @@ private void runReAct( ); scratchpadBuilder.append(toolResponse).append("\n\n"); - // Record tool result for summary - String outputSummary = outputToOutputString(filteredOutput); - executionSteps.add(String.format("Result: %s", outputSummary)); - saveTraceData( conversationIndexMemory, "ReAct", @@ -540,7 +521,6 @@ private void runReAct( maxIterations, tools, tmpParameters, - executionSteps, llm, tenantId ); @@ -917,14 +897,13 @@ private void handleMaxIterationsReached( int maxIterations, Map tools, Map parameters, - List executionSteps, LLMSpec llmSpec, String tenantId ) { boolean shouldSummarize = Boolean.parseBoolean(parameters.getOrDefault(SUMMARIZE_WHEN_MAX_ITERATION, "false")); - if (shouldSummarize && !executionSteps.isEmpty()) { - generateLLMSummary(executionSteps, llmSpec, tenantId, ActionListener.wrap(summary -> { + if (shouldSummarize && !traceTensors.isEmpty()) { + generateLLMSummary(traceTensors, llmSpec, tenantId, ActionListener.wrap(summary -> { String incompleteResponse = String.format(MAX_ITERATIONS_SUMMARY_MESSAGE, maxIterations, summary); sendFinalAnswer( sessionId, @@ -995,7 +974,7 @@ private void sendTraditionalMaxIterationsResponse( cleanUpResource(tools); } - void generateLLMSummary(List stepsSummary, LLMSpec llmSpec, String tenantId, ActionListener listener) { + void generateLLMSummary(List stepsSummary, LLMSpec llmSpec, String tenantId, ActionListener listener) { if (stepsSummary == null || stepsSummary.isEmpty()) { 
listener.onFailure(new IllegalArgumentException("Steps summary cannot be null or empty")); return; @@ -1006,7 +985,13 @@ void generateLLMSummary(List stepsSummary, LLMSpec llmSpec, String tenan if (llmSpec.getParameters() != null) { summaryParams.putAll(llmSpec.getParameters()); } - String summaryPrompt = String.format(SUMMARY_PROMPT_TEMPLATE, String.join("\n", stepsSummary)); + + // Convert ModelTensors to strings before joining + List stepStrings = new ArrayList<>(); + for (ModelTensors tensor : stepsSummary) { + stepStrings.add(outputToOutputString(tensor)); + } + String summaryPrompt = String.format(SUMMARY_PROMPT_TEMPLATE, String.join("\n", stepStrings)); summaryParams.put("inputs", summaryPrompt); summaryParams.put("prompt", summaryPrompt); summaryParams.putIfAbsent("stop", gson.toJson(new String[] { "\n\n", "```" })); From 763de9dcdf422be5b48dc0a8d6dbc972a291e8b3 Mon Sep 17 00:00:00 2001 From: Jiaru Jiang Date: Tue, 23 Sep 2025 15:13:38 +0800 Subject: [PATCH 06/24] fix:String.format() Signed-off-by: Jiaru Jiang --- .../ml/engine/algorithms/agent/MLChatAgentRunner.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index 32aa4ae76c..cc7f4289bc 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -904,7 +904,7 @@ private void handleMaxIterationsReached( if (shouldSummarize && !traceTensors.isEmpty()) { generateLLMSummary(traceTensors, llmSpec, tenantId, ActionListener.wrap(summary -> { - String incompleteResponse = String.format(MAX_ITERATIONS_SUMMARY_MESSAGE, maxIterations, summary); + String incompleteResponse = String.format(Locale.ROOT, MAX_ITERATIONS_SUMMARY_MESSAGE, maxIterations, summary); sendFinalAnswer( sessionId, listener, From 1eb03eebfe194f8306268523c133f627cc189e44 Mon Sep 17 00:00:00 2001 From: Jiaru Jiang Date: Tue, 23 Sep 2025 15:53:50 +0800 Subject: [PATCH 07/24] fix:String.format() Signed-off-by: Jiaru Jiang --- .../ml/engine/algorithms/agent/MLChatAgentRunner.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index cc7f4289bc..5494e97649 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -991,7 +991,7 @@ void generateLLMSummary(List stepsSummary, LLMSpec llmSpec, String for (ModelTensors tensor : stepsSummary) { stepStrings.add(outputToOutputString(tensor)); } - String summaryPrompt = String.format(SUMMARY_PROMPT_TEMPLATE, String.join("\n", stepStrings)); + String summaryPrompt = String.format(Locale.ROOT, SUMMARY_PROMPT_TEMPLATE, String.join("\n", stepStrings)); summaryParams.put("inputs", summaryPrompt); summaryParams.put("prompt", summaryPrompt); summaryParams.putIfAbsent("stop", gson.toJson(new String[] { "\n\n", "```" })); From f27def8c543670f4f95b1e9b081bef5a8b6ae66d Mon Sep 17 00:00:00 2001 From: Jiaru Jiang Date: Tue, 23 Sep 2025 16:17:45 +0800 Subject: [PATCH 08/24] fix:reuse sendTraditionalMaxIterationsResponse method Signed-off-by: Jiaru 
Jiang --- .../algorithms/agent/MLChatAgentRunner.java | 26 ++++++++++++------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index 5494e97649..4359617372 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -904,8 +904,9 @@ private void handleMaxIterationsReached( if (shouldSummarize && !traceTensors.isEmpty()) { generateLLMSummary(traceTensors, llmSpec, tenantId, ActionListener.wrap(summary -> { - String incompleteResponse = String.format(Locale.ROOT, MAX_ITERATIONS_SUMMARY_MESSAGE, maxIterations, summary); - sendFinalAnswer( + String summaryResponse = String.format(Locale.ROOT, MAX_ITERATIONS_SUMMARY_MESSAGE, maxIterations, summary); + AtomicReference summaryThought = new AtomicReference<>(summaryResponse); + sendTraditionalMaxIterationsResponse( sessionId, listener, question, @@ -916,12 +917,12 @@ private void handleMaxIterationsReached( conversationIndexMemory, traceNumber, additionalInfo, - incompleteResponse + summaryThought, + 0, // 不使用 maxIterations 格式化,直接使用 summaryResponse + tools ); - cleanUpResource(tools); }, e -> { log.warn("Failed to generate LLM summary", e); })); } else { - // Use traditional approach sendTraditionalMaxIterationsResponse( sessionId, listener, @@ -955,9 +956,16 @@ private void sendTraditionalMaxIterationsResponse( int maxIterations, Map tools ) { - String incompleteResponse = (lastThought.get() != null && !lastThought.get().isEmpty() && !"null".equals(lastThought.get())) - ? String.format("%s. Last thought: %s", String.format(MAX_ITERATIONS_MESSAGE, maxIterations), lastThought.get()) - : String.format(MAX_ITERATIONS_MESSAGE, maxIterations); + String incompleteResponse; + if (maxIterations == 0) { + // 直接使用 lastThought 中的完整消息(用于摘要情况) + incompleteResponse = lastThought.get(); + } else { + // 传统格式化(用于普通情况) + incompleteResponse = (lastThought.get() != null && !lastThought.get().isEmpty() && !"null".equals(lastThought.get())) + ? String.format("%s. 
Last thought: %s", String.format(MAX_ITERATIONS_MESSAGE, maxIterations), lastThought.get()) + : String.format(MAX_ITERATIONS_MESSAGE, maxIterations); + } sendFinalAnswer( sessionId, listener, @@ -993,7 +1001,7 @@ void generateLLMSummary(List stepsSummary, LLMSpec llmSpec, String } String summaryPrompt = String.format(Locale.ROOT, SUMMARY_PROMPT_TEMPLATE, String.join("\n", stepStrings)); summaryParams.put("inputs", summaryPrompt); - summaryParams.put("prompt", summaryPrompt); + summaryParams.put(PROMPT, summaryPrompt); summaryParams.putIfAbsent("stop", gson.toJson(new String[] { "\n\n", "```" })); ActionRequest request = new MLPredictionTaskRequest( From 81a491c9a0087379a2763ee5b6dc7f1b4d654e75 Mon Sep 17 00:00:00 2001 From: Jiaru Jiang Date: Tue, 23 Sep 2025 16:22:39 +0800 Subject: [PATCH 09/24] fix:remove useless comment Signed-off-by: Jiaru Jiang --- .../ml/engine/algorithms/agent/MLChatAgentRunner.java | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index 4359617372..83aa3ad333 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -918,7 +918,7 @@ private void handleMaxIterationsReached( traceNumber, additionalInfo, summaryThought, - 0, // 不使用 maxIterations 格式化,直接使用 summaryResponse + 0, tools ); }, e -> { log.warn("Failed to generate LLM summary", e); })); @@ -958,10 +958,8 @@ private void sendTraditionalMaxIterationsResponse( ) { String incompleteResponse; if (maxIterations == 0) { - // 直接使用 lastThought 中的完整消息(用于摘要情况) incompleteResponse = lastThought.get(); } else { - // 传统格式化(用于普通情况) incompleteResponse = (lastThought.get() != null && !lastThought.get().isEmpty() && !"null".equals(lastThought.get())) ? String.format("%s. 
Last thought: %s", String.format(MAX_ITERATIONS_MESSAGE, maxIterations), lastThought.get()) : String.format(MAX_ITERATIONS_MESSAGE, maxIterations); From a1f6f2ef3d70e3c693e1bc45226c9003fd42f608 Mon Sep 17 00:00:00 2001 From: Jiaru Jiang Date: Fri, 10 Oct 2025 13:39:03 +0800 Subject: [PATCH 10/24] fix: delete stop Signed-off-by: Jiaru Jiang --- .../opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java | 1 - 1 file changed, 1 deletion(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index 83aa3ad333..172b5e670f 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -1000,7 +1000,6 @@ void generateLLMSummary(List stepsSummary, LLMSpec llmSpec, String String summaryPrompt = String.format(Locale.ROOT, SUMMARY_PROMPT_TEMPLATE, String.join("\n", stepStrings)); summaryParams.put("inputs", summaryPrompt); summaryParams.put(PROMPT, summaryPrompt); - summaryParams.putIfAbsent("stop", gson.toJson(new String[] { "\n\n", "```" })); ActionRequest request = new MLPredictionTaskRequest( llmSpec.getModelId(), From c5a3664b536c575006f1a7a2ef84e2262edce0ae Mon Sep 17 00:00:00 2001 From: Jiaru Jiang Date: Fri, 10 Oct 2025 15:06:51 +0800 Subject: [PATCH 11/24] fix: refactor Signed-off-by: Jiaru Jiang --- .../algorithms/agent/MLChatAgentRunner.java | 33 ++++++++----------- 1 file changed, 14 insertions(+), 19 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index 172b5e670f..6225ba40e8 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -905,7 +905,6 @@ private void handleMaxIterationsReached( if (shouldSummarize && !traceTensors.isEmpty()) { generateLLMSummary(traceTensors, llmSpec, tenantId, ActionListener.wrap(summary -> { String summaryResponse = String.format(Locale.ROOT, MAX_ITERATIONS_SUMMARY_MESSAGE, maxIterations, summary); - AtomicReference summaryThought = new AtomicReference<>(summaryResponse); sendTraditionalMaxIterationsResponse( sessionId, listener, @@ -917,12 +916,18 @@ private void handleMaxIterationsReached( conversationIndexMemory, traceNumber, additionalInfo, - summaryThought, - 0, + summaryResponse, tools ); - }, e -> { log.warn("Failed to generate LLM summary", e); })); + }, e -> { + log.error("Failed to generate LLM summary", e); + listener.onFailure(e); + cleanUpResource(tools); + })); } else { + String response = (lastThought.get() != null && !lastThought.get().isEmpty() && !"null".equals(lastThought.get())) + ? String.format("%s. 
Last thought: %s", String.format(MAX_ITERATIONS_MESSAGE, maxIterations), lastThought.get()) + : String.format(MAX_ITERATIONS_MESSAGE, maxIterations); sendTraditionalMaxIterationsResponse( sessionId, listener, @@ -934,8 +939,7 @@ private void handleMaxIterationsReached( conversationIndexMemory, traceNumber, additionalInfo, - lastThought, - maxIterations, + response, tools ); } @@ -952,18 +956,9 @@ private void sendTraditionalMaxIterationsResponse( ConversationIndexMemory conversationIndexMemory, AtomicInteger traceNumber, Map additionalInfo, - AtomicReference lastThought, - int maxIterations, + String response, Map tools ) { - String incompleteResponse; - if (maxIterations == 0) { - incompleteResponse = lastThought.get(); - } else { - incompleteResponse = (lastThought.get() != null && !lastThought.get().isEmpty() && !"null".equals(lastThought.get())) - ? String.format("%s. Last thought: %s", String.format(MAX_ITERATIONS_MESSAGE, maxIterations), lastThought.get()) - : String.format(MAX_ITERATIONS_MESSAGE, maxIterations); - } sendFinalAnswer( sessionId, listener, @@ -975,7 +970,7 @@ private void sendTraditionalMaxIterationsResponse( conversationIndexMemory, traceNumber, additionalInfo, - incompleteResponse + response ); cleanUpResource(tools); } @@ -1038,8 +1033,8 @@ public String extractSummaryFromResponse(MLTaskResponse response) { } return null; } catch (Exception e) { - log.warn("Failed to extract summary from response", e); - return null; + log.error("Failed to extract summary from response", e); + throw new RuntimeException("Failed to extract summary from response", e); } } From 2de8e8612c0581d32b6b7ba9757430e65a597965 Mon Sep 17 00:00:00 2001 From: Jiaru Jiang Date: Fri, 10 Oct 2025 15:55:20 +0800 Subject: [PATCH 12/24] fix: json serialization Signed-off-by: Jiaru Jiang --- .../algorithms/agent/MLChatAgentRunner.java | 24 ++++++++++++++----- .../agent/MLChatAgentRunnerTest.java | 17 +++++++++++-- 2 files changed, 33 insertions(+), 8 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index 6225ba40e8..0f3981af66 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -990,7 +990,15 @@ void generateLLMSummary(List stepsSummary, LLMSpec llmSpec, String // Convert ModelTensors to strings before joining List stepStrings = new ArrayList<>(); for (ModelTensors tensor : stepsSummary) { - stepStrings.add(outputToOutputString(tensor)); + if (tensor != null && tensor.getMlModelTensors() != null) { + for (ModelTensor modelTensor : tensor.getMlModelTensors()) { + if (modelTensor.getResult() != null) { + stepStrings.add(modelTensor.getResult()); + } else if (modelTensor.getDataAsMap() != null && modelTensor.getDataAsMap().containsKey("response")) { + stepStrings.add(String.valueOf(modelTensor.getDataAsMap().get("response"))); + } + } + } } String summaryPrompt = String.format(Locale.ROOT, SUMMARY_PROMPT_TEMPLATE, String.join("\n", stepStrings)); summaryParams.put("inputs", summaryPrompt); @@ -1023,12 +1031,16 @@ public String extractSummaryFromResponse(MLTaskResponse response) { try { String outputString = outputToOutputString(response.getOutput()); if (outputString != null && !outputString.trim().isEmpty()) { - Map dataMap = gson.fromJson(outputString, Map.class); - if 
(dataMap.containsKey("response")) { - String summary = String.valueOf(dataMap.get("response")); - if (summary != null && !summary.trim().isEmpty() && !"null".equals(summary)) { - return summary.trim(); + try { + Map dataMap = gson.fromJson(outputString, Map.class); + if (dataMap.containsKey("response")) { + String summary = String.valueOf(dataMap.get("response")); + if (summary != null && !summary.trim().isEmpty() && !"null".equals(summary)) { + return summary.trim(); + } } + } catch (Exception jsonException) { + return outputString.trim(); } } return null; diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java index b9814980e2..0f233f7fa7 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java @@ -1188,11 +1188,24 @@ public void testMaxIterationsWithSummaryEnabled() { // Reset and setup fresh mocks Mockito.reset(client); + Mockito.reset(firstTool); + when(firstTool.getName()).thenReturn(FIRST_TOOL); + when(firstTool.validate(Mockito.anyMap())).thenReturn(true); + Mockito.doAnswer(generateToolResponse("First tool response")).when(firstTool).run(Mockito.anyMap(), any()); + // First call: LLM response without final_answer to trigger max iterations - // Second call: Summary LLM response + // Second call: Summary LLM response with result field instead of dataAsMap Mockito .doAnswer(getLLMAnswer(ImmutableMap.of("thought", "I need to analyze the data", "action", FIRST_TOOL))) - .doAnswer(getLLMAnswer(ImmutableMap.of("response", "Summary: Analysis step was attempted"))) + .doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + ModelTensor modelTensor = ModelTensor.builder().result("Summary: Analysis step was attempted").build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + MLTaskResponse mlTaskResponse = MLTaskResponse.builder().output(mlModelTensorOutput).build(); + listener.onResponse(mlTaskResponse); + return null; + }) .when(client) .execute(any(ActionType.class), any(ActionRequest.class), isA(ActionListener.class)); From fab05352716cc81fe03617444f5eb9ede47c4ea8 Mon Sep 17 00:00:00 2001 From: Jiaru Jiang Date: Mon, 13 Oct 2025 14:00:22 +0800 Subject: [PATCH 13/24] fix: parameter Signed-off-by: Jiaru Jiang --- .../ml/engine/algorithms/agent/MLChatAgentRunner.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index 0f3981af66..ee0f03eebd 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -1001,8 +1001,8 @@ void generateLLMSummary(List stepsSummary, LLMSpec llmSpec, String } } String summaryPrompt = String.format(Locale.ROOT, SUMMARY_PROMPT_TEMPLATE, String.join("\n", stepStrings)); - summaryParams.put("inputs", summaryPrompt); summaryParams.put(PROMPT, summaryPrompt); + summaryParams.putIfAbsent(SYSTEM_PROMPT_FIELD, ""); 
ActionRequest request = new MLPredictionTaskRequest( llmSpec.getModelId(), From 6cef967200ba5b9d7618ce80dd4006257402365e Mon Sep 17 00:00:00 2001 From: Jiaru Jiang Date: Wed, 29 Oct 2025 11:45:17 +0800 Subject: [PATCH 14/24] delete:summarize_when_max_iteration Signed-off-by: Jiaru Jiang --- .../algorithms/agent/MLChatAgentRunner.java | 23 ------------------- 1 file changed, 23 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index ee0f03eebd..27f558dcb3 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -123,7 +123,6 @@ public class MLChatAgentRunner implements MLAgentRunner { public static final String INJECT_DATETIME_FIELD = "inject_datetime"; public static final String DATETIME_FORMAT_FIELD = "datetime_format"; public static final String SYSTEM_PROMPT_FIELD = "system_prompt"; - public static final String SUMMARIZE_WHEN_MAX_ITERATION = "summarize_when_max_iteration"; private static final String DEFAULT_MAX_ITERATIONS = "10"; private static final String MAX_ITERATIONS_MESSAGE = "Agent reached maximum iterations (%d) without completing the task"; @@ -900,9 +899,6 @@ private void handleMaxIterationsReached( LLMSpec llmSpec, String tenantId ) { - boolean shouldSummarize = Boolean.parseBoolean(parameters.getOrDefault(SUMMARIZE_WHEN_MAX_ITERATION, "false")); - - if (shouldSummarize && !traceTensors.isEmpty()) { generateLLMSummary(traceTensors, llmSpec, tenantId, ActionListener.wrap(summary -> { String summaryResponse = String.format(Locale.ROOT, MAX_ITERATIONS_SUMMARY_MESSAGE, maxIterations, summary); sendTraditionalMaxIterationsResponse( @@ -924,25 +920,6 @@ private void handleMaxIterationsReached( listener.onFailure(e); cleanUpResource(tools); })); - } else { - String response = (lastThought.get() != null && !lastThought.get().isEmpty() && !"null".equals(lastThought.get())) - ? String.format("%s. 
Last thought: %s", String.format(MAX_ITERATIONS_MESSAGE, maxIterations), lastThought.get()) - : String.format(MAX_ITERATIONS_MESSAGE, maxIterations); - sendTraditionalMaxIterationsResponse( - sessionId, - listener, - question, - parentInteractionId, - verbose, - traceDisabled, - traceTensors, - conversationIndexMemory, - traceNumber, - additionalInfo, - response, - tools - ); - } } private void sendTraditionalMaxIterationsResponse( From 4a739587ad8b1d42d518b24936491b8e26d493d9 Mon Sep 17 00:00:00 2001 From: Jiaru Jiang Date: Wed, 29 Oct 2025 14:08:42 +0800 Subject: [PATCH 15/24] fix: unit test Signed-off-by: Jiaru Jiang --- .../algorithms/agent/MLChatAgentRunner.java | 58 ++++++++++++------- .../agent/MLChatAgentRunnerTest.java | 19 +++--- 2 files changed, 46 insertions(+), 31 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index 27f558dcb3..4668153ba2 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -899,27 +899,43 @@ private void handleMaxIterationsReached( LLMSpec llmSpec, String tenantId ) { - generateLLMSummary(traceTensors, llmSpec, tenantId, ActionListener.wrap(summary -> { - String summaryResponse = String.format(Locale.ROOT, MAX_ITERATIONS_SUMMARY_MESSAGE, maxIterations, summary); - sendTraditionalMaxIterationsResponse( - sessionId, - listener, - question, - parentInteractionId, - verbose, - traceDisabled, - traceTensors, - conversationIndexMemory, - traceNumber, - additionalInfo, - summaryResponse, - tools - ); - }, e -> { - log.error("Failed to generate LLM summary", e); - listener.onFailure(e); - cleanUpResource(tools); - })); + ActionListener responseListener = ActionListener.wrap(response -> { + sendTraditionalMaxIterationsResponse( + sessionId, + listener, + question, + parentInteractionId, + verbose, + traceDisabled, + traceTensors, + conversationIndexMemory, + traceNumber, + additionalInfo, + response, + tools + ); + }, listener::onFailure); + + generateLLMSummary( + traceTensors, + llmSpec, + tenantId, + ActionListener + .wrap( + summary -> responseListener + .onResponse(String.format(Locale.ROOT, MAX_ITERATIONS_SUMMARY_MESSAGE, maxIterations, summary)), + e -> { + log.error("Failed to generate LLM summary, using fallback strategy", e); + String fallbackResponse = (lastThought.get() != null + && !lastThought.get().isEmpty() + && !"null".equals(lastThought.get())) + ? String + .format("%s. 
Last thought: %s", String.format(MAX_ITERATIONS_MESSAGE, maxIterations), lastThought.get()) + : String.format(MAX_ITERATIONS_MESSAGE, maxIterations); + responseListener.onResponse(fallbackResponse); + } + ) + ); } private void sendTraditionalMaxIterationsResponse( diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java index 0f233f7fa7..4c2e4a43e9 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java @@ -1211,8 +1211,6 @@ public void testMaxIterationsWithSummaryEnabled() { Map params = new HashMap<>(); params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "parent_interaction_id"); - params.put(MLChatAgentRunner.SUMMARIZE_WHEN_MAX_ITERATION, "true"); - mlChatAgentRunner.run(mlAgent, params, agentActionListener); // Verify response is captured @@ -1234,7 +1232,7 @@ public void testMaxIterationsWithSummaryEnabled() { @Test public void testMaxIterationsWithSummaryDisabled() { - // Create LLM spec with max_iteration = 1 and summary disabled + // Create LLM spec with max_iteration = 1 LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").parameters(Map.of("max_iteration", "1")).build(); MLToolSpec firstToolSpec = MLToolSpec.builder().name(FIRST_TOOL).type(FIRST_TOOL).build(); final MLAgent mlAgent = MLAgent @@ -1248,15 +1246,16 @@ public void testMaxIterationsWithSummaryDisabled() { // Reset client mock for this test Mockito.reset(client); - // Mock LLM response that doesn't contain final_answer - Mockito - .doAnswer(getLLMAnswer(ImmutableMap.of("thought", "I need to use the tool", "action", FIRST_TOOL))) - .when(client) - .execute(any(ActionType.class), any(ActionRequest.class), isA(ActionListener.class)); + // First call: LLM response without final_answer to trigger max iterations + // Second call: Summary LLM fails + Mockito.doAnswer(getLLMAnswer(ImmutableMap.of("thought", "I need to use the tool", "action", FIRST_TOOL))).doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onFailure(new RuntimeException("LLM summary generation failed")); + return null; + }).when(client).execute(any(ActionType.class), any(ActionRequest.class), isA(ActionListener.class)); Map params = new HashMap<>(); params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "parent_interaction_id"); - params.put(MLChatAgentRunner.SUMMARIZE_WHEN_MAX_ITERATION, "false"); mlChatAgentRunner.run(mlAgent, params, agentActionListener); @@ -1269,7 +1268,7 @@ public void testMaxIterationsWithSummaryDisabled() { List agentOutput = modelTensorOutput.getMlModelOutputs().get(1).getMlModelTensors(); assertEquals(1, agentOutput.size()); - // Verify the response contains traditional max iterations message + // Verify the response uses fallback strategy with last thought String response = (String) agentOutput.get(0).getDataAsMap().get("response"); assertEquals("Agent reached maximum iterations (1) without completing the task. 
Last thought: I need to use the tool", response); } From 6f1fc22bc0aacd48e755741b82ec9925d360a231 Mon Sep 17 00:00:00 2001 From: Jiaru Jiang Date: Wed, 29 Oct 2025 17:37:35 +0800 Subject: [PATCH 16/24] fix: summary message Signed-off-by: Jiaru Jiang --- .../ml/engine/algorithms/agent/MLChatAgentRunner.java | 9 +++++---- .../engine/algorithms/agent/MLChatAgentRunnerTest.java | 5 ++++- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index 4668153ba2..be3927f5ab 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -126,7 +126,8 @@ public class MLChatAgentRunner implements MLAgentRunner { private static final String DEFAULT_MAX_ITERATIONS = "10"; private static final String MAX_ITERATIONS_MESSAGE = "Agent reached maximum iterations (%d) without completing the task"; - private static final String MAX_ITERATIONS_SUMMARY_MESSAGE = MAX_ITERATIONS_MESSAGE + ". Here's a summary of the steps taken:\n\n%s"; + private static final String MAX_ITERATIONS_SUMMARY_MESSAGE = MAX_ITERATIONS_MESSAGE + + ". Here's a summary of the steps completed so far:\n\n%s"; private Client client; private Settings settings; @@ -1009,11 +1010,11 @@ void generateLLMSummary(List stepsSummary, LLMSpec llmSpec, String ); client.execute(MLPredictionTaskAction.INSTANCE, request, ActionListener.wrap(response -> { String summary = extractSummaryFromResponse(response); - if (summary != null) { - listener.onResponse(summary); - } else { + if (summary == null) { listener.onFailure(new RuntimeException("Empty or invalid LLM summary response")); + return; } + listener.onResponse(summary); }, listener::onFailure)); } catch (Exception e) { listener.onFailure(e); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java index 4c2e4a43e9..9a67f09f50 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java @@ -1225,7 +1225,10 @@ public void testMaxIterationsWithSummaryEnabled() { // Verify the response contains summary message String response = (String) agentOutput.get(0).getDataAsMap().get("response"); assertTrue( - response.startsWith("Agent reached maximum iterations (1) without completing the task. Here's a summary of the steps taken:") + response + .startsWith( + "Agent reached maximum iterations (1) without completing the task. 
Here's a summary of the steps completed so far:" + ) ); assertTrue(response.contains("Summary: Analysis step was attempted")); } From a4048f9a5e25d944da2f4c3e3664678332cd81a2 Mon Sep 17 00:00:00 2001 From: Jiaru Jiang Date: Wed, 29 Oct 2025 23:27:55 +0800 Subject: [PATCH 17/24] fix: summary prompt Signed-off-by: Jiaru Jiang --- build.gradle | 2 +- .../engine/algorithms/agent/MLChatAgentRunner.java | 12 +----------- .../ml/engine/algorithms/agent/PromptTemplate.java | 2 +- .../algorithms/agent/MLChatAgentRunnerTest.java | 7 +------ .../ml/helper/ModelAccessControlHelper.java | 2 +- 5 files changed, 5 insertions(+), 20 deletions(-) diff --git a/build.gradle b/build.gradle index 91267a067b..854100e791 100644 --- a/build.gradle +++ b/build.gradle @@ -11,7 +11,7 @@ buildscript { ext { opensearch_group = "org.opensearch" isSnapshot = "true" == System.getProperty("build.snapshot", "true") - opensearch_version = System.getProperty("opensearch.version", "3.4.0-SNAPSHOT") + opensearch_version = System.getProperty("opensearch.version", "3.3.0-SNAPSHOT") buildVersionQualifier = System.getProperty("build.version_qualifier", "") asm_version = "9.7" diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index be3927f5ab..06b180a3fe 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -1025,17 +1025,7 @@ public String extractSummaryFromResponse(MLTaskResponse response) { try { String outputString = outputToOutputString(response.getOutput()); if (outputString != null && !outputString.trim().isEmpty()) { - try { - Map dataMap = gson.fromJson(outputString, Map.class); - if (dataMap.containsKey("response")) { - String summary = String.valueOf(dataMap.get("response")); - if (summary != null && !summary.trim().isEmpty() && !"null".equals(summary)) { - return summary.trim(); - } - } - } catch (Exception jsonException) { - return outputString.trim(); - } + return outputString.trim(); } return null; } catch (Exception e) { diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/PromptTemplate.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/PromptTemplate.java index 234699b37c..1ea97f1ebc 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/PromptTemplate.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/PromptTemplate.java @@ -142,5 +142,5 @@ public class PromptTemplate { - Break complex searches into simpler queries when appropriate."""; public static final String SUMMARY_PROMPT_TEMPLATE = - "Please provide a concise summary of the following agent execution steps. Focus on what the agent was trying to accomplish and what progress was made:\n\n%s\n\nPlease respond in the following JSON format:\n{\"response\": \"your summary here\"}"; + "Please provide a concise summary of the following agent execution steps. 
Focus on what the agent was trying to accomplish and what progress was made:\n\n%s"; } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java index 9a67f09f50..4c83e63adb 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java @@ -1288,12 +1288,7 @@ public void testExtractSummaryFromResponse() { .asList( ModelTensors .builder() - .mlModelTensors( - Arrays - .asList( - ModelTensor.builder().dataAsMap(ImmutableMap.of("response", "Valid summary text")).build() - ) - ) + .mlModelTensors(Arrays.asList(ModelTensor.builder().result("Valid summary text").build())) .build() ) ) diff --git a/plugin/src/main/java/org/opensearch/ml/helper/ModelAccessControlHelper.java b/plugin/src/main/java/org/opensearch/ml/helper/ModelAccessControlHelper.java index 5661c01fdc..67817adecd 100644 --- a/plugin/src/main/java/org/opensearch/ml/helper/ModelAccessControlHelper.java +++ b/plugin/src/main/java/org/opensearch/ml/helper/ModelAccessControlHelper.java @@ -295,7 +295,7 @@ public void checkModelGroupPermission(MLModelGroup mlModelGroup, User user, Acti */ public static boolean shouldUseResourceAuthz(String resourceType) { var client = ResourceSharingClientAccessor.getInstance().getResourceSharingClient(); - return client != null && client.isFeatureEnabledForType(resourceType); + return client != null; } public boolean skipModelAccessControl(User user) { From 5d2a2abcc737351766f3ea87ceb9981ab98f9f01 Mon Sep 17 00:00:00 2001 From: Jiaru Jiang Date: Thu, 30 Oct 2025 10:27:25 +0800 Subject: [PATCH 18/24] fix: configuration file Signed-off-by: Jiaru Jiang --- build.gradle | 2 +- .../java/org/opensearch/ml/helper/ModelAccessControlHelper.java | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/build.gradle b/build.gradle index 854100e791..91267a067b 100644 --- a/build.gradle +++ b/build.gradle @@ -11,7 +11,7 @@ buildscript { ext { opensearch_group = "org.opensearch" isSnapshot = "true" == System.getProperty("build.snapshot", "true") - opensearch_version = System.getProperty("opensearch.version", "3.3.0-SNAPSHOT") + opensearch_version = System.getProperty("opensearch.version", "3.4.0-SNAPSHOT") buildVersionQualifier = System.getProperty("build.version_qualifier", "") asm_version = "9.7" diff --git a/plugin/src/main/java/org/opensearch/ml/helper/ModelAccessControlHelper.java b/plugin/src/main/java/org/opensearch/ml/helper/ModelAccessControlHelper.java index 67817adecd..5661c01fdc 100644 --- a/plugin/src/main/java/org/opensearch/ml/helper/ModelAccessControlHelper.java +++ b/plugin/src/main/java/org/opensearch/ml/helper/ModelAccessControlHelper.java @@ -295,7 +295,7 @@ public void checkModelGroupPermission(MLModelGroup mlModelGroup, User user, Acti */ public static boolean shouldUseResourceAuthz(String resourceType) { var client = ResourceSharingClientAccessor.getInstance().getResourceSharingClient(); - return client != null; + return client != null && client.isFeatureEnabledForType(resourceType); } public boolean skipModelAccessControl(User user) { From 0e08b92b89b10c3efa24fee39d1c4e88fd667424 Mon Sep 17 00:00:00 2001 From: Jiaru Jiang Date: Thu, 30 Oct 2025 17:20:14 +0800 Subject: [PATCH 19/24] add: planner max step summary Signed-off-by: Jiaru Jiang --- .../MLPlanExecuteAndReflectAgentRunner.java | 
145 ++++++++++++++--- .../agent/MLChatAgentRunnerTest.java | 40 +++-- ...LPlanExecuteAndReflectAgentRunnerTest.java | 148 +++++++++++++++++- 3 files changed, 303 insertions(+), 30 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java index 762b60ca5c..a000597705 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java @@ -33,6 +33,7 @@ import static org.opensearch.ml.engine.algorithms.agent.PromptTemplate.FINAL_RESULT_RESPONSE_INSTRUCTIONS; import static org.opensearch.ml.engine.algorithms.agent.PromptTemplate.PLANNER_RESPONSIBILITY; import static org.opensearch.ml.engine.algorithms.agent.PromptTemplate.PLAN_EXECUTE_REFLECT_RESPONSE_FORMAT; +import static org.opensearch.ml.engine.algorithms.agent.PromptTemplate.SUMMARY_PROMPT_TEMPLATE; import java.util.ArrayList; import java.util.HashMap; @@ -363,27 +364,8 @@ private void executePlanningLoop( int maxSteps = Integer.parseInt(allParams.getOrDefault(MAX_STEPS_EXECUTED_FIELD, DEFAULT_MAX_STEPS_EXECUTED)); String parentInteractionId = allParams.get(MLAgentExecutor.PARENT_INTERACTION_ID); - // completedSteps stores the step and its result, hence divide by 2 to find total steps completed - // on reaching max iteration, update parent interaction question with last executed step rather than task to allow continue using - // memory_id if (stepsExecuted >= maxSteps) { - String finalResult = String - .format( - "Max Steps Limit Reached. Use memory_id with same task to restart. \n " - + "Last executed step: %s, \n " - + "Last executed step result: %s", - completedSteps.get(completedSteps.size() - 2), - completedSteps.getLast() - ); - saveAndReturnFinalResult( - (ConversationIndexMemory) memory, - parentInteractionId, - allParams.get(EXECUTOR_AGENT_MEMORY_ID_FIELD), - allParams.get(EXECUTOR_AGENT_PARENT_INTERACTION_ID_FIELD), - finalResult, - null, - finalListener - ); + handleMaxStepsReached(llm, allParams, completedSteps, memory, parentInteractionId, finalListener); return; } @@ -740,4 +722,127 @@ static List createModelTensors( Map getTaskUpdates() { return taskUpdates; } + + private void handleMaxStepsReached( + LLMSpec llm, + Map allParams, + List completedSteps, + Memory memory, + String parentInteractionId, + ActionListener finalListener + ) { + int maxSteps = Integer.parseInt(allParams.getOrDefault(MAX_STEPS_EXECUTED_FIELD, DEFAULT_MAX_STEPS_EXECUTED)); + log.info("[SUMMARY] Max steps reached. Completed steps: {}", completedSteps.size()); + + ActionListener responseListener = ActionListener.wrap(response -> { + saveAndReturnFinalResult( + (ConversationIndexMemory) memory, + parentInteractionId, + allParams.get(EXECUTOR_AGENT_MEMORY_ID_FIELD), + allParams.get(EXECUTOR_AGENT_PARENT_INTERACTION_ID_FIELD), + response, + null, + finalListener + ); + }, finalListener::onFailure); + + generateSummary(llm, completedSteps, allParams.get(TENANT_ID_FIELD), ActionListener.wrap(summary -> { + log.info("Summary generated successfully"); + responseListener + .onResponse( + String.format("Max Steps Limit (%d) Reached. 
Here's a summary of the steps completed so far:\n\n%s", maxSteps, summary) + ); + }, e -> { + log.error("Summary generation failed, using fallback", e); + String fallbackResult = completedSteps.isEmpty() || completedSteps.size() < 2 + ? String.format("Max Steps Limit (%d) Reached. Use memory_id with same task to restart.", maxSteps) + : String + .format( + "Max Steps Limit (%d) Reached. Use memory_id with same task to restart. \n " + + "Last executed step: %s, \n " + + "Last executed step result: %s", + maxSteps, + completedSteps.get(completedSteps.size() - 2), + completedSteps.getLast() + ); + responseListener.onResponse(fallbackResult); + })); + } + + private void generateSummary(LLMSpec llmSpec, List completedSteps, String tenantId, ActionListener listener) { + if (completedSteps == null || completedSteps.isEmpty()) { + listener.onFailure(new IllegalArgumentException("Completed steps cannot be null or empty")); + return; + } + + try { + Map summaryParams = new HashMap<>(); + if (llmSpec.getParameters() != null) { + summaryParams.putAll(llmSpec.getParameters()); + } + + String summaryPrompt = String.format(Locale.ROOT, SUMMARY_PROMPT_TEMPLATE, String.join("\n", completedSteps)); + summaryParams.put(PROMPT_FIELD, summaryPrompt); + summaryParams.putIfAbsent(SYSTEM_PROMPT_FIELD, SUMMARY_PROMPT_TEMPLATE); + + MLPredictionTaskRequest request = new MLPredictionTaskRequest( + llmSpec.getModelId(), + RemoteInferenceMLInput + .builder() + .algorithm(FunctionName.REMOTE) + .inputDataset(RemoteInferenceInputDataSet.builder().parameters(summaryParams).build()) + .build(), + null, + tenantId + ); + + client.execute(MLPredictionTaskAction.INSTANCE, request, ActionListener.wrap(response -> { + String summary = extractSummaryFromResponse(response); + if (summary == null || summary.trim().isEmpty()) { + log.error("Extracted summary is empty"); + listener.onFailure(new RuntimeException("Empty or invalid LLM summary response")); + return; + } + listener.onResponse(summary); + }, listener::onFailure)); + } catch (Exception e) { + listener.onFailure(e); + } + } + + private String extractSummaryFromResponse(MLTaskResponse response) { + try { + ModelTensorOutput output = (ModelTensorOutput) response.getOutput(); + if (output != null && output.getMlModelOutputs() != null && !output.getMlModelOutputs().isEmpty()) { + ModelTensors tensors = output.getMlModelOutputs().getFirst(); + if (tensors != null && tensors.getMlModelTensors() != null && !tensors.getMlModelTensors().isEmpty()) { + ModelTensor tensor = tensors.getMlModelTensors().getFirst(); + if (tensor.getResult() != null) { + return tensor.getResult().trim(); + } + if (tensor.getDataAsMap() != null) { + Map dataMap = tensor.getDataAsMap(); + if (dataMap.containsKey(RESPONSE_FIELD)) { + return String.valueOf(dataMap.get(RESPONSE_FIELD)).trim(); + } + if (dataMap.containsKey("output")) { + Object outputObj = JsonPath.read(dataMap, "$.output.message.content[0].text"); + if (outputObj != null) { + return String.valueOf(outputObj).trim(); + } + } + } + log + .error( + "Summary generate error. No result/response field. Available: {}", + tensor.getDataAsMap() != null ? 
tensor.getDataAsMap().keySet() : "null" + ); + } + } + return null; + } catch (Exception e) { + log.error("Summary extraction failed", e); + throw new RuntimeException("Failed to extract summary from response", e); + } + } } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java index 4c83e63adb..ec480661e9 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java @@ -1031,11 +1031,19 @@ public void testMaxIterationsReached() { .tools(Arrays.asList(firstToolSpec)) .build(); - // Mock LLM response that doesn't contain final_answer to force max iterations - Mockito - .doAnswer(getLLMAnswer(ImmutableMap.of("thought", "", "action", FIRST_TOOL))) - .when(client) - .execute(any(ActionType.class), any(ActionRequest.class), isA(ActionListener.class)); + // Reset client mock for this test + Mockito.reset(client); + // First call: LLM response without final_answer to force max iterations + // Second call: Summary LLM response + Mockito.doAnswer(getLLMAnswer(ImmutableMap.of("thought", "", "action", FIRST_TOOL))).doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + ModelTensor modelTensor = ModelTensor.builder().result("The agent attempted to use the first tool").build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + MLTaskResponse mlTaskResponse = MLTaskResponse.builder().output(mlModelTensorOutput).build(); + listener.onResponse(mlTaskResponse); + return null; + }).when(client).execute(any(ActionType.class), any(ActionRequest.class), isA(ActionListener.class)); Map params = new HashMap<>(); params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "parent_interaction_id"); @@ -1051,9 +1059,15 @@ public void testMaxIterationsReached() { List agentOutput = modelTensorOutput.getMlModelOutputs().get(1).getMlModelTensors(); assertEquals(1, agentOutput.size()); - // Verify the response contains max iterations message + // Verify the response contains max iterations message with summary String response = (String) agentOutput.get(0).getDataAsMap().get("response"); - assertEquals("Agent reached maximum iterations (1) without completing the task", response); + assertTrue( + response + .startsWith( + "Agent reached maximum iterations (1) without completing the task. 
Here's a summary of the steps completed so far:" + ) + ); + assertTrue(response.contains("The agent attempted to use the first tool")); } @Test @@ -1070,9 +1084,17 @@ public void testMaxIterationsReachedWithValidThought() { .tools(Arrays.asList(firstToolSpec)) .build(); - // Mock LLM response with valid thought + // Reset client mock for this test + Mockito.reset(client); + // First call: LLM response with valid thought to trigger max iterations + // Second call: Summary LLM fails to trigger fallback Mockito .doAnswer(getLLMAnswer(ImmutableMap.of("thought", "I need to use the first tool", "action", FIRST_TOOL))) + .doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onFailure(new RuntimeException("LLM summary generation failed")); + return null; + }) .when(client) .execute(any(ActionType.class), any(ActionRequest.class), isA(ActionListener.class)); @@ -1090,7 +1112,7 @@ public void testMaxIterationsReachedWithValidThought() { List agentOutput = modelTensorOutput.getMlModelOutputs().get(1).getMlModelTensors(); assertEquals(1, agentOutput.size()); - // Verify the response contains the last valid thought instead of max iterations message + // Verify the response contains the last valid thought (fallback when summary fails) String response = (String) agentOutput.get(0).getDataAsMap().get("response"); assertEquals( "Agent reached maximum iterations (1) without completing the task. Last thought: I need to use the first tool", diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunnerTest.java index 7ed4e91b1c..df1cf04305 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunnerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunnerTest.java @@ -382,7 +382,153 @@ public void testMessageHistoryLimits() { assertEquals("3", executorParams.get("message_history_limit")); } - // ToDo: add test case for when max steps is reached + @Test + public void testMaxStepsReachedWithSummary() { + MLAgent mlAgent = createMLAgentWithTools(); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + ModelTensor modelTensor = ModelTensor.builder().result("Summary of work done").build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + when(mlTaskResponse.getOutput()).thenReturn(mlModelTensorOutput); + listener.onResponse(mlTaskResponse); + return null; + }).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(MLPredictionTaskRequest.class), any(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(updateResponse); + return null; + }).when(mlMemoryManager).updateInteraction(any(), any(), any()); + + Map params = new HashMap<>(); + params.put("question", "test question"); + params.put("parent_interaction_id", "test_parent_interaction_id"); + params.put("max_steps", "0"); + mlPlanExecuteAndReflectAgentRunner.run(mlAgent, params, agentActionListener); + + verify(agentActionListener).onResponse(objectCaptor.capture()); + Object response = objectCaptor.getValue(); + assertTrue(response instanceof 
ModelTensorOutput); + ModelTensorOutput modelTensorOutput = (ModelTensorOutput) response; + + List mlModelOutputs = modelTensorOutput.getMlModelOutputs(); + ModelTensor responseTensor = mlModelOutputs.get(1).getMlModelTensors().get(0); + String finalResponse = (String) responseTensor.getDataAsMap().get("response"); + assertTrue(finalResponse.contains("Max Steps Limit (0) Reached. Here's a summary of the steps completed so far:")); + assertTrue(finalResponse.contains("Summary of work done")); + } + + @Test + public void testMaxStepsReachedWithSummaryGeneration() { + MLAgent mlAgent = createMLAgentWithTools(); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + ModelTensor modelTensor = ModelTensor.builder().result("Generated summary of completed steps").build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + when(mlTaskResponse.getOutput()).thenReturn(mlModelTensorOutput); + listener.onResponse(mlTaskResponse); + return null; + }).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(MLPredictionTaskRequest.class), any(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(updateResponse); + return null; + }).when(mlMemoryManager).updateInteraction(any(), any(), any()); + + Map params = new HashMap<>(); + params.put("question", "test question"); + params.put("parent_interaction_id", "test_parent_interaction_id"); + params.put("max_steps", "0"); + mlPlanExecuteAndReflectAgentRunner.run(mlAgent, params, agentActionListener); + + verify(agentActionListener).onResponse(objectCaptor.capture()); + Object response = objectCaptor.getValue(); + assertTrue(response instanceof ModelTensorOutput); + ModelTensorOutput modelTensorOutput = (ModelTensorOutput) response; + + List mlModelOutputs = modelTensorOutput.getMlModelOutputs(); + ModelTensor responseTensor = mlModelOutputs.get(1).getMlModelTensors().get(0); + String finalResponse = (String) responseTensor.getDataAsMap().get("response"); + assertTrue(finalResponse.contains("Max Steps Limit (0) Reached. 
Here's a summary of the steps completed so far:")); + assertTrue(finalResponse.contains("Generated summary of completed steps")); + } + + @Test + public void testMaxStepsReachedWithSummaryFailure() { + MLAgent mlAgent = createMLAgentWithTools(); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onFailure(new RuntimeException("Summary generation failed")); + return null; + }).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(MLPredictionTaskRequest.class), any(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(updateResponse); + return null; + }).when(mlMemoryManager).updateInteraction(any(), any(), any()); + + Map params = new HashMap<>(); + params.put("question", "test question"); + params.put("parent_interaction_id", "test_parent_interaction_id"); + params.put("max_steps", "0"); + mlPlanExecuteAndReflectAgentRunner.run(mlAgent, params, agentActionListener); + + verify(agentActionListener).onResponse(objectCaptor.capture()); + Object response = objectCaptor.getValue(); + assertTrue(response instanceof ModelTensorOutput); + ModelTensorOutput modelTensorOutput = (ModelTensorOutput) response; + + List mlModelOutputs = modelTensorOutput.getMlModelOutputs(); + ModelTensor responseTensor = mlModelOutputs.get(1).getMlModelTensors().get(0); + String finalResponse = (String) responseTensor.getDataAsMap().get("response"); + assertTrue(finalResponse.contains("Max Steps Limit (0) Reached")); + } + + @Test + public void testMaxStepsReachedWithEmptyCompletedSteps() { + MLAgent mlAgent = createMLAgentWithTools(); + + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(0); + listener.onResponse(Collections.emptyList()); + return null; + }).when(conversationIndexMemory).getMessages(any(), anyInt()); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onFailure(new IllegalArgumentException("Completed steps cannot be null or empty")); + return null; + }).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(MLPredictionTaskRequest.class), any()); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(updateResponse); + return null; + }).when(mlMemoryManager).updateInteraction(any(), any(), any()); + + Map params = new HashMap<>(); + params.put("question", "test question"); + params.put("parent_interaction_id", "test_parent_interaction_id"); + params.put("max_steps", "0"); + mlPlanExecuteAndReflectAgentRunner.run(mlAgent, params, agentActionListener); + + verify(agentActionListener).onResponse(objectCaptor.capture()); + Object response = objectCaptor.getValue(); + assertTrue(response instanceof ModelTensorOutput); + ModelTensorOutput modelTensorOutput = (ModelTensorOutput) response; + + List mlModelOutputs = modelTensorOutput.getMlModelOutputs(); + ModelTensor responseTensor = mlModelOutputs.get(1).getMlModelTensors().get(0); + String finalResponse = (String) responseTensor.getDataAsMap().get("response"); + assertTrue(finalResponse.contains("Max Steps Limit (0) Reached")); + } private MLAgent createMLAgentWithTools() { LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build(); From 4c48e653700d014bd926b17ed05da1b5a8edd18b Mon Sep 17 00:00:00 2001 From: Jiaru Jiang Date: Fri, 31 Oct 2025 10:29:44 +0800 Subject: [PATCH 20/24] fix: system prompt Signed-off-by: Jiaru Jiang --- .../ml/engine/algorithms/agent/MLChatAgentRunner.java | 2 +- 1 file 
changed, 1 insertion(+), 1 deletion(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index 06b180a3fe..0e87bd31fb 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -996,7 +996,7 @@ void generateLLMSummary(List stepsSummary, LLMSpec llmSpec, String } String summaryPrompt = String.format(Locale.ROOT, SUMMARY_PROMPT_TEMPLATE, String.join("\n", stepStrings)); summaryParams.put(PROMPT, summaryPrompt); - summaryParams.putIfAbsent(SYSTEM_PROMPT_FIELD, ""); + summaryParams.putIfAbsent(SYSTEM_PROMPT_FIELD, SUMMARY_PROMPT_TEMPLATE); ActionRequest request = new MLPredictionTaskRequest( llmSpec.getModelId(), From e5ab87bb5ec014d2bf83fe805fec521e646467c0 Mon Sep 17 00:00:00 2001 From: Jiaru Jiang Date: Tue, 4 Nov 2025 14:17:15 +0800 Subject: [PATCH 21/24] fix: parseLLMOutput Signed-off-by: Jiaru Jiang --- .../algorithms/agent/MLChatAgentRunner.java | 36 +++++++++++-- .../MLPlanExecuteAndReflectAgentRunner.java | 53 ++++++++++--------- .../TransportUpdateModelGroupActionTests.java | 4 +- 3 files changed, 64 insertions(+), 29 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index 0e87bd31fb..789afca248 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -14,6 +14,7 @@ import static org.opensearch.ml.common.utils.ToolUtils.parseResponse; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.DISABLE_TRACE; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.INTERACTIONS_PREFIX; +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_RESPONSE_FILTER; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.PROMPT_CHAT_HISTORY_PREFIX; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.PROMPT_PREFIX; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.PROMPT_SUFFIX; @@ -84,6 +85,7 @@ import org.opensearch.transport.client.Client; import com.google.common.annotations.VisibleForTesting; +import com.jayway.jsonpath.JsonPath; import lombok.Data; import lombok.NoArgsConstructor; @@ -1023,10 +1025,38 @@ void generateLLMSummary(List stepsSummary, LLMSpec llmSpec, String public String extractSummaryFromResponse(MLTaskResponse response) { try { - String outputString = outputToOutputString(response.getOutput()); - if (outputString != null && !outputString.trim().isEmpty()) { - return outputString.trim(); + ModelTensorOutput output = (ModelTensorOutput) response.getOutput(); + if (output == null || output.getMlModelOutputs() == null || output.getMlModelOutputs().isEmpty()) { + return null; } + + ModelTensors tensors = output.getMlModelOutputs().getFirst(); + if (tensors == null || tensors.getMlModelTensors() == null || tensors.getMlModelTensors().isEmpty()) { + return null; + } + + ModelTensor tensor = tensors.getMlModelTensors().getFirst(); + if (tensor.getResult() != null) { + return tensor.getResult().trim(); + } + + if (tensor.getDataAsMap() == null) { + return null; + } + + Map dataMap = tensor.getDataAsMap(); 
+ if (dataMap.containsKey("response")) { + return String.valueOf(dataMap.get("response")).trim(); + } + + if (dataMap.containsKey("output")) { + Object outputObj = JsonPath.read(dataMap, LLM_RESPONSE_FILTER); + if (outputObj != null) { + return String.valueOf(outputObj).trim(); + } + } + + log.error("Summary generate error. No result/response field found. Available fields: {}", dataMap.keySet()); return null; } catch (Exception e) { log.error("Failed to extract summary from response", e); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java index a000597705..08f9e6dca8 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java @@ -813,32 +813,37 @@ private void generateSummary(LLMSpec llmSpec, List completedSteps, Strin private String extractSummaryFromResponse(MLTaskResponse response) { try { ModelTensorOutput output = (ModelTensorOutput) response.getOutput(); - if (output != null && output.getMlModelOutputs() != null && !output.getMlModelOutputs().isEmpty()) { - ModelTensors tensors = output.getMlModelOutputs().getFirst(); - if (tensors != null && tensors.getMlModelTensors() != null && !tensors.getMlModelTensors().isEmpty()) { - ModelTensor tensor = tensors.getMlModelTensors().getFirst(); - if (tensor.getResult() != null) { - return tensor.getResult().trim(); - } - if (tensor.getDataAsMap() != null) { - Map dataMap = tensor.getDataAsMap(); - if (dataMap.containsKey(RESPONSE_FIELD)) { - return String.valueOf(dataMap.get(RESPONSE_FIELD)).trim(); - } - if (dataMap.containsKey("output")) { - Object outputObj = JsonPath.read(dataMap, "$.output.message.content[0].text"); - if (outputObj != null) { - return String.valueOf(outputObj).trim(); - } - } - } - log - .error( - "Summary generate error. No result/response field. Available: {}", - tensor.getDataAsMap() != null ? tensor.getDataAsMap().keySet() : "null" - ); + if (output == null || output.getMlModelOutputs() == null || output.getMlModelOutputs().isEmpty()) { + return null; + } + + ModelTensors tensors = output.getMlModelOutputs().getFirst(); + if (tensors == null || tensors.getMlModelTensors() == null || tensors.getMlModelTensors().isEmpty()) { + return null; + } + + ModelTensor tensor = tensors.getMlModelTensors().getFirst(); + if (tensor.getResult() != null) { + return tensor.getResult().trim(); + } + + if (tensor.getDataAsMap() == null) { + return null; + } + + Map dataMap = tensor.getDataAsMap(); + if (dataMap.containsKey(RESPONSE_FIELD)) { + return String.valueOf(dataMap.get(RESPONSE_FIELD)).trim(); + } + + if (dataMap.containsKey("output")) { + Object outputObj = JsonPath.read(dataMap, LLM_RESPONSE_FILTER); + if (outputObj != null) { + return String.valueOf(outputObj).trim(); } } + + log.error("Summary generate error. No result/response field found. 
Available fields: {}", dataMap.keySet()); return null; } catch (Exception e) { log.error("Summary extraction failed", e); diff --git a/plugin/src/test/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupActionTests.java index 5d7d1d5917..1b2c772bfd 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupActionTests.java @@ -464,7 +464,7 @@ public void test_Update_RSC_FeatureEnabled_TypeEnabled_SkipsLegacyValidation() t // Enable RSC fast-path. ResourceSharingClient rsc = mock(ResourceSharingClient.class); ResourceSharingClientAccessor.getInstance().setResourceSharingClient(rsc); - when(rsc.isFeatureEnabledForType(any())).thenReturn(true); + // when(rsc.isFeatureEnabledForType(any())).thenReturn(true); // No ACL changes in request (so even legacy would pass, but we won't go there). MLUpdateModelGroupRequest req = prepareRequest(null, null, null); @@ -486,7 +486,7 @@ public void test_Update_RSC_FeatureEnabled_TypeDisabled_UsesLegacyValidation() t // RSC feature on, but type disabled → legacy path. ResourceSharingClient rsc = mock(ResourceSharingClient.class); ResourceSharingClientAccessor.getInstance().setResourceSharingClient(rsc); - when(rsc.isFeatureEnabledForType(any())).thenReturn(false); + // when(rsc.isFeatureEnabledForType(any())).thenReturn(false); // Allow legacy validation to pass: // security/model-access-control enabled: From c776066b7b0a953d2eb2a196db1b96eb4712f850 Mon Sep 17 00:00:00 2001 From: Jiaru Jiang Date: Tue, 4 Nov 2025 15:16:30 +0800 Subject: [PATCH 22/24] add: test cases Signed-off-by: Jiaru Jiang --- .../agent/MLChatAgentRunnerTest.java | 37 +++++ ...LPlanExecuteAndReflectAgentRunnerTest.java | 141 ++++++++++++++++++ .../TransportUpdateModelGroupActionTests.java | 4 +- 3 files changed, 180 insertions(+), 2 deletions(-) diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java index ec480661e9..44d05183ab 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java @@ -1331,4 +1331,41 @@ public void testGenerateLLMSummaryWithNullSteps() { verify(listener).onFailure(any(IllegalArgumentException.class)); } + + @Test + public void testExtractSummaryFromResponse_WithResponseField() { + Map dataMap = new HashMap<>(); + dataMap.put("response", "Summary from response field"); + ModelTensor tensor = ModelTensor.builder().dataAsMap(dataMap).build(); + ModelTensors tensors = ModelTensors.builder().mlModelTensors(Arrays.asList(tensor)).build(); + ModelTensorOutput output = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(tensors)).build(); + MLTaskResponse response = MLTaskResponse.builder().output(output).build(); + + String result = mlChatAgentRunner.extractSummaryFromResponse(response); + assertEquals("Summary from response field", result); + } + + @Test + public void testExtractSummaryFromResponse_WithNullDataMap() { + ModelTensor tensor = ModelTensor.builder().build(); + ModelTensors tensors = ModelTensors.builder().mlModelTensors(Arrays.asList(tensor)).build(); + ModelTensorOutput output = 
ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(tensors)).build(); + MLTaskResponse response = MLTaskResponse.builder().output(output).build(); + + String result = mlChatAgentRunner.extractSummaryFromResponse(response); + assertEquals(null, result); + } + + @Test + public void testExtractSummaryFromResponse_WithEmptyDataMap() { + Map dataMap = new HashMap<>(); + dataMap.put("other_field", "some value"); + ModelTensor tensor = ModelTensor.builder().dataAsMap(dataMap).build(); + ModelTensors tensors = ModelTensors.builder().mlModelTensors(Arrays.asList(tensor)).build(); + ModelTensorOutput output = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(tensors)).build(); + MLTaskResponse response = MLTaskResponse.builder().output(output).build(); + + String result = mlChatAgentRunner.extractSummaryFromResponse(response); + assertEquals(null, result); + } } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunnerTest.java index df1cf04305..f1906e8630 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunnerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunnerTest.java @@ -1001,4 +1001,145 @@ public void testExecutionWithNullStepResult() { // Verify that onFailure was called with the expected exception verify(agentActionListener).onFailure(any(IllegalStateException.class)); } + + @Test + public void testMaxStepsWithSingleCompletedStep() { + MLAgent mlAgent = createMLAgentWithTools(); + + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(0); + listener.onResponse(Arrays.asList(Interaction.builder().id("i1").input("step1").response("").build())); + return null; + }).when(conversationIndexMemory).getMessages(any(), anyInt()); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(updateResponse); + return null; + }).when(mlMemoryManager).updateInteraction(any(), any(), any()); + + Map params = new HashMap<>(); + params.put("question", "test"); + params.put("parent_interaction_id", "pid"); + params.put("max_steps", "0"); + mlPlanExecuteAndReflectAgentRunner.run(mlAgent, params, agentActionListener); + + verify(agentActionListener).onResponse(objectCaptor.capture()); + String response = (String) ((ModelTensorOutput) objectCaptor.getValue()) + .getMlModelOutputs() + .get(1) + .getMlModelTensors() + .get(0) + .getDataAsMap() + .get("response"); + assertTrue(response.contains("Max Steps Limit (0) Reached")); + } + + @Test + public void testSummaryExtractionWithResultField() { + MLAgent mlAgent = createMLAgentWithTools(); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + ModelTensor tensor = ModelTensor.builder().result("Summary from result").build(); + when(mlTaskResponse.getOutput()) + .thenReturn( + ModelTensorOutput + .builder() + .mlModelOutputs(Arrays.asList(ModelTensors.builder().mlModelTensors(Arrays.asList(tensor)).build())) + .build() + ); + listener.onResponse(mlTaskResponse); + return null; + }).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(MLPredictionTaskRequest.class), any()); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(updateResponse); + return null; + 
}).when(mlMemoryManager).updateInteraction(any(), any(), any()); + + Map params = new HashMap<>(); + params.put("question", "test"); + params.put("parent_interaction_id", "pid"); + params.put("max_steps", "0"); + mlPlanExecuteAndReflectAgentRunner.run(mlAgent, params, agentActionListener); + + verify(agentActionListener).onResponse(objectCaptor.capture()); + String response = (String) ((ModelTensorOutput) objectCaptor.getValue()) + .getMlModelOutputs() + .get(1) + .getMlModelTensors() + .get(0) + .getDataAsMap() + .get("response"); + assertTrue(response.contains("Summary from result")); + } + + @Test + public void testSummaryExtractionWithEmptyResponse() { + MLAgent mlAgent = createMLAgentWithTools(); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + ModelTensor tensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", " ")).build(); + when(mlTaskResponse.getOutput()) + .thenReturn( + ModelTensorOutput + .builder() + .mlModelOutputs(Arrays.asList(ModelTensors.builder().mlModelTensors(Arrays.asList(tensor)).build())) + .build() + ); + listener.onResponse(mlTaskResponse); + return null; + }).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(MLPredictionTaskRequest.class), any()); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(updateResponse); + return null; + }).when(mlMemoryManager).updateInteraction(any(), any(), any()); + + Map params = new HashMap<>(); + params.put("question", "test"); + params.put("parent_interaction_id", "pid"); + params.put("max_steps", "0"); + mlPlanExecuteAndReflectAgentRunner.run(mlAgent, params, agentActionListener); + + verify(agentActionListener).onResponse(objectCaptor.capture()); + String response = (String) ((ModelTensorOutput) objectCaptor.getValue()) + .getMlModelOutputs() + .get(1) + .getMlModelTensors() + .get(0) + .getDataAsMap() + .get("response"); + assertTrue(response.contains("Max Steps Limit")); + } + + @Test + public void testSummaryExtractionWithNullOutput() { + MLAgent mlAgent = createMLAgentWithTools(); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + when(mlTaskResponse.getOutput()).thenReturn(null); + listener.onResponse(mlTaskResponse); + return null; + }).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(MLPredictionTaskRequest.class), any()); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(updateResponse); + return null; + }).when(mlMemoryManager).updateInteraction(any(), any(), any()); + + Map params = new HashMap<>(); + params.put("question", "test"); + params.put("parent_interaction_id", "pid"); + params.put("max_steps", "0"); + mlPlanExecuteAndReflectAgentRunner.run(mlAgent, params, agentActionListener); + + verify(agentActionListener).onResponse(any()); + } } diff --git a/plugin/src/test/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupActionTests.java index 1b2c772bfd..5d7d1d5917 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupActionTests.java @@ -464,7 +464,7 @@ public void test_Update_RSC_FeatureEnabled_TypeEnabled_SkipsLegacyValidation() t // Enable RSC fast-path. 
ResourceSharingClient rsc = mock(ResourceSharingClient.class); ResourceSharingClientAccessor.getInstance().setResourceSharingClient(rsc); - // when(rsc.isFeatureEnabledForType(any())).thenReturn(true); + when(rsc.isFeatureEnabledForType(any())).thenReturn(true); // No ACL changes in request (so even legacy would pass, but we won't go there). MLUpdateModelGroupRequest req = prepareRequest(null, null, null); @@ -486,7 +486,7 @@ public void test_Update_RSC_FeatureEnabled_TypeDisabled_UsesLegacyValidation() t // RSC feature on, but type disabled → legacy path. ResourceSharingClient rsc = mock(ResourceSharingClient.class); ResourceSharingClientAccessor.getInstance().setResourceSharingClient(rsc); - // when(rsc.isFeatureEnabledForType(any())).thenReturn(false); + when(rsc.isFeatureEnabledForType(any())).thenReturn(false); // Allow legacy validation to pass: // security/model-access-control enabled: From d8757bcb8f9f848cbf0a60ea8ee03fbdb79938dc Mon Sep 17 00:00:00 2001 From: Jiaru Jiang Date: Thu, 6 Nov 2025 16:01:46 +0800 Subject: [PATCH 23/24] add: test cases for fallback Signed-off-by: Jiaru Jiang --- .../algorithms/agent/MLChatAgentRunner.java | 3 -- .../agent/MLChatAgentRunnerTest.java | 47 +++++++++++++++++++ 2 files changed, 47 insertions(+), 3 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index 789afca248..ca5975e080 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -418,7 +418,6 @@ private void runReAct( lastThought, maxIterations, tools, - tmpParameters, llm, tenantId ); @@ -522,7 +521,6 @@ private void runReAct( lastThought, maxIterations, tools, - tmpParameters, llm, tenantId ); @@ -898,7 +896,6 @@ private void handleMaxIterationsReached( AtomicReference lastThought, int maxIterations, Map tools, - Map parameters, LLMSpec llmSpec, String tenantId ) { diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java index 44d05183ab..8c0a288746 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java @@ -1368,4 +1368,51 @@ public void testExtractSummaryFromResponse_WithEmptyDataMap() { String result = mlChatAgentRunner.extractSummaryFromResponse(response); assertEquals(null, result); } + + @Test + public void testExtractSummaryFromResponse_ThrowsException_FallbackStrategyUsed() { + LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").parameters(Map.of("max_iteration", "1")).build(); + MLToolSpec firstToolSpec = MLToolSpec.builder().name(FIRST_TOOL).type(FIRST_TOOL).build(); + final MLAgent mlAgent = MLAgent + .builder() + .name("TestAgent") + .type(MLAgentType.CONVERSATIONAL.name()) + .llm(llmSpec) + .memory(mlMemorySpec) + .tools(Arrays.asList(firstToolSpec)) + .build(); + + Mockito.reset(client); + Mockito.reset(firstTool); + when(firstTool.getName()).thenReturn(FIRST_TOOL); + when(firstTool.validate(Mockito.anyMap())).thenReturn(true); + Mockito.doAnswer(generateToolResponse("First tool response")).when(firstTool).run(Mockito.anyMap(), any()); + + 
Mockito.doAnswer(getLLMAnswer(ImmutableMap.of("thought", "Analyzing the problem", "action", FIRST_TOOL))).doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + Map invalidDataMap = new HashMap<>(); + invalidDataMap.put("output", new HashMap<>()); + ModelTensor modelTensor = ModelTensor.builder().dataAsMap(invalidDataMap).build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + MLTaskResponse mlTaskResponse = MLTaskResponse.builder().output(mlModelTensorOutput).build(); + listener.onResponse(mlTaskResponse); + return null; + }).when(client).execute(any(ActionType.class), any(ActionRequest.class), isA(ActionListener.class)); + + Map params = new HashMap<>(); + params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "parent_interaction_id"); + mlChatAgentRunner.run(mlAgent, params, agentActionListener, null); + + verify(agentActionListener).onResponse(objectCaptor.capture()); + Object capturedResponse = objectCaptor.getValue(); + assertTrue(capturedResponse instanceof ModelTensorOutput); + + ModelTensorOutput modelTensorOutput = (ModelTensorOutput) capturedResponse; + List agentOutput = modelTensorOutput.getMlModelOutputs().get(1).getMlModelTensors(); + assertEquals(1, agentOutput.size()); + + String response = (String) agentOutput.get(0).getDataAsMap().get("response"); + assertEquals("Agent reached maximum iterations (1) without completing the task. Last thought: Analyzing the problem", response); + } } From 8b802a6508b20a73d5718ca4509f4d2f39638c66 Mon Sep 17 00:00:00 2001 From: Jiaru Jiang Date: Tue, 18 Nov 2025 12:59:40 +0800 Subject: [PATCH 24/24] fix:import Signed-off-by: Jiaru Jiang --- .../ml/engine/algorithms/agent/MLChatAgentRunner.java | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index ca5975e080..89c0492a89 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -59,10 +59,13 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.Strings; import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.agent.LLMSpec; import org.opensearch.ml.common.agent.MLAgent; import org.opensearch.ml.common.agent.MLToolSpec; import org.opensearch.ml.common.conversation.Interaction; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.remote.RemoteInferenceMLInput; import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.common.output.model.ModelTensors; @@ -70,6 +73,8 @@ import org.opensearch.ml.common.spi.memory.Message; import org.opensearch.ml.common.spi.tools.Tool; import org.opensearch.ml.common.transport.MLTaskResponse; +import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; +import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; import org.opensearch.ml.common.utils.StringUtils; import 
org.opensearch.ml.engine.encryptor.Encryptor; import org.opensearch.ml.engine.function_calling.FunctionCalling; @@ -1007,8 +1012,8 @@ void generateLLMSummary(List stepsSummary, LLMSpec llmSpec, String null, tenantId ); - client.execute(MLPredictionTaskAction.INSTANCE, request, ActionListener.wrap(response -> { - String summary = extractSummaryFromResponse(response); + client.execute(MLPredictionTaskAction.INSTANCE, request, ActionListener.wrap(mlTaskResponse -> { + String summary = extractSummaryFromResponse(mlTaskResponse); if (summary == null) { listener.onFailure(new RuntimeException("Empty or invalid LLM summary response")); return;
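
For orientation only (this note and the snippet below are not part of the patch series): a minimal caller-side sketch of how the max-iteration summary behaviour added above surfaces through MLChatAgentRunner. It reuses the builders and the four-argument run(...) call shown in the tests in these patches; the model id, tool name, memory spec and runner instance are placeholders, and max_iteration is set to 1 purely to force the limit.

    // Hypothetical sketch; "MODEL_ID", "demo_tool", memorySpec and chatAgentRunner are placeholders.
    LLMSpec llm = LLMSpec.builder()
        .modelId("MODEL_ID")
        .parameters(Map.of("max_iteration", "1"))   // stop after a single ReAct step
        .build();

    MLAgent agent = MLAgent.builder()
        .name("demo-agent")
        .type(MLAgentType.CONVERSATIONAL.name())
        .llm(llm)
        .memory(memorySpec)
        .tools(List.of(MLToolSpec.builder().name("demo_tool").type("demo_tool").build()))
        .build();

    Map<String, String> params = new HashMap<>();
    params.put("question", "Investigate the failing index");
    params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "parent_interaction_id");

    chatAgentRunner.run(agent, params, ActionListener.wrap(output -> {
        // When the limit is hit, the final "response" tensor is expected to begin with
        // "Agent reached maximum iterations (1) without completing the task." followed by
        // the LLM-generated step summary, or by "Last thought: ..." if summary generation fails.
        log.info("Agent output: {}", output);
    }, e -> log.error("Agent run failed", e)), null);

The summary-or-fallback split mirrors handleMaxIterationsReached above: the summarized message is only returned when the prediction call yields a non-empty summary, otherwise the traditional max-iterations response is kept.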