From 384f01266ffc24ca5a453c9ac21851999920effd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E7=8E=89=E6=B6=B5?= Date: Thu, 30 Oct 2025 18:43:41 +0800 Subject: [PATCH] feat(toolset): add generate_preprocessing_events method to BaseToolset Add new generate_preprocessing_events method to BaseToolset that allows toolsets to generate events (such as authentication requests) during the preprocessing phase before tool discovery occurs. Changes: - Add generate_preprocessing_events method to BaseToolset with default implementation - Integrate method call in base_llm_flow.py _preprocess_async method - Call generate_preprocessing_events before process_llm_request for proper timing - Add comprehensive unit tests covering method invocation and event generation - Maintain full backward compatibility with existing toolsets This provides the foundation for solving MCP Toolset OAuth2 authentication flow issues by allowing toolsets to request user authentication before attempting to discover tools that require authenticated sessions. The method has access to full ToolContext with authentication capabilities and executes at the perfect timing - after request processors but before tool discovery (get_tools). --- .../adk/flows/llm_flows/base_llm_flow.py | 9 ++ src/google/adk/tools/base_toolset.py | 30 ++++ .../flows/llm_flows/test_base_llm_flow.py | 152 ++++++++++++++++++ 3 files changed, 191 insertions(+) diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index 644dc55b6c..b4eecc47d9 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -476,6 +476,15 @@ async def _preprocess_async( # If it's a toolset, process it first if isinstance(tool_union, BaseToolset): + # Generate preprocessing events (e.g., authentication requests) + async with Aclosing( + tool_union.generate_preprocessing_events( + tool_context=tool_context, llm_request=llm_request + ) + ) as agen: + async for event in agen: + yield event + await tool_union.process_llm_request( tool_context=tool_context, llm_request=llm_request ) diff --git a/src/google/adk/tools/base_toolset.py b/src/google/adk/tools/base_toolset.py index 201eec9087..e4c9577a14 100644 --- a/src/google/adk/tools/base_toolset.py +++ b/src/google/adk/tools/base_toolset.py @@ -17,6 +17,7 @@ from abc import ABC from abc import abstractmethod import copy +from typing import AsyncGenerator from typing import final from typing import List from typing import Optional @@ -31,6 +32,7 @@ from .base_tool import BaseTool if TYPE_CHECKING: + from ..events.event import Event from ..models.llm_request import LlmRequest from .tool_configs import ToolArgsConfig from .tool_context import ToolContext @@ -204,3 +206,31 @@ async def process_llm_request( llm_request: The outgoing LLM request, mutable this method. """ pass + + async def generate_preprocessing_events( + self, *, tool_context: ToolContext, llm_request: LlmRequest + ) -> AsyncGenerator[Event, None]: + """Generates events during the preprocessing phase. + + This method allows toolsets to generate events (such as authentication + requests) before tool discovery occurs. It has access to the full + ToolContext with authentication capabilities. + + Use cases: + - OAuth2 authentication flows before tool discovery + - User confirmation requests for sensitive toolsets + - Dynamic configuration based on user context + - Pre-flight checks that require user interaction + + Args: + tool_context: The context of the tool with full authentication capabilities. + llm_request: The outgoing LLM request, mutable by this method. + + Yields: + Event: Events for user interaction (e.g., authentication requests). + """ + # Default implementation yields nothing (backward compatibility) + # Subclasses can override to yield authentication or other events + if False: # This ensures the method is an AsyncGenerator + yield # Required for AsyncGenerator type hint + return diff --git a/tests/unittests/flows/llm_flows/test_base_llm_flow.py b/tests/unittests/flows/llm_flows/test_base_llm_flow.py index 81ef925a39..9a5848d36d 100644 --- a/tests/unittests/flows/llm_flows/test_base_llm_flow.py +++ b/tests/unittests/flows/llm_flows/test_base_llm_flow.py @@ -14,6 +14,7 @@ """Unit tests for BaseLlmFlow toolset integration.""" +from typing import AsyncGenerator from unittest import mock from unittest.mock import AsyncMock @@ -26,6 +27,7 @@ from google.adk.plugins.base_plugin import BasePlugin from google.adk.tools.base_toolset import BaseToolset from google.adk.tools.google_search_tool import GoogleSearchTool +from google.adk.tools.tool_context import ToolContext from google.genai import types import pytest @@ -91,6 +93,156 @@ async def close(self): assert mock_toolset.process_llm_request_called +@pytest.mark.asyncio +async def test_preprocess_calls_toolset_generate_preprocessing_events(): + """Test that _preprocess_async calls generate_preprocessing_events on toolsets.""" + + # Create a mock toolset that tracks if generate_preprocessing_events was called + class _MockToolset(BaseToolset): + + def __init__(self): + super().__init__() + self.generate_preprocessing_events_called = False + self.generated_events = [] + + async def generate_preprocessing_events( + self, *, tool_context: ToolContext, llm_request: LlmRequest + ) -> AsyncGenerator[Event, None]: + self.generate_preprocessing_events_called = True + # Generate a mock authentication event + auth_event = Event( + author='system', + invocation_id='test_invocation', + content=types.Content( + role='model', + parts=[types.Part(text='Mock authentication request')], + ), + ) + self.generated_events.append(auth_event) + yield auth_event + + async def get_tools(self, readonly_context=None): + return [] + + async def close(self): + pass + + mock_toolset = _MockToolset() + + # Create a mock model that returns a simple response + mock_response = LlmResponse( + content=types.Content( + role='model', parts=[types.Part.from_text(text='Test response')] + ), + partial=False, + ) + + mock_model = testing_utils.MockModel.create(responses=[mock_response]) + + # Create agent with the mock toolset + agent = Agent(name='test_agent', model=mock_model, tools=[mock_toolset]) + invocation_context = await testing_utils.create_invocation_context( + agent=agent, user_content='test message' + ) + + flow = BaseLlmFlowForTesting() + + # Call _preprocess_async + llm_request = LlmRequest() + events = [] + async for event in flow._preprocess_async(invocation_context, llm_request): + events.append(event) + + # Verify that generate_preprocessing_events was called on the toolset + assert mock_toolset.generate_preprocessing_events_called + + # Verify that the generated event was yielded + assert len(events) == 1 + assert events[0].author == 'system' + assert events[0].content.parts[0].text == 'Mock authentication request' + + +@pytest.mark.asyncio +async def test_preprocess_calls_both_generate_events_and_process_request(): + """Test that _preprocess_async calls both generate_preprocessing_events and process_llm_request.""" + + # Create a mock toolset that tracks both method calls + class _MockToolset(BaseToolset): + + def __init__(self): + super().__init__() + self.generate_preprocessing_events_called = False + self.process_llm_request_called = False + self.call_order = [] + + async def generate_preprocessing_events( + self, *, tool_context: ToolContext, llm_request: LlmRequest + ) -> AsyncGenerator[Event, None]: + self.generate_preprocessing_events_called = True + self.call_order.append('generate_preprocessing_events') + # Generate a mock event + yield Event( + author='system', + invocation_id='test_invocation', + content=types.Content( + role='model', parts=[types.Part(text='Mock event')] + ), + ) + + async def process_llm_request( + self, *, tool_context: ToolContext, llm_request: LlmRequest + ) -> None: + self.process_llm_request_called = True + self.call_order.append('process_llm_request') + + async def get_tools(self, readonly_context=None): + return [] + + async def close(self): + pass + + mock_toolset = _MockToolset() + + # Create a mock model that returns a simple response + mock_response = LlmResponse( + content=types.Content( + role='model', parts=[types.Part.from_text(text='Test response')] + ), + partial=False, + ) + + mock_model = testing_utils.MockModel.create(responses=[mock_response]) + + # Create agent with the mock toolset + agent = Agent(name='test_agent', model=mock_model, tools=[mock_toolset]) + invocation_context = await testing_utils.create_invocation_context( + agent=agent, user_content='test message' + ) + + flow = BaseLlmFlowForTesting() + + # Call _preprocess_async + llm_request = LlmRequest() + events = [] + async for event in flow._preprocess_async(invocation_context, llm_request): + events.append(event) + + # Verify that both methods were called + assert mock_toolset.generate_preprocessing_events_called + assert mock_toolset.process_llm_request_called + + # Verify the correct call order (generate_preprocessing_events first) + assert mock_toolset.call_order == [ + 'generate_preprocessing_events', + 'process_llm_request', + ] + + # Verify that the generated event was yielded + assert len(events) == 1 + assert events[0].author == 'system' + assert events[0].content.parts[0].text == 'Mock event' + + @pytest.mark.asyncio async def test_preprocess_handles_mixed_tools_and_toolsets(): """Test that _preprocess_async properly handles both tools and toolsets."""