Skip to content

Commit 96a0749

Browse files
committed
refactor: use CallbackPipeline consistently in all callback execution sites
Address bot feedback (round 4) by replacing all manual callback iterations with CallbackPipeline.execute() for consistency and maintainability. Changes (9 locations): 1. base_agent.py: Use CallbackPipeline for before/after agent callbacks 2. callback_pipeline.py: Optimize single plugin callback execution 3. base_llm_flow.py: Use CallbackPipeline for before/after model callbacks 4. functions.py: Use CallbackPipeline for all tool callbacks (async + live) Impact: - Eliminates remaining manual callback iteration logic (~40 lines) - Achieves 100% consistency in callback execution - All sync/async handling and early exit logic centralized - Tests: 24/24 passing - Lint: 9.57/10 (improved from 9.49/10) #non-breaking
1 parent aaf3c19 commit 96a0749

File tree

4 files changed

+42
-57
lines changed

4 files changed

+42
-57
lines changed

src/google/adk/agents/base_agent.py

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
from ..utils.feature_decorator import experimental
4646
from .base_agent_config import BaseAgentConfig
4747
from .callback_context import CallbackContext
48+
from .callback_pipeline import CallbackPipeline
4849
from .callback_pipeline import normalize_callbacks
4950

5051
if TYPE_CHECKING:
@@ -429,14 +430,10 @@ async def _handle_before_agent_callback(
429430
# callbacks.
430431
callbacks = normalize_callbacks(self.before_agent_callback)
431432
if not before_agent_callback_content and callbacks:
432-
for callback in callbacks:
433-
before_agent_callback_content = callback(
434-
callback_context=callback_context
435-
)
436-
if inspect.isawaitable(before_agent_callback_content):
437-
before_agent_callback_content = await before_agent_callback_content
438-
if before_agent_callback_content:
439-
break
433+
pipeline = CallbackPipeline(callbacks)
434+
before_agent_callback_content = await pipeline.execute(
435+
callback_context=callback_context
436+
)
440437

441438
# Process the override content if exists, and further process the state
442439
# change if exists.
@@ -487,14 +484,10 @@ async def _handle_after_agent_callback(
487484
# callbacks.
488485
callbacks = normalize_callbacks(self.after_agent_callback)
489486
if not after_agent_callback_content and callbacks:
490-
for callback in callbacks:
491-
after_agent_callback_content = callback(
492-
callback_context=callback_context
493-
)
494-
if inspect.isawaitable(after_agent_callback_content):
495-
after_agent_callback_content = await after_agent_callback_content
496-
if after_agent_callback_content:
497-
break
487+
pipeline = CallbackPipeline(callbacks)
488+
after_agent_callback_content = await pipeline.execute(
489+
callback_context=callback_context
490+
)
498491

499492
# Process the override content if exists, and further process the state
500493
# change if exists.

src/google/adk/agents/callback_pipeline.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,9 @@ async def execute_with_plugins(
238238
... )
239239
"""
240240
# Step 1: Execute plugin callback (priority)
241-
result = await CallbackPipeline([plugin_callback]).execute(*args, **kwargs)
241+
result = plugin_callback(*args, **kwargs)
242+
if inspect.isawaitable(result):
243+
result = await result
242244
if result is not None:
243245
return result
244246

src/google/adk/flows/llm_flows/base_llm_flow.py

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from . import functions
3333
from ...agents.base_agent import BaseAgent
3434
from ...agents.callback_context import CallbackContext
35+
from ...agents.callback_pipeline import CallbackPipeline
3536
from ...agents.callback_pipeline import normalize_callbacks
3637
from ...agents.invocation_context import InvocationContext
3738
from ...agents.live_request_queue import LiveRequestQueue
@@ -833,14 +834,12 @@ async def _handle_before_model_callback(
833834
callbacks = normalize_callbacks(agent.before_model_callback)
834835
if not callbacks:
835836
return
836-
for callback in callbacks:
837-
callback_response = callback(
838-
callback_context=callback_context, llm_request=llm_request
839-
)
840-
if inspect.isawaitable(callback_response):
841-
callback_response = await callback_response
842-
if callback_response:
843-
return callback_response
837+
pipeline = CallbackPipeline(callbacks)
838+
callback_response = await pipeline.execute(
839+
callback_context=callback_context, llm_request=llm_request
840+
)
841+
if callback_response:
842+
return callback_response
844843

845844
async def _handle_after_model_callback(
846845
self,
@@ -891,14 +890,12 @@ async def _maybe_add_grounding_metadata(
891890
callbacks = normalize_callbacks(agent.after_model_callback)
892891
if not callbacks:
893892
return await _maybe_add_grounding_metadata()
894-
for callback in callbacks:
895-
callback_response = callback(
896-
callback_context=callback_context, llm_response=llm_response
897-
)
898-
if inspect.isawaitable(callback_response):
899-
callback_response = await callback_response
900-
if callback_response:
901-
return await _maybe_add_grounding_metadata(callback_response)
893+
pipeline = CallbackPipeline(callbacks)
894+
callback_response = await pipeline.execute(
895+
callback_context=callback_context, llm_response=llm_response
896+
)
897+
if callback_response:
898+
return await _maybe_add_grounding_metadata(callback_response)
902899
return await _maybe_add_grounding_metadata()
903900

904901
def _finalize_model_response_event(

src/google/adk/flows/llm_flows/functions.py

Lines changed: 17 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from google.genai import types
3232

3333
from ...agents.active_streaming_tool import ActiveStreamingTool
34+
from ...agents.callback_pipeline import CallbackPipeline
3435
from ...agents.callback_pipeline import normalize_callbacks
3536
from ...agents.invocation_context import InvocationContext
3637
from ...auth.auth_tool import AuthToolArguments
@@ -352,14 +353,12 @@ async def _run_with_trace():
352353
# Step 2: If no overrides are provided from the plugins, further run the
353354
# canonical callback.
354355
if function_response is None:
355-
for callback in normalize_callbacks(agent.before_tool_callback):
356-
function_response = callback(
356+
callbacks = normalize_callbacks(agent.before_tool_callback)
357+
if callbacks:
358+
pipeline = CallbackPipeline(callbacks)
359+
function_response = await pipeline.execute(
357360
tool=tool, args=function_args, tool_context=tool_context
358361
)
359-
if inspect.isawaitable(function_response):
360-
function_response = await function_response
361-
if function_response:
362-
break
363362

364363
# Step 3: Otherwise, proceed calling the tool normally.
365364
if function_response is None:
@@ -393,17 +392,15 @@ async def _run_with_trace():
393392
# Step 5: If no overrides are provided from the plugins, further run the
394393
# canonical after_tool_callbacks.
395394
if altered_function_response is None:
396-
for callback in normalize_callbacks(agent.after_tool_callback):
397-
altered_function_response = callback(
395+
callbacks = normalize_callbacks(agent.after_tool_callback)
396+
if callbacks:
397+
pipeline = CallbackPipeline(callbacks)
398+
altered_function_response = await pipeline.execute(
398399
tool=tool,
399400
args=function_args,
400401
tool_context=tool_context,
401402
tool_response=function_response,
402403
)
403-
if inspect.isawaitable(altered_function_response):
404-
altered_function_response = await altered_function_response
405-
if altered_function_response:
406-
break
407404

408405
# Step 6: If alternative response exists from after_tool_callback, use it
409406
# instead of the original function response.
@@ -525,14 +522,12 @@ async def _run_with_trace():
525522

526523
# Handle before_tool_callbacks - iterate through the canonical callback
527524
# list
528-
for callback in normalize_callbacks(agent.before_tool_callback):
529-
function_response = callback(
525+
callbacks = normalize_callbacks(agent.before_tool_callback)
526+
if callbacks:
527+
pipeline = CallbackPipeline(callbacks)
528+
function_response = await pipeline.execute(
530529
tool=tool, args=function_args, tool_context=tool_context
531530
)
532-
if inspect.isawaitable(function_response):
533-
function_response = await function_response
534-
if function_response:
535-
break
536531

537532
if function_response is None:
538533
function_response = await _process_function_live_helper(
@@ -546,17 +541,15 @@ async def _run_with_trace():
546541

547542
# Calls after_tool_callback if it exists.
548543
altered_function_response = None
549-
for callback in normalize_callbacks(agent.after_tool_callback):
550-
altered_function_response = callback(
544+
callbacks = normalize_callbacks(agent.after_tool_callback)
545+
if callbacks:
546+
pipeline = CallbackPipeline(callbacks)
547+
altered_function_response = await pipeline.execute(
551548
tool=tool,
552549
args=function_args,
553550
tool_context=tool_context,
554551
tool_response=function_response,
555552
)
556-
if inspect.isawaitable(altered_function_response):
557-
altered_function_response = await altered_function_response
558-
if altered_function_response:
559-
break
560553

561554
if altered_function_response is not None:
562555
function_response = altered_function_response

0 commit comments

Comments
 (0)