Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ public class AgentUtils {
public static final String TOOL_CALL_ID_PATH = "tool_calls.id_path";
private static final String NAME = "name";
private static final String DESCRIPTION = "description";
public static final String AGENT_LLM_MODEL_ID = "agent_llm_model_id";

public static final String TOOLS = "_tools";
public static final String TOOL_TEMPLATE = "tool_template";
Expand Down Expand Up @@ -857,6 +858,10 @@ public static void createTools(
if (toolSpecs == null) {
return;
}
// Add agent's model_id for tools that may need it like QPT
if (mlAgent.getLlm() != null && mlAgent.getLlm().getModelId() != null) {
params.put(AGENT_LLM_MODEL_ID, mlAgent.getLlm().getModelId());
}
for (MLToolSpec toolSpec : toolSpecs) {
Map<String, String> toolParams = buildToolParameters(params, toolSpec, mlAgent.getTenantId());
Tool tool = createTool(toolFactories, toolParams, toolSpec);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import static org.opensearch.action.support.clustermanager.ClusterManagerNodeRequest.DEFAULT_CLUSTER_MANAGER_NODE_TIMEOUT;
import static org.opensearch.ml.common.CommonValue.TOOL_INPUT_SCHEMA_FIELD;
import static org.opensearch.ml.common.utils.StringUtils.gson;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.AGENT_LLM_MODEL_ID;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.DEFAULT_DATETIME_FORMAT;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getCurrentDateTime;
import static org.opensearch.ml.engine.tools.QueryPlanningPromptTemplate.DEFAULT_QUERY;
Expand Down Expand Up @@ -420,6 +421,10 @@ public void init(Client client) {

@Override
public QueryPlanningTool create(Map<String, Object> params) {
// Use agent's Agent model_id if tool doesn't have its own model_id
if (!params.containsKey(MODEL_ID_FIELD) && params.containsKey(AGENT_LLM_MODEL_ID)) {
params.put(MODEL_ID_FIELD, params.get(AGENT_LLM_MODEL_ID));
}
Comment on lines +424 to +427
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When creating the Query Planner tool, during the registration of the agent, we are adding the agent's LLM Model ID as the model ID for QPT. But I believe this approach is a much cleaner way of doing things.

Can we also clean up this code in register Agent?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@rithin-pullela-aws earlier, you had a PR that check if it's query planning tool then you passed over the model id to query planning tool and that's specific for queryPlanningTool.I wanted something more generic to tools and I like this PR's approach to handle model id for all tool execution.

Can you try find out that PR and then we can revert that commit and take this one.

Copy link
Member Author

@owaiskazi19 owaiskazi19 Nov 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the PR but it has some other changes as well. I will leave that to @rithin-pullela-aws


MLModelTool queryGenerationTool = MLModelTool.Factory.getInstance().create(params);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -710,7 +710,7 @@ public void testToolParameters() {
// Verify the size of parameters passed in the tool run method.
ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Map.class);
verify(firstTool).run((Map<String, String>) argumentCaptor.capture(), any());
assertEquals(15, ((Map) argumentCaptor.getValue()).size());
assertEquals(16, ((Map) argumentCaptor.getValue()).size());

Mockito.verify(agentActionListener).onResponse(objectCaptor.capture());
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) objectCaptor.getValue();
Expand Down Expand Up @@ -738,7 +738,7 @@ public void testToolUseOriginalInput() {
// Verify the size of parameters passed in the tool run method.
ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Map.class);
verify(firstTool).run((Map<String, String>) argumentCaptor.capture(), any());
assertEquals(16, ((Map) argumentCaptor.getValue()).size());
assertEquals(17, ((Map) argumentCaptor.getValue()).size());
assertEquals("raw input", ((Map<?, ?>) argumentCaptor.getValue()).get("input"));

Mockito.verify(agentActionListener).onResponse(objectCaptor.capture());
Expand Down Expand Up @@ -804,7 +804,7 @@ public void testToolConfig() {
// Verify the size of parameters passed in the tool run method.
ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Map.class);
verify(firstTool).run((Map<String, String>) argumentCaptor.capture(), any());
assertEquals(16, ((Map) argumentCaptor.getValue()).size());
assertEquals(17, ((Map) argumentCaptor.getValue()).size());
// The value of input should be "config_value".
assertEquals("config_value", ((Map<?, ?>) argumentCaptor.getValue()).get("input"));

Expand Down Expand Up @@ -834,7 +834,7 @@ public void testToolConfigWithInputPlaceholder() {
// Verify the size of parameters passed in the tool run method.
ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Map.class);
verify(firstTool).run((Map<String, String>) argumentCaptor.capture(), any());
assertEquals(16, ((Map) argumentCaptor.getValue()).size());
assertEquals(17, ((Map) argumentCaptor.getValue()).size());
// The value of input should be replaced with the value associated with the key "key2" of the first tool.
assertEquals("value2", ((Map<?, ?>) argumentCaptor.getValue()).get("input"));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.opensearch.ml.common.utils.StringUtils.gson;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.AGENT_LLM_MODEL_ID;
import static org.opensearch.ml.engine.tools.QueryPlanningPromptTemplate.DEFAULT_QUERY;
import static org.opensearch.ml.engine.tools.QueryPlanningPromptTemplate.DEFAULT_QUERY_PLANNING_SYSTEM_PROMPT;
import static org.opensearch.ml.engine.tools.QueryPlanningTool.DEFAULT_DESCRIPTION;
Expand Down Expand Up @@ -1497,4 +1498,41 @@ public void testQueryPlanningTool_WithMockedMLModelTool_EndToEnd() {
assertEquals("{\"query\":{\"match\":{\"title\":\"test\"}}}", result);
}

@Test
public void testFactoryCreate_UsesAgentLlmModelIdAsFallback() {
Map<String, Object> params = new HashMap<>();
params.put(AGENT_LLM_MODEL_ID, "agent_model_123");
// No model_id specified - should use agent_llm_model_id

Tool tool = QueryPlanningTool.Factory.getInstance().create(params);
assertNotNull(tool);
assertEquals(QueryPlanningTool.TYPE, tool.getName());

assertTrue(params.containsKey(MODEL_ID_FIELD));
assertEquals("agent_model_123", params.get(MODEL_ID_FIELD));
}

@Test
public void testFactoryCreate_ToolModelIdTakesPrecedence() {
Map<String, Object> params = new HashMap<>();
params.put(MODEL_ID_FIELD, "tool_model_456");
params.put(AGENT_LLM_MODEL_ID, "agent_model_123");

Tool tool = QueryPlanningTool.Factory.getInstance().create(params);
assertNotNull(tool);
assertEquals(QueryPlanningTool.TYPE, tool.getName());

assertEquals("tool_model_456", params.get(MODEL_ID_FIELD));
}

@Test
public void testFactoryCreate_NoModelIdProvided() {
Map<String, Object> params = new HashMap<>();

Exception exception = assertThrows(
IllegalArgumentException.class,
() -> { QueryPlanningTool.Factory.getInstance().create(params); }
);
assertEquals("Model ID can't be null or empty", exception.getMessage());
}
}
Loading