diff --git a/src/google/adk/agents/base_agent.py b/src/google/adk/agents/base_agent.py index a1d633bc06..ed49cd8e1a 100644 --- a/src/google/adk/agents/base_agent.py +++ b/src/google/adk/agents/base_agent.py @@ -45,6 +45,8 @@ from ..utils.feature_decorator import experimental from .base_agent_config import BaseAgentConfig from .callback_context import CallbackContext +from .callback_pipeline import CallbackPipeline +from .callback_pipeline import normalize_callbacks if TYPE_CHECKING: from .invocation_context import InvocationContext @@ -404,30 +406,6 @@ def _create_invocation_context( invocation_context = parent_context.model_copy(update={'agent': self}) return invocation_context - @property - def canonical_before_agent_callbacks(self) -> list[_SingleAgentCallback]: - """The resolved self.before_agent_callback field as a list of _SingleAgentCallback. - - This method is only for use by Agent Development Kit. - """ - if not self.before_agent_callback: - return [] - if isinstance(self.before_agent_callback, list): - return self.before_agent_callback - return [self.before_agent_callback] - - @property - def canonical_after_agent_callbacks(self) -> list[_SingleAgentCallback]: - """The resolved self.after_agent_callback field as a list of _SingleAgentCallback. - - This method is only for use by Agent Development Kit. - """ - if not self.after_agent_callback: - return [] - if isinstance(self.after_agent_callback, list): - return self.after_agent_callback - return [self.after_agent_callback] - async def _handle_before_agent_callback( self, ctx: InvocationContext ) -> Optional[Event]: @@ -450,18 +428,12 @@ async def _handle_before_agent_callback( # If no overrides are provided from the plugins, further run the canonical # callbacks. - if ( - not before_agent_callback_content - and self.canonical_before_agent_callbacks - ): - for callback in self.canonical_before_agent_callbacks: - before_agent_callback_content = callback( - callback_context=callback_context - ) - if inspect.isawaitable(before_agent_callback_content): - before_agent_callback_content = await before_agent_callback_content - if before_agent_callback_content: - break + callbacks = normalize_callbacks(self.before_agent_callback) + if not before_agent_callback_content and callbacks: + pipeline = CallbackPipeline(callbacks) + before_agent_callback_content = await pipeline.execute( + callback_context=callback_context + ) # Process the override content if exists, and further process the state # change if exists. @@ -510,18 +482,12 @@ async def _handle_after_agent_callback( # If no overrides are provided from the plugins, further run the canonical # callbacks. - if ( - not after_agent_callback_content - and self.canonical_after_agent_callbacks - ): - for callback in self.canonical_after_agent_callbacks: - after_agent_callback_content = callback( - callback_context=callback_context - ) - if inspect.isawaitable(after_agent_callback_content): - after_agent_callback_content = await after_agent_callback_content - if after_agent_callback_content: - break + callbacks = normalize_callbacks(self.after_agent_callback) + if not after_agent_callback_content and callbacks: + pipeline = CallbackPipeline(callbacks) + after_agent_callback_content = await pipeline.execute( + callback_context=callback_context + ) # Process the override content if exists, and further process the state # change if exists. diff --git a/src/google/adk/agents/callback_pipeline.py b/src/google/adk/agents/callback_pipeline.py new file mode 100644 index 0000000000..edd46ac586 --- /dev/null +++ b/src/google/adk/agents/callback_pipeline.py @@ -0,0 +1,189 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unified callback pipeline system for ADK. + +This module provides a unified way to handle all callback types in ADK, +eliminating code duplication and improving maintainability. + +Key components: +- CallbackPipeline: Generic pipeline executor for callbacks +- normalize_callbacks: Helper to standardize callback inputs + +Example: + >>> # Normalize callbacks + >>> callbacks = normalize_callbacks(agent.before_model_callback) + >>> + >>> # Execute pipeline + >>> pipeline = CallbackPipeline(callbacks=callbacks) + >>> result = await pipeline.execute(callback_context, llm_request) +""" + +from __future__ import annotations + +import inspect +from typing import Any +from typing import Callable +from typing import Generic +from typing import Optional +from typing import TypeVar +from typing import Union + + +TOutput = TypeVar('TOutput') + + +class CallbackPipeline(Generic[TOutput]): + """Unified callback execution pipeline. + + This class provides a consistent way to execute callbacks with the following + features: + - Automatic sync/async callback handling + - Early exit on first non-None result + - Type-safe through generics + - Minimal performance overhead + + The pipeline executes callbacks in order and returns the first non-None + result. If all callbacks return None, the pipeline returns None. + + Example: + >>> async def callback1(ctx, req): + ... return None # Continue to next callback + >>> + >>> async def callback2(ctx, req): + ... return LlmResponse(...) # Early exit, this is returned + >>> + >>> pipeline = CallbackPipeline([callback1, callback2]) + >>> result = await pipeline.execute(context, request) + >>> # result is the return value of callback2 + """ + + def __init__( + self, + callbacks: Optional[list[Callable]] = None, + ): + """Initializes the callback pipeline. + + Args: + callbacks: List of callback functions. Can be sync or async. + Callbacks are executed in the order provided. + """ + self._callbacks = callbacks or [] + + async def execute( + self, + *args: Any, + **kwargs: Any, + ) -> Optional[TOutput]: + """Executes the callback pipeline. + + Callbacks are executed in order. The pipeline returns the first non-None + result (early exit). If all callbacks return None, returns None. + + Both sync and async callbacks are supported automatically. + + Args: + *args: Positional arguments passed to each callback + **kwargs: Keyword arguments passed to each callback + + Returns: + The first non-None result from callbacks, or None if all callbacks + return None. + + Example: + >>> result = await pipeline.execute( + ... callback_context=ctx, + ... llm_request=request, + ... ) + """ + for callback in self._callbacks: + result = callback(*args, **kwargs) + + # Handle async callbacks + if inspect.isawaitable(result): + result = await result + + # Early exit: return first non-None result + if result is not None: + return result + + return None + + def add_callback(self, callback: Callable) -> None: + """Adds a callback to the pipeline. + + Args: + callback: The callback function to add. Can be sync or async. + """ + self._callbacks.append(callback) + + def has_callbacks(self) -> bool: + """Checks if the pipeline has any callbacks. + + Returns: + True if the pipeline has callbacks, False otherwise. + """ + return bool(self._callbacks) + + @property + def callbacks(self) -> list[Callable]: + """Returns a copy of the list of callbacks in the pipeline. + + Returns: + List of callback functions. + """ + return self._callbacks.copy() + + +def normalize_callbacks( + callback: Union[None, Callable, list[Callable]] +) -> list[Callable]: + """Normalizes callback input to a list. + + This function replaces all the canonical_*_callbacks properties in + BaseAgent and LlmAgent by providing a single utility to standardize + callback inputs. + + Args: + callback: Can be: + - None: Returns empty list + - Single callback: Returns list with one element + - List of callbacks: Returns the list as-is + + Returns: + Normalized list of callbacks. + + Example: + >>> normalize_callbacks(None) + [] + >>> normalize_callbacks(my_callback) + [my_callback] + >>> normalize_callbacks([cb1, cb2]) + [cb1, cb2] + + Note: + This function eliminates 6 duplicate canonical_*_callbacks methods: + - canonical_before_agent_callbacks + - canonical_after_agent_callbacks + - canonical_before_model_callbacks + - canonical_after_model_callbacks + - canonical_before_tool_callbacks + - canonical_after_tool_callbacks + """ + if callback is None: + return [] + if isinstance(callback, list): + return callback + return [callback] + diff --git a/src/google/adk/agents/llm_agent.py b/src/google/adk/agents/llm_agent.py index 8e13ea8910..53bee38ad7 100644 --- a/src/google/adk/agents/llm_agent.py +++ b/src/google/adk/agents/llm_agent.py @@ -38,6 +38,7 @@ from typing_extensions import override from typing_extensions import TypeAlias +from .callback_pipeline import normalize_callbacks from ..code_executors.base_code_executor import BaseCodeExecutor from ..events.event import Event from ..flows.llm_flows.auto_flow import AutoFlow @@ -565,69 +566,99 @@ async def canonical_tools( def canonical_before_model_callbacks( self, ) -> list[_SingleBeforeModelCallback]: - """The resolved self.before_model_callback field as a list of _SingleBeforeModelCallback. + """Deprecated: Use normalize_callbacks(self.before_model_callback). - This method is only for use by Agent Development Kit. + This property is deprecated and will be removed in a future version. + Use normalize_callbacks() from callback_pipeline module instead. + + Returns: + List of before_model callbacks. """ - if not self.before_model_callback: - return [] - if isinstance(self.before_model_callback, list): - return self.before_model_callback - return [self.before_model_callback] + warnings.warn( + 'canonical_before_model_callbacks is deprecated. ' + 'Use normalize_callbacks(self.before_model_callback) instead.', + DeprecationWarning, + stacklevel=2, + ) + return normalize_callbacks(self.before_model_callback) @property def canonical_after_model_callbacks(self) -> list[_SingleAfterModelCallback]: - """The resolved self.after_model_callback field as a list of _SingleAfterModelCallback. + """Deprecated: Use normalize_callbacks(self.after_model_callback). - This method is only for use by Agent Development Kit. + This property is deprecated and will be removed in a future version. + Use normalize_callbacks() from callback_pipeline module instead. + + Returns: + List of after_model callbacks. """ - if not self.after_model_callback: - return [] - if isinstance(self.after_model_callback, list): - return self.after_model_callback - return [self.after_model_callback] + warnings.warn( + 'canonical_after_model_callbacks is deprecated. ' + 'Use normalize_callbacks(self.after_model_callback) instead.', + DeprecationWarning, + stacklevel=2, + ) + return normalize_callbacks(self.after_model_callback) @property def canonical_before_tool_callbacks( self, ) -> list[BeforeToolCallback]: - """The resolved self.before_tool_callback field as a list of BeforeToolCallback. + """Deprecated: Use normalize_callbacks(self.before_tool_callback). - This method is only for use by Agent Development Kit. + This property is deprecated and will be removed in a future version. + Use normalize_callbacks() from callback_pipeline module instead. + + Returns: + List of before_tool callbacks. """ - if not self.before_tool_callback: - return [] - if isinstance(self.before_tool_callback, list): - return self.before_tool_callback - return [self.before_tool_callback] + warnings.warn( + 'canonical_before_tool_callbacks is deprecated. ' + 'Use normalize_callbacks(self.before_tool_callback) instead.', + DeprecationWarning, + stacklevel=2, + ) + return normalize_callbacks(self.before_tool_callback) @property def canonical_after_tool_callbacks( self, ) -> list[AfterToolCallback]: - """The resolved self.after_tool_callback field as a list of AfterToolCallback. + """Deprecated: Use normalize_callbacks(self.after_tool_callback). - This method is only for use by Agent Development Kit. + This property is deprecated and will be removed in a future version. + Use normalize_callbacks() from callback_pipeline module instead. + + Returns: + List of after_tool callbacks. """ - if not self.after_tool_callback: - return [] - if isinstance(self.after_tool_callback, list): - return self.after_tool_callback - return [self.after_tool_callback] + warnings.warn( + 'canonical_after_tool_callbacks is deprecated. ' + 'Use normalize_callbacks(self.after_tool_callback) instead.', + DeprecationWarning, + stacklevel=2, + ) + return normalize_callbacks(self.after_tool_callback) @property def canonical_on_tool_error_callbacks( self, ) -> list[OnToolErrorCallback]: - """The resolved self.on_tool_error_callback field as a list of OnToolErrorCallback. + """Deprecated: Use normalize_callbacks(self.on_tool_error_callback). - This method is only for use by Agent Development Kit. + This property is deprecated and will be removed in a future version. + Use normalize_callbacks() from callback_pipeline module instead. + + Returns: + List of on_tool_error callbacks. """ - if not self.on_tool_error_callback: - return [] - if isinstance(self.on_tool_error_callback, list): - return self.on_tool_error_callback - return [self.on_tool_error_callback] + warnings.warn( + 'canonical_on_tool_error_callbacks is deprecated. ' + 'Use normalize_callbacks(self.on_tool_error_callback) instead.', + DeprecationWarning, + stacklevel=2, + ) + return normalize_callbacks(self.on_tool_error_callback) @property def _llm_flow(self) -> BaseLlmFlow: diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index 644dc55b6c..ada26ccf67 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -32,6 +32,8 @@ from . import functions from ...agents.base_agent import BaseAgent from ...agents.callback_context import CallbackContext +from ...agents.callback_pipeline import CallbackPipeline +from ...agents.callback_pipeline import normalize_callbacks from ...agents.invocation_context import InvocationContext from ...agents.live_request_queue import LiveRequestQueue from ...agents.readonly_context import ReadonlyContext @@ -829,16 +831,13 @@ async def _handle_before_model_callback( # If no overrides are provided from the plugins, further run the canonical # callbacks. - if not agent.canonical_before_model_callbacks: - return - for callback in agent.canonical_before_model_callbacks: - callback_response = callback( - callback_context=callback_context, llm_request=llm_request - ) - if inspect.isawaitable(callback_response): - callback_response = await callback_response - if callback_response: - return callback_response + callbacks = normalize_callbacks(agent.before_model_callback) + pipeline = CallbackPipeline(callbacks) + callback_response = await pipeline.execute( + callback_context=callback_context, llm_request=llm_request + ) + if callback_response: + return callback_response async def _handle_after_model_callback( self, @@ -886,16 +885,13 @@ async def _maybe_add_grounding_metadata( # If no overrides are provided from the plugins, further run the canonical # callbacks. - if not agent.canonical_after_model_callbacks: - return await _maybe_add_grounding_metadata() - for callback in agent.canonical_after_model_callbacks: - callback_response = callback( - callback_context=callback_context, llm_response=llm_response - ) - if inspect.isawaitable(callback_response): - callback_response = await callback_response - if callback_response: - return await _maybe_add_grounding_metadata(callback_response) + callbacks = normalize_callbacks(agent.after_model_callback) + pipeline = CallbackPipeline(callbacks) + callback_response = await pipeline.execute( + callback_context=callback_context, llm_response=llm_response + ) + if callback_response: + return await _maybe_add_grounding_metadata(callback_response) return await _maybe_add_grounding_metadata() def _finalize_model_response_event( diff --git a/src/google/adk/flows/llm_flows/functions.py b/src/google/adk/flows/llm_flows/functions.py index 91f8808f5f..0aaad976cb 100644 --- a/src/google/adk/flows/llm_flows/functions.py +++ b/src/google/adk/flows/llm_flows/functions.py @@ -31,6 +31,8 @@ from google.genai import types from ...agents.active_streaming_tool import ActiveStreamingTool +from ...agents.callback_pipeline import CallbackPipeline +from ...agents.callback_pipeline import normalize_callbacks from ...agents.invocation_context import InvocationContext from ...auth.auth_tool import AuthToolArguments from ...events.event import Event @@ -295,17 +297,16 @@ async def _run_on_tool_error_callbacks( if error_response is not None: return error_response - for callback in agent.canonical_on_tool_error_callbacks: - error_response = callback( - tool=tool, - args=tool_args, - tool_context=tool_context, - error=error, - ) - if inspect.isawaitable(error_response): - error_response = await error_response - if error_response is not None: - return error_response + callbacks = normalize_callbacks(agent.on_tool_error_callback) + pipeline = CallbackPipeline(callbacks) + error_response = await pipeline.execute( + tool=tool, + args=tool_args, + tool_context=tool_context, + error=error, + ) + if error_response is not None: + return error_response return None @@ -351,14 +352,11 @@ async def _run_with_trace(): # Step 2: If no overrides are provided from the plugins, further run the # canonical callback. if function_response is None: - for callback in agent.canonical_before_tool_callbacks: - function_response = callback( - tool=tool, args=function_args, tool_context=tool_context - ) - if inspect.isawaitable(function_response): - function_response = await function_response - if function_response: - break + callbacks = normalize_callbacks(agent.before_tool_callback) + pipeline = CallbackPipeline(callbacks) + function_response = await pipeline.execute( + tool=tool, args=function_args, tool_context=tool_context + ) # Step 3: Otherwise, proceed calling the tool normally. if function_response is None: @@ -392,17 +390,14 @@ async def _run_with_trace(): # Step 5: If no overrides are provided from the plugins, further run the # canonical after_tool_callbacks. if altered_function_response is None: - for callback in agent.canonical_after_tool_callbacks: - altered_function_response = callback( - tool=tool, - args=function_args, - tool_context=tool_context, - tool_response=function_response, - ) - if inspect.isawaitable(altered_function_response): - altered_function_response = await altered_function_response - if altered_function_response: - break + callbacks = normalize_callbacks(agent.after_tool_callback) + pipeline = CallbackPipeline(callbacks) + altered_function_response = await pipeline.execute( + tool=tool, + args=function_args, + tool_context=tool_context, + tool_response=function_response, + ) # Step 6: If alternative response exists from after_tool_callback, use it # instead of the original function response. @@ -524,14 +519,11 @@ async def _run_with_trace(): # Handle before_tool_callbacks - iterate through the canonical callback # list - for callback in agent.canonical_before_tool_callbacks: - function_response = callback( - tool=tool, args=function_args, tool_context=tool_context - ) - if inspect.isawaitable(function_response): - function_response = await function_response - if function_response: - break + callbacks = normalize_callbacks(agent.before_tool_callback) + pipeline = CallbackPipeline(callbacks) + function_response = await pipeline.execute( + tool=tool, args=function_args, tool_context=tool_context + ) if function_response is None: function_response = await _process_function_live_helper( @@ -545,17 +537,14 @@ async def _run_with_trace(): # Calls after_tool_callback if it exists. altered_function_response = None - for callback in agent.canonical_after_tool_callbacks: - altered_function_response = callback( - tool=tool, - args=function_args, - tool_context=tool_context, - tool_response=function_response, - ) - if inspect.isawaitable(altered_function_response): - altered_function_response = await altered_function_response - if altered_function_response: - break + callbacks = normalize_callbacks(agent.after_tool_callback) + pipeline = CallbackPipeline(callbacks) + altered_function_response = await pipeline.execute( + tool=tool, + args=function_args, + tool_context=tool_context, + tool_response=function_response, + ) if altered_function_response is not None: function_response = altered_function_response diff --git a/tests/unittests/agents/test_callback_pipeline.py b/tests/unittests/agents/test_callback_pipeline.py new file mode 100644 index 0000000000..1c89cebdfa --- /dev/null +++ b/tests/unittests/agents/test_callback_pipeline.py @@ -0,0 +1,240 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for callback_pipeline module.""" + +import pytest + +from google.adk.agents.callback_pipeline import CallbackPipeline +from google.adk.agents.callback_pipeline import normalize_callbacks + + +class TestNormalizeCallbacks: + """Tests for normalize_callbacks helper function.""" + + def test_none_input(self): + """None should return empty list.""" + result = normalize_callbacks(None) + assert result == [] + assert isinstance(result, list) + + def test_single_callback(self): + """Single callback should be wrapped in list.""" + + def my_callback(): + return 'result' + + result = normalize_callbacks(my_callback) + assert result == [my_callback] + assert len(result) == 1 + assert callable(result[0]) + + def test_list_input(self): + """List of callbacks should be returned as-is.""" + + def cb1(): + pass + + def cb2(): + pass + + callbacks = [cb1, cb2] + result = normalize_callbacks(callbacks) + assert result == callbacks + assert result is callbacks # Same object + + def test_empty_list_input(self): + """Empty list should be returned as-is.""" + result = normalize_callbacks([]) + assert result == [] + + +class TestCallbackPipeline: + """Tests for CallbackPipeline class.""" + + @pytest.mark.asyncio + async def test_empty_pipeline(self): + """Empty pipeline should return None.""" + pipeline = CallbackPipeline() + result = await pipeline.execute() + assert result is None + + @pytest.mark.asyncio + async def test_single_sync_callback(self): + """Pipeline should execute single sync callback.""" + + def callback(): + return 'result' + + pipeline = CallbackPipeline(callbacks=[callback]) + result = await pipeline.execute() + assert result == 'result' + + @pytest.mark.asyncio + async def test_single_async_callback(self): + """Pipeline should execute single async callback.""" + + async def callback(): + return 'async_result' + + pipeline = CallbackPipeline(callbacks=[callback]) + result = await pipeline.execute() + assert result == 'async_result' + + @pytest.mark.asyncio + async def test_early_exit_on_first_non_none(self): + """Pipeline should exit on first non-None result.""" + call_count = {'count': 0} + + def cb1(): + call_count['count'] += 1 + return None + + def cb2(): + call_count['count'] += 1 + return 'second' + + def cb3(): + call_count['count'] += 1 + raise AssertionError('cb3 should not be called') + + pipeline = CallbackPipeline(callbacks=[cb1, cb2, cb3]) + result = await pipeline.execute() + + assert result == 'second' + assert call_count['count'] == 2 # Only cb1 and cb2 called + + @pytest.mark.asyncio + async def test_all_callbacks_return_none(self): + """Pipeline should return None if all callbacks return None.""" + + def cb1(): + return None + + def cb2(): + return None + + pipeline = CallbackPipeline(callbacks=[cb1, cb2]) + result = await pipeline.execute() + assert result is None + + @pytest.mark.asyncio + async def test_mixed_sync_async_callbacks(self): + """Pipeline should handle mix of sync and async callbacks.""" + + def sync_cb(): + return None + + async def async_cb(): + return 'mixed_result' + + pipeline = CallbackPipeline(callbacks=[sync_cb, async_cb]) + result = await pipeline.execute() + assert result == 'mixed_result' + + @pytest.mark.asyncio + async def test_callback_with_arguments(self): + """Pipeline should pass arguments to callbacks.""" + + def callback(x, y, z=None): + return f'{x}-{y}-{z}' + + pipeline = CallbackPipeline(callbacks=[callback]) + result = await pipeline.execute('a', 'b', z='c') + assert result == 'a-b-c' + + @pytest.mark.asyncio + async def test_callback_with_keyword_arguments(self): + """Pipeline should pass keyword arguments to callbacks.""" + + def callback(*, name, value): + return f'{name}={value}' + + pipeline = CallbackPipeline(callbacks=[callback]) + result = await pipeline.execute(name='test', value=42) + assert result == 'test=42' + + @pytest.mark.asyncio + async def test_add_callback_dynamically(self): + """Should be able to add callbacks dynamically.""" + pipeline = CallbackPipeline() + + def callback(): + return 'added' + + assert not pipeline.has_callbacks() + pipeline.add_callback(callback) + assert pipeline.has_callbacks() + + result = await pipeline.execute() + assert result == 'added' + + def test_has_callbacks(self): + """has_callbacks should return correct value.""" + pipeline = CallbackPipeline() + assert not pipeline.has_callbacks() + + pipeline = CallbackPipeline(callbacks=[lambda: None]) + assert pipeline.has_callbacks() + + def test_callbacks_property(self): + """callbacks property should return the callbacks list.""" + + def cb1(): + pass + + def cb2(): + pass + + callbacks = [cb1, cb2] + pipeline = CallbackPipeline(callbacks=callbacks) + assert pipeline.callbacks == callbacks + + +class TestBackwardCompatibility: + """Tests ensuring backward compatibility with existing code.""" + + def test_normalize_callbacks_matches_canonical_behavior(self): + """normalize_callbacks should match canonical_*_callbacks behavior.""" + + def callback1(): + pass + + def callback2(): + pass + + # Test None case + assert normalize_callbacks(None) == [] + + # Test single callback case + assert normalize_callbacks(callback1) == [callback1] + + # Test list case + callback_list = [callback1, callback2] + assert normalize_callbacks(callback_list) == callback_list + + # This mimics the old canonical_*_callbacks logic: + def old_canonical_callbacks(callback_input): + if not callback_input: + return [] + if isinstance(callback_input, list): + return callback_input + return [callback_input] + + # Verify they produce identical results + for test_input in [None, callback1, callback_list]: + assert normalize_callbacks(test_input) == old_canonical_callbacks( + test_input + ) +