Skip to content

Commit 116b26c

Browse files
RKestcopybara-github
authored andcommitted
feat: add plugin for returning GenAI Parts from tools into the model request
Added to mitigate #3064 Co-authored-by: Max Ind <maxind@google.com> PiperOrigin-RevId: 830135940
1 parent e218254 commit 116b26c

File tree

4 files changed

+300
-0
lines changed

4 files changed

+300
-0
lines changed
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from . import agent
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from google.adk.agents import LlmAgent
16+
from google.adk.apps.app import App
17+
from google.adk.plugins.multimodal_tool_results_plugin import MultimodalToolResultsPlugin
18+
from google.genai import types
19+
20+
APP_NAME = "multimodal_tool_results"
21+
USER_ID = "test_user"
22+
23+
24+
def get_image():
25+
return [types.Part.from_uri(file_uri="gs://replace_with_your_image_uri")]
26+
27+
28+
root_agent = LlmAgent(
29+
name="image_describing_agent",
30+
description="image describing agent",
31+
instruction="""Whatever the user says, get the image using the get_image tool, and describe it.""",
32+
model="gemini-2.0-flash",
33+
tools=[get_image],
34+
)
35+
36+
37+
app = App(
38+
name=APP_NAME,
39+
root_agent=root_agent,
40+
plugins=[MultimodalToolResultsPlugin()],
41+
)
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
from typing import Any
18+
from typing import Optional
19+
20+
from google.genai import types
21+
22+
from ..agents.callback_context import CallbackContext
23+
from ..models.llm_request import LlmRequest
24+
from ..models.llm_response import LlmResponse
25+
from ..tools.base_tool import BaseTool
26+
from ..tools.tool_context import ToolContext
27+
from .base_plugin import BasePlugin
28+
29+
PARTS_RETURNED_BY_TOOLS_ID = "temp:PARTS_RETURNED_BY_TOOLS_ID"
30+
31+
32+
class MultimodalToolResultsPlugin(BasePlugin):
33+
"""A plugin that modifies function tool responses to support returning list of parts directly.
34+
35+
Should be removed in favor of directly supporting FunctionResponsePart when these
36+
are supported outside of computer use tool.
37+
For context see: https://github.com/google/adk-python/issues/3064#issuecomment-3463067459
38+
"""
39+
40+
def __init__(self, name: str = "multimodal_tool_results_plugin"):
41+
"""Initialize the multimodal tool results plugin.
42+
43+
Args:
44+
name: The name of the plugin instance.
45+
"""
46+
super().__init__(name)
47+
48+
async def after_tool_callback(
49+
self,
50+
*,
51+
tool: BaseTool,
52+
tool_args: dict[str, Any],
53+
tool_context: ToolContext,
54+
result: dict,
55+
) -> Optional[dict]:
56+
"""Saves parts returned by the tool in ToolContext.
57+
58+
Later these are passed to LLM's context as-is.
59+
No-op if tool doesn't return list[google.genai.types.Part] or google.genai.types.Part.
60+
"""
61+
62+
if not (
63+
isinstance(result, types.Part)
64+
or isinstance(result, list)
65+
and result
66+
and isinstance(result[0], types.Part)
67+
):
68+
return result
69+
70+
parts = [result] if isinstance(result, types.Part) else result[:]
71+
72+
if PARTS_RETURNED_BY_TOOLS_ID in tool_context.state:
73+
tool_context.state[PARTS_RETURNED_BY_TOOLS_ID] += parts
74+
else:
75+
tool_context.state[PARTS_RETURNED_BY_TOOLS_ID] = parts
76+
77+
return None
78+
79+
async def before_model_callback(
80+
self, *, callback_context: CallbackContext, llm_request: LlmRequest
81+
) -> Optional[LlmResponse]:
82+
"""Attach saved list[google.genai.types.Part] returned by the tool to llm_request."""
83+
84+
if saved_parts := callback_context.state.get(
85+
PARTS_RETURNED_BY_TOOLS_ID, None
86+
):
87+
llm_request.contents[-1].parts += saved_parts
88+
callback_context.state.update({PARTS_RETURNED_BY_TOOLS_ID: []})
89+
90+
return None
Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
from typing import Any
18+
from unittest.mock import Mock
19+
20+
from google.adk.agents.base_agent import BaseAgent
21+
from google.adk.agents.callback_context import CallbackContext
22+
from google.adk.models.llm_request import LlmRequest
23+
from google.adk.plugins.multimodal_tool_results_plugin import MultimodalToolResultsPlugin
24+
from google.adk.plugins.multimodal_tool_results_plugin import PARTS_RETURNED_BY_TOOLS_ID
25+
from google.adk.tools.base_tool import BaseTool
26+
from google.adk.tools.tool_context import ToolContext
27+
from google.genai import types
28+
import pytest
29+
30+
from .. import testing_utils
31+
32+
33+
@pytest.fixture
34+
def plugin() -> MultimodalToolResultsPlugin:
35+
"""Create a default plugin instance for testing."""
36+
return MultimodalToolResultsPlugin()
37+
38+
39+
@pytest.fixture
40+
def mock_tool() -> MockTool:
41+
"""Create a mock tool for testing."""
42+
return Mock(spec=BaseTool)
43+
44+
45+
@pytest.fixture
46+
async def tool_context() -> ToolContext:
47+
"""Create a mock tool context."""
48+
return ToolContext(
49+
invocation_context=await testing_utils.create_invocation_context(
50+
agent=Mock(spec=BaseAgent)
51+
)
52+
)
53+
54+
55+
@pytest.mark.asyncio
56+
async def test_tool_returning_parts_are_added_to_llm_request(
57+
plugin: MultimodalToolResultsPlugin,
58+
mock_tool: MockTool,
59+
tool_context: ToolContext,
60+
):
61+
"""Test that parts returned by a tool are present in the llm_request later."""
62+
parts = [types.Part(text="part1"), types.Part(text="part2")]
63+
64+
result = await plugin.after_tool_callback(
65+
tool=mock_tool,
66+
tool_args={},
67+
tool_context=tool_context,
68+
result=parts,
69+
)
70+
71+
assert result == None
72+
assert PARTS_RETURNED_BY_TOOLS_ID in tool_context.state
73+
assert tool_context.state[PARTS_RETURNED_BY_TOOLS_ID] == parts
74+
75+
callback_context = Mock(spec=CallbackContext)
76+
callback_context.state = tool_context.state
77+
llm_request = LlmRequest(contents=[types.Content(parts=[])])
78+
79+
await plugin.before_model_callback(
80+
callback_context=callback_context, llm_request=llm_request
81+
)
82+
83+
assert llm_request.contents[-1].parts == parts
84+
85+
86+
@pytest.mark.asyncio
87+
async def test_tool_returning_non_list_of_parts_is_unchanged(
88+
plugin: MultimodalToolResultsPlugin,
89+
mock_tool: MockTool,
90+
tool_context: ToolContext,
91+
):
92+
"""Test where tool returning non list of parts, has this result unchanged."""
93+
original_result = {"some": "data"}
94+
95+
result = await plugin.after_tool_callback(
96+
tool=mock_tool,
97+
tool_args={},
98+
tool_context=tool_context,
99+
result=original_result,
100+
)
101+
102+
assert result == original_result
103+
assert PARTS_RETURNED_BY_TOOLS_ID not in tool_context.state
104+
105+
callback_context = Mock(spec=CallbackContext)
106+
callback_context.state = tool_context.state
107+
llm_request = LlmRequest(
108+
contents=[types.Content(parts=[types.Part(text="original")])]
109+
)
110+
original_parts = list(llm_request.contents[-1].parts)
111+
112+
await plugin.before_model_callback(
113+
callback_context=callback_context, llm_request=llm_request
114+
)
115+
116+
assert llm_request.contents[-1].parts == original_parts
117+
118+
119+
@pytest.mark.asyncio
120+
async def test_multiple_tools_returning_parts_are_accumulated(
121+
plugin: ToolReturningGenAiPartsPlugin,
122+
mock_tool: MockTool,
123+
tool_context: ToolContext,
124+
):
125+
"""Test that parts from multiple tool calls are accumulated."""
126+
parts1 = [types.Part(text="part1")]
127+
parts2 = [types.Part(text="part2")]
128+
129+
await plugin.after_tool_callback(
130+
tool=mock_tool,
131+
tool_args={},
132+
tool_context=tool_context,
133+
result=parts1,
134+
)
135+
136+
await plugin.after_tool_callback(
137+
tool=mock_tool,
138+
tool_args={},
139+
tool_context=tool_context,
140+
result=parts2,
141+
)
142+
143+
assert PARTS_RETURNED_BY_TOOLS_ID in tool_context.state
144+
assert tool_context.state[PARTS_RETURNED_BY_TOOLS_ID] == parts1 + parts2
145+
146+
callback_context = Mock(spec=CallbackContext)
147+
callback_context.state = tool_context.state
148+
llm_request = LlmRequest(contents=[types.Content(parts=[])])
149+
150+
await plugin.before_model_callback(
151+
callback_context=callback_context, llm_request=llm_request
152+
)
153+
154+
assert llm_request.contents[-1].parts == parts1 + parts2

0 commit comments

Comments
 (0)