Skip to content

Commit 9ab17f2

Browse files
DeanChensjcopybara-github
authored andcommitted
fix: Yield the long running tool response before pausing execution
PiperOrigin-RevId: 825056377
1 parent 86f0155 commit 9ab17f2

File tree

5 files changed

+146
-50
lines changed

5 files changed

+146
-50
lines changed

src/google/adk/agents/llm_agent.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -430,16 +430,24 @@ async def _run_async_impl(
430430
yield self._create_agent_state_event(ctx)
431431
return
432432

433+
should_pause = False
433434
async with Aclosing(self._llm_flow.run_async(ctx)) as agen:
434435
async for event in agen:
435436
self.__maybe_save_output_to_state(event)
436437
yield event
437438
if ctx.should_pause_invocation(event):
438-
return
439+
# Do not pause immediately, wait until the long running tool call is
440+
# executed.
441+
should_pause = True
442+
if should_pause:
443+
return
439444

440445
if ctx.is_resumable:
441446
events = ctx._get_events(current_invocation=True, current_branch=True)
442-
if events and ctx.should_pause_invocation(events[-1]):
447+
if events and (
448+
ctx.should_pause_invocation(events[-1])
449+
or ctx.should_pause_invocation(events[-2])
450+
):
443451
return
444452
# Only yield an end state if the last event is no longer a long running
445453
# tool call.

src/google/adk/flows/llm_flows/base_llm_flow.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -383,16 +383,30 @@ async def _run_one_step_async(
383383
events = invocation_context._get_events(
384384
current_invocation=True, current_branch=True
385385
)
386+
387+
# Long running tool calls should have been handled before this point.
388+
# If there are still long running tool calls, it means the agent is paused
389+
# before, and its branch hasn't been resumed yet.
390+
if (
391+
invocation_context.is_resumable
392+
and events
393+
and len(events) > 1
394+
# TODO: here we are using the last 2 events to decide whether to pause
395+
# the invocation. But this is just being optmisitic, we should find a
396+
# way to pause when the long running tool call is followed by more than
397+
# one text responses.
398+
and (
399+
invocation_context.should_pause_invocation(events[-1])
400+
or invocation_context.should_pause_invocation(events[-2])
401+
)
402+
):
403+
return
404+
386405
if (
387406
invocation_context.is_resumable
388407
and events
389408
and events[-1].get_function_calls()
390409
):
391-
# Long running tool calls should have been handled before this point.
392-
# If there are still long running tool calls, it means the agent is paused
393-
# before, and its branch hasn't been resumed yet.
394-
if invocation_context.should_pause_invocation(events[-1]):
395-
return
396410
model_response_event = events[-1]
397411
async with Aclosing(
398412
self._postprocess_handle_function_calls_async(

src/google/adk/tools/function_tool.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,7 @@ async def run_async(
205205
' ToolConfirmation payload.'
206206
),
207207
)
208+
tool_context.actions.skip_summarization = True
208209
return {
209210
'error': (
210211
'This tool call requires confirmation, please approve or'

tests/unittests/runners/test_pause_invocation.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def _transfer_call_part(agent_name: str) -> Part:
4444

4545

4646
def test_tool() -> str:
47-
return ""
47+
return "result"
4848

4949

5050
class _TestingAgent(BaseAgent):
@@ -126,6 +126,12 @@ def test_pause_on_long_running_function_call(
126126
"""Tests that a single LlmAgent pauses on long running function call."""
127127
assert testing_utils.simplify_resumable_app_events(runner.run("test")) == [
128128
("root_agent", Part.from_function_call(name="test_tool", args={})),
129+
(
130+
"root_agent",
131+
Part.from_function_response(
132+
name="test_tool", response={"result": "result"}
133+
),
134+
),
129135
]
130136

131137

@@ -168,6 +174,12 @@ def test_pause_first_agent_on_long_running_function_call(
168174
),
169175
),
170176
("sub_agent_1", Part.from_function_call(name="test_tool", args={})),
177+
(
178+
"sub_agent_1",
179+
Part.from_function_response(
180+
name="test_tool", response={"result": "result"}
181+
),
182+
),
171183
]
172184

173185
@pytest.mark.asyncio
@@ -195,7 +207,7 @@ def test_pause_second_agent_on_long_running_function_call(
195207
(
196208
"sub_agent_1",
197209
Part.from_function_response(
198-
name="test_tool", response={"result": ""}
210+
name="test_tool", response={"result": "result"}
199211
),
200212
),
201213
("sub_agent_1", "model response after tool call"),
@@ -207,6 +219,12 @@ def test_pause_second_agent_on_long_running_function_call(
207219
),
208220
),
209221
("sub_agent_2", Part.from_function_call(name="test_tool", args={})),
222+
(
223+
"sub_agent_2",
224+
Part.from_function_response(
225+
name="test_tool", response={"result": "result"}
226+
),
227+
),
210228
]
211229

212230

@@ -384,6 +402,12 @@ def test_pause_on_long_running_function_call(
384402
),
385403
),
386404
("sub_agent_2", Part.from_function_call(name="test_tool", args={})),
405+
(
406+
"sub_agent_2",
407+
Part.from_function_response(
408+
name="test_tool", response={"result": "result"}
409+
),
410+
),
387411
]
388412

389413

@@ -435,6 +459,12 @@ def test_pause_on_long_running_function_call(
435459
("sub_llm_agent_1", _transfer_call_part("sub_llm_agent_2")),
436460
("sub_llm_agent_1", _TRANSFER_RESPONSE_PART),
437461
("sub_llm_agent_2", Part.from_function_call(name="test_tool", args={})),
462+
(
463+
"sub_llm_agent_2",
464+
Part.from_function_response(
465+
name="test_tool", response={"result": "result"}
466+
),
467+
),
438468
]
439469

440470

@@ -489,4 +519,10 @@ def test_pause_on_long_running_function_call(
489519
("sub_llm_agent_2", _transfer_call_part("root_agent")),
490520
("sub_llm_agent_2", _TRANSFER_RESPONSE_PART),
491521
("root_agent", Part.from_function_call(name="test_tool", args={})),
522+
(
523+
"root_agent",
524+
Part.from_function_response(
525+
name="test_tool", response={"result": "result"}
526+
),
527+
),
492528
]

0 commit comments

Comments
 (0)