Skip to content

Commit d97fd6a

Browse files
committed
allow context management hook register in during agent execute
Signed-off-by: Mingshi Liu <mingshl@amazon.com>
1 parent 6a527ee commit d97fd6a

File tree

4 files changed

+269
-10
lines changed

4 files changed

+269
-10
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -209,9 +209,12 @@ public void execute(Input input, ActionListener<Output> listener, TransportChann
209209
) {
210210
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
211211
MLAgent mlAgent = MLAgent.parse(parser);
212-
// Always create a fresh HookRegistry for agent execution
213-
// This prevents callback accumulation from previous executions
214-
HookRegistry hookRegistry = new HookRegistry();
212+
// Use existing HookRegistry from AgentMLInput if available (set by MLExecuteTaskRunner for template
213+
// references)
214+
// Otherwise create a fresh HookRegistry for agent execution
215+
final HookRegistry hookRegistry = agentMLInput.getHookRegistry() != null
216+
? agentMLInput.getHookRegistry()
217+
: new HookRegistry();
215218
if (isMultiTenancyEnabled && !Objects.equals(tenantId, mlAgent.getTenantId())) {
216219
listener
217220
.onFailure(

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/contextmanager/SummarizationManager.java

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
package org.opensearch.ml.engine.algorithms.contextmanager;
77

88
import static org.opensearch.ml.common.FunctionName.REMOTE;
9+
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_RESPONSE_FILTER;
910
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.INTERACTIONS;
1011

1112
import java.util.ArrayList;
@@ -31,11 +32,15 @@
3132
import org.opensearch.ml.common.utils.StringUtils;
3233
import org.opensearch.transport.client.Client;
3334

35+
import com.jayway.jsonpath.JsonPath;
36+
import com.jayway.jsonpath.PathNotFoundException;
37+
3438
import lombok.extern.log4j.Log4j2;
3539

3640
/**
3741
* Context manager that implements summarization approach for tool interactions.
38-
* Summarizes older interactions while preserving recent ones to manage context window.
42+
* Summarizes older interactions while preserving recent ones to manage context
43+
* window.
3944
*/
4045
@Log4j2
4146
public class SummarizationManager implements ContextManager {
@@ -191,7 +196,7 @@ protected void executeSummarization(
191196
// Execute prediction
192197
ActionListener<MLTaskResponse> listener = ActionListener.wrap(response -> {
193198
try {
194-
String summary = extractSummaryFromResponse(response);
199+
String summary = extractSummaryFromResponse(response, context);
195200
processSummarizationResult(context, summary, messagesToSummarizeCount, remainingMessages, originalToolInteractions);
196201
} catch (Exception e) {
197202
// Fallback to default behavior
@@ -279,7 +284,7 @@ protected void processSummarizationResult(
279284
}
280285
}
281286

282-
private String extractSummaryFromResponse(MLTaskResponse response) {
287+
private String extractSummaryFromResponse(MLTaskResponse response, ContextManagerContext context) {
283288
try {
284289
MLOutput output = response.getOutput();
285290
if (output instanceof ModelTensorOutput) {
@@ -290,7 +295,38 @@ private String extractSummaryFromResponse(MLTaskResponse response) {
290295
List<ModelTensor> tensors = mlModelOutputs.get(0).getMlModelTensors();
291296
if (tensors != null && !tensors.isEmpty()) {
292297
Map<String, ?> dataAsMap = tensors.get(0).getDataAsMap();
293-
// TODO need to parse LLM response output, maybe reused how filtered output from chatAgentRunner
298+
299+
// Use LLM_RESPONSE_FILTER from agent configuration if available
300+
Map<String, String> parameters = context.getParameters();
301+
if (parameters != null
302+
&& parameters.containsKey(LLM_RESPONSE_FILTER)
303+
&& !parameters.get(LLM_RESPONSE_FILTER).isEmpty()) {
304+
try {
305+
String responseFilter = parameters.get(LLM_RESPONSE_FILTER);
306+
Object filteredResponse = JsonPath.read(dataAsMap, responseFilter);
307+
if (filteredResponse instanceof String) {
308+
String result = ((String) filteredResponse).trim();
309+
return result;
310+
} else {
311+
String result = StringUtils.toJson(filteredResponse);
312+
return result;
313+
}
314+
} catch (PathNotFoundException e) {
315+
// Fall back to default parsing
316+
} catch (Exception e) {
317+
// Fall back to default parsing
318+
}
319+
}
320+
321+
// Fallback to default parsing if no filter or filter fails
322+
if (dataAsMap.size() == 1 && dataAsMap.containsKey("response")) {
323+
Object responseObj = dataAsMap.get("response");
324+
if (responseObj instanceof String) {
325+
return ((String) responseObj).trim();
326+
}
327+
}
328+
329+
// Last resort: return JSON representation
294330
return StringUtils.toJson(dataAsMap);
295331
}
296332
}

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/contextmanager/SummarizationManagerTest.java

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
package org.opensearch.ml.engine.algorithms.contextmanager;
77

8+
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_RESPONSE_FILTER;
9+
810
import java.util.ArrayList;
911
import java.util.HashMap;
1012
import java.util.List;
@@ -16,6 +18,10 @@
1618
import org.mockito.Mock;
1719
import org.mockito.MockitoAnnotations;
1820
import org.opensearch.ml.common.contextmanager.ContextManagerContext;
21+
import org.opensearch.ml.common.output.model.ModelTensor;
22+
import org.opensearch.ml.common.output.model.ModelTensorOutput;
23+
import org.opensearch.ml.common.output.model.ModelTensors;
24+
import org.opensearch.ml.common.transport.MLTaskResponse;
1925
import org.opensearch.transport.client.Client;
2026

2127
/**
@@ -161,6 +167,159 @@ public void testProcessSummarizationResult() {
161167
Assert.assertTrue(firstOutput.contains("Test summary"));
162168
}
163169

170+
@Test
171+
public void testExtractSummaryFromResponseWithLLMResponseFilter() throws Exception {
172+
Map<String, Object> config = new HashMap<>();
173+
manager.initialize(config);
174+
175+
// Set up context with LLM_RESPONSE_FILTER
176+
Map<String, String> parameters = new HashMap<>();
177+
parameters.put(LLM_RESPONSE_FILTER, "$.choices[0].message.content");
178+
context.setParameters(parameters);
179+
180+
// Create mock response with OpenAI-style structure
181+
Map<String, Object> responseData = new HashMap<>();
182+
Map<String, Object> choice = new HashMap<>();
183+
Map<String, Object> message = new HashMap<>();
184+
message.put("content", "This is the extracted summary content");
185+
choice.put("message", message);
186+
responseData.put("choices", List.of(choice));
187+
188+
MLTaskResponse mockResponse = createMockMLTaskResponse(responseData);
189+
190+
// Use reflection to access the private method
191+
java.lang.reflect.Method extractMethod = SummarizationManager.class
192+
.getDeclaredMethod("extractSummaryFromResponse", MLTaskResponse.class, ContextManagerContext.class);
193+
extractMethod.setAccessible(true);
194+
195+
String result = (String) extractMethod.invoke(manager, mockResponse, context);
196+
197+
Assert.assertEquals("This is the extracted summary content", result);
198+
}
199+
200+
@Test
201+
public void testExtractSummaryFromResponseWithBedrockResponseFilter() throws Exception {
202+
Map<String, Object> config = new HashMap<>();
203+
manager.initialize(config);
204+
205+
// Set up context with Bedrock-style LLM_RESPONSE_FILTER
206+
Map<String, String> parameters = new HashMap<>();
207+
parameters.put(LLM_RESPONSE_FILTER, "$.output.message.content[0].text");
208+
context.setParameters(parameters);
209+
210+
// Create mock response with Bedrock-style structure
211+
Map<String, Object> responseData = new HashMap<>();
212+
Map<String, Object> output = new HashMap<>();
213+
Map<String, Object> message = new HashMap<>();
214+
Map<String, Object> content = new HashMap<>();
215+
content.put("text", "Bedrock extracted summary");
216+
message.put("content", List.of(content));
217+
output.put("message", message);
218+
responseData.put("output", output);
219+
220+
MLTaskResponse mockResponse = createMockMLTaskResponse(responseData);
221+
222+
// Use reflection to access the private method
223+
java.lang.reflect.Method extractMethod = SummarizationManager.class
224+
.getDeclaredMethod("extractSummaryFromResponse", MLTaskResponse.class, ContextManagerContext.class);
225+
extractMethod.setAccessible(true);
226+
227+
String result = (String) extractMethod.invoke(manager, mockResponse, context);
228+
229+
Assert.assertEquals("Bedrock extracted summary", result);
230+
}
231+
232+
@Test
233+
public void testExtractSummaryFromResponseWithInvalidFilter() throws Exception {
234+
Map<String, Object> config = new HashMap<>();
235+
manager.initialize(config);
236+
237+
// Set up context with invalid LLM_RESPONSE_FILTER path
238+
Map<String, String> parameters = new HashMap<>();
239+
parameters.put(LLM_RESPONSE_FILTER, "$.invalid.path");
240+
context.setParameters(parameters);
241+
242+
// Create mock response with simple structure
243+
Map<String, Object> responseData = new HashMap<>();
244+
responseData.put("response", "Fallback summary content");
245+
246+
MLTaskResponse mockResponse = createMockMLTaskResponse(responseData);
247+
248+
// Use reflection to access the private method
249+
java.lang.reflect.Method extractMethod = SummarizationManager.class
250+
.getDeclaredMethod("extractSummaryFromResponse", MLTaskResponse.class, ContextManagerContext.class);
251+
extractMethod.setAccessible(true);
252+
253+
String result = (String) extractMethod.invoke(manager, mockResponse, context);
254+
255+
// Should fall back to default parsing
256+
Assert.assertEquals("Fallback summary content", result);
257+
}
258+
259+
@Test
260+
public void testExtractSummaryFromResponseWithoutFilter() throws Exception {
261+
Map<String, Object> config = new HashMap<>();
262+
manager.initialize(config);
263+
264+
// Context without LLM_RESPONSE_FILTER
265+
Map<String, String> parameters = new HashMap<>();
266+
context.setParameters(parameters);
267+
268+
// Create mock response with simple structure
269+
Map<String, Object> responseData = new HashMap<>();
270+
responseData.put("response", "Default parsed summary");
271+
272+
MLTaskResponse mockResponse = createMockMLTaskResponse(responseData);
273+
274+
// Use reflection to access the private method
275+
java.lang.reflect.Method extractMethod = SummarizationManager.class
276+
.getDeclaredMethod("extractSummaryFromResponse", MLTaskResponse.class, ContextManagerContext.class);
277+
extractMethod.setAccessible(true);
278+
279+
String result = (String) extractMethod.invoke(manager, mockResponse, context);
280+
281+
Assert.assertEquals("Default parsed summary", result);
282+
}
283+
284+
@Test
285+
public void testExtractSummaryFromResponseWithEmptyFilter() throws Exception {
286+
Map<String, Object> config = new HashMap<>();
287+
manager.initialize(config);
288+
289+
// Set up context with empty LLM_RESPONSE_FILTER
290+
Map<String, String> parameters = new HashMap<>();
291+
parameters.put(LLM_RESPONSE_FILTER, "");
292+
context.setParameters(parameters);
293+
294+
// Create mock response
295+
Map<String, Object> responseData = new HashMap<>();
296+
responseData.put("response", "Empty filter fallback");
297+
298+
MLTaskResponse mockResponse = createMockMLTaskResponse(responseData);
299+
300+
// Use reflection to access the private method
301+
java.lang.reflect.Method extractMethod = SummarizationManager.class
302+
.getDeclaredMethod("extractSummaryFromResponse", MLTaskResponse.class, ContextManagerContext.class);
303+
extractMethod.setAccessible(true);
304+
305+
String result = (String) extractMethod.invoke(manager, mockResponse, context);
306+
307+
Assert.assertEquals("Empty filter fallback", result);
308+
}
309+
310+
/**
311+
* Helper method to create a mock MLTaskResponse with the given data.
312+
*/
313+
private MLTaskResponse createMockMLTaskResponse(Map<String, Object> responseData) {
314+
ModelTensor tensor = ModelTensor.builder().dataAsMap(responseData).build();
315+
316+
ModelTensors tensors = ModelTensors.builder().mlModelTensors(List.of(tensor)).build();
317+
318+
ModelTensorOutput output = ModelTensorOutput.builder().mlModelOutputs(List.of(tensors)).build();
319+
320+
return MLTaskResponse.builder().output(output).build();
321+
}
322+
164323
/**
165324
* Helper method to add tool interactions to the context.
166325
*/

plugin/src/main/java/org/opensearch/ml/task/MLExecuteTaskRunner.java

Lines changed: 64 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,7 @@ private void executeAgentWithContextManagement(
274274

275275
/**
276276
* Gets the effective context management name for an agent.
277-
* Priority: 1) Runtime parameter from execution request, 2) Agent's stored configuration (set by MLAgentExecutor)
277+
* Priority: 1) Runtime parameter from execution request, 2) Agent's stored configuration, 3) Runtime parameters set by MLAgentExecutor
278278
* This follows the same pattern as MCP connectors.
279279
*
280280
* @param agentInput the agent ML input
@@ -288,7 +288,69 @@ private String getEffectiveContextManagementName(AgentMLInput agentInput) {
288288
return runtimeContextManagementName;
289289
}
290290

291-
// Priority 2: Agent's stored configuration (set by MLAgentExecutor in input parameters)
291+
// Priority 2: Check agent's stored configuration directly
292+
String agentId = agentInput.getAgentId();
293+
if (agentId != null) {
294+
try {
295+
// Use a blocking call to get the agent synchronously
296+
// This is acceptable here since we're in the task execution path
297+
java.util.concurrent.CompletableFuture<String> future = new java.util.concurrent.CompletableFuture<>();
298+
299+
try (
300+
org.opensearch.common.util.concurrent.ThreadContext.StoredContext context = client
301+
.threadPool()
302+
.getThreadContext()
303+
.stashContext()
304+
) {
305+
client
306+
.get(
307+
new org.opensearch.action.get.GetRequest(org.opensearch.ml.common.CommonValue.ML_AGENT_INDEX, agentId),
308+
org.opensearch.core.action.ActionListener.runBefore(org.opensearch.core.action.ActionListener.wrap(response -> {
309+
if (response.isExists()) {
310+
try {
311+
org.opensearch.core.xcontent.XContentParser parser =
312+
org.opensearch.common.xcontent.json.JsonXContent.jsonXContent
313+
.createParser(
314+
null,
315+
org.opensearch.common.xcontent.LoggingDeprecationHandler.INSTANCE,
316+
response.getSourceAsString()
317+
);
318+
org.opensearch.core.xcontent.XContentParserUtils
319+
.ensureExpectedToken(
320+
org.opensearch.core.xcontent.XContentParser.Token.START_OBJECT,
321+
parser.nextToken(),
322+
parser
323+
);
324+
org.opensearch.ml.common.agent.MLAgent mlAgent = org.opensearch.ml.common.agent.MLAgent
325+
.parse(parser);
326+
327+
if (mlAgent.hasContextManagementTemplate()) {
328+
String templateName = mlAgent.getContextManagementTemplateName();
329+
future.complete(templateName);
330+
} else {
331+
future.complete(null);
332+
}
333+
} catch (Exception e) {
334+
future.completeExceptionally(e);
335+
}
336+
} else {
337+
future.complete(null); // Agent not found
338+
}
339+
}, future::completeExceptionally), context::restore)
340+
);
341+
}
342+
343+
// Wait for the result with a timeout
344+
String contextManagementName = future.get(5, java.util.concurrent.TimeUnit.SECONDS);
345+
if (contextManagementName != null && !contextManagementName.trim().isEmpty()) {
346+
return contextManagementName;
347+
}
348+
} catch (Exception e) {
349+
// Continue to fallback methods
350+
}
351+
}
352+
353+
// Priority 3: Agent's runtime parameters (set by MLAgentExecutor in input parameters)
292354
if (agentInput.getInputDataset() instanceof org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet) {
293355
org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet dataset =
294356
(org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet) agentInput.getInputDataset();
@@ -303,7 +365,6 @@ private String getEffectiveContextManagementName(AgentMLInput agentInput) {
303365
// Handle template references (not processed by MLAgentExecutor)
304366
String agentContextManagementName = dataset.getParameters().get("context_management");
305367
if (agentContextManagementName != null && !agentContextManagementName.trim().isEmpty()) {
306-
log.debug("Using agent-level context management template reference: {}", agentContextManagementName);
307368
return agentContextManagementName;
308369
}
309370
}

0 commit comments

Comments
 (0)