|
| 1 | +# Copyright 2025 Google LLC |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +"""Unified callback pipeline system for ADK. |
| 16 | +
|
| 17 | +This module provides a unified way to handle all callback types in ADK, |
| 18 | +eliminating code duplication and improving maintainability. |
| 19 | +
|
| 20 | +Key components: |
| 21 | +- CallbackPipeline: Generic pipeline executor for callbacks |
| 22 | +- normalize_callbacks: Helper to standardize callback inputs |
| 23 | +- CallbackExecutor: Integrates plugin and agent callbacks |
| 24 | +
|
| 25 | +Example: |
| 26 | + >>> # Normalize callbacks |
| 27 | + >>> callbacks = normalize_callbacks(agent.before_model_callback) |
| 28 | + >>> |
| 29 | + >>> # Execute pipeline |
| 30 | + >>> pipeline = CallbackPipeline(callbacks=callbacks) |
| 31 | + >>> result = await pipeline.execute(callback_context, llm_request) |
| 32 | +""" |
| 33 | + |
| 34 | +from __future__ import annotations |
| 35 | + |
| 36 | +import inspect |
| 37 | +from typing import Any |
| 38 | +from typing import Callable |
| 39 | +from typing import Generic |
| 40 | +from typing import Optional |
| 41 | +from typing import TypeVar |
| 42 | +from typing import Union |
| 43 | + |
| 44 | + |
| 45 | +TInput = TypeVar('TInput') |
| 46 | +TOutput = TypeVar('TOutput') |
| 47 | +TCallback = TypeVar('TCallback', bound=Callable) |
| 48 | + |
| 49 | + |
| 50 | +class CallbackPipeline(Generic[TInput, TOutput]): |
| 51 | + """Unified callback execution pipeline. |
| 52 | +
|
| 53 | + This class provides a consistent way to execute callbacks with the following |
| 54 | + features: |
| 55 | + - Automatic sync/async callback handling |
| 56 | + - Early exit on first non-None result |
| 57 | + - Type-safe through generics |
| 58 | + - Minimal performance overhead |
| 59 | +
|
| 60 | + The pipeline executes callbacks in order and returns the first non-None |
| 61 | + result. If all callbacks return None, the pipeline returns None. |
| 62 | +
|
| 63 | + Example: |
| 64 | + >>> async def callback1(ctx, req): |
| 65 | + ... return None # Continue to next callback |
| 66 | + >>> |
| 67 | + >>> async def callback2(ctx, req): |
| 68 | + ... return LlmResponse(...) # Early exit, this is returned |
| 69 | + >>> |
| 70 | + >>> pipeline = CallbackPipeline([callback1, callback2]) |
| 71 | + >>> result = await pipeline.execute(context, request) |
| 72 | + >>> # result is the return value of callback2 |
| 73 | + """ |
| 74 | + |
| 75 | + def __init__( |
| 76 | + self, |
| 77 | + callbacks: Optional[list[Callable]] = None, |
| 78 | + ): |
| 79 | + """Initializes the callback pipeline. |
| 80 | +
|
| 81 | + Args: |
| 82 | + callbacks: List of callback functions. Can be sync or async. |
| 83 | + Callbacks are executed in the order provided. |
| 84 | + """ |
| 85 | + self._callbacks = callbacks or [] |
| 86 | + |
| 87 | + async def execute( |
| 88 | + self, |
| 89 | + *args: Any, |
| 90 | + **kwargs: Any, |
| 91 | + ) -> Optional[TOutput]: |
| 92 | + """Executes the callback pipeline. |
| 93 | +
|
| 94 | + Callbacks are executed in order. The pipeline returns the first non-None |
| 95 | + result (early exit). If all callbacks return None, returns None. |
| 96 | +
|
| 97 | + Both sync and async callbacks are supported automatically. |
| 98 | +
|
| 99 | + Args: |
| 100 | + *args: Positional arguments passed to each callback |
| 101 | + **kwargs: Keyword arguments passed to each callback |
| 102 | +
|
| 103 | + Returns: |
| 104 | + The first non-None result from callbacks, or None if all callbacks |
| 105 | + return None. |
| 106 | +
|
| 107 | + Example: |
| 108 | + >>> result = await pipeline.execute( |
| 109 | + ... callback_context=ctx, |
| 110 | + ... llm_request=request, |
| 111 | + ... ) |
| 112 | + """ |
| 113 | + for callback in self._callbacks: |
| 114 | + result = callback(*args, **kwargs) |
| 115 | + |
| 116 | + # Handle async callbacks |
| 117 | + if inspect.isawaitable(result): |
| 118 | + result = await result |
| 119 | + |
| 120 | + # Early exit: return first non-None result |
| 121 | + if result is not None: |
| 122 | + return result |
| 123 | + |
| 124 | + return None |
| 125 | + |
| 126 | + def add_callback(self, callback: Callable) -> None: |
| 127 | + """Adds a callback to the pipeline. |
| 128 | +
|
| 129 | + Args: |
| 130 | + callback: The callback function to add. Can be sync or async. |
| 131 | + """ |
| 132 | + self._callbacks.append(callback) |
| 133 | + |
| 134 | + def has_callbacks(self) -> bool: |
| 135 | + """Checks if the pipeline has any callbacks. |
| 136 | +
|
| 137 | + Returns: |
| 138 | + True if the pipeline has callbacks, False otherwise. |
| 139 | + """ |
| 140 | + return len(self._callbacks) > 0 |
| 141 | + |
| 142 | + @property |
| 143 | + def callbacks(self) -> list[Callable]: |
| 144 | + """Returns the list of callbacks in the pipeline. |
| 145 | +
|
| 146 | + Returns: |
| 147 | + List of callback functions. |
| 148 | + """ |
| 149 | + return self._callbacks |
| 150 | + |
| 151 | + |
| 152 | +def normalize_callbacks( |
| 153 | + callback: Union[None, Callable, list[Callable]] |
| 154 | +) -> list[Callable]: |
| 155 | + """Normalizes callback input to a list. |
| 156 | +
|
| 157 | + This function replaces all the canonical_*_callbacks properties in |
| 158 | + BaseAgent and LlmAgent by providing a single utility to standardize |
| 159 | + callback inputs. |
| 160 | +
|
| 161 | + Args: |
| 162 | + callback: Can be: |
| 163 | + - None: Returns empty list |
| 164 | + - Single callback: Returns list with one element |
| 165 | + - List of callbacks: Returns the list as-is |
| 166 | +
|
| 167 | + Returns: |
| 168 | + Normalized list of callbacks. |
| 169 | +
|
| 170 | + Example: |
| 171 | + >>> normalize_callbacks(None) |
| 172 | + [] |
| 173 | + >>> normalize_callbacks(my_callback) |
| 174 | + [my_callback] |
| 175 | + >>> normalize_callbacks([cb1, cb2]) |
| 176 | + [cb1, cb2] |
| 177 | +
|
| 178 | + Note: |
| 179 | + This function eliminates 6 duplicate canonical_*_callbacks methods: |
| 180 | + - canonical_before_agent_callbacks |
| 181 | + - canonical_after_agent_callbacks |
| 182 | + - canonical_before_model_callbacks |
| 183 | + - canonical_after_model_callbacks |
| 184 | + - canonical_before_tool_callbacks |
| 185 | + - canonical_after_tool_callbacks |
| 186 | + """ |
| 187 | + if callback is None: |
| 188 | + return [] |
| 189 | + if isinstance(callback, list): |
| 190 | + return callback |
| 191 | + return [callback] |
| 192 | + |
| 193 | + |
| 194 | +class CallbackExecutor: |
| 195 | + """Unified executor for plugin and agent callbacks. |
| 196 | +
|
| 197 | + This class coordinates the execution order of plugin callbacks and agent |
| 198 | + callbacks: |
| 199 | + 1. Execute plugin callback first (higher priority) |
| 200 | + 2. If plugin returns None, execute agent callbacks |
| 201 | + 3. Return first non-None result |
| 202 | +
|
| 203 | + This pattern is used in: |
| 204 | + - Before/after agent callbacks |
| 205 | + - Before/after model callbacks |
| 206 | + - Before/after tool callbacks |
| 207 | + """ |
| 208 | + |
| 209 | + @staticmethod |
| 210 | + async def execute_with_plugins( |
| 211 | + plugin_callback: Callable, |
| 212 | + agent_callbacks: list[Callable], |
| 213 | + *args: Any, |
| 214 | + **kwargs: Any, |
| 215 | + ) -> Optional[Any]: |
| 216 | + """Executes plugin and agent callbacks in order. |
| 217 | +
|
| 218 | + Execution order: |
| 219 | + 1. Plugin callback (priority) |
| 220 | + 2. Agent callbacks (if plugin returns None) |
| 221 | +
|
| 222 | + Args: |
| 223 | + plugin_callback: The plugin callback function to execute first. |
| 224 | + agent_callbacks: List of agent callbacks to execute if plugin returns |
| 225 | + None. |
| 226 | + *args: Positional arguments passed to callbacks |
| 227 | + **kwargs: Keyword arguments passed to callbacks |
| 228 | +
|
| 229 | + Returns: |
| 230 | + First non-None result from plugin or agent callbacks, or None. |
| 231 | +
|
| 232 | + Example: |
| 233 | + >>> result = await CallbackExecutor.execute_with_plugins( |
| 234 | + ... plugin_callback=lambda: plugin_manager.run_before_model_callback( |
| 235 | + ... callback_context=ctx, |
| 236 | + ... llm_request=request, |
| 237 | + ... ), |
| 238 | + ... agent_callbacks=normalize_callbacks(agent.before_model_callback), |
| 239 | + ... callback_context=ctx, |
| 240 | + ... llm_request=request, |
| 241 | + ... ) |
| 242 | + """ |
| 243 | + # Step 1: Execute plugin callback (priority) |
| 244 | + result = plugin_callback(*args, **kwargs) |
| 245 | + if inspect.isawaitable(result): |
| 246 | + result = await result |
| 247 | + |
| 248 | + if result is not None: |
| 249 | + return result |
| 250 | + |
| 251 | + # Step 2: Execute agent callbacks if plugin returned None |
| 252 | + if agent_callbacks: |
| 253 | + pipeline = CallbackPipeline(callbacks=agent_callbacks) |
| 254 | + result = await pipeline.execute(*args, **kwargs) |
| 255 | + |
| 256 | + return result |
| 257 | + |
0 commit comments