Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import static org.opensearch.ml.common.utils.ToolUtils.parseResponse;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.DISABLE_TRACE;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.INTERACTIONS_PREFIX;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_RESPONSE_FILTER;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.PROMPT_CHAT_HISTORY_PREFIX;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.PROMPT_PREFIX;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.PROMPT_SUFFIX;
Expand All @@ -34,6 +35,7 @@
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.parseLLMOutput;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.substitute;
import static org.opensearch.ml.engine.algorithms.agent.PromptTemplate.CHAT_HISTORY_PREFIX;
import static org.opensearch.ml.engine.algorithms.agent.PromptTemplate.SUMMARY_PROMPT_TEMPLATE;
import static org.opensearch.ml.engine.tools.ReadFromScratchPadTool.SCRATCHPAD_NOTES_KEY;

import java.security.PrivilegedActionException;
Expand All @@ -57,17 +59,22 @@
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.Strings;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.agent.LLMSpec;
import org.opensearch.ml.common.agent.MLAgent;
import org.opensearch.ml.common.agent.MLToolSpec;
import org.opensearch.ml.common.conversation.Interaction;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.remote.RemoteInferenceMLInput;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.common.spi.memory.Memory;
import org.opensearch.ml.common.spi.memory.Message;
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
import org.opensearch.ml.common.utils.StringUtils;
import org.opensearch.ml.engine.encryptor.Encryptor;
import org.opensearch.ml.engine.function_calling.FunctionCalling;
Expand All @@ -83,6 +90,7 @@
import org.opensearch.transport.client.Client;

import com.google.common.annotations.VisibleForTesting;
import com.jayway.jsonpath.JsonPath;

import lombok.Data;
import lombok.NoArgsConstructor;
Expand Down Expand Up @@ -125,6 +133,8 @@ public class MLChatAgentRunner implements MLAgentRunner {

private static final String DEFAULT_MAX_ITERATIONS = "10";
private static final String MAX_ITERATIONS_MESSAGE = "Agent reached maximum iterations (%d) without completing the task";
private static final String MAX_ITERATIONS_SUMMARY_MESSAGE = MAX_ITERATIONS_MESSAGE
+ ". Here's a summary of the steps completed so far:\n\n%s";

private Client client;
private Settings settings;
Expand Down Expand Up @@ -321,7 +331,6 @@ private void runReAct(

StringBuilder scratchpadBuilder = new StringBuilder();
List<String> interactions = new CopyOnWriteArrayList<>();

StringSubstitutor tmpSubstitutor = new StringSubstitutor(Map.of(SCRATCHPAD, scratchpadBuilder.toString()), "${parameters.", "}");
AtomicReference<String> newPrompt = new AtomicReference<>(tmpSubstitutor.replace(prompt));
tmpParameters.put(PROMPT, newPrompt.get());
Expand Down Expand Up @@ -413,7 +422,9 @@ private void runReAct(
additionalInfo,
lastThought,
maxIterations,
tools
tools,
llm,
tenantId
);
return;
}
Expand Down Expand Up @@ -514,7 +525,9 @@ private void runReAct(
additionalInfo,
lastThought,
maxIterations,
tools
tools,
llm,
tenantId
);
return;
}
Expand Down Expand Up @@ -887,11 +900,63 @@ private void handleMaxIterationsReached(
Map<String, Object> additionalInfo,
AtomicReference<String> lastThought,
int maxIterations,
Map<String, Tool> tools,
LLMSpec llmSpec,
String tenantId
) {
ActionListener<String> responseListener = ActionListener.wrap(response -> {
sendTraditionalMaxIterationsResponse(
sessionId,
listener,
question,
parentInteractionId,
verbose,
traceDisabled,
traceTensors,
conversationIndexMemory,
traceNumber,
additionalInfo,
response,
tools
);
}, listener::onFailure);

generateLLMSummary(
traceTensors,
llmSpec,
tenantId,
ActionListener
.wrap(
summary -> responseListener
.onResponse(String.format(Locale.ROOT, MAX_ITERATIONS_SUMMARY_MESSAGE, maxIterations, summary)),
e -> {
log.error("Failed to generate LLM summary, using fallback strategy", e);
String fallbackResponse = (lastThought.get() != null
&& !lastThought.get().isEmpty()
&& !"null".equals(lastThought.get()))
? String
.format("%s. Last thought: %s", String.format(MAX_ITERATIONS_MESSAGE, maxIterations), lastThought.get())
: String.format(MAX_ITERATIONS_MESSAGE, maxIterations);
responseListener.onResponse(fallbackResponse);
}
)
);
}

