Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 14 additions & 48 deletions src/google/adk/agents/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@
from ..utils.feature_decorator import experimental
from .base_agent_config import BaseAgentConfig
from .callback_context import CallbackContext
from .callback_pipeline import CallbackPipeline
from .callback_pipeline import normalize_callbacks

if TYPE_CHECKING:
from .invocation_context import InvocationContext
Expand Down Expand Up @@ -404,30 +406,6 @@ def _create_invocation_context(
invocation_context = parent_context.model_copy(update={'agent': self})
return invocation_context

@property
def canonical_before_agent_callbacks(self) -> list[_SingleAgentCallback]:
"""The resolved self.before_agent_callback field as a list of _SingleAgentCallback.

This method is only for use by Agent Development Kit.
"""
if not self.before_agent_callback:
return []
if isinstance(self.before_agent_callback, list):
return self.before_agent_callback
return [self.before_agent_callback]

@property
def canonical_after_agent_callbacks(self) -> list[_SingleAgentCallback]:
"""The resolved self.after_agent_callback field as a list of _SingleAgentCallback.

This method is only for use by Agent Development Kit.
"""
if not self.after_agent_callback:
return []
if isinstance(self.after_agent_callback, list):
return self.after_agent_callback
return [self.after_agent_callback]

async def _handle_before_agent_callback(
self, ctx: InvocationContext
) -> Optional[Event]:
Expand All @@ -450,18 +428,12 @@ async def _handle_before_agent_callback(

# If no overrides are provided from the plugins, further run the canonical
# callbacks.
if (
not before_agent_callback_content
and self.canonical_before_agent_callbacks
):
for callback in self.canonical_before_agent_callbacks:
before_agent_callback_content = callback(
callback_context=callback_context
)
if inspect.isawaitable(before_agent_callback_content):
before_agent_callback_content = await before_agent_callback_content
if before_agent_callback_content:
break
callbacks = normalize_callbacks(self.before_agent_callback)
if not before_agent_callback_content and callbacks:
pipeline = CallbackPipeline(callbacks)
before_agent_callback_content = await pipeline.execute(
callback_context=callback_context
)

# Process the override content if exists, and further process the state
# change if exists.
Expand Down Expand Up @@ -510,18 +482,12 @@ async def _handle_after_agent_callback(

# If no overrides are provided from the plugins, further run the canonical
# callbacks.
if (
not after_agent_callback_content
and self.canonical_after_agent_callbacks
):
for callback in self.canonical_after_agent_callbacks:
after_agent_callback_content = callback(
callback_context=callback_context
)
if inspect.isawaitable(after_agent_callback_content):
after_agent_callback_content = await after_agent_callback_content
if after_agent_callback_content:
break
callbacks = normalize_callbacks(self.after_agent_callback)
if not after_agent_callback_content and callbacks:
pipeline = CallbackPipeline(callbacks)
after_agent_callback_content = await pipeline.execute(
callback_context=callback_context
)

# Process the override content if exists, and further process the state
# change if exists.
Expand Down
189 changes: 189 additions & 0 deletions src/google/adk/agents/callback_pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unified callback pipeline system for ADK.

This module provides a unified way to handle all callback types in ADK,
eliminating code duplication and improving maintainability.

Key components:
- CallbackPipeline: Generic pipeline executor for callbacks
- normalize_callbacks: Helper to standardize callback inputs

Example:
>>> # Normalize callbacks
>>> callbacks = normalize_callbacks(agent.before_model_callback)
>>>
>>> # Execute pipeline
>>> pipeline = CallbackPipeline(callbacks=callbacks)
>>> result = await pipeline.execute(callback_context, llm_request)
"""

from __future__ import annotations

import inspect
from typing import Any
from typing import Callable
from typing import Generic
from typing import Optional
from typing import TypeVar
from typing import Union


TOutput = TypeVar('TOutput')


class CallbackPipeline(Generic[TOutput]):

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The PR description highlights the goal of strong type safety. While TOutput provides type safety for return values, the input arguments are typed as *args: Any, **kwargs: Any, which bypasses static type checking for callback arguments.

To enhance type safety and ensure all callbacks within a pipeline share a compatible signature, consider using typing.ParamSpec. This will allow mypy to validate the arguments passed to pipeline.execute() against the expected signature of the callbacks. This change would make the CallbackPipeline more robust and better align with the stated goal of strong type safety.

Here's an example of how you could apply this:

from typing import ParamSpec
# ... other imports

P = ParamSpec('P')
TOutput = TypeVar('TOutput')

class CallbackPipeline(Generic[P, TOutput]):
    def __init__(
        self,
        callbacks: Optional[list[Callable[P, Any]]] = None,
    ):
        self._callbacks = callbacks or []

    async def execute(
        self,
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> Optional[TOutput]:
        # ... implementation remains the same

"""Unified callback execution pipeline.

This class provides a consistent way to execute callbacks with the following
features:
- Automatic sync/async callback handling
- Early exit on first non-None result
- Type-safe through generics
- Minimal performance overhead

The pipeline executes callbacks in order and returns the first non-None
result. If all callbacks return None, the pipeline returns None.

Example:
>>> async def callback1(ctx, req):
... return None # Continue to next callback
>>>
>>> async def callback2(ctx, req):
... return LlmResponse(...) # Early exit, this is returned
>>>
>>> pipeline = CallbackPipeline([callback1, callback2])
>>> result = await pipeline.execute(context, request)
>>> # result is the return value of callback2
"""

def __init__(
self,
callbacks: Optional[list[Callable]] = None,
):
"""Initializes the callback pipeline.

Args:
callbacks: List of callback functions. Can be sync or async.
Callbacks are executed in the order provided.
"""
self._callbacks = callbacks or []

async def execute(
self,
*args: Any,
**kwargs: Any,
) -> Optional[TOutput]:
"""Executes the callback pipeline.

Callbacks are executed in order. The pipeline returns the first non-None
result (early exit). If all callbacks return None, returns None.

Both sync and async callbacks are supported automatically.

Args:
*args: Positional arguments passed to each callback
**kwargs: Keyword arguments passed to each callback

Returns:
The first non-None result from callbacks, or None if all callbacks
return None.

Example:
>>> result = await pipeline.execute(
... callback_context=ctx,
... llm_request=request,
... )
"""
for callback in self._callbacks:
result = callback(*args, **kwargs)

# Handle async callbacks
if inspect.isawaitable(result):
result = await result

# Early exit: return first non-None result
if result is not None:
return result

return None

def add_callback(self, callback: Callable) -> None:
"""Adds a callback to the pipeline.

Args:
callback: The callback function to add. Can be sync or async.
"""
self._callbacks.append(callback)

def has_callbacks(self) -> bool:
"""Checks if the pipeline has any callbacks.

Returns:
True if the pipeline has callbacks, False otherwise.
"""
return bool(self._callbacks)

@property
def callbacks(self) -> list[Callable]:
"""Returns a copy of the list of callbacks in the pipeline.

Returns:
List of callback functions.
"""
return self._callbacks.copy()


def normalize_callbacks(
callback: Union[None, Callable, list[Callable]]
) -> list[Callable]:
"""Normalizes callback input to a list.

This function replaces all the canonical_*_callbacks properties in
BaseAgent and LlmAgent by providing a single utility to standardize
callback inputs.

Args:
callback: Can be:
- None: Returns empty list
- Single callback: Returns list with one element
- List of callbacks: Returns the list as-is

Returns:
Normalized list of callbacks.

Example:
>>> normalize_callbacks(None)
[]
>>> normalize_callbacks(my_callback)
[my_callback]
>>> normalize_callbacks([cb1, cb2])
[cb1, cb2]

Note:
This function eliminates 6 duplicate canonical_*_callbacks methods:
- canonical_before_agent_callbacks
- canonical_after_agent_callbacks
- canonical_before_model_callbacks
- canonical_after_model_callbacks
- canonical_before_tool_callbacks
- canonical_after_tool_callbacks
"""
if callback is None:
return []
if isinstance(callback, list):
return callback
return [callback]

Loading