Skip to content

Commit f022728

Browse files
committed
Update the POST_TOOL hook emit saving to agentic memory (#4408)
* Fix POST_TOOL hook interaction updates and add tenant ID support Signed-off-by: Mingshi Liu <mingshl@amazon.com> - Fix POST_TOOL hook to return full ContextManagerContext like PRE_LLM hook - Update MLChatAgentRunner to properly handle interaction updates from POST_TOOL hook - Ensure interactions list and tmpParameters.INTERACTIONS stay synchronized - Add tenant ID support to MLPredictionTaskRequest in ModelGuardrail and SummarizationManager Signed-off-by: Mingshi Liu <mingshl@amazon.com> * fix error message escaping Signed-off-by: Mingshi Liu <mingshl@amazon.com> * consolicate post_hook logic Signed-off-by: Mingshi Liu <mingshl@amazon.com> --------- Signed-off-by: Mingshi Liu <mingshl@amazon.com>
1 parent 7eb18a9 commit f022728

File tree

4 files changed

+45
-39
lines changed

4 files changed

+45
-39
lines changed

common/src/main/java/org/opensearch/ml/common/model/ModelGuardrail.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import static java.util.concurrent.TimeUnit.SECONDS;
99
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
10+
import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD;
1011
import static org.opensearch.ml.common.utils.StringUtils.gson;
1112

1213
import java.io.IOException;
@@ -125,13 +126,16 @@ public Boolean validate(String in, Map<String, String> parameters) {
125126
guardrailModelParams.put("response_filter", responseFilter);
126127
}
127128
log.info("Guardrail resFilter: {}", responseFilter);
129+
String tenantId = parameters != null ? parameters.get(TENANT_ID_FIELD) : null;
128130
ActionRequest request = new MLPredictionTaskRequest(
129131
modelId,
130132
RemoteInferenceMLInput
131133
.builder()
132134
.algorithm(FunctionName.REMOTE)
133135
.inputDataset(RemoteInferenceInputDataSet.builder().parameters(guardrailModelParams).build())
134-
.build()
136+
.build(),
137+
null,
138+
tenantId
135139
);
136140
client.execute(MLPredictionTaskAction.INSTANCE, request, new LatchedActionListener(actionListener, latch));
137141
try {

ml-algorithms/src/main/java/org/opensearch/ml/engine/agents/AgentContextUtil.java

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -104,36 +104,33 @@ public static ContextManagerContext buildContextManagerContext(
104104
return builder.build();
105105
}
106106

107-
public static Object emitPostToolHook(
107+
public static ContextManagerContext emitPostToolHook(
108108
Object toolOutput,
109109
Map<String, String> parameters,
110110
List<MLToolSpec> toolSpecs,
111111
Memory memory,
112112
HookRegistry hookRegistry
113113
) {
114+
ContextManagerContext context = buildContextManagerContextForToolOutput(
115+
StringUtils.toJson(toolOutput),
116+
parameters,
117+
toolSpecs,
118+
memory
119+
);
120+
114121
if (hookRegistry != null) {
115122
try {
116123
if (toolOutput == null) {
117124
log.warn("Tool output is null, skipping POST_TOOL hook");
118-
return null;
125+
return context;
119126
}
120-
ContextManagerContext context = buildContextManagerContextForToolOutput(
121-
StringUtils.toJson(toolOutput),
122-
parameters,
123-
toolSpecs,
124-
memory
125-
);
126127
EnhancedPostToolEvent event = new EnhancedPostToolEvent(null, null, context, new HashMap<>());
127128
hookRegistry.emit(event);
128-
129-
Object processedOutput = extractProcessedToolOutput(context);
130-
return processedOutput != null ? processedOutput : toolOutput;
131129
} catch (Exception e) {
132130
log.error("Failed to emit POST_TOOL hook event", e);
133-
return toolOutput;
134131
}
135132
}
136-
return toolOutput;
133+
return context;
137134
}
138135

139136
public static ContextManagerContext emitPreLLMHook(

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java

Lines changed: 22 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -475,7 +475,7 @@ private void runReAct(
475475
((ActionListener<Object>) nextStepListener).onResponse(res);
476476
}
477477
} else {
478-
// filteredOutput is the POST Tool output
478+
// output is now the processed output from POST_TOOL hook in runTool
479479
Object filteredOutput = filterToolOutput(lastToolParams, output);
480480
addToolOutputToAddtionalInfo(toolSpecMap, lastAction, additionalInfo, filteredOutput);
481481

@@ -488,6 +488,7 @@ private void runReAct(
488488
);
489489
scratchpadBuilder.append(toolResponse).append("\n\n");
490490

491+
// Save trace with processed output
491492
saveTraceData(
492493
conversationIndexMemory,
493494
"ReAct",
@@ -669,26 +670,23 @@ private static void runTool(
669670
try {
670671
String finalAction = action;
671672
ActionListener<Object> toolListener = ActionListener.wrap(r -> {
672-
if (functionCalling != null) {
673-
String outputResponse = parseResponse(filterToolOutput(toolParams, r));
673+
// Emit POST_TOOL hook event - common for all tool executions
674+
List<MLToolSpec> postToolSpecs = new ArrayList<>(toolSpecMap.values());
675+
ContextManagerContext contextAfterPostTool = AgentContextUtil
676+
.emitPostToolHook(r, tmpParameters, postToolSpecs, null, hookRegistry);
674677

675-
// Emit POST_TOOL hook event after tool execution and process current tool
676-
// output
677-
List<MLToolSpec> postToolSpecs = new ArrayList<>(toolSpecMap.values());
678-
String outputResponseAfterHook = AgentContextUtil
679-
.emitPostToolHook(outputResponse, tmpParameters, postToolSpecs, null, hookRegistry)
680-
.toString();
678+
// Extract processed output from POST_TOOL hook
679+
String processedToolOutput = contextAfterPostTool.getParameters().get("_current_tool_output");
680+
Object processedOutput = processedToolOutput != null ? processedToolOutput : r;
681681

682+
if (functionCalling != null) {
683+
String outputResponse = parseResponse(filterToolOutput(toolParams, processedOutput));
682684
List<Map<String, Object>> toolResults = List
683-
.of(Map.of(TOOL_CALL_ID, toolCallId, TOOL_RESULT, Map.of("text", outputResponseAfterHook)));
685+
.of(Map.of(TOOL_CALL_ID, toolCallId, TOOL_RESULT, Map.of("text", outputResponse)));
684686
List<LLMMessage> llmMessages = functionCalling.supply(toolResults);
685-
// TODO: support multiple tool calls at the same time so that multiple
686-
// LLMMessages can be generated here
687+
// TODO: support multiple tool calls at the same time so that multiple LLMMessages can be generated here
687688
interactions.add(llmMessages.getFirst().getResponse());
688689
} else {
689-
// Emit POST_TOOL hook event for non-function calling path
690-
List<MLToolSpec> postToolSpecs = new ArrayList<>(toolSpecMap.values());
691-
Object processedOutput = AgentContextUtil.emitPostToolHook(r, tmpParameters, postToolSpecs, null, hookRegistry);
692690
interactions
693691
.add(
694692
substitute(
@@ -698,25 +696,25 @@ private static void runTool(
698696
)
699697
);
700698
}
701-
nextStepListener.onResponse(r);
699+
nextStepListener.onResponse(processedOutput);
702700
}, e -> {
703701
interactions
704702
.add(
705703
substitute(
706704
tmpParameters.get(INTERACTION_TEMPLATE_TOOL_RESPONSE),
707-
Map.of(TOOL_CALL_ID, toolCallId, "tool_response", "Tool " + action + " failed: " + e.getMessage()),
705+
Map
706+
.of(
707+
TOOL_CALL_ID,
708+
toolCallId,
709+
"tool_response",
710+
"Tool " + action + " failed: " + StringUtils.processTextDoc(e.getMessage())
711+
),
708712
INTERACTIONS_PREFIX
709713
)
710714
);
711715
nextStepListener
712716
.onResponse(
713-
String
714-
.format(
715-
Locale.ROOT,
716-
"Failed to run the tool %s with the error message %s.",
717-
finalAction,
718-
e.getMessage().replaceAll("\\n", "\n")
719-
)
717+
String.format(Locale.ROOT, "Failed to run the tool %s with the error message %s.", finalAction, e.getMessage())
720718
);
721719
});
722720
if (tools.get(action) instanceof MLModelTool) {

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/contextmanager/SummarizationManager.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
package org.opensearch.ml.engine.algorithms.contextmanager;
77

88
import static java.lang.Math.min;
9+
import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD;
910
import static org.opensearch.ml.common.FunctionName.REMOTE;
1011
import static org.opensearch.ml.common.utils.StringUtils.processTextDoc;
1112
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_RESPONSE_FILTER;
@@ -193,7 +194,13 @@ protected void executeSummarization(
193194
MLInput mlInput = MLInput.builder().algorithm(REMOTE).inputDataset(inputDataset).build();
194195

195196
// Create prediction request
196-
MLPredictionTaskRequest request = MLPredictionTaskRequest.builder().modelId(modelId).mlInput(mlInput).build();
197+
String tenantId = (String) context.getParameter(TENANT_ID_FIELD);
198+
MLPredictionTaskRequest request = MLPredictionTaskRequest
199+
.builder()
200+
.modelId(modelId)
201+
.mlInput(mlInput)
202+
.tenantId(tenantId)
203+
.build();
197204

198205
// Execute prediction
199206
ActionListener<MLTaskResponse> listener = ActionListener.wrap(response -> {

0 commit comments

Comments
 (0)