private void sendTraditionalMaxIterationsResponse(
String sessionId,
ActionListener<Object> listener,
String question,
String parentInteractionId,
boolean verbose,
boolean traceDisabled,
List<ModelTensors> traceTensors,
ConversationIndexMemory conversationIndexMemory,
AtomicInteger traceNumber,
Map<String, Object> additionalInfo,
String response,
Map<String, Tool> tools
) {
String incompleteResponse = (lastThought.get() != null && !lastThought.get().isEmpty() && !"null".equals(lastThought.get()))
? String.format("%s. Last thought: %s", String.format(MAX_ITERATIONS_MESSAGE, maxIterations), lastThought.get())
: String.format(MAX_ITERATIONS_MESSAGE, maxIterations);
sendFinalAnswer(
sessionId,
listener,
Expand All @@ -903,11 +968,104 @@ private void handleMaxIterationsReached(
conversationIndexMemory,
traceNumber,
additionalInfo,
incompleteResponse
response
);
cleanUpResource(tools);
}

/**
 * Asks the configured LLM to produce a natural-language summary of the agent's
 * completed ReAct steps (used when the agent hits its max-iterations limit).
 *
 * The step trace is flattened to plain strings, injected into
 * {@code SUMMARY_PROMPT_TEMPLATE}, and sent as a REMOTE prediction request to
 * the model referenced by {@code llmSpec}.
 *
 * @param stepsSummary trace tensors collected from the completed steps; must be non-null and non-empty
 * @param llmSpec      spec of the LLM to call (model id plus optional default parameters)
 * @param tenantId     tenant id forwarded on the prediction request (may be null)
 * @param listener     receives the extracted summary string, or a failure if the
 *                     input is empty, the request fails, or the response has no usable text
 */
void generateLLMSummary(List<ModelTensors> stepsSummary, LLMSpec llmSpec, String tenantId, ActionListener<String> listener) {
    // Fail fast: nothing to summarize.
    if (stepsSummary == null || stepsSummary.isEmpty()) {
        listener.onFailure(new IllegalArgumentException("Steps summary cannot be null or empty"));
        return;
    }

    try {
        // Start from the LLM's own configured parameters (if any), then overlay the summary prompt.
        Map<String, String> summaryParams = new HashMap<>();
        if (llmSpec.getParameters() != null) {
            summaryParams.putAll(llmSpec.getParameters());
        }

        // Flatten each step's tensors to text: prefer the raw result string,
        // else fall back to the "response" entry of the data map.
        List<String> stepStrings = new ArrayList<>();
        for (ModelTensors tensor : stepsSummary) {
            if (tensor != null && tensor.getMlModelTensors() != null) {
                for (ModelTensor modelTensor : tensor.getMlModelTensors()) {
                    if (modelTensor.getResult() != null) {
                        stepStrings.add(modelTensor.getResult());
                    } else if (modelTensor.getDataAsMap() != null && modelTensor.getDataAsMap().containsKey("response")) {
                        stepStrings.add(String.valueOf(modelTensor.getDataAsMap().get("response")));
                    }
                }
            }
        }
        // Locale.ROOT keeps formatting locale-independent.
        String summaryPrompt = String.format(Locale.ROOT, SUMMARY_PROMPT_TEMPLATE, String.join("\n", stepStrings));
        summaryParams.put(PROMPT, summaryPrompt);
        // NOTE(review): this installs the UNFORMATTED template (still containing its %s
        // placeholder) as the default system prompt, duplicating the instructions already
        // carried by PROMPT above — confirm this is intentional.
        summaryParams.putIfAbsent(SYSTEM_PROMPT_FIELD, SUMMARY_PROMPT_TEMPLATE);

        // Build a remote-inference prediction request against the summarizing LLM.
        ActionRequest request = new MLPredictionTaskRequest(
            llmSpec.getModelId(),
            RemoteInferenceMLInput
                .builder()
                .algorithm(FunctionName.REMOTE)
                .inputDataset(RemoteInferenceInputDataSet.builder().parameters(summaryParams).build())
                .build(),
            null,
            tenantId
        );
        client.execute(MLPredictionTaskAction.INSTANCE, request, ActionListener.wrap(mlTaskResponse -> {
            String summary = extractSummaryFromResponse(mlTaskResponse);
            // Treat an unparseable response as a failure so the caller can fall back.
            if (summary == null) {
                listener.onFailure(new RuntimeException("Empty or invalid LLM summary response"));
                return;
            }
            listener.onResponse(summary);
        }, listener::onFailure));
    } catch (Exception e) {
        // Any synchronous failure (bad spec, request construction, etc.) is routed to the listener.
        listener.onFailure(e);
    }
}

/**
 * Pulls the summary text out of an LLM prediction response.
 *
 * Looks at the first tensor of the first model output and tries, in order:
 * the raw result string, the data map's "response" entry, then — if an
 * "output" key exists — the value selected by the {@code LLM_RESPONSE_FILTER}
 * JsonPath expression.
 *
 * @param response the prediction task response from the LLM call
 * @return the trimmed summary text, or {@code null} when the response carries
 *         no recognizable text field
 * @throws RuntimeException (wrapping the cause) if the response cannot be parsed at all
 */
public String extractSummaryFromResponse(MLTaskResponse response) {
    try {
        ModelTensorOutput output = (ModelTensorOutput) response.getOutput();
        if (output == null || output.getMlModelOutputs() == null || output.getMlModelOutputs().isEmpty()) {
            return null;
        }

        // getFirst() is SequencedCollection API — requires Java 21+.
        ModelTensors tensors = output.getMlModelOutputs().getFirst();
        if (tensors == null || tensors.getMlModelTensors() == null || tensors.getMlModelTensors().isEmpty()) {
            return null;
        }

        // Only the first tensor is inspected; any additional tensors are ignored.
        ModelTensor tensor = tensors.getMlModelTensors().getFirst();
        if (tensor.getResult() != null) {
            return tensor.getResult().trim();
        }

        if (tensor.getDataAsMap() == null) {
            return null;
        }

        Map<String, ?> dataMap = tensor.getDataAsMap();
        if (dataMap.containsKey("response")) {
            return String.valueOf(dataMap.get("response")).trim();
        }

        if (dataMap.containsKey("output")) {
            // NOTE(review): JsonPath.read throws PathNotFoundException when the filter path
            // is absent; that surfaces via the catch below as a RuntimeException rather than
            // a null return — confirm that is the intended behavior.
            Object outputObj = JsonPath.read(dataMap, LLM_RESPONSE_FILTER);
            if (outputObj != null) {
                return String.valueOf(outputObj).trim();
            }
        }

        log.error("Summary generate error. No result/response field found. Available fields: {}", dataMap.keySet());
        return null;
    } catch (Exception e) {
        // Preserve the cause so callers can see the original parse failure.
        log.error("Failed to extract summary from response", e);
        throw new RuntimeException("Failed to extract summary from response", e);
    }
}

private void saveMessage(
ConversationIndexMemory memory,
String question,
Expand Down
Loading
Loading