From 02a4ff5f6351fc00fd95d8b86599fa5a102ecfb4 Mon Sep 17 00:00:00 2001 From: Owais Kazi Date: Wed, 19 Nov 2025 16:18:18 -0800 Subject: [PATCH 1/2] Fix model id parsing for QueryPlanningTool Signed-off-by: Owais Kazi --- .../engine/algorithms/agent/AgentUtils.java | 5 +++ .../ml/engine/tools/QueryPlanningTool.java | 5 +++ .../engine/tools/QueryPlanningToolTests.java | 38 +++++++++++++++++++ 3 files changed, 48 insertions(+) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java index 7de547127a..e4a2f0fe10 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java @@ -124,6 +124,7 @@ public class AgentUtils { public static final String TOOL_CALL_ID_PATH = "tool_calls.id_path"; private static final String NAME = "name"; private static final String DESCRIPTION = "description"; + public static final String AGENT_LLM_MODEL_ID = "agent_llm_model_id"; public static final String TOOLS = "_tools"; public static final String TOOL_TEMPLATE = "tool_template"; @@ -857,6 +858,10 @@ public static void createTools( if (toolSpecs == null) { return; } + // Add agent's model_id for tools that may need it like QPT + if (mlAgent.getLlm() != null && mlAgent.getLlm().getModelId() != null) { + params.put(AGENT_LLM_MODEL_ID, mlAgent.getLlm().getModelId()); + } for (MLToolSpec toolSpec : toolSpecs) { Map toolParams = buildToolParameters(params, toolSpec, mlAgent.getTenantId()); Tool tool = createTool(toolFactories, toolParams, toolSpec); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/QueryPlanningTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/QueryPlanningTool.java index 26bc77a053..7d2890800d 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/QueryPlanningTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/QueryPlanningTool.java @@ -8,6 +8,7 @@ import static org.opensearch.action.support.clustermanager.ClusterManagerNodeRequest.DEFAULT_CLUSTER_MANAGER_NODE_TIMEOUT; import static org.opensearch.ml.common.CommonValue.TOOL_INPUT_SCHEMA_FIELD; import static org.opensearch.ml.common.utils.StringUtils.gson; +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.AGENT_LLM_MODEL_ID; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.DEFAULT_DATETIME_FORMAT; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getCurrentDateTime; import static org.opensearch.ml.engine.tools.QueryPlanningPromptTemplate.DEFAULT_QUERY; @@ -420,6 +421,10 @@ public void init(Client client) { @Override public QueryPlanningTool create(Map params) { + // Use agent's Agent model_id if tool doesn't have its own model_id + if (!params.containsKey(MODEL_ID_FIELD) && params.containsKey(AGENT_LLM_MODEL_ID)) { + params.put(MODEL_ID_FIELD, params.get(AGENT_LLM_MODEL_ID)); + } MLModelTool queryGenerationTool = MLModelTool.Factory.getInstance().create(params); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/QueryPlanningToolTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/QueryPlanningToolTests.java index 27fafefa43..0bf5262216 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/QueryPlanningToolTests.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/QueryPlanningToolTests.java @@ -20,6 +20,7 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.opensearch.ml.common.utils.StringUtils.gson; +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.AGENT_LLM_MODEL_ID; import static org.opensearch.ml.engine.tools.QueryPlanningPromptTemplate.DEFAULT_QUERY; import static org.opensearch.ml.engine.tools.QueryPlanningPromptTemplate.DEFAULT_QUERY_PLANNING_SYSTEM_PROMPT; import static org.opensearch.ml.engine.tools.QueryPlanningTool.DEFAULT_DESCRIPTION; @@ -1497,4 +1498,41 @@ public void testQueryPlanningTool_WithMockedMLModelTool_EndToEnd() { assertEquals("{\"query\":{\"match\":{\"title\":\"test\"}}}", result); } + @Test + public void testFactoryCreate_UsesAgentLlmModelIdAsFallback() { + Map params = new HashMap<>(); + params.put(AGENT_LLM_MODEL_ID, "agent_model_123"); + // No model_id specified - should use agent_llm_model_id + + Tool tool = QueryPlanningTool.Factory.getInstance().create(params); + assertNotNull(tool); + assertEquals(QueryPlanningTool.TYPE, tool.getName()); + + assertTrue(params.containsKey(MODEL_ID_FIELD)); + assertEquals("agent_model_123", params.get(MODEL_ID_FIELD)); + } + + @Test + public void testFactoryCreate_ToolModelIdTakesPrecedence() { + Map params = new HashMap<>(); + params.put(MODEL_ID_FIELD, "tool_model_456"); + params.put(AGENT_LLM_MODEL_ID, "agent_model_123"); + + Tool tool = QueryPlanningTool.Factory.getInstance().create(params); + assertNotNull(tool); + assertEquals(QueryPlanningTool.TYPE, tool.getName()); + + assertEquals("tool_model_456", params.get(MODEL_ID_FIELD)); + } + + @Test + public void testFactoryCreate_NoModelIdProvided() { + Map params = new HashMap<>(); + + Exception exception = assertThrows( + IllegalArgumentException.class, + () -> { QueryPlanningTool.Factory.getInstance().create(params); } + ); + assertEquals("Model ID can't be null or empty", exception.getMessage()); + } } From a6b0633126802360ec5f113e85ecab00300c72fd Mon Sep 17 00:00:00 2001 From: Owais Kazi Date: Wed, 19 Nov 2025 16:50:13 -0800 Subject: [PATCH 2/2] Fixed tests Signed-off-by: Owais Kazi --- .../ml/engine/algorithms/agent/MLChatAgentRunnerTest.java | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java index f6c3e3618e..c63db9df4f 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java @@ -710,7 +710,7 @@ public void testToolParameters() { // Verify the size of parameters passed in the tool run method. ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Map.class); verify(firstTool).run((Map) argumentCaptor.capture(), any()); - assertEquals(15, ((Map) argumentCaptor.getValue()).size()); + assertEquals(16, ((Map) argumentCaptor.getValue()).size()); Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); ModelTensorOutput modelTensorOutput = (ModelTensorOutput) objectCaptor.getValue(); @@ -738,7 +738,7 @@ public void testToolUseOriginalInput() { // Verify the size of parameters passed in the tool run method. ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Map.class); verify(firstTool).run((Map) argumentCaptor.capture(), any()); - assertEquals(16, ((Map) argumentCaptor.getValue()).size()); + assertEquals(17, ((Map) argumentCaptor.getValue()).size()); assertEquals("raw input", ((Map) argumentCaptor.getValue()).get("input")); Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); @@ -804,7 +804,7 @@ public void testToolConfig() { // Verify the size of parameters passed in the tool run method. ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Map.class); verify(firstTool).run((Map) argumentCaptor.capture(), any()); - assertEquals(16, ((Map) argumentCaptor.getValue()).size()); + assertEquals(17, ((Map) argumentCaptor.getValue()).size()); // The value of input should be "config_value". assertEquals("config_value", ((Map) argumentCaptor.getValue()).get("input")); @@ -834,7 +834,7 @@ public void testToolConfigWithInputPlaceholder() { // Verify the size of parameters passed in the tool run method. ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Map.class); verify(firstTool).run((Map) argumentCaptor.capture(), any()); - assertEquals(16, ((Map) argumentCaptor.getValue()).size()); + assertEquals(17, ((Map) argumentCaptor.getValue()).size()); // The value of input should be replaced with the value associated with the key "key2" of the first tool. assertEquals("value2", ((Map) argumentCaptor.getValue()).get("input"));