Skip to content

Commit aa46ff3

Browse files
committed
refactor: introduce unified callback pipeline system
- Add CallbackPipeline generic class for type-safe callback execution - Add normalize_callbacks helper to replace 6 duplicate canonical methods - Add CallbackExecutor for plugin + agent callback integration - Add comprehensive test suite (24 test cases, all passing) This is Phase 1-3 and 6 of the refactoring plan. Seeking feedback before proceeding with full implementation. #non-breaking
1 parent d193c39 commit aa46ff3

File tree

2 files changed

+657
-0
lines changed

2 files changed

+657
-0
lines changed
Lines changed: 257 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,257 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Unified callback pipeline system for ADK.
16+
17+
This module provides a unified way to handle all callback types in ADK,
18+
eliminating code duplication and improving maintainability.
19+
20+
Key components:
21+
- CallbackPipeline: Generic pipeline executor for callbacks
22+
- normalize_callbacks: Helper to standardize callback inputs
23+
- CallbackExecutor: Integrates plugin and agent callbacks
24+
25+
Example:
26+
>>> # Normalize callbacks
27+
>>> callbacks = normalize_callbacks(agent.before_model_callback)
28+
>>>
29+
>>> # Execute pipeline
30+
>>> pipeline = CallbackPipeline(callbacks=callbacks)
31+
>>> result = await pipeline.execute(callback_context, llm_request)
32+
"""
33+
34+
from __future__ import annotations
35+
36+
import inspect
37+
from typing import Any
38+
from typing import Callable
39+
from typing import Generic
40+
from typing import Optional
41+
from typing import TypeVar
42+
from typing import Union
43+
44+
45+
TInput = TypeVar('TInput')
46+
TOutput = TypeVar('TOutput')
47+
TCallback = TypeVar('TCallback', bound=Callable)
48+
49+
50+
class CallbackPipeline(Generic[TInput, TOutput]):
51+
"""Unified callback execution pipeline.
52+
53+
This class provides a consistent way to execute callbacks with the following
54+
features:
55+
- Automatic sync/async callback handling
56+
- Early exit on first non-None result
57+
- Type-safe through generics
58+
- Minimal performance overhead
59+
60+
The pipeline executes callbacks in order and returns the first non-None
61+
result. If all callbacks return None, the pipeline returns None.
62+
63+
Example:
64+
>>> async def callback1(ctx, req):
65+
... return None # Continue to next callback
66+
>>>
67+
>>> async def callback2(ctx, req):
68+
... return LlmResponse(...) # Early exit, this is returned
69+
>>>
70+
>>> pipeline = CallbackPipeline([callback1, callback2])
71+
>>> result = await pipeline.execute(context, request)
72+
>>> # result is the return value of callback2
73+
"""
74+
75+
def __init__(
76+
self,
77+
callbacks: Optional[list[Callable]] = None,
78+
):
79+
"""Initializes the callback pipeline.
80+
81+
Args:
82+
callbacks: List of callback functions. Can be sync or async.
83+
Callbacks are executed in the order provided.
84+
"""
85+
self._callbacks = callbacks or []
86+
87+
async def execute(
88+
self,
89+
*args: Any,
90+
**kwargs: Any,
91+
) -> Optional[TOutput]:
92+
"""Executes the callback pipeline.
93+
94+
Callbacks are executed in order. The pipeline returns the first non-None
95+
result (early exit). If all callbacks return None, returns None.
96+
97+
Both sync and async callbacks are supported automatically.
98+
99+
Args:
100+
*args: Positional arguments passed to each callback
101+
**kwargs: Keyword arguments passed to each callback
102+
103+
Returns:
104+
The first non-None result from callbacks, or None if all callbacks
105+
return None.
106+
107+
Example:
108+
>>> result = await pipeline.execute(
109+
... callback_context=ctx,
110+
... llm_request=request,
111+
... )
112+
"""
113+
for callback in self._callbacks:
114+
result = callback(*args, **kwargs)
115+
116+
# Handle async callbacks
117+
if inspect.isawaitable(result):
118+
result = await result
119+
120+
# Early exit: return first non-None result
121+
if result is not None:
122+
return result
123+
124+
return None
125+
126+
def add_callback(self, callback: Callable) -> None:
127+
"""Adds a callback to the pipeline.
128+
129+
Args:
130+
callback: The callback function to add. Can be sync or async.
131+
"""
132+
self._callbacks.append(callback)
133+
134+
def has_callbacks(self) -> bool:
135+
"""Checks if the pipeline has any callbacks.
136+
137+
Returns:
138+
True if the pipeline has callbacks, False otherwise.
139+
"""
140+
return len(self._callbacks) > 0
141+
142+
@property
143+
def callbacks(self) -> list[Callable]:
144+
"""Returns the list of callbacks in the pipeline.
145+
146+
Returns:
147+
List of callback functions.
148+
"""
149+
return self._callbacks
150+
151+
152+
def normalize_callbacks(
153+
callback: Union[None, Callable, list[Callable]]
154+
) -> list[Callable]:
155+
"""Normalizes callback input to a list.
156+
157+
This function replaces all the canonical_*_callbacks properties in
158+
BaseAgent and LlmAgent by providing a single utility to standardize
159+
callback inputs.
160+
161+
Args:
162+
callback: Can be:
163+
- None: Returns empty list
164+
- Single callback: Returns list with one element
165+
- List of callbacks: Returns the list as-is
166+
167+
Returns:
168+
Normalized list of callbacks.
169+
170+
Example:
171+
>>> normalize_callbacks(None)
172+
[]
173+
>>> normalize_callbacks(my_callback)
174+
[my_callback]
175+
>>> normalize_callbacks([cb1, cb2])
176+
[cb1, cb2]
177+
178+
Note:
179+
This function eliminates 6 duplicate canonical_*_callbacks methods:
180+
- canonical_before_agent_callbacks
181+
- canonical_after_agent_callbacks
182+
- canonical_before_model_callbacks
183+
- canonical_after_model_callbacks
184+
- canonical_before_tool_callbacks
185+
- canonical_after_tool_callbacks
186+
"""
187+
if callback is None:
188+
return []
189+
if isinstance(callback, list):
190+
return callback
191+
return [callback]
192+
193+
194+
class CallbackExecutor:
195+
"""Unified executor for plugin and agent callbacks.
196+
197+
This class coordinates the execution order of plugin callbacks and agent
198+
callbacks:
199+
1. Execute plugin callback first (higher priority)
200+
2. If plugin returns None, execute agent callbacks
201+
3. Return first non-None result
202+
203+
This pattern is used in:
204+
- Before/after agent callbacks
205+
- Before/after model callbacks
206+
- Before/after tool callbacks
207+
"""
208+
209+
@staticmethod
210+
async def execute_with_plugins(
211+
plugin_callback: Callable,
212+
agent_callbacks: list[Callable],
213+
*args: Any,
214+
**kwargs: Any,
215+
) -> Optional[Any]:
216+
"""Executes plugin and agent callbacks in order.
217+
218+
Execution order:
219+
1. Plugin callback (priority)
220+
2. Agent callbacks (if plugin returns None)
221+
222+
Args:
223+
plugin_callback: The plugin callback function to execute first.
224+
agent_callbacks: List of agent callbacks to execute if plugin returns
225+
None.
226+
*args: Positional arguments passed to callbacks
227+
**kwargs: Keyword arguments passed to callbacks
228+
229+
Returns:
230+
First non-None result from plugin or agent callbacks, or None.
231+
232+
Example:
233+
>>> result = await CallbackExecutor.execute_with_plugins(
234+
... plugin_callback=lambda: plugin_manager.run_before_model_callback(
235+
... callback_context=ctx,
236+
... llm_request=request,
237+
... ),
238+
... agent_callbacks=normalize_callbacks(agent.before_model_callback),
239+
... callback_context=ctx,
240+
... llm_request=request,
241+
... )
242+
"""
243+
# Step 1: Execute plugin callback (priority)
244+
result = plugin_callback(*args, **kwargs)
245+
if inspect.isawaitable(result):
246+
result = await result
247+
248+
if result is not None:
249+
return result
250+
251+
# Step 2: Execute agent callbacks if plugin returned None
252+
if agent_callbacks:
253+
pipeline = CallbackPipeline(callbacks=agent_callbacks)
254+
result = await pipeline.execute(*args, **kwargs)
255+
256+
return result
257+

0 commit comments

Comments
 (0)