Skip to content

Commit 777dba3

Browse files
davidkl97Jacksunwei
authored andcommitted
feat(tools): Add an option to disallow propagating runner plugins to AgentTool runner
Merge #2779 Fixes #2780 ### testing plan not available as is doesn't introduce new functionality Co-authored-by: Wei Sun (Jack) <weisun@google.com> COPYBARA_INTEGRATE_REVIEW=#2779 from davidkl97:feature/agent-tool-plugins a602c80 PiperOrigin-RevId: 835366974
1 parent 2247a45 commit 777dba3

File tree

2 files changed

+133
-3
lines changed

2 files changed

+133
-3
lines changed

src/google/adk/tools/agent_tool.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,22 @@ class AgentTool(BaseTool):
4545
Attributes:
4646
agent: The agent to wrap.
4747
skip_summarization: Whether to skip summarization of the agent output.
48+
include_plugins: Whether to propagate plugins from the parent runner context
49+
to the agent's runner. When True (default), the agent will inherit all
50+
plugins from its parent. Set to False to run the agent with an isolated
51+
plugin environment.
4852
"""
4953

50-
def __init__(self, agent: BaseAgent, skip_summarization: bool = False):
54+
def __init__(
55+
self,
56+
agent: BaseAgent,
57+
skip_summarization: bool = False,
58+
*,
59+
include_plugins: bool = True,
60+
):
5161
self.agent = agent
5262
self.skip_summarization: bool = skip_summarization
63+
self.include_plugins = include_plugins
5364

5465
super().__init__(name=agent.name, description=agent.description)
5566

@@ -130,14 +141,19 @@ async def run_async(
130141
invocation_context.app_name if invocation_context else None
131142
)
132143
child_app_name = parent_app_name or self.agent.name
144+
plugins = (
145+
tool_context._invocation_context.plugin_manager.plugins
146+
if self.include_plugins
147+
else None
148+
)
133149
runner = Runner(
134150
app_name=child_app_name,
135151
agent=self.agent,
136152
artifact_service=ForwardingArtifactService(tool_context),
137153
session_service=InMemorySessionService(),
138154
memory_service=InMemoryMemoryService(),
139155
credential_service=tool_context._invocation_context.credential_service,
140-
plugins=list(tool_context._invocation_context.plugin_manager.plugins),
156+
plugins=plugins,
141157
)
142158

143159
state_dict = {
@@ -192,7 +208,9 @@ def from_config(
192208
agent_tool_config.agent, config_abs_path
193209
)
194210
return cls(
195-
agent=agent, skip_summarization=agent_tool_config.skip_summarization
211+
agent=agent,
212+
skip_summarization=agent_tool_config.skip_summarization,
213+
include_plugins=agent_tool_config.include_plugins,
196214
)
197215

198216

@@ -204,3 +222,6 @@ class AgentToolConfig(BaseToolConfig):
204222

205223
skip_summarization: bool = False
206224
"""Whether to skip summarization of the agent output."""
225+
226+
include_plugins: bool = True
227+
"""Whether to include plugins from parent runner context."""

tests/unittests/tools/test_agent_tool.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -570,3 +570,112 @@ class CustomInput(BaseModel):
570570
# Should have string response schema for VERTEX_AI when no output_schema
571571
assert declaration.response is not None
572572
assert declaration.response.type == types.Type.STRING
573+
574+
575+
def test_include_plugins_default_true():
576+
"""Test that plugins are propagated by default (include_plugins=True)."""
577+
578+
# Create a test plugin that tracks callbacks
579+
class TrackingPlugin(BasePlugin):
580+
581+
def __init__(self, name: str):
582+
super().__init__(name)
583+
self.before_agent_calls = 0
584+
585+
async def before_agent_callback(self, **kwargs):
586+
self.before_agent_calls += 1
587+
588+
tracking_plugin = TrackingPlugin(name='tracking')
589+
590+
mock_model = testing_utils.MockModel.create(
591+
responses=[function_call_no_schema, 'response1', 'response2']
592+
)
593+
594+
tool_agent = Agent(
595+
name='tool_agent',
596+
model=mock_model,
597+
)
598+
599+
root_agent = Agent(
600+
name='root_agent',
601+
model=mock_model,
602+
tools=[AgentTool(agent=tool_agent)], # Default include_plugins=True
603+
)
604+
605+
runner = testing_utils.InMemoryRunner(root_agent, plugins=[tracking_plugin])
606+
runner.run('test1')
607+
608+
# Plugin should be called for both root_agent and tool_agent
609+
assert tracking_plugin.before_agent_calls == 2
610+
611+
612+
def test_include_plugins_explicit_true():
613+
"""Test that plugins are propagated when include_plugins=True."""
614+
615+
class TrackingPlugin(BasePlugin):
616+
617+
def __init__(self, name: str):
618+
super().__init__(name)
619+
self.before_agent_calls = 0
620+
621+
async def before_agent_callback(self, **kwargs):
622+
self.before_agent_calls += 1
623+
624+
tracking_plugin = TrackingPlugin(name='tracking')
625+
626+
mock_model = testing_utils.MockModel.create(
627+
responses=[function_call_no_schema, 'response1', 'response2']
628+
)
629+
630+
tool_agent = Agent(
631+
name='tool_agent',
632+
model=mock_model,
633+
)
634+
635+
root_agent = Agent(
636+
name='root_agent',
637+
model=mock_model,
638+
tools=[AgentTool(agent=tool_agent, include_plugins=True)],
639+
)
640+
641+
runner = testing_utils.InMemoryRunner(root_agent, plugins=[tracking_plugin])
642+
runner.run('test1')
643+
644+
# Plugin should be called for both root_agent and tool_agent
645+
assert tracking_plugin.before_agent_calls == 2
646+
647+
648+
def test_include_plugins_false():
649+
"""Test that plugins are NOT propagated when include_plugins=False."""
650+
651+
class TrackingPlugin(BasePlugin):
652+
653+
def __init__(self, name: str):
654+
super().__init__(name)
655+
self.before_agent_calls = 0
656+
657+
async def before_agent_callback(self, **kwargs):
658+
self.before_agent_calls += 1
659+
660+
tracking_plugin = TrackingPlugin(name='tracking')
661+
662+
mock_model = testing_utils.MockModel.create(
663+
responses=[function_call_no_schema, 'response1', 'response2']
664+
)
665+
666+
tool_agent = Agent(
667+
name='tool_agent',
668+
model=mock_model,
669+
)
670+
671+
root_agent = Agent(
672+
name='root_agent',
673+
model=mock_model,
674+
tools=[AgentTool(agent=tool_agent, include_plugins=False)],
675+
)
676+
677+
runner = testing_utils.InMemoryRunner(root_agent, plugins=[tracking_plugin])
678+
runner.run('test1')
679+
680+
# Plugin should only be called for root_agent, not tool_agent
681+
assert tracking_plugin.before_agent_calls == 1

0 commit comments

Comments
 (0)