From 5a7076bfbd01c415fee1c2ec2316c005da9d973a Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow <3211021+zastrowm@users.noreply.github.com> Date: Tue, 22 Jul 2025 14:22:17 -0400 Subject: [PATCH 01/41] Don't re-run workflows on un/approvals (#516) These were necessary when we had conditional running but we switched to needing to approve all workflows for non-maintainers, so we no longer need these. Co-authored-by: Mackenzie Zastrow --- .github/workflows/pr-and-push.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/pr-and-push.yml b/.github/workflows/pr-and-push.yml index 2b2d026f4..b558943dd 100644 --- a/.github/workflows/pr-and-push.yml +++ b/.github/workflows/pr-and-push.yml @@ -3,7 +3,7 @@ name: Pull Request and Push Action on: pull_request: # Safer than pull_request_target for untrusted code branches: [ main ] - types: [opened, synchronize, reopened, ready_for_review, review_requested, review_request_removed] + types: [opened, synchronize, reopened, ready_for_review] push: branches: [ main ] # Also run on direct pushes to main concurrency: From 9aba0189abf43136a9c3eb477ee5257f735730c9 Mon Sep 17 00:00:00 2001 From: Didier Durand Date: Tue, 22 Jul 2025 21:49:29 +0200 Subject: [PATCH 02/41] Fixing some typos in various texts (#487) --- .../conversation_manager/conversation_manager.py | 2 +- src/strands/multiagent/a2a/executor.py | 2 +- src/strands/session/repository_session_manager.py | 14 +++++++------- src/strands/types/session.py | 4 ++-- 4 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/strands/agent/conversation_manager/conversation_manager.py b/src/strands/agent/conversation_manager/conversation_manager.py index 8756a1022..2c1ee7847 100644 --- a/src/strands/agent/conversation_manager/conversation_manager.py +++ b/src/strands/agent/conversation_manager/conversation_manager.py @@ -36,7 +36,7 @@ def restore_from_session(self, state: dict[str, Any]) -> Optional[list[Message]] Args: state: Previous state of the conversation manager Returns: - Optional list of messages to prepend to the agents messages. By defualt returns None. + Optional list of messages to prepend to the agents messages. By default returns None. """ if state.get("__name__") != self.__class__.__name__: raise ValueError("Invalid conversation manager state.") diff --git a/src/strands/multiagent/a2a/executor.py b/src/strands/multiagent/a2a/executor.py index 00eb4764f..d65c64aff 100644 --- a/src/strands/multiagent/a2a/executor.py +++ b/src/strands/multiagent/a2a/executor.py @@ -4,7 +4,7 @@ to be used as an executor in the A2A protocol. It handles the execution of agent requests and the conversion of Strands Agent streamed responses to A2A events. -The A2A AgentExecutor ensures clients recieve responses for synchronous and +The A2A AgentExecutor ensures clients receive responses for synchronous and streamed requests to the A2AServer. """ diff --git a/src/strands/session/repository_session_manager.py b/src/strands/session/repository_session_manager.py index 487335ac9..18a6ac474 100644 --- a/src/strands/session/repository_session_manager.py +++ b/src/strands/session/repository_session_manager.py @@ -32,7 +32,7 @@ def __init__(self, session_id: str, session_repository: SessionRepository, **kwa Args: session_id: ID to use for the session. A new session with this id will be created if it does - not exist in the reposiory yet + not exist in the repository yet session_repository: Underlying session repository to use to store the sessions state. 
**kwargs: Additional keyword arguments for future extensibility. @@ -133,15 +133,15 @@ def initialize(self, agent: "Agent", **kwargs: Any) -> None: agent.state = AgentState(session_agent.state) # Restore the conversation manager to its previous state, and get the optional prepend messages - prepend_messsages = agent.conversation_manager.restore_from_session( + prepend_messages = agent.conversation_manager.restore_from_session( session_agent.conversation_manager_state ) - if prepend_messsages is None: - prepend_messsages = [] + if prepend_messages is None: + prepend_messages = [] # List the messages currently in the session, using an offset of the messages previously removed - # by the converstaion manager. + # by the conversation manager. session_messages = self.session_repository.list_messages( session_id=self.session_id, agent_id=agent.agent_id, @@ -150,5 +150,5 @@ def initialize(self, agent: "Agent", **kwargs: Any) -> None: if len(session_messages) > 0: self._latest_agent_message[agent.agent_id] = session_messages[-1] - # Resore the agents messages array including the optional prepend messages - agent.messages = prepend_messsages + [session_message.to_message() for session_message in session_messages] + # Restore the agents messages array including the optional prepend messages + agent.messages = prepend_messages + [session_message.to_message() for session_message in session_messages] diff --git a/src/strands/types/session.py b/src/strands/types/session.py index 259ab1171..e51816f74 100644 --- a/src/strands/types/session.py +++ b/src/strands/types/session.py @@ -125,7 +125,7 @@ def from_agent(cls, agent: "Agent") -> "SessionAgent": @classmethod def from_dict(cls, env: dict[str, Any]) -> "SessionAgent": - """Initialize a SessionAgent from a dictionary, ignoring keys that are not calss parameters.""" + """Initialize a SessionAgent from a dictionary, ignoring keys that are not class parameters.""" return cls(**{k: v for k, v in env.items() if k in inspect.signature(cls).parameters}) def to_dict(self) -> dict[str, Any]: @@ -144,7 +144,7 @@ class Session: @classmethod def from_dict(cls, env: dict[str, Any]) -> "Session": - """Initialize a Session from a dictionary, ignoring keys that are not calss parameters.""" + """Initialize a Session from a dictionary, ignoring keys that are not class parameters.""" return cls(**{k: v for k, v in env.items() if k in inspect.signature(cls).parameters}) def to_dict(self) -> dict[str, Any]: From 040ba21cdfeb5dfbcdbb6e76ec227356a4429329 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=2E/c=C2=B2?= Date: Tue, 22 Jul 2025 15:52:35 -0400 Subject: [PATCH 03/41] docs(readme): add hot reloading documentation for load_tools_from_directory (#517) - Add new section showcasing Agent(load_tools_from_directory=True) functionality - Document automatic tool loading and reloading from ./tools/ directory - Include practical code example for developers - Improve discoverability of this development feature --- README.md | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/README.md b/README.md index 58c647f8d..62ed54d47 100644 --- a/README.md +++ b/README.md @@ -91,6 +91,17 @@ agent = Agent(tools=[word_count]) response = agent("How many words are in this sentence?") ``` +**Hot Reloading from Directory:** +Enable automatic tool loading and reloading from the `./tools/` directory: + +```python +from strands import Agent + +# Agent will watch ./tools/ directory for changes +agent = Agent(load_tools_from_directory=True) +response = agent("Use any tools you find in the tools 
directory") +``` + ### MCP Support Seamlessly integrate Model Context Protocol (MCP) servers: From 022ec556d7eed2de935deb8293e86f8263056af5 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Tue, 22 Jul 2025 16:19:15 -0400 Subject: [PATCH 04/41] ci: enable integ tests for anthropic, cohere, mistral, openai, writer (#510) --- tests_integ/conftest.py | 52 +++++++++++++++++++ tests_integ/models/providers.py | 4 +- .../{conformance.py => test_conformance.py} | 4 +- tests_integ/models/test_model_anthropic.py | 13 +++-- tests_integ/models/test_model_cohere.py | 2 +- 5 files changed, 67 insertions(+), 8 deletions(-) rename tests_integ/models/{conformance.py => test_conformance.py} (81%) diff --git a/tests_integ/conftest.py b/tests_integ/conftest.py index f83f0e299..61c2bf9a1 100644 --- a/tests_integ/conftest.py +++ b/tests_integ/conftest.py @@ -1,5 +1,17 @@ +import json +import logging +import os + +import boto3 import pytest +logger = logging.getLogger(__name__) + + +def pytest_sessionstart(session): + _load_api_keys_from_secrets_manager() + + ## Data @@ -28,3 +40,43 @@ async def alist(items): return [item async for item in items] return alist + + +## Models + + +def _load_api_keys_from_secrets_manager(): + """Load API keys as environment variables from AWS Secrets Manager.""" + session = boto3.session.Session() + client = session.client(service_name="secretsmanager") + if "STRANDS_TEST_API_KEYS_SECRET_NAME" in os.environ: + try: + secret_name = os.getenv("STRANDS_TEST_API_KEYS_SECRET_NAME") + response = client.get_secret_value(SecretId=secret_name) + + if "SecretString" in response: + secret = json.loads(response["SecretString"]) + for key, value in secret.items(): + os.environ[f"{key.upper()}_API_KEY"] = str(value) + + except Exception as e: + logger.warning("Error retrieving secret", e) + + """ + Validate that required environment variables are set when running in GitHub Actions. + This prevents tests from being unintentionally skipped due to missing credentials. 
+ """ + if os.environ.get("GITHUB_ACTIONS") != "true": + logger.warning("Tests running outside GitHub Actions, skipping required provider validation") + return + + required_providers = { + "ANTHROPIC_API_KEY", + "COHERE_API_KEY", + "MISTRAL_API_KEY", + "OPENAI_API_KEY", + "WRITER_API_KEY", + } + for provider in required_providers: + if provider not in os.environ or not os.environ[provider]: + raise ValueError(f"Missing required environment variables for {provider}") diff --git a/tests_integ/models/providers.py b/tests_integ/models/providers.py index 543f58480..d2ac148d3 100644 --- a/tests_integ/models/providers.py +++ b/tests_integ/models/providers.py @@ -72,11 +72,11 @@ def __init__(self): bedrock = ProviderInfo(id="bedrock", factory=lambda: BedrockModel()) cohere = ProviderInfo( id="cohere", - environment_variable="CO_API_KEY", + environment_variable="COHERE_API_KEY", factory=lambda: OpenAIModel( client_args={ "base_url": "https://api.cohere.com/compatibility/v1", - "api_key": os.getenv("CO_API_KEY"), + "api_key": os.getenv("COHERE_API_KEY"), }, model_id="command-a-03-2025", params={"stream_options": None}, diff --git a/tests_integ/models/conformance.py b/tests_integ/models/test_conformance.py similarity index 81% rename from tests_integ/models/conformance.py rename to tests_integ/models/test_conformance.py index 262e41e42..d9875bc07 100644 --- a/tests_integ/models/conformance.py +++ b/tests_integ/models/test_conformance.py @@ -1,6 +1,6 @@ import pytest -from strands.types.models import Model +from strands.models import Model from tests_integ.models.providers import ProviderInfo, all_providers @@ -9,7 +9,7 @@ def get_models(): pytest.param( provider_info, id=provider_info.id, # Adds the provider name to the test name - marks=[provider_info.mark], # ignores tests that don't have the requirements + marks=provider_info.mark, # ignores tests that don't have the requirements ) for provider_info in all_providers ] diff --git a/tests_integ/models/test_model_anthropic.py b/tests_integ/models/test_model_anthropic.py index 2ee5e7f23..62a95d06d 100644 --- a/tests_integ/models/test_model_anthropic.py +++ b/tests_integ/models/test_model_anthropic.py @@ -6,10 +6,17 @@ import strands from strands import Agent from strands.models.anthropic import AnthropicModel -from tests_integ.models import providers -# these tests only run if we have the anthropic api key -pytestmark = providers.anthropic.mark +""" +These tests only run if we have the anthropic api key + +Because of infrequent burst usage, Anthropic tests are unreliable, failing tests with 529s. 
+{'type': 'error', 'error': {'details': None, 'type': 'overloaded_error', 'message': 'Overloaded'}} +https://docs.anthropic.com/en/api/errors#http-errors +""" +pytestmark = pytest.skip( + "Because of infrequent burst usage, Anthropic tests are unreliable, failing with 529s", allow_module_level=True +) @pytest.fixture diff --git a/tests_integ/models/test_model_cohere.py b/tests_integ/models/test_model_cohere.py index 996b0f326..33fb1a8c6 100644 --- a/tests_integ/models/test_model_cohere.py +++ b/tests_integ/models/test_model_cohere.py @@ -16,7 +16,7 @@ def model(): return OpenAIModel( client_args={ "base_url": "https://api.cohere.com/compatibility/v1", - "api_key": os.getenv("CO_API_KEY"), + "api_key": os.getenv("COHERE_API_KEY"), }, model_id="command-a-03-2025", params={"stream_options": None}, From e597e07f06665292c4207270f41eb37cc45fd645 Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow <3211021+zastrowm@users.noreply.github.com> Date: Wed, 23 Jul 2025 11:26:30 -0400 Subject: [PATCH 05/41] Automatically flatten nested tool collections (#508) Fixes issue #50 Customers naturally want to pass nested collections of tools - the above issue has gathered enough data points proving that. --- src/strands/tools/registry.py | 11 +++++++++-- tests/strands/agent/test_agent.py | 19 +++++++++++++++++++ tests/strands/tools/test_registry.py | 27 +++++++++++++++++++++++++++ 3 files changed, 55 insertions(+), 2 deletions(-) diff --git a/src/strands/tools/registry.py b/src/strands/tools/registry.py index 9d835d28e..fd395ae77 100644 --- a/src/strands/tools/registry.py +++ b/src/strands/tools/registry.py @@ -11,7 +11,7 @@ from importlib import import_module, util from os.path import expanduser from pathlib import Path -from typing import Any, Dict, List, Optional +from typing import Any, Dict, Iterable, List, Optional from typing_extensions import TypedDict, cast @@ -54,7 +54,7 @@ def process_tools(self, tools: List[Any]) -> List[str]: """ tool_names = [] - for tool in tools: + def add_tool(tool: Any) -> None: # Case 1: String file path if isinstance(tool, str): # Extract tool name from path @@ -97,9 +97,16 @@ def process_tools(self, tools: List[Any]) -> List[str]: elif isinstance(tool, AgentTool): self.register_tool(tool) tool_names.append(tool.tool_name) + # Case 6: Nested iterable (list, tuple, etc.) 
- add each sub-tool + elif isinstance(tool, Iterable) and not isinstance(tool, (str, bytes, bytearray)): + for t in tool: + add_tool(t) else: logger.warning("tool=<%s> | unrecognized tool specification", tool) + for a_tool in tools: + add_tool(a_tool) + return tool_names def load_tool_from_filepath(self, tool_name: str, tool_path: str) -> None: diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index d6471a09a..4e310dace 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -231,6 +231,25 @@ def test_agent__init__with_string_model_id(): assert agent.model.config["model_id"] == "nonsense" +def test_agent__init__nested_tools_flattening(tool_decorated, tool_module, tool_imported, tool_registry): + _ = tool_registry + # Nested structure: [tool_decorated, [tool_module, [tool_imported]]] + agent = Agent(tools=[tool_decorated, [tool_module, [tool_imported]]]) + tru_tool_names = sorted(agent.tool_names) + exp_tool_names = ["tool_decorated", "tool_imported", "tool_module"] + assert tru_tool_names == exp_tool_names + + +def test_agent__init__deeply_nested_tools(tool_decorated, tool_module, tool_imported, tool_registry): + _ = tool_registry + # Deeply nested structure + nested_tools = [[[[tool_decorated]], [[tool_module]], tool_imported]] + agent = Agent(tools=nested_tools) + tru_tool_names = sorted(agent.tool_names) + exp_tool_names = ["tool_decorated", "tool_imported", "tool_module"] + assert tru_tool_names == exp_tool_names + + def test_agent__call__( mock_model, system_prompt, diff --git a/tests/strands/tools/test_registry.py b/tests/strands/tools/test_registry.py index ebcba3fb1..66494c987 100644 --- a/tests/strands/tools/test_registry.py +++ b/tests/strands/tools/test_registry.py @@ -93,3 +93,30 @@ def tool_function_4(d): assert len(tools) == 2 assert all(isinstance(tool, DecoratedFunctionTool) for tool in tools) + + +def test_process_tools_flattens_lists_and_tuples_and_sets(): + def function() -> str: + return "done" + + tool_a = tool(name="tool_a")(function) + tool_b = tool(name="tool_b")(function) + tool_c = tool(name="tool_c")(function) + tool_d = tool(name="tool_d")(function) + tool_e = tool(name="tool_e")(function) + tool_f = tool(name="tool_f")(function) + + registry = ToolRegistry() + + all_tools = [tool_a, (tool_b, tool_c), [{tool_d, tool_e}, [tool_f]]] + + tru_tool_names = sorted(registry.process_tools(all_tools)) + exp_tool_names = [ + "tool_a", + "tool_b", + "tool_c", + "tool_d", + "tool_e", + "tool_f", + ] + assert tru_tool_names == exp_tool_names From 4f4e5efd6730fd05ae4382d5ab1715e7b363be6c Mon Sep 17 00:00:00 2001 From: Jeremiah Date: Wed, 23 Jul 2025 13:44:47 -0400 Subject: [PATCH 06/41] feat(a2a): support mounts for containerized deployments (#524) * feat(a2a): support mounts for containerized deployments * feat(a2a): escape hatch for load balancers which strip paths * feat(a2a): formatting --------- Co-authored-by: jer --- src/strands/multiagent/a2a/server.py | 75 +++- .../session/repository_session_manager.py | 4 +- tests/strands/multiagent/a2a/test_server.py | 343 ++++++++++++++++++ 3 files changed, 412 insertions(+), 10 deletions(-) diff --git a/src/strands/multiagent/a2a/server.py b/src/strands/multiagent/a2a/server.py index de891499d..fa7b6b887 100644 --- a/src/strands/multiagent/a2a/server.py +++ b/src/strands/multiagent/a2a/server.py @@ -6,6 +6,7 @@ import logging from typing import Any, Literal +from urllib.parse import urlparse import uvicorn from a2a.server.apps import A2AFastAPIApplication, 
A2AStarletteApplication @@ -31,6 +32,8 @@ def __init__( # AgentCard host: str = "0.0.0.0", port: int = 9000, + http_url: str | None = None, + serve_at_root: bool = False, version: str = "0.0.1", skills: list[AgentSkill] | None = None, ): @@ -40,13 +43,34 @@ def __init__( agent: The Strands Agent to wrap with A2A compatibility. host: The hostname or IP address to bind the A2A server to. Defaults to "0.0.0.0". port: The port to bind the A2A server to. Defaults to 9000. + http_url: The public HTTP URL where this agent will be accessible. If provided, + this overrides the generated URL from host/port and enables automatic + path-based mounting for load balancer scenarios. + Example: "http://my-alb.amazonaws.com/agent1" + serve_at_root: If True, forces the server to serve at root path regardless of + http_url path component. Use this when your load balancer strips path prefixes. + Defaults to False. version: The version of the agent. Defaults to "0.0.1". skills: The list of capabilities or functions the agent can perform. """ self.host = host self.port = port - self.http_url = f"http://{self.host}:{self.port}/" self.version = version + + if http_url: + # Parse the provided URL to extract components for mounting + self.public_base_url, self.mount_path = self._parse_public_url(http_url) + self.http_url = http_url.rstrip("/") + "/" + + # Override mount path if serve_at_root is requested + if serve_at_root: + self.mount_path = "" + else: + # Fall back to constructing the URL from host and port + self.public_base_url = f"http://{host}:{port}" + self.http_url = f"{self.public_base_url}/" + self.mount_path = "" + self.strands_agent = agent self.name = self.strands_agent.name self.description = self.strands_agent.description @@ -58,6 +82,25 @@ def __init__( self._agent_skills = skills logger.info("Strands' integration with A2A is experimental. Be aware of frequent breaking changes.") + def _parse_public_url(self, url: str) -> tuple[str, str]: + """Parse the public URL into base URL and mount path components. + + Args: + url: The full public URL (e.g., "http://my-alb.amazonaws.com/agent1") + + Returns: + tuple: (base_url, mount_path) where base_url is the scheme+netloc + and mount_path is the path component + + Example: + _parse_public_url("http://my-alb.amazonaws.com/agent1") + Returns: ("http://my-alb.amazonaws.com", "/agent1") + """ + parsed = urlparse(url.rstrip("/")) + base_url = f"{parsed.scheme}://{parsed.netloc}" + mount_path = parsed.path if parsed.path != "/" else "" + return base_url, mount_path + @property def public_agent_card(self) -> AgentCard: """Get the public AgentCard for this agent. @@ -119,24 +162,42 @@ def agent_skills(self, skills: list[AgentSkill]) -> None: def to_starlette_app(self) -> Starlette: """Create a Starlette application for serving this agent via HTTP. - This method creates a Starlette application that can be used to serve - the agent via HTTP using the A2A protocol. + Automatically handles path-based mounting if a mount path was derived + from the http_url parameter. Returns: Starlette: A Starlette application configured to serve this agent. 
""" - return A2AStarletteApplication(agent_card=self.public_agent_card, http_handler=self.request_handler).build() + a2a_app = A2AStarletteApplication(agent_card=self.public_agent_card, http_handler=self.request_handler).build() + + if self.mount_path: + # Create parent app and mount the A2A app at the specified path + parent_app = Starlette() + parent_app.mount(self.mount_path, a2a_app) + logger.info("Mounting A2A server at path: %s", self.mount_path) + return parent_app + + return a2a_app def to_fastapi_app(self) -> FastAPI: """Create a FastAPI application for serving this agent via HTTP. - This method creates a FastAPI application that can be used to serve - the agent via HTTP using the A2A protocol. + Automatically handles path-based mounting if a mount path was derived + from the http_url parameter. Returns: FastAPI: A FastAPI application configured to serve this agent. """ - return A2AFastAPIApplication(agent_card=self.public_agent_card, http_handler=self.request_handler).build() + a2a_app = A2AFastAPIApplication(agent_card=self.public_agent_card, http_handler=self.request_handler).build() + + if self.mount_path: + # Create parent app and mount the A2A app at the specified path + parent_app = FastAPI() + parent_app.mount(self.mount_path, a2a_app) + logger.info("Mounting A2A server at path: %s", self.mount_path) + return parent_app + + return a2a_app def serve( self, diff --git a/src/strands/session/repository_session_manager.py b/src/strands/session/repository_session_manager.py index 18a6ac474..75058b251 100644 --- a/src/strands/session/repository_session_manager.py +++ b/src/strands/session/repository_session_manager.py @@ -133,9 +133,7 @@ def initialize(self, agent: "Agent", **kwargs: Any) -> None: agent.state = AgentState(session_agent.state) # Restore the conversation manager to its previous state, and get the optional prepend messages - prepend_messages = agent.conversation_manager.restore_from_session( - session_agent.conversation_manager_state - ) + prepend_messages = agent.conversation_manager.restore_from_session(session_agent.conversation_manager_state) if prepend_messages is None: prepend_messages = [] diff --git a/tests/strands/multiagent/a2a/test_server.py b/tests/strands/multiagent/a2a/test_server.py index 74f470741..fc76b5f1d 100644 --- a/tests/strands/multiagent/a2a/test_server.py +++ b/tests/strands/multiagent/a2a/test_server.py @@ -509,3 +509,346 @@ def test_serve_handles_general_exception(mock_run, mock_strands_agent, caplog): assert "Strands A2A server encountered exception" in caplog.text assert "Strands A2A server has shutdown" in caplog.text + + +def test_initialization_with_http_url_no_path(mock_strands_agent): + """Test initialization with http_url containing no path.""" + mock_strands_agent.tool_registry.get_all_tools_config.return_value = {} + + a2a_agent = A2AServer( + mock_strands_agent, host="0.0.0.0", port=8080, http_url="http://my-alb.amazonaws.com", skills=[] + ) + + assert a2a_agent.host == "0.0.0.0" + assert a2a_agent.port == 8080 + assert a2a_agent.http_url == "http://my-alb.amazonaws.com/" + assert a2a_agent.public_base_url == "http://my-alb.amazonaws.com" + assert a2a_agent.mount_path == "" + + +def test_initialization_with_http_url_with_path(mock_strands_agent): + """Test initialization with http_url containing a path for mounting.""" + mock_strands_agent.tool_registry.get_all_tools_config.return_value = {} + + a2a_agent = A2AServer( + mock_strands_agent, host="0.0.0.0", port=8080, http_url="http://my-alb.amazonaws.com/agent1", skills=[] + ) + 
+ assert a2a_agent.host == "0.0.0.0" + assert a2a_agent.port == 8080 + assert a2a_agent.http_url == "http://my-alb.amazonaws.com/agent1/" + assert a2a_agent.public_base_url == "http://my-alb.amazonaws.com" + assert a2a_agent.mount_path == "/agent1" + + +def test_initialization_with_https_url(mock_strands_agent): + """Test initialization with HTTPS URL.""" + mock_strands_agent.tool_registry.get_all_tools_config.return_value = {} + + a2a_agent = A2AServer(mock_strands_agent, http_url="https://my-alb.amazonaws.com/secure-agent", skills=[]) + + assert a2a_agent.http_url == "https://my-alb.amazonaws.com/secure-agent/" + assert a2a_agent.public_base_url == "https://my-alb.amazonaws.com" + assert a2a_agent.mount_path == "/secure-agent" + + +def test_initialization_with_http_url_with_port(mock_strands_agent): + """Test initialization with http_url containing explicit port.""" + mock_strands_agent.tool_registry.get_all_tools_config.return_value = {} + + a2a_agent = A2AServer(mock_strands_agent, http_url="http://my-server.com:8080/api/agent", skills=[]) + + assert a2a_agent.http_url == "http://my-server.com:8080/api/agent/" + assert a2a_agent.public_base_url == "http://my-server.com:8080" + assert a2a_agent.mount_path == "/api/agent" + + +def test_parse_public_url_method(mock_strands_agent): + """Test the _parse_public_url method directly.""" + mock_strands_agent.tool_registry.get_all_tools_config.return_value = {} + a2a_agent = A2AServer(mock_strands_agent, skills=[]) + + # Test various URL formats + base_url, mount_path = a2a_agent._parse_public_url("http://example.com/path") + assert base_url == "http://example.com" + assert mount_path == "/path" + + base_url, mount_path = a2a_agent._parse_public_url("https://example.com:443/deep/path") + assert base_url == "https://example.com:443" + assert mount_path == "/deep/path" + + base_url, mount_path = a2a_agent._parse_public_url("http://example.com/") + assert base_url == "http://example.com" + assert mount_path == "" + + base_url, mount_path = a2a_agent._parse_public_url("http://example.com") + assert base_url == "http://example.com" + assert mount_path == "" + + +def test_public_agent_card_with_http_url(mock_strands_agent): + """Test that public_agent_card uses the http_url when provided.""" + mock_strands_agent.tool_registry.get_all_tools_config.return_value = {} + + a2a_agent = A2AServer(mock_strands_agent, http_url="https://my-alb.amazonaws.com/agent1", skills=[]) + + card = a2a_agent.public_agent_card + + assert isinstance(card, AgentCard) + assert card.url == "https://my-alb.amazonaws.com/agent1/" + assert card.name == "Test Agent" + assert card.description == "A test agent for unit testing" + + +def test_to_starlette_app_with_mounting(mock_strands_agent): + """Test that to_starlette_app creates mounted app when mount_path exists.""" + mock_strands_agent.tool_registry.get_all_tools_config.return_value = {} + + a2a_agent = A2AServer(mock_strands_agent, http_url="http://example.com/agent1", skills=[]) + + app = a2a_agent.to_starlette_app() + + assert isinstance(app, Starlette) + + +def test_to_starlette_app_without_mounting(mock_strands_agent): + """Test that to_starlette_app creates regular app when no mount_path.""" + mock_strands_agent.tool_registry.get_all_tools_config.return_value = {} + + a2a_agent = A2AServer(mock_strands_agent, http_url="http://example.com", skills=[]) + + app = a2a_agent.to_starlette_app() + + assert isinstance(app, Starlette) + + +def test_to_fastapi_app_with_mounting(mock_strands_agent): + """Test that to_fastapi_app 
creates mounted app when mount_path exists.""" + mock_strands_agent.tool_registry.get_all_tools_config.return_value = {} + + a2a_agent = A2AServer(mock_strands_agent, http_url="http://example.com/agent1", skills=[]) + + app = a2a_agent.to_fastapi_app() + + assert isinstance(app, FastAPI) + + +def test_to_fastapi_app_without_mounting(mock_strands_agent): + """Test that to_fastapi_app creates regular app when no mount_path.""" + mock_strands_agent.tool_registry.get_all_tools_config.return_value = {} + + a2a_agent = A2AServer(mock_strands_agent, http_url="http://example.com", skills=[]) + + app = a2a_agent.to_fastapi_app() + + assert isinstance(app, FastAPI) + + +def test_backwards_compatibility_without_http_url(mock_strands_agent): + """Test that the old behavior is preserved when http_url is not provided.""" + mock_strands_agent.tool_registry.get_all_tools_config.return_value = {} + + a2a_agent = A2AServer(mock_strands_agent, host="localhost", port=9000, skills=[]) + + # Should behave exactly like before + assert a2a_agent.host == "localhost" + assert a2a_agent.port == 9000 + assert a2a_agent.http_url == "http://localhost:9000/" + assert a2a_agent.public_base_url == "http://localhost:9000" + assert a2a_agent.mount_path == "" + + # Agent card should use the traditional URL + card = a2a_agent.public_agent_card + assert card.url == "http://localhost:9000/" + + +def test_mount_path_logging(mock_strands_agent, caplog): + """Test that mounting logs the correct message.""" + mock_strands_agent.tool_registry.get_all_tools_config.return_value = {} + + a2a_agent = A2AServer(mock_strands_agent, http_url="http://example.com/test-agent", skills=[]) + + # Test Starlette app mounting logs + caplog.clear() + a2a_agent.to_starlette_app() + assert "Mounting A2A server at path: /test-agent" in caplog.text + + # Test FastAPI app mounting logs + caplog.clear() + a2a_agent.to_fastapi_app() + assert "Mounting A2A server at path: /test-agent" in caplog.text + + +def test_http_url_trailing_slash_handling(mock_strands_agent): + """Test that trailing slashes in http_url are handled correctly.""" + mock_strands_agent.tool_registry.get_all_tools_config.return_value = {} + + # Test with trailing slash + a2a_agent1 = A2AServer(mock_strands_agent, http_url="http://example.com/agent1/", skills=[]) + + # Test without trailing slash + a2a_agent2 = A2AServer(mock_strands_agent, http_url="http://example.com/agent1", skills=[]) + + # Both should result in the same normalized URL + assert a2a_agent1.http_url == "http://example.com/agent1/" + assert a2a_agent2.http_url == "http://example.com/agent1/" + assert a2a_agent1.mount_path == "/agent1" + assert a2a_agent2.mount_path == "/agent1" + + +def test_serve_at_root_default_behavior(mock_strands_agent): + """Test default behavior extracts mount path from http_url.""" + mock_strands_agent.tool_registry.get_all_tools_config.return_value = {} + + server = A2AServer(mock_strands_agent, http_url="http://my-alb.com/agent1", skills=[]) + + assert server.mount_path == "/agent1" + assert server.http_url == "http://my-alb.com/agent1/" + + +def test_serve_at_root_overrides_mounting(mock_strands_agent): + """Test serve_at_root=True overrides automatic path mounting.""" + mock_strands_agent.tool_registry.get_all_tools_config.return_value = {} + + server = A2AServer(mock_strands_agent, http_url="http://my-alb.com/agent1", serve_at_root=True, skills=[]) + + assert server.mount_path == "" # Should be empty despite path in URL + assert server.http_url == "http://my-alb.com/agent1/" # Public URL 
unchanged + + +def test_serve_at_root_with_no_path(mock_strands_agent): + """Test serve_at_root=True when no path in URL (redundant but valid).""" + mock_strands_agent.tool_registry.get_all_tools_config.return_value = {} + + server = A2AServer(mock_strands_agent, host="localhost", port=8080, serve_at_root=True, skills=[]) + + assert server.mount_path == "" + assert server.http_url == "http://localhost:8080/" + + +def test_serve_at_root_complex_path(mock_strands_agent): + """Test serve_at_root=True with complex nested paths.""" + mock_strands_agent.tool_registry.get_all_tools_config.return_value = {} + + server = A2AServer( + mock_strands_agent, http_url="http://api.example.com/v1/agents/my-agent", serve_at_root=True, skills=[] + ) + + assert server.mount_path == "" + assert server.http_url == "http://api.example.com/v1/agents/my-agent/" + + +def test_serve_at_root_fastapi_mounting_behavior(mock_strands_agent): + """Test FastAPI mounting behavior with serve_at_root.""" + from fastapi.testclient import TestClient + + mock_strands_agent.tool_registry.get_all_tools_config.return_value = {} + + # Normal mounting + server_mounted = A2AServer(mock_strands_agent, http_url="http://my-alb.com/agent1", skills=[]) + app_mounted = server_mounted.to_fastapi_app() + client_mounted = TestClient(app_mounted) + + # Should work at mounted path + response = client_mounted.get("/agent1/.well-known/agent.json") + assert response.status_code == 200 + + # Should not work at root + response = client_mounted.get("/.well-known/agent.json") + assert response.status_code == 404 + + +def test_serve_at_root_fastapi_root_behavior(mock_strands_agent): + """Test FastAPI serve_at_root behavior.""" + from fastapi.testclient import TestClient + + mock_strands_agent.tool_registry.get_all_tools_config.return_value = {} + + # Serve at root + server_root = A2AServer(mock_strands_agent, http_url="http://my-alb.com/agent1", serve_at_root=True, skills=[]) + app_root = server_root.to_fastapi_app() + client_root = TestClient(app_root) + + # Should work at root + response = client_root.get("/.well-known/agent.json") + assert response.status_code == 200 + + # Should not work at mounted path (since we're serving at root) + response = client_root.get("/agent1/.well-known/agent.json") + assert response.status_code == 404 + + +def test_serve_at_root_starlette_behavior(mock_strands_agent): + """Test Starlette serve_at_root behavior.""" + from starlette.testclient import TestClient + + mock_strands_agent.tool_registry.get_all_tools_config.return_value = {} + + # Normal mounting + server_mounted = A2AServer(mock_strands_agent, http_url="http://my-alb.com/agent1", skills=[]) + app_mounted = server_mounted.to_starlette_app() + client_mounted = TestClient(app_mounted) + + # Should work at mounted path + response = client_mounted.get("/agent1/.well-known/agent.json") + assert response.status_code == 200 + + # Serve at root + server_root = A2AServer(mock_strands_agent, http_url="http://my-alb.com/agent1", serve_at_root=True, skills=[]) + app_root = server_root.to_starlette_app() + client_root = TestClient(app_root) + + # Should work at root + response = client_root.get("/.well-known/agent.json") + assert response.status_code == 200 + + +def test_serve_at_root_alb_scenarios(mock_strands_agent): + """Test common ALB deployment scenarios.""" + from fastapi.testclient import TestClient + + mock_strands_agent.tool_registry.get_all_tools_config.return_value = {} + + # ALB with path preservation + server_preserved = A2AServer(mock_strands_agent, 
http_url="http://my-alb.amazonaws.com/agent1", skills=[]) + app_preserved = server_preserved.to_fastapi_app() + client_preserved = TestClient(app_preserved) + + # Container receives /agent1/.well-known/agent.json + response = client_preserved.get("/agent1/.well-known/agent.json") + assert response.status_code == 200 + agent_data = response.json() + assert agent_data["url"] == "http://my-alb.amazonaws.com/agent1/" + + # ALB with path stripping + server_stripped = A2AServer( + mock_strands_agent, http_url="http://my-alb.amazonaws.com/agent1", serve_at_root=True, skills=[] + ) + app_stripped = server_stripped.to_fastapi_app() + client_stripped = TestClient(app_stripped) + + # Container receives /.well-known/agent.json (path stripped by ALB) + response = client_stripped.get("/.well-known/agent.json") + assert response.status_code == 200 + agent_data = response.json() + assert agent_data["url"] == "http://my-alb.amazonaws.com/agent1/" + + +def test_serve_at_root_edge_cases(mock_strands_agent): + """Test edge cases for serve_at_root parameter.""" + mock_strands_agent.tool_registry.get_all_tools_config.return_value = {} + + # Root path in URL + server1 = A2AServer(mock_strands_agent, http_url="http://example.com/", skills=[]) + assert server1.mount_path == "" + + # serve_at_root should be redundant but not cause issues + server2 = A2AServer(mock_strands_agent, http_url="http://example.com/", serve_at_root=True, skills=[]) + assert server2.mount_path == "" + + # Multiple nested paths + server3 = A2AServer( + mock_strands_agent, http_url="http://api.example.com/v1/agents/team1/agent1", serve_at_root=True, skills=[] + ) + assert server3.mount_path == "" + assert server3.http_url == "http://api.example.com/v1/agents/team1/agent1/" From b30e7e6e41e7a2dce70d74e8c1753503959f3619 Mon Sep 17 00:00:00 2001 From: poshinchen Date: Wed, 23 Jul 2025 15:20:28 -0400 Subject: [PATCH 07/41] fix: include agent trace into tool for agent as tools (#526) --- src/strands/telemetry/tracer.py | 2 +- src/strands/tools/executor.py | 37 ++++++++++++++++----------------- 2 files changed, 19 insertions(+), 20 deletions(-) diff --git a/src/strands/telemetry/tracer.py b/src/strands/telemetry/tracer.py index eebffef29..802865189 100644 --- a/src/strands/telemetry/tracer.py +++ b/src/strands/telemetry/tracer.py @@ -273,7 +273,7 @@ def end_model_invoke_span( self._end_span(span, attributes, error) - def start_tool_call_span(self, tool: ToolUse, parent_span: Optional[Span] = None, **kwargs: Any) -> Optional[Span]: + def start_tool_call_span(self, tool: ToolUse, parent_span: Optional[Span] = None, **kwargs: Any) -> Span: """Start a new span for a tool call. Args: diff --git a/src/strands/tools/executor.py b/src/strands/tools/executor.py index 1214fa608..d90f9a5aa 100644 --- a/src/strands/tools/executor.py +++ b/src/strands/tools/executor.py @@ -5,7 +5,7 @@ import time from typing import Any, Optional, cast -from opentelemetry import trace +from opentelemetry import trace as trace_api from ..telemetry.metrics import EventLoopMetrics, Trace from ..telemetry.tracer import get_tracer @@ -23,7 +23,7 @@ async def run_tools( invalid_tool_use_ids: list[str], tool_results: list[ToolResult], cycle_trace: Trace, - parent_span: Optional[trace.Span] = None, + parent_span: Optional[trace_api.Span] = None, ) -> ToolGenerator: """Execute tools concurrently. 
@@ -53,24 +53,23 @@ async def work( tool_name = tool_use["name"] tool_trace = Trace(f"Tool: {tool_name}", parent_id=cycle_trace.id, raw_name=tool_name) tool_start_time = time.time() + with trace_api.use_span(tool_call_span): + try: + async for event in handler(tool_use): + worker_queue.put_nowait((worker_id, event)) + await worker_event.wait() + worker_event.clear() + + result = cast(ToolResult, event) + finally: + worker_queue.put_nowait((worker_id, stop_event)) + + tool_success = result.get("status") == "success" + tool_duration = time.time() - tool_start_time + message = Message(role="user", content=[{"toolResult": result}]) + event_loop_metrics.add_tool_usage(tool_use, tool_duration, tool_trace, tool_success, message) + cycle_trace.add_child(tool_trace) - try: - async for event in handler(tool_use): - worker_queue.put_nowait((worker_id, event)) - await worker_event.wait() - worker_event.clear() - - result = cast(ToolResult, event) - finally: - worker_queue.put_nowait((worker_id, stop_event)) - - tool_success = result.get("status") == "success" - tool_duration = time.time() - tool_start_time - message = Message(role="user", content=[{"toolResult": result}]) - event_loop_metrics.add_tool_usage(tool_use, tool_duration, tool_trace, tool_success, message) - cycle_trace.add_child(tool_trace) - - if tool_call_span: tracer.end_tool_call_span(tool_call_span, result) return result From 8c5562575f8c6c26c2b2a18591d1d5926a96514a Mon Sep 17 00:00:00 2001 From: Davide Gallitelli Date: Mon, 28 Jul 2025 13:34:04 +0200 Subject: [PATCH 08/41] Support for Amazon SageMaker AI endpoints as Model Provider (#176) --- pyproject.toml | 18 +- src/strands/models/sagemaker.py | 600 +++++++++++++++++++++ tests/strands/models/test_sagemaker.py | 574 ++++++++++++++++++++ tests_integ/models/test_model_sagemaker.py | 76 +++ 4 files changed, 1262 insertions(+), 6 deletions(-) create mode 100644 src/strands/models/sagemaker.py create mode 100644 tests/strands/models/test_sagemaker.py create mode 100644 tests_integ/models/test_model_sagemaker.py diff --git a/pyproject.toml b/pyproject.toml index 765e815ef..745c80e0c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -89,8 +89,14 @@ writer = [ "writer-sdk>=2.2.0,<3.0.0" ] +sagemaker = [ + "boto3>=1.26.0,<2.0.0", + "botocore>=1.29.0,<2.0.0", + "boto3-stubs[sagemaker-runtime]>=1.26.0,<2.0.0" +] + a2a = [ - "a2a-sdk[sql]>=0.2.16,<1.0.0", + "a2a-sdk[sql]>=0.2.11,<1.0.0", "uvicorn>=0.34.2,<1.0.0", "httpx>=0.28.1,<1.0.0", "fastapi>=0.115.12,<1.0.0", @@ -136,7 +142,7 @@ all = [ "opentelemetry-exporter-otlp-proto-http>=1.30.0,<2.0.0", # a2a - "a2a-sdk[sql]>=0.2.16,<1.0.0", + "a2a-sdk[sql]>=0.2.11,<1.0.0", "uvicorn>=0.34.2,<1.0.0", "httpx>=0.28.1,<1.0.0", "fastapi>=0.115.12,<1.0.0", @@ -148,7 +154,7 @@ all = [ source = "vcs" [tool.hatch.envs.hatch-static-analysis] -features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel", "mistral", "writer", "a2a"] +features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel", "mistral", "writer", "a2a", "sagemaker"] dependencies = [ "mypy>=1.15.0,<2.0.0", "ruff>=0.11.6,<0.12.0", @@ -171,7 +177,7 @@ lint-fix = [ ] [tool.hatch.envs.hatch-test] -features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel", "mistral", "writer", "a2a"] +features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel", "mistral", "writer", "a2a", "sagemaker"] extra-dependencies = [ "moto>=5.1.0,<6.0.0", "pytest>=8.0.0,<9.0.0", @@ -187,7 +193,7 @@ extra-args = [ [tool.hatch.envs.dev] dev-mode = true -features = 
["dev", "docs", "anthropic", "litellm", "llamaapi", "ollama", "otel", "mistral", "writer", "a2a"] +features = ["dev", "docs", "anthropic", "litellm", "llamaapi", "ollama", "otel", "mistral", "writer", "a2a", "sagemaker"] [[tool.hatch.envs.hatch-test.matrix]] python = ["3.13", "3.12", "3.11", "3.10"] @@ -315,4 +321,4 @@ style = [ ["instruction", ""], ["text", ""], ["disabled", "fg:#858585 italic"] -] +] \ No newline at end of file diff --git a/src/strands/models/sagemaker.py b/src/strands/models/sagemaker.py new file mode 100644 index 000000000..bb2db45a2 --- /dev/null +++ b/src/strands/models/sagemaker.py @@ -0,0 +1,600 @@ +"""Amazon SageMaker model provider.""" + +import json +import logging +import os +from dataclasses import dataclass +from typing import Any, AsyncGenerator, Literal, Optional, Type, TypedDict, TypeVar, Union, cast + +import boto3 +from botocore.config import Config as BotocoreConfig +from mypy_boto3_sagemaker_runtime import SageMakerRuntimeClient +from pydantic import BaseModel +from typing_extensions import Unpack, override + +from ..types.content import ContentBlock, Messages +from ..types.streaming import StreamEvent +from ..types.tools import ToolResult, ToolSpec +from .openai import OpenAIModel + +T = TypeVar("T", bound=BaseModel) + +logger = logging.getLogger(__name__) + + +@dataclass +class UsageMetadata: + """Usage metadata for the model. + + Attributes: + total_tokens: Total number of tokens used in the request + completion_tokens: Number of tokens used in the completion + prompt_tokens: Number of tokens used in the prompt + prompt_tokens_details: Additional information about the prompt tokens (optional) + """ + + total_tokens: int + completion_tokens: int + prompt_tokens: int + prompt_tokens_details: Optional[int] = 0 + + +@dataclass +class FunctionCall: + """Function call for the model. + + Attributes: + name: Name of the function to call + arguments: Arguments to pass to the function + """ + + name: Union[str, dict[Any, Any]] + arguments: Union[str, dict[Any, Any]] + + def __init__(self, **kwargs: dict[str, str]): + """Initialize function call. + + Args: + **kwargs: Keyword arguments for the function call. + """ + self.name = kwargs.get("name", "") + self.arguments = kwargs.get("arguments", "") + + +@dataclass +class ToolCall: + """Tool call for the model object. + + Attributes: + id: Tool call ID + type: Tool call type + function: Tool call function + """ + + id: str + type: Literal["function"] + function: FunctionCall + + def __init__(self, **kwargs: dict): + """Initialize tool call object. + + Args: + **kwargs: Keyword arguments for the tool call. + """ + self.id = str(kwargs.get("id", "")) + self.type = "function" + self.function = FunctionCall(**kwargs.get("function", {"name": "", "arguments": ""})) + + +class SageMakerAIModel(OpenAIModel): + """Amazon SageMaker model provider implementation.""" + + client: SageMakerRuntimeClient # type: ignore[assignment] + + class SageMakerAIPayloadSchema(TypedDict, total=False): + """Payload schema for the Amazon SageMaker AI model. 
+ + Attributes: + max_tokens: Maximum number of tokens to generate in the completion + stream: Whether to stream the response + temperature: Sampling temperature to use for the model (optional) + top_p: Nucleus sampling parameter (optional) + top_k: Top-k sampling parameter (optional) + stop: List of stop sequences to use for the model (optional) + tool_results_as_user_messages: Convert tool result to user messages (optional) + additional_args: Additional request parameters, as supported by https://bit.ly/djl-lmi-request-schema + """ + + max_tokens: int + stream: bool + temperature: Optional[float] + top_p: Optional[float] + top_k: Optional[int] + stop: Optional[list[str]] + tool_results_as_user_messages: Optional[bool] + additional_args: Optional[dict[str, Any]] + + class SageMakerAIEndpointConfig(TypedDict, total=False): + """Configuration options for SageMaker models. + + Attributes: + endpoint_name: The name of the SageMaker endpoint to invoke + inference_component_name: The name of the inference component to use + + additional_args: Other request parameters, as supported by https://bit.ly/sagemaker-invoke-endpoint-params + """ + + endpoint_name: str + region_name: str + inference_component_name: Union[str, None] + target_model: Union[Optional[str], None] + target_variant: Union[Optional[str], None] + additional_args: Optional[dict[str, Any]] + + def __init__( + self, + endpoint_config: SageMakerAIEndpointConfig, + payload_config: SageMakerAIPayloadSchema, + boto_session: Optional[boto3.Session] = None, + boto_client_config: Optional[BotocoreConfig] = None, + ): + """Initialize provider instance. + + Args: + endpoint_config: Endpoint configuration for SageMaker. + payload_config: Payload configuration for the model. + boto_session: Boto Session to use when calling the SageMaker Runtime. + boto_client_config: Configuration to use when creating the SageMaker-Runtime Boto Client. + """ + payload_config.setdefault("stream", True) + payload_config.setdefault("tool_results_as_user_messages", False) + self.endpoint_config = dict(endpoint_config) + self.payload_config = dict(payload_config) + logger.debug( + "endpoint_config=<%s> payload_config=<%s> | initializing", self.endpoint_config, self.payload_config + ) + + region = self.endpoint_config.get("region_name") or os.getenv("AWS_REGION") or "us-west-2" + session = boto_session or boto3.Session(region_name=str(region)) + + # Add strands-agents to the request user agent + if boto_client_config: + existing_user_agent = getattr(boto_client_config, "user_agent_extra", None) + + # Append 'strands-agents' to existing user_agent_extra or set it if not present + new_user_agent = f"{existing_user_agent} strands-agents" if existing_user_agent else "strands-agents" + + client_config = boto_client_config.merge(BotocoreConfig(user_agent_extra=new_user_agent)) + else: + client_config = BotocoreConfig(user_agent_extra="strands-agents") + + self.client = session.client( + service_name="sagemaker-runtime", + config=client_config, + ) + + @override + def update_config(self, **endpoint_config: Unpack[SageMakerAIEndpointConfig]) -> None: # type: ignore[override] + """Update the Amazon SageMaker model configuration with the provided arguments. + + Args: + **endpoint_config: Configuration overrides. + """ + self.endpoint_config.update(endpoint_config) + + @override + def get_config(self) -> "SageMakerAIModel.SageMakerAIEndpointConfig": # type: ignore[override] + """Get the Amazon SageMaker model configuration. 
+ + Returns: + The Amazon SageMaker model configuration. + """ + return cast(SageMakerAIModel.SageMakerAIEndpointConfig, self.endpoint_config) + + @override + def format_request( + self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None + ) -> dict[str, Any]: + """Format an Amazon SageMaker chat streaming request. + + Args: + messages: List of message objects to be processed by the model. + tool_specs: List of tool specifications to make available to the model. + system_prompt: System prompt to provide context to the model. + + Returns: + An Amazon SageMaker chat streaming request. + """ + formatted_messages = self.format_request_messages(messages, system_prompt) + + payload = { + "messages": formatted_messages, + "tools": [ + { + "type": "function", + "function": { + "name": tool_spec["name"], + "description": tool_spec["description"], + "parameters": tool_spec["inputSchema"]["json"], + }, + } + for tool_spec in tool_specs or [] + ], + # Add payload configuration parameters + **{ + k: v + for k, v in self.payload_config.items() + if k not in ["additional_args", "tool_results_as_user_messages"] + }, + } + + # Remove tools and tool_choice if tools = [] + if not payload["tools"]: + payload.pop("tools") + payload.pop("tool_choice", None) + else: + # Ensure the model can use tools when available + payload["tool_choice"] = "auto" + + for message in payload["messages"]: # type: ignore + # Assistant message must have either content or tool_calls, but not both + if message.get("role", "") == "assistant" and message.get("tool_calls", []) != []: + message.pop("content", None) + if message.get("role") == "tool" and self.payload_config.get("tool_results_as_user_messages", False): + # Convert tool message to user message + tool_call_id = message.get("tool_call_id", "ABCDEF") + content = message.get("content", "") + message = {"role": "user", "content": f"Tool call ID '{tool_call_id}' returned: {content}"} + # Cannot have both reasoning_text and text - if "text", content becomes an array of content["text"] + for c in message.get("content", []): + if "text" in c: + message["content"] = [c] + break + # Cast message content to string for TGI compatibility + # message["content"] = str(message.get("content", "")) + + logger.info("payload=<%s>", json.dumps(payload, indent=2)) + # Format the request according to the SageMaker Runtime API requirements + request = { + "EndpointName": self.endpoint_config["endpoint_name"], + "Body": json.dumps(payload), + "ContentType": "application/json", + "Accept": "application/json", + } + + # Add optional SageMaker parameters if provided + if self.endpoint_config.get("inference_component_name"): + request["InferenceComponentName"] = self.endpoint_config["inference_component_name"] + if self.endpoint_config.get("target_model"): + request["TargetModel"] = self.endpoint_config["target_model"] + if self.endpoint_config.get("target_variant"): + request["TargetVariant"] = self.endpoint_config["target_variant"] + + # Add additional args if provided + if self.endpoint_config.get("additional_args"): + request.update(self.endpoint_config["additional_args"].__dict__) + + print(json.dumps(request["Body"], indent=2)) + + return request + + @override + async def stream( + self, + messages: Messages, + tool_specs: Optional[list[ToolSpec]] = None, + system_prompt: Optional[str] = None, + **kwargs: Any, + ) -> AsyncGenerator[StreamEvent, None]: + """Stream conversation with the SageMaker model. 
+ + Args: + messages: List of message objects to be processed by the model. + tool_specs: List of tool specifications to make available to the model. + system_prompt: System prompt to provide context to the model. + **kwargs: Additional keyword arguments for future extensibility. + + Yields: + Formatted message chunks from the model. + """ + logger.debug("formatting request") + request = self.format_request(messages, tool_specs, system_prompt) + logger.debug("formatted request=<%s>", request) + + logger.debug("invoking model") + try: + if self.payload_config.get("stream", True): + response = self.client.invoke_endpoint_with_response_stream(**request) + + # Message start + yield self.format_chunk({"chunk_type": "message_start"}) + + # Parse the content + finish_reason = "" + partial_content = "" + tool_calls: dict[int, list[Any]] = {} + has_text_content = False + text_content_started = False + reasoning_content_started = False + + for event in response["Body"]: + chunk = event["PayloadPart"]["Bytes"].decode("utf-8") + partial_content += chunk[6:] if chunk.startswith("data: ") else chunk # TGI fix + logger.info("chunk=<%s>", partial_content) + try: + content = json.loads(partial_content) + partial_content = "" + choice = content["choices"][0] + logger.info("choice=<%s>", json.dumps(choice, indent=2)) + + # Handle text content + if choice["delta"].get("content", None): + if not text_content_started: + yield self.format_chunk({"chunk_type": "content_start", "data_type": "text"}) + text_content_started = True + has_text_content = True + yield self.format_chunk( + { + "chunk_type": "content_delta", + "data_type": "text", + "data": choice["delta"]["content"], + } + ) + + # Handle reasoning content + if choice["delta"].get("reasoning_content", None): + if not reasoning_content_started: + yield self.format_chunk( + {"chunk_type": "content_start", "data_type": "reasoning_content"} + ) + reasoning_content_started = True + yield self.format_chunk( + { + "chunk_type": "content_delta", + "data_type": "reasoning_content", + "data": choice["delta"]["reasoning_content"], + } + ) + + # Handle tool calls + generated_tool_calls = choice["delta"].get("tool_calls", []) + if not isinstance(generated_tool_calls, list): + generated_tool_calls = [generated_tool_calls] + for tool_call in generated_tool_calls: + tool_calls.setdefault(tool_call["index"], []).append(tool_call) + + if choice["finish_reason"] is not None: + finish_reason = choice["finish_reason"] + break + + if choice.get("usage", None): + yield self.format_chunk( + {"chunk_type": "metadata", "data": UsageMetadata(**choice["usage"])} + ) + + except json.JSONDecodeError: + # Continue accumulating content until we have valid JSON + continue + + # Close reasoning content if it was started + if reasoning_content_started: + yield self.format_chunk({"chunk_type": "content_stop", "data_type": "reasoning_content"}) + + # Close text content if it was started + if text_content_started: + yield self.format_chunk({"chunk_type": "content_stop", "data_type": "text"}) + + # Handle tool calling + logger.info("tool_calls=<%s>", json.dumps(tool_calls, indent=2)) + for tool_deltas in tool_calls.values(): + if not tool_deltas[0]["function"].get("name", None): + raise Exception("The model did not provide a tool name.") + yield self.format_chunk( + {"chunk_type": "content_start", "data_type": "tool", "data": ToolCall(**tool_deltas[0])} + ) + for tool_delta in tool_deltas: + yield self.format_chunk( + {"chunk_type": "content_delta", "data_type": "tool", "data": 
ToolCall(**tool_delta)} + ) + yield self.format_chunk({"chunk_type": "content_stop", "data_type": "tool"}) + + # If no content was generated at all, ensure we have empty text content + if not has_text_content and not tool_calls: + yield self.format_chunk({"chunk_type": "content_start", "data_type": "text"}) + yield self.format_chunk({"chunk_type": "content_stop", "data_type": "text"}) + + # Message close + yield self.format_chunk({"chunk_type": "message_stop", "data": finish_reason}) + + else: + # Not all SageMaker AI models support streaming! + response = self.client.invoke_endpoint(**request) # type: ignore[assignment] + final_response_json = json.loads(response["Body"].read().decode("utf-8")) # type: ignore[attr-defined] + logger.info("response=<%s>", json.dumps(final_response_json, indent=2)) + + # Obtain the key elements from the response + message = final_response_json["choices"][0]["message"] + message_stop_reason = final_response_json["choices"][0]["finish_reason"] + + # Message start + yield self.format_chunk({"chunk_type": "message_start"}) + + # Handle text + if message.get("content", ""): + yield self.format_chunk({"chunk_type": "content_start", "data_type": "text"}) + yield self.format_chunk( + {"chunk_type": "content_delta", "data_type": "text", "data": message["content"]} + ) + yield self.format_chunk({"chunk_type": "content_stop", "data_type": "text"}) + + # Handle reasoning content + if message.get("reasoning_content", None): + yield self.format_chunk({"chunk_type": "content_start", "data_type": "reasoning_content"}) + yield self.format_chunk( + { + "chunk_type": "content_delta", + "data_type": "reasoning_content", + "data": message["reasoning_content"], + } + ) + yield self.format_chunk({"chunk_type": "content_stop", "data_type": "reasoning_content"}) + + # Handle the tool calling, if any + if message.get("tool_calls", None) or message_stop_reason == "tool_calls": + if not isinstance(message["tool_calls"], list): + message["tool_calls"] = [message["tool_calls"]] + for tool_call in message["tool_calls"]: + # if arguments of tool_call is not str, cast it + if not isinstance(tool_call["function"]["arguments"], str): + tool_call["function"]["arguments"] = json.dumps(tool_call["function"]["arguments"]) + yield self.format_chunk( + {"chunk_type": "content_start", "data_type": "tool", "data": ToolCall(**tool_call)} + ) + yield self.format_chunk( + {"chunk_type": "content_delta", "data_type": "tool", "data": ToolCall(**tool_call)} + ) + yield self.format_chunk({"chunk_type": "content_stop", "data_type": "tool"}) + message_stop_reason = "tool_calls" + + # Message close + yield self.format_chunk({"chunk_type": "message_stop", "data": message_stop_reason}) + # Handle usage metadata + if final_response_json.get("usage", None): + yield self.format_chunk( + {"chunk_type": "metadata", "data": UsageMetadata(**final_response_json.get("usage", None))} + ) + except ( + self.client.exceptions.InternalFailure, + self.client.exceptions.ServiceUnavailable, + self.client.exceptions.ValidationError, + self.client.exceptions.ModelError, + self.client.exceptions.InternalDependencyException, + self.client.exceptions.ModelNotReadyException, + ) as e: + logger.error("SageMaker error: %s", str(e)) + raise e + + logger.debug("finished streaming response from model") + + @override + @classmethod + def format_request_tool_message(cls, tool_result: ToolResult) -> dict[str, Any]: + """Format a SageMaker compatible tool message. + + Args: + tool_result: Tool result collected from a tool execution. 
+
+        Returns:
+            SageMaker compatible tool message with content as a string.
+        """
+        # Convert content blocks to a simple string for SageMaker compatibility
+        content_parts = []
+        for content in tool_result["content"]:
+            if "json" in content:
+                content_parts.append(json.dumps(content["json"]))
+            elif "text" in content:
+                content_parts.append(content["text"])
+            else:
+                # Handle other content types by converting to string
+                content_parts.append(str(content))
+
+        content_string = " ".join(content_parts)
+
+        return {
+            "role": "tool",
+            "tool_call_id": tool_result["toolUseId"],
+            "content": content_string,  # String instead of list
+        }
+
+    @override
+    @classmethod
+    def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any]:
+        """Format a content block.
+
+        Args:
+            content: Message content.
+
+        Returns:
+            Formatted content block.
+
+        Raises:
+            TypeError: If the content block type cannot be converted to a SageMaker-compatible format.
+        """
+        if "reasoningContent" in content and content["reasoningContent"]:
+            return {
+                "signature": content["reasoningContent"].get("reasoningText", {}).get("signature", ""),
+                "thinking": content["reasoningContent"].get("reasoningText", {}).get("text", ""),
+                "type": "thinking",
+            }
+        elif not content.get("reasoningContent", None):
+            content.pop("reasoningContent", None)
+
+        if "video" in content:
+            return {
+                "type": "video_url",
+                "video_url": {
+                    "detail": "auto",
+                    "url": content["video"]["source"]["bytes"],
+                },
+            }
+
+        return super().format_request_message_content(content)
+
+    @override
+    async def structured_output(
+        self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any
+    ) -> AsyncGenerator[dict[str, Union[T, Any]], None]:
+        """Get structured output from the model.
+
+        Args:
+            output_model: The output model to use for the agent.
+            prompt: The prompt messages to use for the agent.
+            system_prompt: System prompt to provide context to the model.
+            **kwargs: Additional keyword arguments for future extensibility.
+
+        Yields:
+            Model events with the last being the structured output.
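+
+        Raises:
+            ValueError: If the endpoint response contains no content or the content
+                cannot be parsed into the requested output model.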
+ """ + # Format the request for structured output + request = self.format_request(prompt, system_prompt=system_prompt) + + # Parse the payload to add response format + payload = json.loads(request["Body"]) + payload["response_format"] = { + "type": "json_schema", + "json_schema": {"name": output_model.__name__, "schema": output_model.model_json_schema(), "strict": True}, + } + request["Body"] = json.dumps(payload) + + try: + # Use non-streaming mode for structured output + response = self.client.invoke_endpoint(**request) + final_response_json = json.loads(response["Body"].read().decode("utf-8")) + + # Extract the structured content + message = final_response_json["choices"][0]["message"] + + if message.get("content"): + try: + # Parse the JSON content and create the output model instance + content_data = json.loads(message["content"]) + parsed_output = output_model(**content_data) + yield {"output": parsed_output} + except (json.JSONDecodeError, TypeError, ValueError) as e: + raise ValueError(f"Failed to parse structured output: {e}") from e + else: + raise ValueError("No content found in SageMaker response") + + except ( + self.client.exceptions.InternalFailure, + self.client.exceptions.ServiceUnavailable, + self.client.exceptions.ValidationError, + self.client.exceptions.ModelError, + self.client.exceptions.InternalDependencyException, + self.client.exceptions.ModelNotReadyException, + ) as e: + logger.error("SageMaker structured output error: %s", str(e)) + raise ValueError(f"SageMaker structured output error: {str(e)}") from e diff --git a/tests/strands/models/test_sagemaker.py b/tests/strands/models/test_sagemaker.py new file mode 100644 index 000000000..ba395b2d6 --- /dev/null +++ b/tests/strands/models/test_sagemaker.py @@ -0,0 +1,574 @@ +"""Tests for the Amazon SageMaker model provider.""" + +import json +import unittest.mock +from typing import Any, Dict, List + +import boto3 +import pytest +from botocore.config import Config as BotocoreConfig + +from strands.models.sagemaker import ( + FunctionCall, + SageMakerAIModel, + ToolCall, + UsageMetadata, +) +from strands.types.content import Messages +from strands.types.tools import ToolSpec + + +@pytest.fixture +def boto_session(): + """Mock boto3 session.""" + with unittest.mock.patch.object(boto3, "Session") as mock_session: + yield mock_session.return_value + + +@pytest.fixture +def sagemaker_client(boto_session): + """Mock SageMaker runtime client.""" + return boto_session.client.return_value + + +@pytest.fixture +def endpoint_config() -> Dict[str, Any]: + """Default endpoint configuration for tests.""" + return { + "endpoint_name": "test-endpoint", + "inference_component_name": "test-component", + "region_name": "us-east-1", + } + + +@pytest.fixture +def payload_config() -> Dict[str, Any]: + """Default payload configuration for tests.""" + return { + "max_tokens": 1024, + "temperature": 0.7, + "stream": True, + } + + +@pytest.fixture +def model(boto_session, endpoint_config, payload_config): + """SageMaker model instance with mocked boto session.""" + return SageMakerAIModel(endpoint_config=endpoint_config, payload_config=payload_config, boto_session=boto_session) + + +@pytest.fixture +def messages() -> Messages: + """Sample messages for testing.""" + return [{"role": "user", "content": [{"text": "What is the capital of France?"}]}] + + +@pytest.fixture +def tool_specs() -> List[ToolSpec]: + """Sample tool specifications for testing.""" + return [ + { + "name": "get_weather", + "description": "Get the weather for a location", + 
"inputSchema": { + "json": { + "type": "object", + "properties": {"location": {"type": "string"}}, + "required": ["location"], + } + }, + } + ] + + +@pytest.fixture +def system_prompt() -> str: + """Sample system prompt for testing.""" + return "You are a helpful assistant." + + +class TestSageMakerAIModel: + """Test suite for SageMakerAIModel.""" + + def test_init_default(self, boto_session): + """Test initialization with default parameters.""" + endpoint_config = {"endpoint_name": "test-endpoint", "region_name": "us-east-1"} + payload_config = {"max_tokens": 1024} + model = SageMakerAIModel( + endpoint_config=endpoint_config, payload_config=payload_config, boto_session=boto_session + ) + + assert model.endpoint_config["endpoint_name"] == "test-endpoint" + assert model.payload_config.get("stream", True) is True + + boto_session.client.assert_called_once_with( + service_name="sagemaker-runtime", + config=unittest.mock.ANY, + ) + + def test_init_with_all_params(self, boto_session): + """Test initialization with all parameters.""" + endpoint_config = { + "endpoint_name": "test-endpoint", + "inference_component_name": "test-component", + "region_name": "us-west-2", + } + payload_config = { + "stream": False, + "max_tokens": 1024, + "temperature": 0.7, + } + client_config = BotocoreConfig(user_agent_extra="test-agent") + + model = SageMakerAIModel( + endpoint_config=endpoint_config, + payload_config=payload_config, + boto_session=boto_session, + boto_client_config=client_config, + ) + + assert model.endpoint_config["endpoint_name"] == "test-endpoint" + assert model.endpoint_config["inference_component_name"] == "test-component" + assert model.payload_config["stream"] is False + assert model.payload_config["max_tokens"] == 1024 + assert model.payload_config["temperature"] == 0.7 + + boto_session.client.assert_called_once_with( + service_name="sagemaker-runtime", + config=unittest.mock.ANY, + ) + + def test_init_with_client_config(self, boto_session): + """Test initialization with client configuration.""" + endpoint_config = {"endpoint_name": "test-endpoint", "region_name": "us-east-1"} + payload_config = {"max_tokens": 1024} + client_config = BotocoreConfig(user_agent_extra="test-agent") + + SageMakerAIModel( + endpoint_config=endpoint_config, + payload_config=payload_config, + boto_session=boto_session, + boto_client_config=client_config, + ) + + # Verify client was created with a config that includes our user agent + boto_session.client.assert_called_once_with( + service_name="sagemaker-runtime", + config=unittest.mock.ANY, + ) + + # Get the actual config passed to client + actual_config = boto_session.client.call_args[1]["config"] + assert "strands-agents" in actual_config.user_agent_extra + assert "test-agent" in actual_config.user_agent_extra + + def test_update_config(self, model): + """Test updating model configuration.""" + new_config = {"target_model": "new-model", "target_variant": "new-variant"} + model.update_config(**new_config) + + assert model.endpoint_config["target_model"] == "new-model" + assert model.endpoint_config["target_variant"] == "new-variant" + # Original values should be preserved + assert model.endpoint_config["endpoint_name"] == "test-endpoint" + assert model.endpoint_config["inference_component_name"] == "test-component" + + def test_get_config(self, model, endpoint_config): + """Test getting model configuration.""" + config = model.get_config() + assert config == model.endpoint_config + assert isinstance(config, dict) + + # def 
test_format_request_messages_with_system_prompt(self, model): + # """Test formatting request messages with system prompt.""" + # messages = [{"role": "user", "content": "Hello"}] + # system_prompt = "You are a helpful assistant." + + # formatted_messages = model.format_request_messages(messages, system_prompt) + + # assert len(formatted_messages) == 2 + # assert formatted_messages[0]["role"] == "system" + # assert formatted_messages[0]["content"] == system_prompt + # assert formatted_messages[1]["role"] == "user" + # assert formatted_messages[1]["content"] == "Hello" + + # def test_format_request_messages_with_tool_calls(self, model): + # """Test formatting request messages with tool calls.""" + # messages = [ + # {"role": "user", "content": "Hello"}, + # { + # "role": "assistant", + # "content": None, + # "tool_calls": [{"id": "123", "type": "function", "function": {"name": "test", "arguments": "{}"}}], + # }, + # ] + + # formatted_messages = model.format_request_messages(messages, None) + + # assert len(formatted_messages) == 2 + # assert formatted_messages[0]["role"] == "user" + # assert formatted_messages[1]["role"] == "assistant" + # assert "content" not in formatted_messages[1] + # assert "tool_calls" in formatted_messages[1] + + # def test_format_request(self, model, messages, tool_specs, system_prompt): + # """Test formatting a request with all parameters.""" + # request = model.format_request(messages, tool_specs, system_prompt) + + # assert request["EndpointName"] == "test-endpoint" + # assert request["InferenceComponentName"] == "test-component" + # assert request["ContentType"] == "application/json" + # assert request["Accept"] == "application/json" + + # payload = json.loads(request["Body"]) + # assert "messages" in payload + # assert len(payload["messages"]) > 0 + # assert "tools" in payload + # assert len(payload["tools"]) == 1 + # assert payload["tools"][0]["type"] == "function" + # assert payload["tools"][0]["function"]["name"] == "get_weather" + # assert payload["max_tokens"] == 1024 + # assert payload["temperature"] == 0.7 + # assert payload["stream"] is True + + # def test_format_request_without_tools(self, model, messages, system_prompt): + # """Test formatting a request without tools.""" + # request = model.format_request(messages, None, system_prompt) + + # payload = json.loads(request["Body"]) + # assert "tools" in payload + # assert payload["tools"] == [] + + @pytest.mark.asyncio + async def test_stream_with_streaming_enabled(self, sagemaker_client, model, messages): + """Test streaming response with streaming enabled.""" + # Mock the response from SageMaker + mock_response = { + "Body": [ + { + "PayloadPart": { + "Bytes": json.dumps( + { + "choices": [ + { + "delta": {"content": "Paris is the capital of France."}, + "finish_reason": None, + } + ] + } + ).encode("utf-8") + } + }, + { + "PayloadPart": { + "Bytes": json.dumps( + { + "choices": [ + { + "delta": {"content": " It is known for the Eiffel Tower."}, + "finish_reason": "stop", + } + ] + } + ).encode("utf-8") + } + }, + ] + } + sagemaker_client.invoke_endpoint_with_response_stream.return_value = mock_response + + response = [chunk async for chunk in model.stream(messages)] + + assert len(response) >= 5 + assert response[0] == {"messageStart": {"role": "assistant"}} + + # Find content events + content_start = next((e for e in response if "contentBlockStart" in e), None) + content_delta = next((e for e in response if "contentBlockDelta" in e), None) + content_stop = next((e for e in response if 
"contentBlockStop" in e), None) + message_stop = next((e for e in response if "messageStop" in e), None) + + assert content_start is not None + assert content_delta is not None + assert content_stop is not None + assert message_stop is not None + assert message_stop["messageStop"]["stopReason"] == "end_turn" + + sagemaker_client.invoke_endpoint_with_response_stream.assert_called_once() + + @pytest.mark.asyncio + async def test_stream_with_tool_calls(self, sagemaker_client, model, messages): + """Test streaming response with tool calls.""" + # Mock the response from SageMaker with tool calls + mock_response = { + "Body": [ + { + "PayloadPart": { + "Bytes": json.dumps( + { + "choices": [ + { + "delta": { + "content": None, + "tool_calls": [ + { + "index": 0, + "id": "tool123", + "type": "function", + "function": { + "name": "get_weather", + "arguments": '{"location": "Paris"}', + }, + } + ], + }, + "finish_reason": "tool_calls", + } + ] + } + ).encode("utf-8") + } + } + ] + } + sagemaker_client.invoke_endpoint_with_response_stream.return_value = mock_response + + response = [chunk async for chunk in model.stream(messages)] + + # Verify the response contains tool call events + assert len(response) >= 4 + assert response[0] == {"messageStart": {"role": "assistant"}} + + message_stop = next((e for e in response if "messageStop" in e), None) + assert message_stop is not None + assert message_stop["messageStop"]["stopReason"] == "tool_use" + + # Find tool call events + tool_start = next( + ( + e + for e in response + if "contentBlockStart" in e and e.get("contentBlockStart", {}).get("start", {}).get("toolUse") + ), + None, + ) + tool_delta = next( + ( + e + for e in response + if "contentBlockDelta" in e and e.get("contentBlockDelta", {}).get("delta", {}).get("toolUse") + ), + None, + ) + tool_stop = next((e for e in response if "contentBlockStop" in e), None) + + assert tool_start is not None + assert tool_delta is not None + assert tool_stop is not None + + # Verify tool call data + tool_use_data = tool_start["contentBlockStart"]["start"]["toolUse"] + assert tool_use_data["toolUseId"] == "tool123" + assert tool_use_data["name"] == "get_weather" + + @pytest.mark.asyncio + async def test_stream_with_partial_json(self, sagemaker_client, model, messages): + """Test streaming response with partial JSON chunks.""" + # Mock the response from SageMaker with split JSON + mock_response = { + "Body": [ + {"PayloadPart": {"Bytes": '{"choices": [{"delta": {"content": "Paris is'.encode("utf-8")}}, + {"PayloadPart": {"Bytes": ' the capital of France."}, "finish_reason": "stop"}]}'.encode("utf-8")}}, + ] + } + sagemaker_client.invoke_endpoint_with_response_stream.return_value = mock_response + + response = [chunk async for chunk in model.stream(messages)] + + assert len(response) == 5 + assert response[0] == {"messageStart": {"role": "assistant"}} + + # Find content events + content_start = next((e for e in response if "contentBlockStart" in e), None) + content_delta = next((e for e in response if "contentBlockDelta" in e), None) + content_stop = next((e for e in response if "contentBlockStop" in e), None) + message_stop = next((e for e in response if "messageStop" in e), None) + + assert content_start is not None + assert content_delta is not None + assert content_stop is not None + assert message_stop is not None + assert message_stop["messageStop"]["stopReason"] == "end_turn" + + # Verify content + text_delta = content_delta["contentBlockDelta"]["delta"]["text"] + assert text_delta == "Paris is the capital 
of France." + + @pytest.mark.asyncio + async def test_stream_non_streaming(self, sagemaker_client, model, messages): + """Test non-streaming response.""" + # Configure model for non-streaming + model.payload_config["stream"] = False + + # Mock the response from SageMaker + mock_response = {"Body": unittest.mock.MagicMock()} + mock_response["Body"].read.return_value = json.dumps( + { + "choices": [ + { + "message": {"content": "Paris is the capital of France.", "tool_calls": None}, + "finish_reason": "stop", + } + ], + "usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30, "prompt_tokens_details": 0}, + } + ).encode("utf-8") + + sagemaker_client.invoke_endpoint.return_value = mock_response + + response = [chunk async for chunk in model.stream(messages)] + + assert len(response) >= 6 + assert response[0] == {"messageStart": {"role": "assistant"}} + + # Find content events + content_start = next((e for e in response if "contentBlockStart" in e), None) + content_delta = next((e for e in response if "contentBlockDelta" in e), None) + content_stop = next((e for e in response if "contentBlockStop" in e), None) + message_stop = next((e for e in response if "messageStop" in e), None) + + assert content_start is not None + assert content_delta is not None + assert content_stop is not None + assert message_stop is not None + + # Verify content + text_delta = content_delta["contentBlockDelta"]["delta"]["text"] + assert text_delta == "Paris is the capital of France." + + sagemaker_client.invoke_endpoint.assert_called_once() + + @pytest.mark.asyncio + async def test_stream_non_streaming_with_tool_calls(self, sagemaker_client, model, messages): + """Test non-streaming response with tool calls.""" + # Configure model for non-streaming + model.payload_config["stream"] = False + + # Mock the response from SageMaker with tool calls + mock_response = {"Body": unittest.mock.MagicMock()} + mock_response["Body"].read.return_value = json.dumps( + { + "choices": [ + { + "message": { + "content": None, + "tool_calls": [ + { + "id": "tool123", + "type": "function", + "function": {"name": "get_weather", "arguments": '{"location": "Paris"}'}, + } + ], + }, + "finish_reason": "tool_calls", + } + ], + "usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30, "prompt_tokens_details": 0}, + } + ).encode("utf-8") + + sagemaker_client.invoke_endpoint.return_value = mock_response + + response = [chunk async for chunk in model.stream(messages)] + + # Verify basic structure + assert len(response) >= 6 + assert response[0] == {"messageStart": {"role": "assistant"}} + + # Find tool call events + tool_start = next( + ( + e + for e in response + if "contentBlockStart" in e and e.get("contentBlockStart", {}).get("start", {}).get("toolUse") + ), + None, + ) + tool_delta = next( + ( + e + for e in response + if "contentBlockDelta" in e and e.get("contentBlockDelta", {}).get("delta", {}).get("toolUse") + ), + None, + ) + tool_stop = next((e for e in response if "contentBlockStop" in e), None) + message_stop = next((e for e in response if "messageStop" in e), None) + + assert tool_start is not None + assert tool_delta is not None + assert tool_stop is not None + assert message_stop is not None + + # Verify tool call data + tool_use_data = tool_start["contentBlockStart"]["start"]["toolUse"] + assert tool_use_data["toolUseId"] == "tool123" + assert tool_use_data["name"] == "get_weather" + + # Verify metadata + metadata = next((e for e in response if "metadata" in e), None) + assert metadata is not 
None + usage_data = metadata["metadata"]["usage"] + assert usage_data["totalTokens"] == 30 + + +class TestDataClasses: + """Test suite for data classes.""" + + def test_usage_metadata(self): + """Test UsageMetadata dataclass.""" + usage = UsageMetadata(total_tokens=100, completion_tokens=30, prompt_tokens=70, prompt_tokens_details=5) + + assert usage.total_tokens == 100 + assert usage.completion_tokens == 30 + assert usage.prompt_tokens == 70 + assert usage.prompt_tokens_details == 5 + + def test_function_call(self): + """Test FunctionCall dataclass.""" + func = FunctionCall(name="get_weather", arguments='{"location": "Paris"}') + + assert func.name == "get_weather" + assert func.arguments == '{"location": "Paris"}' + + # Test initialization with kwargs + func2 = FunctionCall(**{"name": "get_time", "arguments": '{"timezone": "UTC"}'}) + + assert func2.name == "get_time" + assert func2.arguments == '{"timezone": "UTC"}' + + def test_tool_call(self): + """Test ToolCall dataclass.""" + # Create a tool call using kwargs directly + tool = ToolCall( + id="tool123", type="function", function={"name": "get_weather", "arguments": '{"location": "Paris"}'} + ) + + assert tool.id == "tool123" + assert tool.type == "function" + assert tool.function.name == "get_weather" + assert tool.function.arguments == '{"location": "Paris"}' + + # Test initialization with kwargs + tool2 = ToolCall( + **{ + "id": "tool456", + "type": "function", + "function": {"name": "get_time", "arguments": '{"timezone": "UTC"}'}, + } + ) + + assert tool2.id == "tool456" + assert tool2.type == "function" + assert tool2.function.name == "get_time" + assert tool2.function.arguments == '{"timezone": "UTC"}' diff --git a/tests_integ/models/test_model_sagemaker.py b/tests_integ/models/test_model_sagemaker.py new file mode 100644 index 000000000..62362e299 --- /dev/null +++ b/tests_integ/models/test_model_sagemaker.py @@ -0,0 +1,76 @@ +import os + +import pytest + +import strands +from strands import Agent +from strands.models.sagemaker import SageMakerAIModel + + +@pytest.fixture +def model(): + endpoint_config = SageMakerAIModel.SageMakerAIEndpointConfig( + endpoint_name=os.getenv("SAGEMAKER_ENDPOINT_NAME", ""), region_name="us-east-1" + ) + payload_config = SageMakerAIModel.SageMakerAIPayloadSchema(max_tokens=1024, temperature=0.7, stream=False) + return SageMakerAIModel(endpoint_config=endpoint_config, payload_config=payload_config) + + +@pytest.fixture +def tools(): + @strands.tool + def tool_time(location: str) -> str: + """Get the current time for a location.""" + return f"The time in {location} is 12:00 PM" + + @strands.tool + def tool_weather(location: str) -> str: + """Get the current weather for a location.""" + return f"The weather in {location} is sunny" + + return [tool_time, tool_weather] + + +@pytest.fixture +def system_prompt(): + return "You are a helpful assistant that provides concise answers." 
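+
+
+# A minimal sketch (not a test) of how the same non-streaming payload config can
+# drive structured output directly; WeatherReport is a hypothetical Pydantic model:
+#
+#     from pydantic import BaseModel
+#
+#     class WeatherReport(BaseModel):
+#         location: str
+#         conditions: str
+#
+#     async for event in model.structured_output(
+#         WeatherReport, [{"role": "user", "content": [{"text": "Weather in Paris?"}]}]
+#     ):
+#         pass  # the final event is {"output": WeatherReport(...)}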
+ + +@pytest.fixture +def agent(model, tools, system_prompt): + return Agent(model=model, tools=tools, system_prompt=system_prompt) + + +@pytest.mark.skipif( + "SAGEMAKER_ENDPOINT_NAME" not in os.environ, + reason="SAGEMAKER_ENDPOINT_NAME environment variable missing", +) +def test_agent_with_tools(agent): + result = agent("What is the time and weather in New York?") + text = result.message["content"][0]["text"].lower() + + assert "12:00" in text and "sunny" in text + + +@pytest.mark.skipif( + "SAGEMAKER_ENDPOINT_NAME" not in os.environ, + reason="SAGEMAKER_ENDPOINT_NAME environment variable missing", +) +def test_agent_without_tools(model, system_prompt): + agent = Agent(model=model, system_prompt=system_prompt) + result = agent("Hello, how are you?") + + assert result.message["content"][0]["text"] + assert len(result.message["content"][0]["text"]) > 0 + + +@pytest.mark.skipif( + "SAGEMAKER_ENDPOINT_NAME" not in os.environ, + reason="SAGEMAKER_ENDPOINT_NAME environment variable missing", +) +@pytest.mark.parametrize("location", ["Tokyo", "London", "Sydney"]) +def test_agent_different_locations(agent, location): + result = agent(f"What is the weather in {location}?") + text = result.message["content"][0]["text"].lower() + + assert location.lower() in text and "sunny" in text From 3f4c3a35ce14800e4852998e0c2b68f90295ffb7 Mon Sep 17 00:00:00 2001 From: mehtarac Date: Mon, 28 Jul 2025 10:23:43 -0400 Subject: [PATCH 09/41] fix: Remove leftover print statement from sagemaker model provider (#553) --- src/strands/models/sagemaker.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/strands/models/sagemaker.py b/src/strands/models/sagemaker.py index bb2db45a2..9cfe27d9e 100644 --- a/src/strands/models/sagemaker.py +++ b/src/strands/models/sagemaker.py @@ -274,8 +274,6 @@ def format_request( if self.endpoint_config.get("additional_args"): request.update(self.endpoint_config["additional_args"].__dict__) - print(json.dumps(request["Body"], indent=2)) - return request @override From bdc893bbae711c1af301e6f18901cb30814789a0 Mon Sep 17 00:00:00 2001 From: Nick Clegg Date: Tue, 29 Jul 2025 14:41:57 -0400 Subject: [PATCH 10/41] [Feat] Update structured output error message (#563) * Update bedrock.py * Update anthropic.py --- src/strands/models/anthropic.py | 2 +- src/strands/models/bedrock.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/strands/models/anthropic.py b/src/strands/models/anthropic.py index eb72becfd..0d734b762 100644 --- a/src/strands/models/anthropic.py +++ b/src/strands/models/anthropic.py @@ -414,7 +414,7 @@ async def structured_output( stop_reason, messages, _, _ = event["stop"] if stop_reason != "tool_use": - raise ValueError("No valid tool use or tool use input was found in the Anthropic response.") + raise ValueError(f"Model returned stop_reason: {stop_reason} instead of \"tool_use\".") content = messages["content"] output_response: dict[str, Any] | None = None diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index 679f1ea3d..cf1e4d3a9 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -584,7 +584,7 @@ async def structured_output( stop_reason, messages, _, _ = event["stop"] if stop_reason != "tool_use": - raise ValueError("No valid tool use or tool use input was found in the Bedrock response.") + raise ValueError(f"Model returned stop_reason: {stop_reason} instead of \"tool_use\".") content = messages["content"] output_response: dict[str, Any] | None = None From 
4e0e0a648c7e441ce15eacca213b7b65e982fd3b Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Tue, 29 Jul 2025 18:03:19 -0400 Subject: [PATCH 11/41] feat(mcp): retain structured content in the AgentTool response (#528) --- pyproject.toml | 2 +- src/strands/models/bedrock.py | 53 +++++++++- src/strands/tools/mcp/mcp_client.py | 49 +++++++--- src/strands/tools/mcp/mcp_types.py | 20 ++++ tests/strands/models/test_bedrock.py | 96 ++++++++++++------- tests/strands/tools/mcp/test_mcp_client.py | 67 +++++++++++++ tests_integ/echo_server.py | 16 +++- tests_integ/test_mcp_client.py | 77 +++++++++++++++ ...cp_client_structured_content_with_hooks.py | 65 +++++++++++++ 9 files changed, 389 insertions(+), 56 deletions(-) create mode 100644 tests_integ/test_mcp_client_structured_content_with_hooks.py diff --git a/pyproject.toml b/pyproject.toml index 745c80e0c..095a38cb0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,7 +29,7 @@ dependencies = [ "boto3>=1.26.0,<2.0.0", "botocore>=1.29.0,<2.0.0", "docstring_parser>=0.15,<1.0", - "mcp>=1.8.0,<2.0.0", + "mcp>=1.11.0,<2.0.0", "pydantic>=2.0.0,<3.0.0", "typing-extensions>=4.13.2,<5.0.0", "watchdog>=6.0.0,<7.0.0", diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index cf1e4d3a9..9b36b4244 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -17,10 +17,10 @@ from ..event_loop import streaming from ..tools import convert_pydantic_to_tool_spec -from ..types.content import Messages +from ..types.content import ContentBlock, Message, Messages from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException from ..types.streaming import StreamEvent -from ..types.tools import ToolSpec +from ..types.tools import ToolResult, ToolSpec from .model import Model logger = logging.getLogger(__name__) @@ -181,7 +181,7 @@ def format_request( """ return { "modelId": self.config["model_id"], - "messages": messages, + "messages": self._format_bedrock_messages(messages), "system": [ *([{"text": system_prompt}] if system_prompt else []), *([{"cachePoint": {"type": self.config["cache_prompt"]}}] if self.config.get("cache_prompt") else []), @@ -246,6 +246,53 @@ def format_request( ), } + def _format_bedrock_messages(self, messages: Messages) -> Messages: + """Format messages for Bedrock API compatibility. + + This function ensures messages conform to Bedrock's expected format by: + - Cleaning tool result content blocks by removing additional fields that may be + useful for retaining information in hooks but would cause Bedrock validation + exceptions when presented with unexpected fields + - Ensuring all message content blocks are properly formatted for the Bedrock API + + Args: + messages: List of messages to format + + Returns: + Messages formatted for Bedrock API compatibility + + Note: + Bedrock will throw validation exceptions when presented with additional + unexpected fields in tool result blocks. 
+ https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolResultBlock.html + """ + cleaned_messages = [] + + for message in messages: + cleaned_content: list[ContentBlock] = [] + + for content_block in message["content"]: + if "toolResult" in content_block: + # Create a new content block with only the cleaned toolResult + tool_result: ToolResult = content_block["toolResult"] + + # Keep only the required fields for Bedrock + cleaned_tool_result = ToolResult( + content=tool_result["content"], toolUseId=tool_result["toolUseId"], status=tool_result["status"] + ) + + cleaned_block: ContentBlock = {"toolResult": cleaned_tool_result} + cleaned_content.append(cleaned_block) + else: + # Keep other content blocks as-is + cleaned_content.append(content_block) + + # Create new message with cleaned content + cleaned_message: Message = Message(content=cleaned_content, role=message["role"]) + cleaned_messages.append(cleaned_message) + + return cleaned_messages + def _has_blocked_guardrail(self, guardrail_data: dict[str, Any]) -> bool: """Check if guardrail data contains any blocked policies. diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index 4cf4e1f85..784636fd0 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -26,9 +26,9 @@ from ...types import PaginatedList from ...types.exceptions import MCPClientInitializationError from ...types.media import ImageFormat -from ...types.tools import ToolResult, ToolResultContent, ToolResultStatus +from ...types.tools import ToolResultContent, ToolResultStatus from .mcp_agent_tool import MCPAgentTool -from .mcp_types import MCPTransport +from .mcp_types import MCPToolResult, MCPTransport logger = logging.getLogger(__name__) @@ -57,7 +57,8 @@ class MCPClient: It handles the creation, initialization, and cleanup of MCP connections. The connection runs in a background thread to avoid blocking the main application thread - while maintaining communication with the MCP service. + while maintaining communication with the MCP service. When structured content is available + from MCP tools, it will be returned as the last item in the content array of the ToolResult. """ def __init__(self, transport_callable: Callable[[], MCPTransport]): @@ -170,11 +171,13 @@ def call_tool_sync( name: str, arguments: dict[str, Any] | None = None, read_timeout_seconds: timedelta | None = None, - ) -> ToolResult: + ) -> MCPToolResult: """Synchronously calls a tool on the MCP server. This method calls the asynchronous call_tool method on the MCP session - and converts the result to the ToolResult format. + and converts the result to the ToolResult format. If the MCP tool returns + structured content, it will be included as the last item in the content array + of the returned ToolResult. Args: tool_use_id: Unique identifier for this tool use @@ -183,7 +186,7 @@ def call_tool_sync( read_timeout_seconds: Optional timeout for the tool call Returns: - ToolResult: The result of the tool call + MCPToolResult: The result of the tool call """ self._log_debug_with_thread("calling MCP tool '%s' synchronously with tool_use_id=%s", name, tool_use_id) if not self._is_session_active(): @@ -205,11 +208,11 @@ async def call_tool_async( name: str, arguments: dict[str, Any] | None = None, read_timeout_seconds: timedelta | None = None, - ) -> ToolResult: + ) -> MCPToolResult: """Asynchronously calls a tool on the MCP server. 
This method calls the asynchronous call_tool method on the MCP session - and converts the result to the ToolResult format. + and converts the result to the MCPToolResult format. Args: tool_use_id: Unique identifier for this tool use @@ -218,7 +221,7 @@ async def call_tool_async( read_timeout_seconds: Optional timeout for the tool call Returns: - ToolResult: The result of the tool call + MCPToolResult: The result of the tool call """ self._log_debug_with_thread("calling MCP tool '%s' asynchronously with tool_use_id=%s", name, tool_use_id) if not self._is_session_active(): @@ -235,15 +238,27 @@ async def _call_tool_async() -> MCPCallToolResult: logger.exception("tool execution failed") return self._handle_tool_execution_error(tool_use_id, e) - def _handle_tool_execution_error(self, tool_use_id: str, exception: Exception) -> ToolResult: + def _handle_tool_execution_error(self, tool_use_id: str, exception: Exception) -> MCPToolResult: """Create error ToolResult with consistent logging.""" - return ToolResult( + return MCPToolResult( status="error", toolUseId=tool_use_id, content=[{"text": f"Tool execution failed: {str(exception)}"}], ) - def _handle_tool_result(self, tool_use_id: str, call_tool_result: MCPCallToolResult) -> ToolResult: + def _handle_tool_result(self, tool_use_id: str, call_tool_result: MCPCallToolResult) -> MCPToolResult: + """Maps MCP tool result to the agent's MCPToolResult format. + + This method processes the content from the MCP tool call result and converts it to the format + expected by the framework. + + Args: + tool_use_id: Unique identifier for this tool use + call_tool_result: The result from the MCP tool call + + Returns: + MCPToolResult: The converted tool result + """ self._log_debug_with_thread("received tool result with %d content items", len(call_tool_result.content)) mapped_content = [ @@ -254,7 +269,15 @@ def _handle_tool_result(self, tool_use_id: str, call_tool_result: MCPCallToolRes status: ToolResultStatus = "error" if call_tool_result.isError else "success" self._log_debug_with_thread("tool execution completed with status: %s", status) - return ToolResult(status=status, toolUseId=tool_use_id, content=mapped_content) + result = MCPToolResult( + status=status, + toolUseId=tool_use_id, + content=mapped_content, + ) + if call_tool_result.structuredContent: + result["structuredContent"] = call_tool_result.structuredContent + + return result async def _async_background_thread(self) -> None: """Asynchronous method that runs in the background thread to manage the MCP connection. diff --git a/src/strands/tools/mcp/mcp_types.py b/src/strands/tools/mcp/mcp_types.py index 30defc585..5fafed5dc 100644 --- a/src/strands/tools/mcp/mcp_types.py +++ b/src/strands/tools/mcp/mcp_types.py @@ -1,11 +1,15 @@ """Type definitions for MCP integration.""" from contextlib import AbstractAsyncContextManager +from typing import Any, Dict from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from mcp.client.streamable_http import GetSessionIdCallback from mcp.shared.memory import MessageStream from mcp.shared.message import SessionMessage +from typing_extensions import NotRequired + +from strands.types.tools import ToolResult """ MCPTransport defines the interface for MCP transport implementations. 
This abstracts @@ -41,3 +45,19 @@ async def my_transport_implementation(): MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectSendStream[SessionMessage], GetSessionIdCallback ] MCPTransport = AbstractAsyncContextManager[MessageStream | _MessageStreamWithGetSessionIdCallback] + + +class MCPToolResult(ToolResult): + """Result of an MCP tool execution. + + Extends the base ToolResult with MCP-specific structured content support. + The structuredContent field contains optional JSON data returned by MCP tools + that provides structured results beyond the standard text/image/document content. + + Attributes: + structuredContent: Optional JSON object containing structured data returned + by the MCP tool. This allows MCP tools to return complex data structures + that can be processed programmatically by agents or other tools. + """ + + structuredContent: NotRequired[Dict[str, Any]] diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index 47e028cb9..0a2846adf 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -13,6 +13,7 @@ from strands.models import BedrockModel from strands.models.bedrock import DEFAULT_BEDROCK_MODEL_ID, DEFAULT_BEDROCK_REGION from strands.types.exceptions import ModelThrottledException +from strands.types.tools import ToolSpec @pytest.fixture @@ -51,7 +52,7 @@ def model(bedrock_client, model_id): @pytest.fixture def messages(): - return [{"role": "user", "content": {"text": "test"}}] + return [{"role": "user", "content": [{"text": "test"}]}] @pytest.fixture @@ -90,8 +91,12 @@ def inference_config(): @pytest.fixture -def tool_spec(): - return {"t1": 1} +def tool_spec() -> ToolSpec: + return { + "description": "description", + "name": "name", + "inputSchema": {"key": "val"}, + } @pytest.fixture @@ -750,7 +755,7 @@ async def test_stream_output_no_guardrail_redact( @pytest.mark.asyncio -async def test_stream_with_streaming_false(bedrock_client, alist): +async def test_stream_with_streaming_false(bedrock_client, alist, messages): """Test stream method with streaming=False.""" bedrock_client.converse.return_value = { "output": {"message": {"role": "assistant", "content": [{"text": "test"}]}}, @@ -759,8 +764,7 @@ async def test_stream_with_streaming_false(bedrock_client, alist): # Create model and call stream model = BedrockModel(model_id="test-model", streaming=False) - request = {"modelId": "test-model"} - response = model.stream(request) + response = model.stream(messages) tru_events = await alist(response) exp_events = [ @@ -776,7 +780,7 @@ async def test_stream_with_streaming_false(bedrock_client, alist): @pytest.mark.asyncio -async def test_stream_with_streaming_false_and_tool_use(bedrock_client, alist): +async def test_stream_with_streaming_false_and_tool_use(bedrock_client, alist, messages): """Test stream method with streaming=False.""" bedrock_client.converse.return_value = { "output": { @@ -790,8 +794,7 @@ async def test_stream_with_streaming_false_and_tool_use(bedrock_client, alist): # Create model and call stream model = BedrockModel(model_id="test-model", streaming=False) - request = {"modelId": "test-model"} - response = model.stream(request) + response = model.stream(messages) tru_events = await alist(response) exp_events = [ @@ -808,7 +811,7 @@ async def test_stream_with_streaming_false_and_tool_use(bedrock_client, alist): @pytest.mark.asyncio -async def test_stream_with_streaming_false_and_reasoning(bedrock_client, alist): +async def 
test_stream_with_streaming_false_and_reasoning(bedrock_client, alist, messages): """Test stream method with streaming=False.""" bedrock_client.converse.return_value = { "output": { @@ -828,8 +831,7 @@ async def test_stream_with_streaming_false_and_reasoning(bedrock_client, alist): # Create model and call stream model = BedrockModel(model_id="test-model", streaming=False) - request = {"modelId": "test-model"} - response = model.stream(request) + response = model.stream(messages) tru_events = await alist(response) exp_events = [ @@ -847,7 +849,7 @@ async def test_stream_with_streaming_false_and_reasoning(bedrock_client, alist): @pytest.mark.asyncio -async def test_stream_and_reasoning_no_signature(bedrock_client, alist): +async def test_stream_and_reasoning_no_signature(bedrock_client, alist, messages): """Test stream method with streaming=False.""" bedrock_client.converse.return_value = { "output": { @@ -867,8 +869,7 @@ async def test_stream_and_reasoning_no_signature(bedrock_client, alist): # Create model and call stream model = BedrockModel(model_id="test-model", streaming=False) - request = {"modelId": "test-model"} - response = model.stream(request) + response = model.stream(messages) tru_events = await alist(response) exp_events = [ @@ -884,7 +885,7 @@ async def test_stream_and_reasoning_no_signature(bedrock_client, alist): @pytest.mark.asyncio -async def test_stream_with_streaming_false_with_metrics_and_usage(bedrock_client, alist): +async def test_stream_with_streaming_false_with_metrics_and_usage(bedrock_client, alist, messages): """Test stream method with streaming=False.""" bedrock_client.converse.return_value = { "output": {"message": {"role": "assistant", "content": [{"text": "test"}]}}, @@ -895,8 +896,7 @@ async def test_stream_with_streaming_false_with_metrics_and_usage(bedrock_client # Create model and call stream model = BedrockModel(model_id="test-model", streaming=False) - request = {"modelId": "test-model"} - response = model.stream(request) + response = model.stream(messages) tru_events = await alist(response) exp_events = [ @@ -919,7 +919,7 @@ async def test_stream_with_streaming_false_with_metrics_and_usage(bedrock_client @pytest.mark.asyncio -async def test_stream_input_guardrails(bedrock_client, alist): +async def test_stream_input_guardrails(bedrock_client, alist, messages): """Test stream method with streaming=False.""" bedrock_client.converse.return_value = { "output": {"message": {"role": "assistant", "content": [{"text": "test"}]}}, @@ -937,8 +937,7 @@ async def test_stream_input_guardrails(bedrock_client, alist): # Create model and call stream model = BedrockModel(model_id="test-model", streaming=False) - request = {"modelId": "test-model"} - response = model.stream(request) + response = model.stream(messages) tru_events = await alist(response) exp_events = [ @@ -970,7 +969,7 @@ async def test_stream_input_guardrails(bedrock_client, alist): @pytest.mark.asyncio -async def test_stream_output_guardrails(bedrock_client, alist): +async def test_stream_output_guardrails(bedrock_client, alist, messages): """Test stream method with streaming=False.""" bedrock_client.converse.return_value = { "output": {"message": {"role": "assistant", "content": [{"text": "test"}]}}, @@ -989,8 +988,7 @@ async def test_stream_output_guardrails(bedrock_client, alist): } model = BedrockModel(model_id="test-model", streaming=False) - request = {"modelId": "test-model"} - response = model.stream(request) + response = model.stream(messages) tru_events = await alist(response) exp_events = [ 
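The hunks in this file all make the same change: `stream()` is now called with `messages` rather than a pre-built request dict, because the model formats the request itself and routes messages through `_format_bedrock_messages`. A minimal sketch of the cleaning that enables, mirroring the dedicated test added at the end of this file (assumes default AWS configuration is resolvable locally, since the constructor creates a boto3 client):

```python
from strands.models.bedrock import BedrockModel

model = BedrockModel(model_id="test-model")

# A toolResult carrying an extra bookkeeping field, e.g. structured content
# retained by MCP hooks, which Bedrock itself would reject as unexpected.
messages = [
    {
        "role": "user",
        "content": [
            {
                "toolResult": {
                    "content": [{"text": "Tool output"}],
                    "toolUseId": "tool123",
                    "status": "success",
                    "structuredContent": {"echoed": "Tool output"},
                }
            }
        ],
    }
]

request = model.format_request(messages)

# Only the three fields Bedrock accepts survive formatting.
assert request["messages"][0]["content"][0]["toolResult"] == {
    "content": [{"text": "Tool output"}],
    "toolUseId": "tool123",
    "status": "success",
}
```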
@@ -1024,7 +1022,7 @@ async def test_stream_output_guardrails(bedrock_client, alist): @pytest.mark.asyncio -async def test_stream_output_guardrails_redacts_output(bedrock_client, alist): +async def test_stream_output_guardrails_redacts_output(bedrock_client, alist, messages): """Test stream method with streaming=False.""" bedrock_client.converse.return_value = { "output": {"message": {"role": "assistant", "content": [{"text": "test"}]}}, @@ -1043,8 +1041,7 @@ async def test_stream_output_guardrails_redacts_output(bedrock_client, alist): } model = BedrockModel(model_id="test-model", streaming=False) - request = {"modelId": "test-model"} - response = model.stream(request) + response = model.stream(messages) tru_events = await alist(response) exp_events = [ @@ -1101,7 +1098,7 @@ async def test_structured_output(bedrock_client, model, test_output_model_cls, a @pytest.mark.skipif(sys.version_info < (3, 11), reason="This test requires Python 3.11 or higher (need add_note)") @pytest.mark.asyncio -async def test_add_note_on_client_error(bedrock_client, model, alist): +async def test_add_note_on_client_error(bedrock_client, model, alist, messages): """Test that add_note is called on ClientError with region and model ID information.""" # Mock the client error response error_response = {"Error": {"Code": "ValidationException", "Message": "Some error message"}} @@ -1109,13 +1106,13 @@ async def test_add_note_on_client_error(bedrock_client, model, alist): # Call the stream method which should catch and add notes to the exception with pytest.raises(ClientError) as err: - await alist(model.stream({"modelId": "test-model"})) + await alist(model.stream(messages)) assert err.value.__notes__ == ["└ Bedrock region: us-west-2", "└ Model id: m1"] @pytest.mark.asyncio -async def test_no_add_note_when_not_available(bedrock_client, model, alist): +async def test_no_add_note_when_not_available(bedrock_client, model, alist, messages): """Verify that on any python version (even < 3.11 where add_note is not available, we get the right exception).""" # Mock the client error response error_response = {"Error": {"Code": "ValidationException", "Message": "Some error message"}} @@ -1123,12 +1120,12 @@ async def test_no_add_note_when_not_available(bedrock_client, model, alist): # Call the stream method which should catch and add notes to the exception with pytest.raises(ClientError): - await alist(model.stream({"modelId": "test-model"})) + await alist(model.stream(messages)) @pytest.mark.skipif(sys.version_info < (3, 11), reason="This test requires Python 3.11 or higher (need add_note)") @pytest.mark.asyncio -async def test_add_note_on_access_denied_exception(bedrock_client, model, alist): +async def test_add_note_on_access_denied_exception(bedrock_client, model, alist, messages): """Test that add_note adds documentation link for AccessDeniedException.""" # Mock the client error response for access denied error_response = { @@ -1142,7 +1139,7 @@ async def test_add_note_on_access_denied_exception(bedrock_client, model, alist) # Call the stream method which should catch and add notes to the exception with pytest.raises(ClientError) as err: - await alist(model.stream({"modelId": "test-model"})) + await alist(model.stream(messages)) assert err.value.__notes__ == [ "└ Bedrock region: us-west-2", @@ -1154,7 +1151,7 @@ async def test_add_note_on_access_denied_exception(bedrock_client, model, alist) @pytest.mark.skipif(sys.version_info < (3, 11), reason="This test requires Python 3.11 or higher (need add_note)") 
@pytest.mark.asyncio -async def test_add_note_on_validation_exception_throughput(bedrock_client, model, alist): +async def test_add_note_on_validation_exception_throughput(bedrock_client, model, alist, messages): """Test that add_note adds documentation link for ValidationException about on-demand throughput.""" # Mock the client error response for validation exception error_response = { @@ -1170,7 +1167,7 @@ async def test_add_note_on_validation_exception_throughput(bedrock_client, model # Call the stream method which should catch and add notes to the exception with pytest.raises(ClientError) as err: - await alist(model.stream({"modelId": "test-model"})) + await alist(model.stream(messages)) assert err.value.__notes__ == [ "└ Bedrock region: us-west-2", @@ -1202,3 +1199,32 @@ async def test_stream_logging(bedrock_client, model, messages, caplog, alist): assert "invoking model" in log_text assert "got response from model" in log_text assert "finished streaming response from model" in log_text + + +def test_format_request_cleans_tool_result_content_blocks(model, model_id): + """Test that format_request cleans toolResult blocks by removing extra fields.""" + messages = [ + { + "role": "user", + "content": [ + { + "toolResult": { + "content": [{"text": "Tool output"}], + "toolUseId": "tool123", + "status": "success", + "extraField": "should be removed", + "mcpMetadata": {"server": "test"}, + } + }, + ], + } + ] + + formatted_request = model.format_request(messages) + + # Verify toolResult only contains allowed fields in the formatted request + tool_result = formatted_request["messages"][0]["content"][0]["toolResult"] + expected = {"content": [{"text": "Tool output"}], "toolUseId": "tool123", "status": "success"} + assert tool_result == expected + assert "extraField" not in tool_result + assert "mcpMetadata" not in tool_result diff --git a/tests/strands/tools/mcp/test_mcp_client.py b/tests/strands/tools/mcp/test_mcp_client.py index 6a2fdd00c..3d3792c71 100644 --- a/tests/strands/tools/mcp/test_mcp_client.py +++ b/tests/strands/tools/mcp/test_mcp_client.py @@ -8,6 +8,7 @@ from mcp.types import Tool as MCPTool from strands.tools.mcp import MCPClient +from strands.tools.mcp.mcp_types import MCPToolResult from strands.types.exceptions import MCPClientInitializationError @@ -129,6 +130,8 @@ def test_call_tool_sync_status(mock_transport, mock_session, is_error, expected_ assert result["toolUseId"] == "test-123" assert len(result["content"]) == 1 assert result["content"][0]["text"] == "Test message" + # No structured content should be present when not provided by MCP + assert result.get("structuredContent") is None def test_call_tool_sync_session_not_active(): @@ -139,6 +142,31 @@ def test_call_tool_sync_session_not_active(): client.call_tool_sync(tool_use_id="test-123", name="test_tool", arguments={"param": "value"}) +def test_call_tool_sync_with_structured_content(mock_transport, mock_session): + """Test that call_tool_sync correctly handles structured content.""" + mock_content = MCPTextContent(type="text", text="Test message") + structured_content = {"result": 42, "status": "completed"} + mock_session.call_tool.return_value = MCPCallToolResult( + isError=False, content=[mock_content], structuredContent=structured_content + ) + + with MCPClient(mock_transport["transport_callable"]) as client: + result = client.call_tool_sync(tool_use_id="test-123", name="test_tool", arguments={"param": "value"}) + + mock_session.call_tool.assert_called_once_with("test_tool", {"param": "value"}, None) + + assert 
result["status"] == "success" + assert result["toolUseId"] == "test-123" + # Content should only contain the text content, not the structured content + assert len(result["content"]) == 1 + assert result["content"][0]["text"] == "Test message" + # Structured content should be in its own field + assert "structuredContent" in result + assert result["structuredContent"] == structured_content + assert result["structuredContent"]["result"] == 42 + assert result["structuredContent"]["status"] == "completed" + + def test_call_tool_sync_exception(mock_transport, mock_session): """Test that call_tool_sync correctly handles exceptions.""" mock_session.call_tool.side_effect = Exception("Test exception") @@ -312,6 +340,45 @@ def test_enter_with_initialization_exception(mock_transport): client.start() +def test_mcp_tool_result_type(): + """Test that MCPToolResult extends ToolResult correctly.""" + # Test basic ToolResult functionality + result = MCPToolResult(status="success", toolUseId="test-123", content=[{"text": "Test message"}]) + + assert result["status"] == "success" + assert result["toolUseId"] == "test-123" + assert result["content"][0]["text"] == "Test message" + + # Test that structuredContent is optional + assert "structuredContent" not in result or result.get("structuredContent") is None + + # Test with structuredContent + result_with_structured = MCPToolResult( + status="success", toolUseId="test-456", content=[{"text": "Test message"}], structuredContent={"key": "value"} + ) + + assert result_with_structured["structuredContent"] == {"key": "value"} + + +def test_call_tool_sync_without_structured_content(mock_transport, mock_session): + """Test that call_tool_sync works correctly when no structured content is provided.""" + mock_content = MCPTextContent(type="text", text="Test message") + mock_session.call_tool.return_value = MCPCallToolResult( + isError=False, + content=[mock_content], # No structuredContent + ) + + with MCPClient(mock_transport["transport_callable"]) as client: + result = client.call_tool_sync(tool_use_id="test-123", name="test_tool", arguments={"param": "value"}) + + assert result["status"] == "success" + assert result["toolUseId"] == "test-123" + assert len(result["content"]) == 1 + assert result["content"][0]["text"] == "Test message" + # structuredContent should be None when not provided by MCP + assert result.get("structuredContent") is None + + def test_exception_when_future_not_running(): """Test exception handling when the future is not running.""" # Create a client.with a mock transport diff --git a/tests_integ/echo_server.py b/tests_integ/echo_server.py index d309607a8..52223792c 100644 --- a/tests_integ/echo_server.py +++ b/tests_integ/echo_server.py @@ -2,7 +2,7 @@ Echo Server for MCP Integration Testing This module implements a simple echo server using the Model Context Protocol (MCP). -It provides a basic tool that echoes back any input string, which is useful for +It provides basic tools that echo back input strings and structured content, which is useful for testing the MCP communication flow and validating that messages are properly transmitted between the client and server. @@ -15,6 +15,8 @@ $ python echo_server.py """ +from typing import Any, Dict + from mcp.server import FastMCP @@ -22,16 +24,22 @@ def start_echo_server(): """ Initialize and start the MCP echo server. - Creates a FastMCP server instance with a single 'echo' tool that returns - any input string back to the caller. 
The server uses stdio transport + Creates a FastMCP server instance with tools that return + input strings and structured content back to the caller. The server uses stdio transport for communication. + """ mcp = FastMCP("Echo Server") - @mcp.tool(description="Echos response back to the user") + @mcp.tool(description="Echos response back to the user", structured_output=False) def echo(to_echo: str) -> str: return to_echo + # FastMCP automatically constructs structured output schema from method signature + @mcp.tool(description="Echos response back with structured content", structured_output=True) + def echo_with_structured_content(to_echo: str) -> Dict[str, Any]: + return {"echoed": to_echo} + mcp.run(transport="stdio") diff --git a/tests_integ/test_mcp_client.py b/tests_integ/test_mcp_client.py index 9163f625d..ebd4f5896 100644 --- a/tests_integ/test_mcp_client.py +++ b/tests_integ/test_mcp_client.py @@ -1,4 +1,5 @@ import base64 +import json import os import threading import time @@ -87,6 +88,24 @@ def test_mcp_client(): ] ) + tool_use_id = "test-structured-content-123" + result = stdio_mcp_client.call_tool_sync( + tool_use_id=tool_use_id, + name="echo_with_structured_content", + arguments={"to_echo": "STRUCTURED_DATA_TEST"}, + ) + + # With the new MCPToolResult, structured content is in its own field + assert "structuredContent" in result + assert result["structuredContent"]["result"] == {"echoed": "STRUCTURED_DATA_TEST"} + + # Verify the result is an MCPToolResult (at runtime it's just a dict, but type-wise it should be MCPToolResult) + assert result["status"] == "success" + assert result["toolUseId"] == tool_use_id + + assert len(result["content"]) == 1 + assert json.loads(result["content"][0]["text"]) == {"echoed": "STRUCTURED_DATA_TEST"} + def test_can_reuse_mcp_client(): stdio_mcp_client = MCPClient( @@ -103,6 +122,64 @@ def test_can_reuse_mcp_client(): assert any([block["name"] == "echo" for block in tool_use_content_blocks]) +@pytest.mark.asyncio +async def test_mcp_client_async_structured_content(): + """Test that async MCP client calls properly handle structured content. + + This test demonstrates how tools configure structured output: FastMCP automatically + constructs structured output schema from method signature when structured_output=True + is set in the @mcp.tool decorator. The return type annotation defines the structure + that appears in structuredContent field. 
+ """ + stdio_mcp_client = MCPClient( + lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/echo_server.py"])) + ) + + with stdio_mcp_client: + tool_use_id = "test-async-structured-content-456" + result = await stdio_mcp_client.call_tool_async( + tool_use_id=tool_use_id, + name="echo_with_structured_content", + arguments={"to_echo": "ASYNC_STRUCTURED_TEST"}, + ) + + # Verify structured content is in its own field + assert "structuredContent" in result + # "result" nesting is not part of the MCP Structured Content specification, + # but rather a FastMCP implementation detail + assert result["structuredContent"]["result"] == {"echoed": "ASYNC_STRUCTURED_TEST"} + + # Verify basic MCPToolResult structure + assert result["status"] in ["success", "error"] + assert result["toolUseId"] == tool_use_id + + assert len(result["content"]) == 1 + assert json.loads(result["content"][0]["text"]) == {"echoed": "ASYNC_STRUCTURED_TEST"} + + +def test_mcp_client_without_structured_content(): + """Test that MCP client works correctly when tools don't return structured content.""" + stdio_mcp_client = MCPClient( + lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/echo_server.py"])) + ) + + with stdio_mcp_client: + tool_use_id = "test-no-structured-content-789" + result = stdio_mcp_client.call_tool_sync( + tool_use_id=tool_use_id, + name="echo", # This tool doesn't return structured content + arguments={"to_echo": "SIMPLE_ECHO_TEST"}, + ) + + # Verify no structured content when tool doesn't provide it + assert result.get("structuredContent") is None + + # Verify basic result structure + assert result["status"] == "success" + assert result["toolUseId"] == tool_use_id + assert result["content"] == [{"text": "SIMPLE_ECHO_TEST"}] + + @pytest.mark.skipif( condition=os.environ.get("GITHUB_ACTIONS") == "true", reason="streamable transport is failing in GitHub actions, debugging if linux compatibility issue", diff --git a/tests_integ/test_mcp_client_structured_content_with_hooks.py b/tests_integ/test_mcp_client_structured_content_with_hooks.py new file mode 100644 index 000000000..ca2468c48 --- /dev/null +++ b/tests_integ/test_mcp_client_structured_content_with_hooks.py @@ -0,0 +1,65 @@ +"""Integration test demonstrating hooks system with MCP client structured content tool. + +This test shows how to use the hooks system to capture and inspect tool invocation +results, specifically testing the echo_with_structured_content tool from echo_server. 
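+
+The hook receives an AfterToolInvocationEvent whose result field carries the full
+MCPToolResult, including the structuredContent returned by the MCP server.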
+""" + +import json + +from mcp import StdioServerParameters, stdio_client + +from strands import Agent +from strands.experimental.hooks import AfterToolInvocationEvent +from strands.hooks import HookProvider, HookRegistry +from strands.tools.mcp.mcp_client import MCPClient + + +class StructuredContentHookProvider(HookProvider): + """Hook provider that captures structured content tool results.""" + + def __init__(self): + self.captured_result = None + + def register_hooks(self, registry: HookRegistry) -> None: + """Register callback for after tool invocation events.""" + registry.add_callback(AfterToolInvocationEvent, self.on_after_tool_invocation) + + def on_after_tool_invocation(self, event: AfterToolInvocationEvent) -> None: + """Capture structured content tool results.""" + if event.tool_use["name"] == "echo_with_structured_content": + self.captured_result = event.result + + +def test_mcp_client_hooks_structured_content(): + """Test using hooks to inspect echo_with_structured_content tool result.""" + # Create hook provider to capture tool result + hook_provider = StructuredContentHookProvider() + + # Set up MCP client for echo server + stdio_mcp_client = MCPClient( + lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/echo_server.py"])) + ) + + with stdio_mcp_client: + # Create agent with MCP tools and hook provider + agent = Agent(tools=stdio_mcp_client.list_tools_sync(), hooks=[hook_provider]) + + # Test structured content functionality + test_data = "HOOKS_TEST_DATA" + agent(f"Use the echo_with_structured_content tool to echo: {test_data}") + + # Verify hook captured the tool result + assert hook_provider.captured_result is not None + result = hook_provider.captured_result + + # Verify basic result structure + assert result["status"] == "success" + assert len(result["content"]) == 1 + + # Verify structured content is present and correct + assert "structuredContent" in result + assert result["structuredContent"]["result"] == {"echoed": test_data} + + # Verify text content matches structured content + text_content = json.loads(result["content"][0]["text"]) + assert text_content == {"echoed": test_data} From b13c5c5492e7745acb86d23eb215acdce0120361 Mon Sep 17 00:00:00 2001 From: Ketan Suhaas Saichandran <55935983+Ketansuhaas@users.noreply.github.com> Date: Wed, 30 Jul 2025 08:59:29 -0400 Subject: [PATCH 12/41] feat(mcp): Add list_prompts, get_prompt methods (#160) Co-authored-by: ketan-clairyon Co-authored-by: Dean Schmigelski --- src/strands/tools/mcp/mcp_client.py | 49 +++++++++++++ tests/strands/tools/mcp/test_mcp_client.py | 62 ++++++++++++++++ tests_integ/test_mcp_client.py | 83 +++++++++++++++++++--- 3 files changed, 184 insertions(+), 10 deletions(-) diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index 784636fd0..8c21baa4a 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -20,6 +20,7 @@ from mcp import ClientSession, ListToolsResult from mcp.types import CallToolResult as MCPCallToolResult +from mcp.types import GetPromptResult, ListPromptsResult from mcp.types import ImageContent as MCPImageContent from mcp.types import TextContent as MCPTextContent @@ -165,6 +166,54 @@ async def _list_tools_async() -> ListToolsResult: self._log_debug_with_thread("successfully adapted %d MCP tools", len(mcp_tools)) return PaginatedList[MCPAgentTool](mcp_tools, token=list_tools_response.nextCursor) + def list_prompts_sync(self, pagination_token: Optional[str] = None) -> 
ListPromptsResult: + """Synchronously retrieves the list of available prompts from the MCP server. + + This method calls the asynchronous list_prompts method on the MCP session + and returns the raw ListPromptsResult with pagination support. + + Args: + pagination_token: Optional token for pagination + + Returns: + ListPromptsResult: The raw MCP response containing prompts and pagination info + """ + self._log_debug_with_thread("listing MCP prompts synchronously") + if not self._is_session_active(): + raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE) + + async def _list_prompts_async() -> ListPromptsResult: + return await self._background_thread_session.list_prompts(cursor=pagination_token) + + list_prompts_result: ListPromptsResult = self._invoke_on_background_thread(_list_prompts_async()).result() + self._log_debug_with_thread("received %d prompts from MCP server", len(list_prompts_result.prompts)) + for prompt in list_prompts_result.prompts: + self._log_debug_with_thread(prompt.name) + + return list_prompts_result + + def get_prompt_sync(self, prompt_id: str, args: dict[str, Any]) -> GetPromptResult: + """Synchronously retrieves a prompt from the MCP server. + + Args: + prompt_id: The ID of the prompt to retrieve + args: Optional arguments to pass to the prompt + + Returns: + GetPromptResult: The prompt response from the MCP server + """ + self._log_debug_with_thread("getting MCP prompt synchronously") + if not self._is_session_active(): + raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE) + + async def _get_prompt_async() -> GetPromptResult: + return await self._background_thread_session.get_prompt(prompt_id, arguments=args) + + get_prompt_result: GetPromptResult = self._invoke_on_background_thread(_get_prompt_async()).result() + self._log_debug_with_thread("received prompt from MCP server") + + return get_prompt_result + def call_tool_sync( self, tool_use_id: str, diff --git a/tests/strands/tools/mcp/test_mcp_client.py b/tests/strands/tools/mcp/test_mcp_client.py index 3d3792c71..bd88382cd 100644 --- a/tests/strands/tools/mcp/test_mcp_client.py +++ b/tests/strands/tools/mcp/test_mcp_client.py @@ -4,6 +4,7 @@ import pytest from mcp import ListToolsResult from mcp.types import CallToolResult as MCPCallToolResult +from mcp.types import GetPromptResult, ListPromptsResult, Prompt, PromptMessage from mcp.types import TextContent as MCPTextContent from mcp.types import Tool as MCPTool @@ -404,3 +405,64 @@ def test_exception_when_future_not_running(): # Verify that set_exception was not called since the future was not running mock_future.set_exception.assert_not_called() + + +# Prompt Tests - Sync Methods + + +def test_list_prompts_sync(mock_transport, mock_session): + """Test that list_prompts_sync correctly retrieves prompts.""" + mock_prompt = Prompt(name="test_prompt", description="A test prompt", id="prompt_1") + mock_session.list_prompts.return_value = ListPromptsResult(prompts=[mock_prompt]) + + with MCPClient(mock_transport["transport_callable"]) as client: + result = client.list_prompts_sync() + + mock_session.list_prompts.assert_called_once_with(cursor=None) + assert len(result.prompts) == 1 + assert result.prompts[0].name == "test_prompt" + assert result.nextCursor is None + + +def test_list_prompts_sync_with_pagination_token(mock_transport, mock_session): + """Test that list_prompts_sync correctly passes pagination token and returns next cursor.""" + mock_prompt = Prompt(name="test_prompt", description="A test prompt", 
id="prompt_1") + mock_session.list_prompts.return_value = ListPromptsResult(prompts=[mock_prompt], nextCursor="next_page_token") + + with MCPClient(mock_transport["transport_callable"]) as client: + result = client.list_prompts_sync(pagination_token="current_page_token") + + mock_session.list_prompts.assert_called_once_with(cursor="current_page_token") + assert len(result.prompts) == 1 + assert result.prompts[0].name == "test_prompt" + assert result.nextCursor == "next_page_token" + + +def test_list_prompts_sync_session_not_active(): + """Test that list_prompts_sync raises an error when session is not active.""" + client = MCPClient(MagicMock()) + + with pytest.raises(MCPClientInitializationError, match="client session is not running"): + client.list_prompts_sync() + + +def test_get_prompt_sync(mock_transport, mock_session): + """Test that get_prompt_sync correctly retrieves a prompt.""" + mock_message = PromptMessage(role="user", content=MCPTextContent(type="text", text="This is a test prompt")) + mock_session.get_prompt.return_value = GetPromptResult(messages=[mock_message]) + + with MCPClient(mock_transport["transport_callable"]) as client: + result = client.get_prompt_sync("test_prompt_id", {"key": "value"}) + + mock_session.get_prompt.assert_called_once_with("test_prompt_id", arguments={"key": "value"}) + assert len(result.messages) == 1 + assert result.messages[0].role == "user" + assert result.messages[0].content.text == "This is a test prompt" + + +def test_get_prompt_sync_session_not_active(): + """Test that get_prompt_sync raises an error when session is not active.""" + client = MCPClient(MagicMock()) + + with pytest.raises(MCPClientInitializationError, match="client session is not running"): + client.get_prompt_sync("test_prompt_id", {}) diff --git a/tests_integ/test_mcp_client.py b/tests_integ/test_mcp_client.py index ebd4f5896..3de249435 100644 --- a/tests_integ/test_mcp_client.py +++ b/tests_integ/test_mcp_client.py @@ -18,18 +18,17 @@ from strands.types.tools import ToolUse -def start_calculator_server(transport: Literal["sse", "streamable-http"], port=int): +def start_comprehensive_mcp_server(transport: Literal["sse", "streamable-http"], port=int): """ - Initialize and start an MCP calculator server for integration testing. + Initialize and start a comprehensive MCP server for integration testing. - This function creates a FastMCP server instance that provides a simple - calculator tool for performing addition operations. The server uses - Server-Sent Events (SSE) transport for communication, making it accessible - over HTTP. + This function creates a FastMCP server instance that provides tools, prompts, + and resources all in one server for comprehensive testing. The server uses + Server-Sent Events (SSE) or streamable HTTP transport for communication. """ from mcp.server import FastMCP - mcp = FastMCP("Calculator Server", port=port) + mcp = FastMCP("Comprehensive MCP Server", port=port) @mcp.tool(description="Calculator tool which performs calculations") def calculator(x: int, y: int) -> int: @@ -44,6 +43,15 @@ def generate_custom_image() -> MCPImageContent: except Exception as e: print("Error while generating custom image: {}".format(e)) + # Prompts + @mcp.prompt(description="A greeting prompt template") + def greeting_prompt(name: str = "World") -> str: + return f"Hello, {name}! How are you today?" 
+ + @mcp.prompt(description="A math problem prompt template") + def math_prompt(operation: str = "addition", difficulty: str = "easy") -> str: + return f"Create a {difficulty} {operation} math problem and solve it step by step." + mcp.run(transport=transport) @@ -58,8 +66,9 @@ def test_mcp_client(): {'role': 'assistant', 'content': [{'text': '\n\nThe result of adding 1 and 2 is 3.'}]} """ # noqa: E501 + # Start comprehensive server with tools, prompts, and resources server_thread = threading.Thread( - target=start_calculator_server, kwargs={"transport": "sse", "port": 8000}, daemon=True + target=start_comprehensive_mcp_server, kwargs={"transport": "sse", "port": 8000}, daemon=True ) server_thread.start() time.sleep(2) # wait for server to startup completely @@ -68,8 +77,14 @@ def test_mcp_client(): stdio_mcp_client = MCPClient( lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/echo_server.py"])) ) + with sse_mcp_client, stdio_mcp_client: - agent = Agent(tools=sse_mcp_client.list_tools_sync() + stdio_mcp_client.list_tools_sync()) + # Test Tools functionality + sse_tools = sse_mcp_client.list_tools_sync() + stdio_tools = stdio_mcp_client.list_tools_sync() + all_tools = sse_tools + stdio_tools + + agent = Agent(tools=all_tools) agent("add 1 and 2, then echo the result back to me") tool_use_content_blocks = _messages_to_content_blocks(agent.messages) @@ -88,6 +103,43 @@ def test_mcp_client(): ] ) + # Test Prompts functionality + prompts_result = sse_mcp_client.list_prompts_sync() + assert len(prompts_result.prompts) >= 2 # We expect at least greeting and math prompts + + prompt_names = [prompt.name for prompt in prompts_result.prompts] + assert "greeting_prompt" in prompt_names + assert "math_prompt" in prompt_names + + # Test get_prompt_sync with greeting prompt + greeting_result = sse_mcp_client.get_prompt_sync("greeting_prompt", {"name": "Alice"}) + assert len(greeting_result.messages) > 0 + prompt_text = greeting_result.messages[0].content.text + assert "Hello, Alice!" in prompt_text + assert "How are you today?" 
in prompt_text + + # Test get_prompt_sync with math prompt + math_result = sse_mcp_client.get_prompt_sync( + "math_prompt", {"operation": "multiplication", "difficulty": "medium"} + ) + assert len(math_result.messages) > 0 + math_text = math_result.messages[0].content.text + assert "multiplication" in math_text + assert "medium" in math_text + assert "step by step" in math_text + + # Test pagination support for prompts + prompts_with_token = sse_mcp_client.list_prompts_sync(pagination_token=None) + assert len(prompts_with_token.prompts) >= 0 + + # Test pagination support for tools (existing functionality) + tools_with_token = sse_mcp_client.list_tools_sync(pagination_token=None) + assert len(tools_with_token) >= 0 + + # TODO: Add resources testing when resources are implemented + # resources_result = sse_mcp_client.list_resources_sync() + # assert len(resources_result.resources) >= 0 + tool_use_id = "test-structured-content-123" result = stdio_mcp_client.call_tool_sync( tool_use_id=tool_use_id, @@ -185,8 +237,9 @@ def test_mcp_client_without_structured_content(): reason="streamable transport is failing in GitHub actions, debugging if linux compatibility issue", ) def test_streamable_http_mcp_client(): + """Test comprehensive MCP client with streamable HTTP transport.""" server_thread = threading.Thread( - target=start_calculator_server, kwargs={"transport": "streamable-http", "port": 8001}, daemon=True + target=start_comprehensive_mcp_server, kwargs={"transport": "streamable-http", "port": 8001}, daemon=True ) server_thread.start() time.sleep(2) # wait for server to startup completely @@ -196,12 +249,22 @@ def transport_callback() -> MCPTransport: streamable_http_client = MCPClient(transport_callback) with streamable_http_client: + # Test tools agent = Agent(tools=streamable_http_client.list_tools_sync()) agent("add 1 and 2 using a calculator") tool_use_content_blocks = _messages_to_content_blocks(agent.messages) assert any([block["name"] == "calculator" for block in tool_use_content_blocks]) + # Test prompts + prompts_result = streamable_http_client.list_prompts_sync() + assert len(prompts_result.prompts) >= 2 + + greeting_result = streamable_http_client.get_prompt_sync("greeting_prompt", {"name": "Charlie"}) + assert len(greeting_result.messages) > 0 + prompt_text = greeting_result.messages[0].content.text + assert "Hello, Charlie!" 
in prompt_text
+
+
 def _messages_to_content_blocks(messages: List[Message]) -> List[ToolUse]:
     return [block["toolUse"] for message in messages for block in message["content"] if "toolUse" in block]

From c5e4e51a0392fb921a280a8891de40398927fe98 Mon Sep 17 00:00:00 2001
From: Dean Schmigelski
Date: Wed, 30 Jul 2025 16:31:46 -0400
Subject: [PATCH 13/41] fix(event_loop): raise dedicated exception when encountering max tokens stop reason

---
 src/strands/event_loop/event_loop.py        | 15 ++++++-
 src/strands/types/exceptions.py             | 11 +++++
 tests/strands/event_loop/test_event_loop.py | 48 ++++++++++++++++++++-
 tests_integ/test_max_tokens_reached.py      | 18 ++++++++
 4 files changed, 90 insertions(+), 2 deletions(-)
 create mode 100644 tests_integ/test_max_tokens_reached.py

diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py
index ffcb6a5c9..5b96dfc92 100644
--- a/src/strands/event_loop/event_loop.py
+++ b/src/strands/event_loop/event_loop.py
@@ -28,7 +28,12 @@
 from ..telemetry.tracer import get_tracer
 from ..tools.executor import run_tools, validate_and_prepare_tools
 from ..types.content import Message
-from ..types.exceptions import ContextWindowOverflowException, EventLoopException, ModelThrottledException
+from ..types.exceptions import (
+    ContextWindowOverflowException,
+    EventLoopException,
+    EventLoopMaxTokensReachedException,
+    ModelThrottledException,
+)
 from ..types.streaming import Metrics, StopReason
 from ..types.tools import ToolChoice, ToolChoiceAuto, ToolConfig, ToolGenerator, ToolResult, ToolUse
 from .streaming import stream_messages
@@ -216,6 +221,14 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) ->
             yield event
             return
 
+        elif stop_reason == "max_tokens":
+            raise EventLoopMaxTokensReachedException(
+                (
+                    "Agent has reached an unrecoverable state due to max_tokens limit. "
+                    "For more information see: "
+                    "https://strandsagents.com/latest/user-guide/concepts/agents/agent-loop/#maxtokensreachedexception"
+                )
+            )
 
         # End the cycle and return results
         agent.event_loop_metrics.end_cycle(cycle_start_time, cycle_trace, attributes)
diff --git a/src/strands/types/exceptions.py b/src/strands/types/exceptions.py
index 4bd3fd88e..14f76e945 100644
--- a/src/strands/types/exceptions.py
+++ b/src/strands/types/exceptions.py
@@ -18,6 +18,17 @@ def __init__(self, original_exception: Exception, request_state: Any = None) ->
         super().__init__(str(original_exception))
 
 
+class EventLoopMaxTokensReachedException(EventLoopException):
+    """Exception raised when the model reaches its maximum token generation limit.
+
+    This exception is raised when the model stops generating tokens because it has reached the maximum number of
+    tokens allowed for output generation. This can occur when the model's max_tokens parameter is set too low for
+    the complexity of the response, or when the model naturally reaches its configured output limit during generation.
+    """
+
+    pass
+
+
 class ContextWindowOverflowException(Exception):
     """Exception raised when the context window is exceeded.
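The new `max_tokens` branch above turns a truncated model response into a hard failure instead of letting the loop continue with a clipped message. A minimal sketch of the calling pattern this enables (illustrative only, not part of the patch; the `BedrockModel(max_tokens=...)` setup mirrors the integration test added later in this patch):

```python
from strands import Agent
from strands.models.bedrock import BedrockModel
from strands.types.exceptions import EventLoopMaxTokensReachedException

# A deliberately small output budget makes the "max_tokens" stop reason easy to hit.
agent = Agent(model=BedrockModel(max_tokens=100))

try:
    agent("Tell me a story!")
except EventLoopMaxTokensReachedException:
    # Recover by retrying with a larger output budget rather than continuing
    # the loop with a truncated assistant message.
    agent = Agent(model=BedrockModel(max_tokens=4096))
    agent("Tell me a story!")
```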
diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index 1ac2f8258..3303b7282 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -19,7 +19,12 @@ ) from strands.telemetry.metrics import EventLoopMetrics from strands.tools.registry import ToolRegistry -from strands.types.exceptions import ContextWindowOverflowException, EventLoopException, ModelThrottledException +from strands.types.exceptions import ( + ContextWindowOverflowException, + EventLoopException, + EventLoopMaxTokensReachedException, + ModelThrottledException, +) from tests.fixtures.mock_hook_provider import MockHookProvider @@ -556,6 +561,47 @@ async def test_event_loop_tracing_with_model_error( mock_tracer.end_span_with_error.assert_called_once_with(model_span, "Input too long", model.stream.side_effect) +@pytest.mark.asyncio +async def test_event_loop_cycle_max_tokens_exception( + agent, + model, + agenerator, + alist, +): + """Test that max_tokens stop reason raises MaxTokensReachedException.""" + + # Note the empty toolUse to handle case raised in https://github.com/strands-agents/sdk-python/issues/495 + model.stream.return_value = agenerator( + [ + { + "contentBlockStart": { + "start": { + "toolUse": {}, + }, + }, + }, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "max_tokens"}}, + ] + ) + + # Call event_loop_cycle, expecting it to raise MaxTokensReachedException + with pytest.raises(EventLoopMaxTokensReachedException) as exc_info: + stream = strands.event_loop.event_loop.event_loop_cycle( + agent=agent, + invocation_state={}, + ) + await alist(stream) + + # Verify the exception message contains the expected content + expected_message = ( + "Agent has reached an unrecoverable state due to max_tokens limit. 
" + "For more information see: " + "https://strandsagents.com/latest/user-guide/concepts/agents/agent-loop/#maxtokensreachedexception" + ) + assert str(exc_info.value) == expected_message + + @patch("strands.event_loop.event_loop.get_tracer") @pytest.mark.asyncio async def test_event_loop_tracing_with_tool_execution( diff --git a/tests_integ/test_max_tokens_reached.py b/tests_integ/test_max_tokens_reached.py new file mode 100644 index 000000000..b6f6b2857 --- /dev/null +++ b/tests_integ/test_max_tokens_reached.py @@ -0,0 +1,18 @@ +import pytest + +from strands import Agent, tool +from strands.models.bedrock import BedrockModel +from strands.types.exceptions import EventLoopMaxTokensReachedException + + +@tool +def story_tool(story: str) -> str: + return story + + +def test_context_window_overflow(): + model = BedrockModel(max_tokens=1) + agent = Agent(model=model, tools=[story_tool]) + + with pytest.raises(EventLoopMaxTokensReachedException): + agent("Tell me a story!") From 6703819d6b6cdedb7b08d92e028bb3deca6c4e78 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Wed, 30 Jul 2025 17:02:03 -0400 Subject: [PATCH 14/41] fix: update integ tests --- src/strands/event_loop/event_loop.py | 2 +- src/strands/models/anthropic.py | 2 +- src/strands/models/bedrock.py | 2 +- src/strands/types/exceptions.py | 2 +- tests/strands/event_loop/test_event_loop.py | 9 ++++----- tests_integ/test_max_tokens_reached.py | 7 ++++--- 6 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index 5b96dfc92..16fefa5ac 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -226,7 +226,7 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> ( "Agent has reached an unrecoverable state due to max_tokens limit. 
" "For more information see: " - "https://strandsagents.com/latest/user-guide/concepts/agents/agent-loop/#maxtokensreachedexception" + "https://strandsagents.com/latest/user-guide/concepts/agents/agent-loop/#eventloopmaxtokensreachedexception" ) ) diff --git a/src/strands/models/anthropic.py b/src/strands/models/anthropic.py index 0d734b762..975fca3e9 100644 --- a/src/strands/models/anthropic.py +++ b/src/strands/models/anthropic.py @@ -414,7 +414,7 @@ async def structured_output( stop_reason, messages, _, _ = event["stop"] if stop_reason != "tool_use": - raise ValueError(f"Model returned stop_reason: {stop_reason} instead of \"tool_use\".") + raise ValueError(f'Model returned stop_reason: {stop_reason} instead of "tool_use".') content = messages["content"] output_response: dict[str, Any] | None = None diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index 9b36b4244..4ea1453a4 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -631,7 +631,7 @@ async def structured_output( stop_reason, messages, _, _ = event["stop"] if stop_reason != "tool_use": - raise ValueError(f"Model returned stop_reason: {stop_reason} instead of \"tool_use\".") + raise ValueError(f'Model returned stop_reason: {stop_reason} instead of "tool_use".') content = messages["content"] output_response: dict[str, Any] | None = None diff --git a/src/strands/types/exceptions.py b/src/strands/types/exceptions.py index 14f76e945..7d9f1c6dc 100644 --- a/src/strands/types/exceptions.py +++ b/src/strands/types/exceptions.py @@ -18,7 +18,7 @@ def __init__(self, original_exception: Exception, request_state: Any = None) -> super().__init__(str(original_exception)) -class EventLoopMaxTokensReachedException(EventLoopException): +class EventLoopMaxTokensReachedException(Exception): """Exception raised when the model reaches its maximum token generation limit. This exception is raised when the model stops generating tokens because it has reached the maximum number of diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index 3303b7282..05b20ba01 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -22,7 +22,6 @@ from strands.types.exceptions import ( ContextWindowOverflowException, EventLoopException, - EventLoopMaxTokensReachedException, ModelThrottledException, ) from tests.fixtures.mock_hook_provider import MockHookProvider @@ -568,7 +567,7 @@ async def test_event_loop_cycle_max_tokens_exception( agenerator, alist, ): - """Test that max_tokens stop reason raises MaxTokensReachedException.""" + """Test that max_tokens stop reason raises EventLoopMaxTokensReachedException.""" # Note the empty toolUse to handle case raised in https://github.com/strands-agents/sdk-python/issues/495 model.stream.return_value = agenerator( @@ -585,8 +584,8 @@ async def test_event_loop_cycle_max_tokens_exception( ] ) - # Call event_loop_cycle, expecting it to raise MaxTokensReachedException - with pytest.raises(EventLoopMaxTokensReachedException) as exc_info: + # Call event_loop_cycle, expecting it to raise EventLoopMaxTokensReachedException + with pytest.raises(EventLoopException) as exc_info: stream = strands.event_loop.event_loop.event_loop_cycle( agent=agent, invocation_state={}, @@ -597,7 +596,7 @@ async def test_event_loop_cycle_max_tokens_exception( expected_message = ( "Agent has reached an unrecoverable state due to max_tokens limit. 
" "For more information see: " - "https://strandsagents.com/latest/user-guide/concepts/agents/agent-loop/#maxtokensreachedexception" + "https://strandsagents.com/latest/user-guide/concepts/agents/agent-loop/#eventloopmaxtokensreachedexception" ) assert str(exc_info.value) == expected_message diff --git a/tests_integ/test_max_tokens_reached.py b/tests_integ/test_max_tokens_reached.py index b6f6b2857..1bf75f136 100644 --- a/tests_integ/test_max_tokens_reached.py +++ b/tests_integ/test_max_tokens_reached.py @@ -1,8 +1,7 @@ -import pytest from strands import Agent, tool from strands.models.bedrock import BedrockModel -from strands.types.exceptions import EventLoopMaxTokensReachedException +from strands.types.exceptions import EventLoopException, EventLoopMaxTokensReachedException @tool @@ -14,5 +13,7 @@ def test_context_window_overflow(): model = BedrockModel(max_tokens=1) agent = Agent(model=model, tools=[story_tool]) - with pytest.raises(EventLoopMaxTokensReachedException): + try: agent("Tell me a story!") + except EventLoopException as e: + assert isinstance(e.original_exception, EventLoopMaxTokensReachedException) From 3d526f2e254d38bb83b8ec85af56e79e4e1fe33f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E3=81=BF=E3=81=AE=E3=82=8B=E3=82=93?= <74597894+minorun365@users.noreply.github.com> Date: Thu, 31 Jul 2025 23:40:25 +0900 Subject: [PATCH 15/41] fix(deps): pin a2a-sdk>=0.2.16 to resolve #572 (#581) Co-authored-by: Jeremiah --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 095a38cb0..cdf68e01f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -96,6 +96,7 @@ sagemaker = [ ] a2a = [ + "a2a-sdk>=0.2.16,<1.0.0", "a2a-sdk[sql]>=0.2.11,<1.0.0", "uvicorn>=0.34.2,<1.0.0", "httpx>=0.28.1,<1.0.0", @@ -321,4 +322,4 @@ style = [ ["instruction", ""], ["text", ""], ["disabled", "fg:#858585 italic"] -] \ No newline at end of file +] From c94b74e75236dcbac0ffdb438f3a4a9ff59cda5f Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Thu, 31 Jul 2025 10:50:40 -0400 Subject: [PATCH 16/41] fix: rename exception message, add to exception, move earlier in cycle --- src/strands/event_loop/event_loop.py | 29 ++++++++++++++------- src/strands/types/exceptions.py | 14 ++++++++-- tests/strands/event_loop/test_event_loop.py | 13 ++++++--- tests_integ/test_max_tokens_reached.py | 7 +++-- 4 files changed, 43 insertions(+), 20 deletions(-) diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index 16fefa5ac..ae21d4c6d 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -31,7 +31,7 @@ from ..types.exceptions import ( ContextWindowOverflowException, EventLoopException, - EventLoopMaxTokensReachedException, + MaxTokensReachedException, ModelThrottledException, ) from ..types.streaming import Metrics, StopReason @@ -192,6 +192,22 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> raise e try: + if stop_reason == "max_tokens": + """ + Handle max_tokens limit reached by the model. + + When the model reaches its maximum token limit, this represents a potentially unrecoverable + state where the model's response was truncated. By default, Strands fails hard with an + MaxTokensReachedException to maintain consistency with other failure types. + """ + raise MaxTokensReachedException( + message=( + "Agent has reached an unrecoverable state due to max_tokens limit. 
" + "For more information see: " + "https://strandsagents.com/latest/user-guide/concepts/agents/agent-loop/#maxtokensreachedexception" + ), + incomplete_message=message, + ) # Add message in trace and mark the end of the stream messages trace stream_trace.add_message(message) stream_trace.end() @@ -221,14 +237,6 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> yield event return - elif stop_reason == "max_tokens": - raise EventLoopMaxTokensReachedException( - ( - "Agent has reached an unrecoverable state due to max_tokens limit. " - "For more information see: " - "https://strandsagents.com/latest/user-guide/concepts/agents/agent-loop/#eventloopmaxtokensreachedexception" - ) - ) # End the cycle and return results agent.event_loop_metrics.end_cycle(cycle_start_time, cycle_trace, attributes) @@ -244,7 +252,8 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> # Don't yield or log the exception - we already did it when we # raised the exception and we don't need that duplication. raise - except ContextWindowOverflowException as e: + except (ContextWindowOverflowException, MaxTokensReachedException) as e: + # Special cased exceptions which we want to bubble up rather than get wrapped in an EventLoopException if cycle_span: tracer.end_span_with_error(cycle_span, str(e), e) raise e diff --git a/src/strands/types/exceptions.py b/src/strands/types/exceptions.py index 7d9f1c6dc..71ea28b9f 100644 --- a/src/strands/types/exceptions.py +++ b/src/strands/types/exceptions.py @@ -2,6 +2,8 @@ from typing import Any +from strands.types.content import Message + class EventLoopException(Exception): """Exception raised by the event loop.""" @@ -18,7 +20,7 @@ def __init__(self, original_exception: Exception, request_state: Any = None) -> super().__init__(str(original_exception)) -class EventLoopMaxTokensReachedException(Exception): +class MaxTokensReachedException(Exception): """Exception raised when the model reaches its maximum token generation limit. This exception is raised when the model stops generating tokens because it has reached the maximum number of @@ -26,7 +28,15 @@ class EventLoopMaxTokensReachedException(Exception): the complexity of the response, or when the model naturally reaches its configured output limit during generation. """ - pass + def __init__(self, message: str, incomplete_message: Message): + """Initialize the exception with an error message and the incomplete message object. 
+ + Args: + message: The error message describing the token limit issue + incomplete_message: The valid Message object with incomplete content due to token limits + """ + self.incomplete_message = incomplete_message + super().__init__(message) class ContextWindowOverflowException(Exception): diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index 05b20ba01..3886df8b9 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -22,6 +22,7 @@ from strands.types.exceptions import ( ContextWindowOverflowException, EventLoopException, + MaxTokensReachedException, ModelThrottledException, ) from tests.fixtures.mock_hook_provider import MockHookProvider @@ -567,7 +568,7 @@ async def test_event_loop_cycle_max_tokens_exception( agenerator, alist, ): - """Test that max_tokens stop reason raises EventLoopMaxTokensReachedException.""" + """Test that max_tokens stop reason raises MaxTokensReachedException.""" # Note the empty toolUse to handle case raised in https://github.com/strands-agents/sdk-python/issues/495 model.stream.return_value = agenerator( @@ -584,8 +585,8 @@ async def test_event_loop_cycle_max_tokens_exception( ] ) - # Call event_loop_cycle, expecting it to raise EventLoopMaxTokensReachedException - with pytest.raises(EventLoopException) as exc_info: + # Call event_loop_cycle, expecting it to raise MaxTokensReachedException + with pytest.raises(MaxTokensReachedException) as exc_info: stream = strands.event_loop.event_loop.event_loop_cycle( agent=agent, invocation_state={}, @@ -596,10 +597,14 @@ async def test_event_loop_cycle_max_tokens_exception( expected_message = ( "Agent has reached an unrecoverable state due to max_tokens limit. " "For more information see: " - "https://strandsagents.com/latest/user-guide/concepts/agents/agent-loop/#eventloopmaxtokensreachedexception" + "https://strandsagents.com/latest/user-guide/concepts/agents/agent-loop/#maxtokensreachedexception" ) assert str(exc_info.value) == expected_message + # Verify that the message has not been appended to the messages array + assert len(agent.messages) == 1 + assert exc_info.value.incomplete_message not in agent.messages + @patch("strands.event_loop.event_loop.get_tracer") @pytest.mark.asyncio diff --git a/tests_integ/test_max_tokens_reached.py b/tests_integ/test_max_tokens_reached.py index 1bf75f136..519cf62c2 100644 --- a/tests_integ/test_max_tokens_reached.py +++ b/tests_integ/test_max_tokens_reached.py @@ -1,7 +1,8 @@ +import pytest from strands import Agent, tool from strands.models.bedrock import BedrockModel -from strands.types.exceptions import EventLoopException, EventLoopMaxTokensReachedException +from strands.types.exceptions import MaxTokensReachedException @tool @@ -13,7 +14,5 @@ def test_context_window_overflow(): model = BedrockModel(max_tokens=1) agent = Agent(model=model, tools=[story_tool]) - try: + with pytest.raises(MaxTokensReachedException): agent("Tell me a story!") - except EventLoopException as e: - assert isinstance(e.original_exception, EventLoopMaxTokensReachedException) From 36dd0f9304ba0daa4fceffef614ff91400fcb23a Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Thu, 31 Jul 2025 14:53:04 -0400 Subject: [PATCH 17/41] Update tests_integ/test_max_tokens_reached.py Co-authored-by: Nick Clegg --- tests_integ/test_max_tokens_reached.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests_integ/test_max_tokens_reached.py b/tests_integ/test_max_tokens_reached.py index 
519cf62c2..1b822dcba 100644 --- a/tests_integ/test_max_tokens_reached.py +++ b/tests_integ/test_max_tokens_reached.py @@ -11,7 +11,7 @@ def story_tool(story: str) -> str: def test_context_window_overflow(): - model = BedrockModel(max_tokens=1) + model = BedrockModel(max_tokens=100) agent = Agent(model=model, tools=[story_tool]) with pytest.raises(MaxTokensReachedException): From e04c73d85d86dde5d9e415ae2ef693aa9a55da56 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Thu, 31 Jul 2025 14:53:11 -0400 Subject: [PATCH 18/41] Update tests_integ/test_max_tokens_reached.py Co-authored-by: Nick Clegg --- tests_integ/test_max_tokens_reached.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests_integ/test_max_tokens_reached.py b/tests_integ/test_max_tokens_reached.py index 1b822dcba..5f7e5584c 100644 --- a/tests_integ/test_max_tokens_reached.py +++ b/tests_integ/test_max_tokens_reached.py @@ -16,3 +16,5 @@ def test_context_window_overflow(): with pytest.raises(MaxTokensReachedException): agent("Tell me a story!") + + assert len(agent.messages) == 1 From cca2f86a3f7a1d22cfa8cf59ffa0029943a0efa7 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Thu, 31 Jul 2025 14:57:19 -0400 Subject: [PATCH 19/41] linting --- tests_integ/test_max_tokens_reached.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests_integ/test_max_tokens_reached.py b/tests_integ/test_max_tokens_reached.py index 5f7e5584c..d9c2817b3 100644 --- a/tests_integ/test_max_tokens_reached.py +++ b/tests_integ/test_max_tokens_reached.py @@ -16,5 +16,5 @@ def test_context_window_overflow(): with pytest.raises(MaxTokensReachedException): agent("Tell me a story!") - + assert len(agent.messages) == 1 From b56a4ff32e93dd74a10c8895cd68528091e88f1b Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Fri, 1 Aug 2025 09:42:35 -0400 Subject: [PATCH 20/41] chore: pin a2a to a minor version while it is still in beta (#586) --- pyproject.toml | 6 +++--- src/strands/multiagent/a2a/executor.py | 2 +- tests/strands/multiagent/a2a/test_executor.py | 16 ++++++++-------- tests/strands/multiagent/a2a/test_server.py | 4 ++-- 4 files changed, 14 insertions(+), 14 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index cdf68e01f..586a956af 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -96,8 +96,8 @@ sagemaker = [ ] a2a = [ - "a2a-sdk>=0.2.16,<1.0.0", - "a2a-sdk[sql]>=0.2.11,<1.0.0", + "a2a-sdk>=0.3.0,<0.4.0", + "a2a-sdk[sql]>=0.3.0,<0.4.0", "uvicorn>=0.34.2,<1.0.0", "httpx>=0.28.1,<1.0.0", "fastapi>=0.115.12,<1.0.0", @@ -143,7 +143,7 @@ all = [ "opentelemetry-exporter-otlp-proto-http>=1.30.0,<2.0.0", # a2a - "a2a-sdk[sql]>=0.2.11,<1.0.0", + "a2a-sdk[sql]>=0.3.0,<0.4.0", "uvicorn>=0.34.2,<1.0.0", "httpx>=0.28.1,<1.0.0", "fastapi>=0.115.12,<1.0.0", diff --git a/src/strands/multiagent/a2a/executor.py b/src/strands/multiagent/a2a/executor.py index d65c64aff..5bf9cbfe9 100644 --- a/src/strands/multiagent/a2a/executor.py +++ b/src/strands/multiagent/a2a/executor.py @@ -61,7 +61,7 @@ async def execute( task = new_task(context.message) # type: ignore await event_queue.enqueue_event(task) - updater = TaskUpdater(event_queue, task.id, task.contextId) + updater = TaskUpdater(event_queue, task.id, task.context_id) try: await self._execute_streaming(context, updater) diff --git a/tests/strands/multiagent/a2a/test_executor.py b/tests/strands/multiagent/a2a/test_executor.py index a956cb769..77645fc73 100644 --- a/tests/strands/multiagent/a2a/test_executor.py +++ b/tests/strands/multiagent/a2a/test_executor.py @@ -36,7 +36,7 @@ 
async def mock_stream(user_input): # Mock the task creation mock_task = MagicMock() mock_task.id = "test-task-id" - mock_task.contextId = "test-context-id" + mock_task.context_id = "test-context-id" mock_request_context.current_task = mock_task await executor.execute(mock_request_context, mock_event_queue) @@ -65,7 +65,7 @@ async def mock_stream(user_input): # Mock the task creation mock_task = MagicMock() mock_task.id = "test-task-id" - mock_task.contextId = "test-context-id" + mock_task.context_id = "test-context-id" mock_request_context.current_task = mock_task await executor.execute(mock_request_context, mock_event_queue) @@ -95,7 +95,7 @@ async def mock_stream(user_input): # Mock the task creation mock_task = MagicMock() mock_task.id = "test-task-id" - mock_task.contextId = "test-context-id" + mock_task.context_id = "test-context-id" mock_request_context.current_task = mock_task await executor.execute(mock_request_context, mock_event_queue) @@ -125,7 +125,7 @@ async def mock_stream(user_input): # Mock the task creation mock_task = MagicMock() mock_task.id = "test-task-id" - mock_task.contextId = "test-context-id" + mock_task.context_id = "test-context-id" mock_request_context.current_task = mock_task await executor.execute(mock_request_context, mock_event_queue) @@ -156,7 +156,7 @@ async def mock_stream(user_input): mock_request_context.current_task = None with patch("strands.multiagent.a2a.executor.new_task") as mock_new_task: - mock_new_task.return_value = MagicMock(id="new-task-id", contextId="new-context-id") + mock_new_task.return_value = MagicMock(id="new-task-id", context_id="new-context-id") await executor.execute(mock_request_context, mock_event_queue) @@ -180,7 +180,7 @@ async def test_execute_streaming_mode_handles_agent_exception( # Mock the task creation mock_task = MagicMock() mock_task.id = "test-task-id" - mock_task.contextId = "test-context-id" + mock_task.context_id = "test-context-id" mock_request_context.current_task = mock_task with pytest.raises(ServerError): @@ -210,7 +210,7 @@ async def test_handle_agent_result_with_none_result(mock_strands_agent, mock_req # Mock the task creation mock_task = MagicMock() mock_task.id = "test-task-id" - mock_task.contextId = "test-context-id" + mock_task.context_id = "test-context-id" mock_request_context.current_task = mock_task # Mock TaskUpdater @@ -235,7 +235,7 @@ async def test_handle_agent_result_with_result_but_no_message( # Mock the task creation mock_task = MagicMock() mock_task.id = "test-task-id" - mock_task.contextId = "test-context-id" + mock_task.context_id = "test-context-id" mock_request_context.current_task = mock_task # Mock TaskUpdater diff --git a/tests/strands/multiagent/a2a/test_server.py b/tests/strands/multiagent/a2a/test_server.py index fc76b5f1d..a3b47581c 100644 --- a/tests/strands/multiagent/a2a/test_server.py +++ b/tests/strands/multiagent/a2a/test_server.py @@ -87,8 +87,8 @@ def test_public_agent_card(mock_strands_agent): assert card.description == "A test agent for unit testing" assert card.url == "http://0.0.0.0:9000/" assert card.version == "0.0.1" - assert card.defaultInputModes == ["text"] - assert card.defaultOutputModes == ["text"] + assert card.default_input_modes == ["text"] + assert card.default_output_modes == ["text"] assert card.skills == [] assert card.capabilities == a2a_agent.capabilities From 8b1de4d4cc4f8adc5386bb1a134aabf96e698cdd Mon Sep 17 00:00:00 2001 From: Laith Al-Saadoon <9553966+theagenticguy@users.noreply.github.com> Date: Fri, 1 Aug 2025 09:23:25 -0500 Subject: [PATCH 
21/41] fix: uses new a2a snake_case for lints to pass (#591) --- src/strands/models/anthropic.py | 2 +- src/strands/models/bedrock.py | 2 +- src/strands/session/file_session_manager.py | 3 ++- src/strands/session/s3_session_manager.py | 3 ++- 4 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/strands/models/anthropic.py b/src/strands/models/anthropic.py index 0d734b762..975fca3e9 100644 --- a/src/strands/models/anthropic.py +++ b/src/strands/models/anthropic.py @@ -414,7 +414,7 @@ async def structured_output( stop_reason, messages, _, _ = event["stop"] if stop_reason != "tool_use": - raise ValueError(f"Model returned stop_reason: {stop_reason} instead of \"tool_use\".") + raise ValueError(f'Model returned stop_reason: {stop_reason} instead of "tool_use".') content = messages["content"] output_response: dict[str, Any] | None = None diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index 9b36b4244..4ea1453a4 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -631,7 +631,7 @@ async def structured_output( stop_reason, messages, _, _ = event["stop"] if stop_reason != "tool_use": - raise ValueError(f"Model returned stop_reason: {stop_reason} instead of \"tool_use\".") + raise ValueError(f'Model returned stop_reason: {stop_reason} instead of "tool_use".') content = messages["content"] output_response: dict[str, Any] | None = None diff --git a/src/strands/session/file_session_manager.py b/src/strands/session/file_session_manager.py index b32cb00e6..fec2f0761 100644 --- a/src/strands/session/file_session_manager.py +++ b/src/strands/session/file_session_manager.py @@ -23,6 +23,7 @@ class FileSessionManager(RepositorySessionManager, SessionRepository): """File-based session manager for local filesystem storage. Creates the following filesystem structure for the session storage: + ```bash // └── session_/ ├── session.json # Session metadata @@ -32,7 +33,7 @@ class FileSessionManager(RepositorySessionManager, SessionRepository): └── messages/ ├── message_.json └── message_.json - + ``` """ def __init__(self, session_id: str, storage_dir: Optional[str] = None, **kwargs: Any): diff --git a/src/strands/session/s3_session_manager.py b/src/strands/session/s3_session_manager.py index 8f8423828..0cc0a68c1 100644 --- a/src/strands/session/s3_session_manager.py +++ b/src/strands/session/s3_session_manager.py @@ -24,6 +24,7 @@ class S3SessionManager(RepositorySessionManager, SessionRepository): """S3-based session manager for cloud storage. 
Creates the following filesystem structure for the session storage: + ```bash // └── session_/ ├── session.json # Session metadata @@ -33,7 +34,7 @@ class S3SessionManager(RepositorySessionManager, SessionRepository): └── messages/ ├── message_.json └── message_.json - + ``` """ def __init__( From c85464c45715a9d2ef3f9377f59f9e970ee81cf9 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Fri, 1 Aug 2025 10:37:17 -0400 Subject: [PATCH 22/41] =?UTF-8?q?fix(event=5Floop):=20raise=20dedicated=20?= =?UTF-8?q?exception=20when=20encountering=20max=20toke=E2=80=A6=20(#576)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix(event_loop): raise dedicated exception when encountering max tokens stop reason * fix: update integ tests * fix: rename exception message, add to exception, move earlier in cycle * Update tests_integ/test_max_tokens_reached.py Co-authored-by: Nick Clegg * Update tests_integ/test_max_tokens_reached.py Co-authored-by: Nick Clegg * linting --------- Co-authored-by: Nick Clegg --- src/strands/event_loop/event_loop.py | 26 ++++++++++- src/strands/types/exceptions.py | 21 +++++++++ tests/strands/event_loop/test_event_loop.py | 52 ++++++++++++++++++++- tests_integ/test_max_tokens_reached.py | 20 ++++++++ 4 files changed, 116 insertions(+), 3 deletions(-) create mode 100644 tests_integ/test_max_tokens_reached.py diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index ffcb6a5c9..ae21d4c6d 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -28,7 +28,12 @@ from ..telemetry.tracer import get_tracer from ..tools.executor import run_tools, validate_and_prepare_tools from ..types.content import Message -from ..types.exceptions import ContextWindowOverflowException, EventLoopException, ModelThrottledException +from ..types.exceptions import ( + ContextWindowOverflowException, + EventLoopException, + MaxTokensReachedException, + ModelThrottledException, +) from ..types.streaming import Metrics, StopReason from ..types.tools import ToolChoice, ToolChoiceAuto, ToolConfig, ToolGenerator, ToolResult, ToolUse from .streaming import stream_messages @@ -187,6 +192,22 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> raise e try: + if stop_reason == "max_tokens": + """ + Handle max_tokens limit reached by the model. + + When the model reaches its maximum token limit, this represents a potentially unrecoverable + state where the model's response was truncated. By default, Strands fails hard with an + MaxTokensReachedException to maintain consistency with other failure types. + """ + raise MaxTokensReachedException( + message=( + "Agent has reached an unrecoverable state due to max_tokens limit. " + "For more information see: " + "https://strandsagents.com/latest/user-guide/concepts/agents/agent-loop/#maxtokensreachedexception" + ), + incomplete_message=message, + ) # Add message in trace and mark the end of the stream messages trace stream_trace.add_message(message) stream_trace.end() @@ -231,7 +252,8 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> # Don't yield or log the exception - we already did it when we # raised the exception and we don't need that duplication. 
raise
-    except ContextWindowOverflowException as e:
+    except (ContextWindowOverflowException, MaxTokensReachedException) as e:
+        # Special-cased exceptions which we want to bubble up rather than get wrapped in an EventLoopException
         if cycle_span:
             tracer.end_span_with_error(cycle_span, str(e), e)
         raise e
diff --git a/src/strands/types/exceptions.py b/src/strands/types/exceptions.py
index 4bd3fd88e..71ea28b9f 100644
--- a/src/strands/types/exceptions.py
+++ b/src/strands/types/exceptions.py
@@ -2,6 +2,8 @@
 
 from typing import Any
 
+from strands.types.content import Message
+
 
 class EventLoopException(Exception):
     """Exception raised by the event loop."""
@@ -18,6 +20,25 @@ def __init__(self, original_exception: Exception, request_state: Any = None) ->
         super().__init__(str(original_exception))
 
 
+class MaxTokensReachedException(Exception):
+    """Exception raised when the model reaches its maximum token generation limit.
+
+    This exception is raised when the model stops generating tokens because it has reached the maximum number of
+    tokens allowed for output generation. This can occur when the model's max_tokens parameter is set too low for
+    the complexity of the response, or when the model naturally reaches its configured output limit during generation.
+    """
+
+    def __init__(self, message: str, incomplete_message: Message):
+        """Initialize the exception with an error message and the incomplete message object.
+
+        Args:
+            message: The error message describing the token limit issue
+            incomplete_message: The valid Message object with incomplete content due to token limits
+        """
+        self.incomplete_message = incomplete_message
+        super().__init__(message)
+
+
 class ContextWindowOverflowException(Exception):
     """Exception raised when the context window is exceeded.
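Beyond signaling the failure, the exception now carries the truncated assistant message so callers can inspect any partial output. A hedged sketch of that usage (illustrative only, not part of the patch; the exception and its `incomplete_message` attribute are defined in the hunk above):

```python
from strands import Agent
from strands.models.bedrock import BedrockModel
from strands.types.exceptions import MaxTokensReachedException

agent = Agent(model=BedrockModel(max_tokens=100))

try:
    agent("Tell me a story!")
except MaxTokensReachedException as e:
    # The truncated message is carried on the exception rather than being
    # appended to agent.messages (the unit test below asserts this).
    for block in e.incomplete_message["content"]:
        print(block)
```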
diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index 1ac2f8258..3886df8b9 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -19,7 +19,12 @@ ) from strands.telemetry.metrics import EventLoopMetrics from strands.tools.registry import ToolRegistry -from strands.types.exceptions import ContextWindowOverflowException, EventLoopException, ModelThrottledException +from strands.types.exceptions import ( + ContextWindowOverflowException, + EventLoopException, + MaxTokensReachedException, + ModelThrottledException, +) from tests.fixtures.mock_hook_provider import MockHookProvider @@ -556,6 +561,51 @@ async def test_event_loop_tracing_with_model_error( mock_tracer.end_span_with_error.assert_called_once_with(model_span, "Input too long", model.stream.side_effect) +@pytest.mark.asyncio +async def test_event_loop_cycle_max_tokens_exception( + agent, + model, + agenerator, + alist, +): + """Test that max_tokens stop reason raises MaxTokensReachedException.""" + + # Note the empty toolUse to handle case raised in https://github.com/strands-agents/sdk-python/issues/495 + model.stream.return_value = agenerator( + [ + { + "contentBlockStart": { + "start": { + "toolUse": {}, + }, + }, + }, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "max_tokens"}}, + ] + ) + + # Call event_loop_cycle, expecting it to raise MaxTokensReachedException + with pytest.raises(MaxTokensReachedException) as exc_info: + stream = strands.event_loop.event_loop.event_loop_cycle( + agent=agent, + invocation_state={}, + ) + await alist(stream) + + # Verify the exception message contains the expected content + expected_message = ( + "Agent has reached an unrecoverable state due to max_tokens limit. 
" + "For more information see: " + "https://strandsagents.com/latest/user-guide/concepts/agents/agent-loop/#maxtokensreachedexception" + ) + assert str(exc_info.value) == expected_message + + # Verify that the message has not been appended to the messages array + assert len(agent.messages) == 1 + assert exc_info.value.incomplete_message not in agent.messages + + @patch("strands.event_loop.event_loop.get_tracer") @pytest.mark.asyncio async def test_event_loop_tracing_with_tool_execution( diff --git a/tests_integ/test_max_tokens_reached.py b/tests_integ/test_max_tokens_reached.py new file mode 100644 index 000000000..d9c2817b3 --- /dev/null +++ b/tests_integ/test_max_tokens_reached.py @@ -0,0 +1,20 @@ +import pytest + +from strands import Agent, tool +from strands.models.bedrock import BedrockModel +from strands.types.exceptions import MaxTokensReachedException + + +@tool +def story_tool(story: str) -> str: + return story + + +def test_context_window_overflow(): + model = BedrockModel(max_tokens=100) + agent = Agent(model=model, tools=[story_tool]) + + with pytest.raises(MaxTokensReachedException): + agent("Tell me a story!") + + assert len(agent.messages) == 1 From 2e2d4df9f6d7d98993f65fb40540663c74f7f0ea Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Mon, 4 Aug 2025 17:47:25 -0400 Subject: [PATCH 23/41] feat: add builtin hook provider to address max tokens reached truncation --- src/strands/agent/agent.py | 18 +++- src/strands/experimental/hooks/__init__.py | 2 + src/strands/experimental/hooks/events.py | 26 +++++ .../experimental/hooks/providers/__init__.py | 0 .../correct_tool_use_hook_provider.py | 95 +++++++++++++++++++ tests/strands/agent/test_agent_hooks.py | 55 ++++++++++- tests_integ/test_max_tokens_reached.py | 18 ++++ 7 files changed, 212 insertions(+), 2 deletions(-) create mode 100644 src/strands/experimental/hooks/providers/__init__.py create mode 100644 src/strands/experimental/hooks/providers/correct_tool_use_hook_provider.py diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 111509e3a..c86b64ff3 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -20,6 +20,7 @@ from pydantic import BaseModel from ..event_loop.event_loop import event_loop_cycle, run_tool +from ..experimental.hooks.events import EventLoopFailureEvent from ..handlers.callback_handler import PrintingCallbackHandler, null_callback_handler from ..hooks import ( AfterInvocationEvent, @@ -582,7 +583,7 @@ async def _execute_event_loop_cycle(self, invocation_state: dict[str, Any]) -> A ) async for event in events: yield event - + return except ContextWindowOverflowException as e: # Try reducing the context size and retrying self.conversation_manager.reduce_context(self, e=e) @@ -591,6 +592,21 @@ async def _execute_event_loop_cycle(self, invocation_state: dict[str, Any]) -> A if self._session_manager: self._session_manager.sync_agent(self) + # If the events have been handled, attempt to restart the event loop in the now-healthy state + events = self._execute_event_loop_cycle(invocation_state) + async for event in events: + yield event + except Exception as e: + """ + Catch all other exceptions which are unrecoverable without intervention. 
+ Reraise exception if EventLoopFailureEvent.should_continue is false + """ + event_loop_failure_event = EventLoopFailureEvent(agent=self, exception=e) + self.hooks.invoke_callbacks(event_loop_failure_event) + if not event_loop_failure_event.should_continue_loop: + raise + + # If the events have been handled, attempt to restart the event loop in the now-healthy state events = self._execute_event_loop_cycle(invocation_state) async for event in events: yield event diff --git a/src/strands/experimental/hooks/__init__.py b/src/strands/experimental/hooks/__init__.py index 098d4cf0d..384d8a505 100644 --- a/src/strands/experimental/hooks/__init__.py +++ b/src/strands/experimental/hooks/__init__.py @@ -5,6 +5,7 @@ AfterToolInvocationEvent, BeforeModelInvocationEvent, BeforeToolInvocationEvent, + EventLoopFailureEvent, ) __all__ = [ @@ -12,4 +13,5 @@ "AfterToolInvocationEvent", "BeforeModelInvocationEvent", "AfterModelInvocationEvent", + "EventLoopFailureEvent", ] diff --git a/src/strands/experimental/hooks/events.py b/src/strands/experimental/hooks/events.py index d03e65d85..128882821 100644 --- a/src/strands/experimental/hooks/events.py +++ b/src/strands/experimental/hooks/events.py @@ -121,3 +121,29 @@ class ModelStopResponse: def should_reverse_callbacks(self) -> bool: """True to invoke callbacks in reverse order.""" return True + + +@dataclass +class EventLoopFailureEvent(HookEvent): + """Event triggered when the event loop encounters a failure. + + This event is fired when an exception occurs during event loop execution, + allowing hook providers to handle the failure or perform recovery actions. + + Attributes: + exception: The exception that caused the event loop failure. + should_continue_loop: Flag that hooks can set to True to indicate they have + handled the exception and the event loop should continue normally. + + Warning: + Setting should_continue_loop=True without properly addressing the underlying + cause of the exception may result in infinite loops if the same failure + condition persists. Hooks should implement appropriate error handling, + retry limits, or state modifications to prevent recurring failures. + """ + + exception: Exception + should_continue_loop: bool = False + + def _can_write(self, name: str) -> bool: + return name == "should_continue_loop" diff --git a/src/strands/experimental/hooks/providers/__init__.py b/src/strands/experimental/hooks/providers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/strands/experimental/hooks/providers/correct_tool_use_hook_provider.py b/src/strands/experimental/hooks/providers/correct_tool_use_hook_provider.py new file mode 100644 index 000000000..96020cc58 --- /dev/null +++ b/src/strands/experimental/hooks/providers/correct_tool_use_hook_provider.py @@ -0,0 +1,95 @@ +import logging +from typing import Any + +from src.strands.hooks import MessageAddedEvent +from src.strands.types.tools import ToolUse +from strands.experimental.hooks.events import EventLoopFailureEvent +from strands.hooks import HookProvider, HookRegistry +from strands.types.content import ContentBlock, Message +from strands.types.exceptions import MaxTokensReachedException + +logger = logging.getLogger(__name__) + + +class CorrectToolUseHookProvider(HookProvider): + """Hook provider that handles MaxTokensReachedException by fixing incomplete tool uses. + + This hook provider is triggered when a MaxTokensReachedException occurs during event loop execution. 
+    When the model's response is truncated due to token limits, tool use entries may be incomplete
+    or missing required fields (name, input, toolUseId).
+
+    The provider fixes these issues by:
+
+    1. Inspecting each content block in the incomplete message for invalid tool uses
+    2. Replacing incomplete tool use blocks with informative text messages
+    3. Preserving valid content blocks in the corrected message
+    4. Adding the corrected message to the agent's conversation history
+    5. Allowing the event loop to continue processing
+
+    If a tool use is invalid for unknown reasons, not due to empty fields, the hook
+    allows the original exception to propagate to avoid unsafe recovery attempts.
+    """
+
+    def register_hooks(self, registry: "HookRegistry", **kwargs: Any) -> None:
+        """Register hook to handle EventLoopFailureEvent for MaxTokensReachedException."""
+        registry.add_callback(EventLoopFailureEvent, self._handle_max_tokens_reached)
+
+    def _handle_max_tokens_reached(self, event: EventLoopFailureEvent) -> None:
+        """Handle MaxTokensReachedException by cleaning up orphaned tool uses and allowing continuation."""
+        if not isinstance(event.exception, MaxTokensReachedException):
+            return
+
+        logger.info("Handling MaxTokensReachedException - inspecting incomplete message for invalid tool uses")
+
+        incomplete_message: Message = event.exception.incomplete_message
+        valid_content: list[ContentBlock] = []
+
+        for i, content in enumerate(incomplete_message["content"]):
+            tool_use: ToolUse = content.get("toolUse")
+            if not tool_use:
+                valid_content.append(content)
+                logger.debug(f"Content block {i}: Valid non-tool content preserved")
+                continue
+
+            """
+            Ideally this would be future proofed using a pydantic validator. Since ToolUse is not implemented
+            using pydantic, we inspect each field.
+            """
+            tool_name = tool_use.get("name", "")
+            tool_input = tool_use.get("input")
+            tool_use_id = tool_use.get("toolUseId")
+
+            if not (tool_name and tool_input and tool_use_id):
+                """
+                If tool_use does not conform to the expected schema, it means the max_tokens issue resulted in it not
+                being populated correctly.
+
+                It is safe to drop the content block, but we insert a new one to ensure Agent is aware of failure
+                on the next iteration.
+                """
+                logger.warning(
+                    f"Invalid tool use found at content block {i}: tool_name='{tool_name}'. "
+                    f"Replacing with error message due to max_tokens truncation."
+                )
+
+                valid_content.append(
+                    {
+                        "text": f"The selected tool {tool_name}'s tool use was incomplete due "
+                        f"to maximum token limits being reached."
+                    }
+                )
+            else:
+                # Tool use is invalid for an unknown reason. Cannot safely recover, so allow exception to propagate
+                logger.debug(
+                    f"Tool use at content block {i} appears complete but is still invalid. "
+                    f"tool_name='{tool_name}', tool_use_id='{tool_use_id}'. "
+                    f"Cannot safely recover - allowing exception to propagate." 
+                )
+                return
+
+        valid_message: Message = {"content": valid_content, "role": incomplete_message["role"]}
+        event.agent.messages.append(valid_message)
+        event.agent.hooks.invoke_callbacks(MessageAddedEvent(agent=event.agent, message=valid_message))
+        event.should_continue_loop = True
+
+        logger.info("MaxTokensReachedException handled successfully - continuing event loop")
diff --git a/tests/strands/agent/test_agent_hooks.py b/tests/strands/agent/test_agent_hooks.py
index cd89fbc7a..e71c0aa94 100644
--- a/tests/strands/agent/test_agent_hooks.py
+++ b/tests/strands/agent/test_agent_hooks.py
@@ -1,15 +1,17 @@
-from unittest.mock import ANY, Mock
+from unittest.mock import ANY, Mock, patch

 import pytest
 from pydantic import BaseModel

 import strands
+from src.strands.types.exceptions import MaxTokensReachedException
 from strands import Agent
 from strands.experimental.hooks import (
     AfterModelInvocationEvent,
     AfterToolInvocationEvent,
     BeforeModelInvocationEvent,
     BeforeToolInvocationEvent,
+    EventLoopFailureEvent,
 )
 from strands.hooks import (
     AfterInvocationEvent,
@@ -35,6 +37,7 @@ def hook_provider():
             BeforeModelInvocationEvent,
             AfterModelInvocationEvent,
             MessageAddedEvent,
+            EventLoopFailureEvent,
         ]
     )

@@ -292,3 +295,53 @@ async def test_agent_structured_async_output_hooks(agent, hook_provider, user, a
     assert next(events) == AfterInvocationEvent(agent=agent)

     assert len(agent.messages) == 1
+
+
+def test_event_loop_failure_event_exception_rethrown_when_not_handled(agent, hook_provider):
+    """Test that EventLoopFailureEvent is triggered and exceptions are re-thrown when not handled."""
+
+    # Mock event_loop_cycle to raise a general exception (not ContextWindowOverflowException)
+    with patch("strands.agent.agent.event_loop_cycle") as mock_cycle:
+        mock_cycle.side_effect = MaxTokensReachedException("Event loop failure", {"content": [], "role": "assistant"})
+
+        with pytest.raises(MaxTokensReachedException):
+            agent("test message")
+    length, events = hook_provider.get_events()
+    failure_events = [event for event in list(events) if isinstance(event, EventLoopFailureEvent)]
+
+    assert len(failure_events) == 1
+    assert isinstance(failure_events[0].exception, MaxTokensReachedException)
+    assert failure_events[0].should_continue_loop is False
+
+
+def test_event_loop_failure_event_exception_handled_by_hook(agent, hook_provider):
+    """Test that EventLoopFailureEvent allows hooks to handle exceptions and continue execution."""
+
+    first_call = True
+
+    def hook_callback(event: EventLoopFailureEvent):
+        nonlocal first_call
+        # Hook handles the exception by setting should_continue_loop to True
+        event.should_continue_loop = first_call
+        first_call = False
+
+    agent.hooks.add_callback(EventLoopFailureEvent, hook_callback)
+
+    # Mock event_loop_cycle to raise a general exception
+    with patch("strands.agent.agent.event_loop_cycle") as mock_cycle:
+        mock_cycle.side_effect = MaxTokensReachedException("Event loop failure", {"content": [], "role": "assistant"})
+
+        # The hook handles the first failure; the second failure is not handled, so the exception propagates
+        with pytest.raises(MaxTokensReachedException):
+            agent("test message")
+
+    length, events = hook_provider.get_events()
+    failure_events = [event for event in list(events) if isinstance(event, EventLoopFailureEvent)]
+
+    assert len(failure_events) == 2
+
+    assert isinstance(failure_events[0].exception, MaxTokensReachedException)
+    assert failure_events[0].should_continue_loop is True
+
+    assert isinstance(failure_events[1].exception, MaxTokensReachedException)
+    assert 
failure_events[1].should_continue_loop is False
diff --git a/tests_integ/test_max_tokens_reached.py b/tests_integ/test_max_tokens_reached.py
index d9c2817b3..7c7a48973 100644
--- a/tests_integ/test_max_tokens_reached.py
+++ b/tests_integ/test_max_tokens_reached.py
@@ -1,12 +1,20 @@
+import logging
+
 import pytest

 from strands import Agent, tool
+from strands.experimental.hooks.providers.correct_tool_use_hook_provider import CorrectToolUseHookProvider
 from strands.models.bedrock import BedrockModel
 from strands.types.exceptions import MaxTokensReachedException

+logger = logging.getLogger(__name__)
+

 @tool
 def story_tool(story: str) -> str:
+    """
+    Tool that writes a story that is at least 50,000 lines long.
+    """
     return story


@@ -18,3 +26,13 @@ def test_context_window_overflow():
         agent("Tell me a story!")

     assert len(agent.messages) == 1
+
+
+def test_max_tokens_reached_with_hook_provider():
+    """Test that MaxTokensReachedException can be handled by a hook provider."""
+    model = BedrockModel(max_tokens=100)
+    hook_provider = CorrectToolUseHookProvider()
+    agent = Agent(model=model, tools=[story_tool], hooks=[hook_provider])
+
+    # This should NOT raise an exception because the hook handles it
+    agent("Tell me a story!")
From 447d147ee001288dcec224d4e3389b71a7f0dd2c Mon Sep 17 00:00:00 2001
From: Dean Schmigelski
Date: Mon, 4 Aug 2025 18:23:40 -0400
Subject: [PATCH 24/41] tests: modify integ test to inspect message history

---
 .../experimental/hooks/providers/__init__.py  |  5 ++
 .../correct_tool_use_hook_provider.py         | 40 ++++----
 .../experimental/hooks/providers/__init__.py  |  1 +
 .../test_correct_tool_use_hook_provider.py    | 99 +++++++++++++++++++
 tests_integ/test_max_tokens_reached.py        | 13 +++
 5 files changed, 139 insertions(+), 19 deletions(-)
 create mode 100644 tests/strands/experimental/hooks/providers/__init__.py
 create mode 100644 tests/strands/experimental/hooks/providers/test_correct_tool_use_hook_provider.py

diff --git a/src/strands/experimental/hooks/providers/__init__.py b/src/strands/experimental/hooks/providers/__init__.py
index e69de29bb..5b74733b8 100644
--- a/src/strands/experimental/hooks/providers/__init__.py
+++ b/src/strands/experimental/hooks/providers/__init__.py
@@ -0,0 +1,5 @@
+"""Hook providers for experimental Strands Agents functionality.
+
+This package contains experimental hook providers that extend the core agent functionality
+with additional capabilities.
+"""
diff --git a/src/strands/experimental/hooks/providers/correct_tool_use_hook_provider.py b/src/strands/experimental/hooks/providers/correct_tool_use_hook_provider.py
index 96020cc58..3c9ef0803 100644
--- a/src/strands/experimental/hooks/providers/correct_tool_use_hook_provider.py
+++ b/src/strands/experimental/hooks/providers/correct_tool_use_hook_provider.py
@@ -1,3 +1,11 @@
+"""Hook provider for correcting incomplete tool uses due to token limits.
+
+This module provides the CorrectToolUseHookProvider class, which handles scenarios where
+the model's response is truncated due to maximum token limits, resulting in incomplete
+or malformed tool use entries. The provider automatically corrects these issues to allow
+the agent conversation to continue gracefully. 
+""" + import logging from typing import Any @@ -42,24 +50,25 @@ def _handle_max_tokens_reached(self, event: EventLoopFailureEvent) -> None: logger.info("Handling MaxTokensReachedException - inspecting incomplete message for invalid tool uses") incomplete_message: Message = event.exception.incomplete_message - valid_content: list[ContentBlock] = [] - for i, content in enumerate(incomplete_message["content"]): + if not incomplete_message["content"]: + # Cannot correct invalid content block if content is empty + return + + valid_content: list[ContentBlock] = [] + for content in incomplete_message["content"]: tool_use: ToolUse = content.get("toolUse") if not tool_use: valid_content.append(content) - logger.debug(f"Content block {i}: Valid non-tool content preserved") continue """ Ideally this would be future proofed using a pydantic validator. Since ToolUse is not implemented using pydantic, we inspect each field. """ - tool_name = tool_use.get("name", "") - tool_input = tool_use.get("input") - tool_use_id = tool_use.get("toolUseId") - - if not (tool_name and tool_input and tool_use_id): + # Check if tool use is incomplete (missing or empty required fields) + tool_name = tool_use.get("name") + if not (tool_name and tool_use.get("input") and tool_use.get("toolUseId")): """ If tool_use does not conform to the expected schema it means the max_tokens issue resulted in it not being populated it correctly. @@ -67,29 +76,22 @@ def _handle_max_tokens_reached(self, event: EventLoopFailureEvent) -> None: It is safe to drop the content block, but we insert a new one to ensure Agent is aware of failure on the next iteration. """ + display_name = tool_name if tool_name else "" logger.warning( - f"Invalid tool use found at content block {i}: tool_name='{tool_name}', " - f"Replacing with error message due to max_tokens truncation." + "tool_name=<%s> | replacing with error message due to max_tokens truncation.", display_name ) valid_content.append( { - "text": f"The selected tool {tool_name}'s tool use was incomplete due " + "text": f"The selected tool {display_name}'s tool use was incomplete due " f"to maximum token limits being reached." } ) else: - # Tool use is invalid for an unknown reason. Cannot safely recover, so allow exception to propagate - logger.debug( - f"Tool use at content block {i} appears complete but is still invalid. " - f"tool_name='{tool_name}', tool_use_id='{tool_use_id}'. " - f"Cannot safely recover - allowing exception to propagate." - ) + # ToolUse was invalid for an unknown reason. Cannot correct, return and allow exception to propagate up. 
                 return

         valid_message: Message = {"content": valid_content, "role": incomplete_message["role"]}
         event.agent.messages.append(valid_message)
         event.agent.hooks.invoke_callbacks(MessageAddedEvent(agent=event.agent, message=valid_message))
         event.should_continue_loop = True
-
-        logger.info("MaxTokensReachedException handled successfully - continuing event loop")
diff --git a/tests/strands/experimental/hooks/providers/__init__.py b/tests/strands/experimental/hooks/providers/__init__.py
new file mode 100644
index 000000000..8b1378917
--- /dev/null
+++ b/tests/strands/experimental/hooks/providers/__init__.py
@@ -0,0 +1 @@
+
diff --git a/tests/strands/experimental/hooks/providers/test_correct_tool_use_hook_provider.py b/tests/strands/experimental/hooks/providers/test_correct_tool_use_hook_provider.py
new file mode 100644
index 000000000..93d672ab2
--- /dev/null
+++ b/tests/strands/experimental/hooks/providers/test_correct_tool_use_hook_provider.py
@@ -0,0 +1,99 @@
+"""Unit tests for CorrectToolUseHookProvider."""
+
+from unittest.mock import Mock
+
+import pytest
+
+from strands.experimental.hooks.events import EventLoopFailureEvent
+from strands.experimental.hooks.providers.correct_tool_use_hook_provider import CorrectToolUseHookProvider
+from strands.hooks import HookRegistry
+from strands.types.content import Message
+from strands.types.exceptions import MaxTokensReachedException
+
+
+@pytest.fixture
+def hook_provider():
+    """Create a CorrectToolUseHookProvider instance."""
+    return CorrectToolUseHookProvider()
+
+
+@pytest.fixture
+def mock_agent():
+    """Create a mock agent with messages and hooks."""
+    agent = Mock()
+    agent.messages = []
+    agent.hooks = Mock()
+    return agent
+
+
+@pytest.fixture
+def mock_registry():
+    """Create a mock hook registry."""
+    return Mock(spec=HookRegistry)
+
+
+def test_register_hooks(hook_provider, mock_registry):
+    """Test that the hook provider registers the correct callback."""
+    hook_provider.register_hooks(mock_registry)
+
+    mock_registry.add_callback.assert_called_once_with(EventLoopFailureEvent, hook_provider._handle_max_tokens_reached)
+
+
+def test_handle_non_max_tokens_exception(hook_provider, mock_agent):
+    """Test that non-MaxTokensReachedException events are ignored."""
+    other_exception = ValueError("Some other error")
+    event = EventLoopFailureEvent(agent=mock_agent, exception=other_exception)
+
+    hook_provider._handle_max_tokens_reached(event)
+
+    # Should not modify the agent or event
+    assert len(mock_agent.messages) == 0
+    assert not event.should_continue_loop
+    mock_agent.hooks.invoke_callbacks.assert_not_called()
+
+
+@pytest.mark.parametrize(
+    "incomplete_tool_use,expected_tool_name",
+    [
+        ({"toolUseId": "tool-123", "input": {"param": "value"}}, "<unknown>"),  # Missing name
+        ({"name": "test_tool", "toolUseId": "tool-123"}, "test_tool"),  # Missing input
+        ({"name": "test_tool", "input": {}, "toolUseId": "tool-123"}, "test_tool"),  # Empty input
+        ({"name": "test_tool", "input": {"param": "value"}}, "test_tool"),  # Missing toolUseId
+    ],
+)
+def test_handle_max_tokens_with_incomplete_tool_use(hook_provider, mock_agent, incomplete_tool_use, expected_tool_name):
+    """Test handling various incomplete tool use scenarios."""
+    incomplete_message: Message = {
+        "role": "user",  # Test role preservation
+        "content": [{"text": "I'll use a tool"}, {"toolUse": incomplete_tool_use}],
+    }
+
+    exception = MaxTokensReachedException("Max tokens reached", incomplete_message)
+    event = EventLoopFailureEvent(agent=mock_agent, exception=exception)
+
+    
hook_provider._handle_max_tokens_reached(event)
+
+    # Should add corrected message with error text and preserve role
+    assert len(mock_agent.messages) == 1
+    added_message = mock_agent.messages[0]
+    assert added_message["role"] == "user"  # Role preserved
+    assert len(added_message["content"]) == 2
+    assert added_message["content"][0]["text"] == "I'll use a tool"
+    assert f"The selected tool {expected_tool_name}'s tool use was incomplete" in added_message["content"][1]["text"]
+    assert "maximum token limits being reached" in added_message["content"][1]["text"]
+
+    assert event.should_continue_loop
+
+
+def test_handle_max_tokens_with_no_content(hook_provider, mock_agent):
+    """Test handling message with no content blocks."""
+    incomplete_message: Message = {"role": "assistant", "content": []}
+
+    exception = MaxTokensReachedException("Max tokens reached", incomplete_message)
+    event = EventLoopFailureEvent(agent=mock_agent, exception=exception)
+
+    hook_provider._handle_max_tokens_reached(event)
+
+    # Should not add any message or continue when content is empty
+    assert len(mock_agent.messages) == 0
+    assert not event.should_continue_loop
diff --git a/tests_integ/test_max_tokens_reached.py b/tests_integ/test_max_tokens_reached.py
index 7c7a48973..6bad70636 100644
--- a/tests_integ/test_max_tokens_reached.py
+++ b/tests_integ/test_max_tokens_reached.py
@@ -36,3 +36,16 @@ def test_max_tokens_reached_with_hook_provider():

     # This should NOT raise an exception because the hook handles it
     agent("Tell me a story!")
+
+    # Validate that at least one message contains the incomplete tool use error message
+    expected_text = "tool use was incomplete due to maximum token limits being reached"
+    all_text_content = [
+        content_block["text"]
+        for message in agent.messages
+        for content_block in message.get("content", [])
+        if "text" in content_block
+    ]
+
+    assert any(expected_text in text for text in all_text_content), (
+        f"Expected to find message containing '{expected_text}' in agent messages"
+    )
From 564895d5e04ec6e46ebec04bfa5421fc0fbbcdce Mon Sep 17 00:00:00 2001
From: Dean Schmigelski
Date: Mon, 4 Aug 2025 18:31:25 -0400
Subject: [PATCH 25/41] fix: fix linting errors

---
 tests/strands/agent/test_agent_hooks.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/strands/agent/test_agent_hooks.py b/tests/strands/agent/test_agent_hooks.py
index e71c0aa94..66eb86808 100644
--- a/tests/strands/agent/test_agent_hooks.py
+++ b/tests/strands/agent/test_agent_hooks.py
@@ -4,7 +4,6 @@
 from pydantic import BaseModel

 import strands
-from src.strands.types.exceptions import MaxTokensReachedException
 from strands import Agent
 from strands.experimental.hooks import (
     AfterModelInvocationEvent,
@@ -19,6 +19,7 @@
     MessageAddedEvent,
 )
 from strands.types.content import Messages
+from strands.types.exceptions import MaxTokensReachedException
 from strands.types.tools import ToolResult, ToolUse
 from tests.fixtures.mock_hook_provider import MockHookProvider
 from tests.fixtures.mocked_model_provider import MockedModelProvider
From 2f118fb03b7faba54850981e99c7bf1a76785bca Mon Sep 17 00:00:00 2001
From: Dean Schmigelski
Date: Mon, 4 Aug 2025 18:39:28 -0400
Subject: [PATCH 26/41] fix: linting

---
 .../hooks/providers/correct_tool_use_hook_provider.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/src/strands/experimental/hooks/providers/correct_tool_use_hook_provider.py b/src/strands/experimental/hooks/providers/correct_tool_use_hook_provider.py
index 3c9ef0803..c8b12d98d 100644
--- 
a/src/strands/experimental/hooks/providers/correct_tool_use_hook_provider.py +++ b/src/strands/experimental/hooks/providers/correct_tool_use_hook_provider.py @@ -9,12 +9,11 @@ import logging from typing import Any -from src.strands.hooks import MessageAddedEvent -from src.strands.types.tools import ToolUse from strands.experimental.hooks.events import EventLoopFailureEvent -from strands.hooks import HookProvider, HookRegistry +from strands.hooks import HookProvider, HookRegistry, MessageAddedEvent from strands.types.content import ContentBlock, Message from strands.types.exceptions import MaxTokensReachedException +from strands.types.tools import ToolUse logger = logging.getLogger(__name__) @@ -57,7 +56,7 @@ def _handle_max_tokens_reached(self, event: EventLoopFailureEvent) -> None: valid_content: list[ContentBlock] = [] for content in incomplete_message["content"]: - tool_use: ToolUse = content.get("toolUse") + tool_use: ToolUse | None = content.get("toolUse") if not tool_use: valid_content.append(content) continue From e5fc51a432bdc89f112b3b8fc55b3c1e7b4d063a Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Tue, 5 Aug 2025 18:30:08 -0400 Subject: [PATCH 27/41] refactor: switch from hook approach to conversation manager --- src/strands/agent/agent.py | 38 ++-- .../conversation_manager.py | 16 ++ .../null_conversation_manager.py | 15 +- .../sliding_window_conversation_manager.py | 13 +- .../summarizing_conversation_manager.py | 13 +- .../token_limit_recovery.py | 66 ++++++ src/strands/experimental/hooks/__init__.py | 2 - src/strands/experimental/hooks/events.py | 26 --- .../experimental/hooks/providers/__init__.py | 5 - .../correct_tool_use_hook_provider.py | 96 --------- .../agent/conversation_manager/__init__.py | 1 + .../test_token_limit_recovery.py | 200 ++++++++++++++++++ tests/strands/agent/test_agent.py | 68 +++++- tests/strands/agent/test_agent_hooks.py | 55 +---- .../agent/test_conversation_manager.py | 92 +++++++- .../experimental/hooks/providers/__init__.py | 1 - .../test_correct_tool_use_hook_provider.py | 99 --------- tests_integ/test_max_tokens_reached.py | 9 +- 18 files changed, 497 insertions(+), 318 deletions(-) create mode 100644 src/strands/agent/conversation_manager/token_limit_recovery.py delete mode 100644 src/strands/experimental/hooks/providers/__init__.py delete mode 100644 src/strands/experimental/hooks/providers/correct_tool_use_hook_provider.py create mode 100644 tests/strands/agent/conversation_manager/__init__.py create mode 100644 tests/strands/agent/conversation_manager/test_token_limit_recovery.py delete mode 100644 tests/strands/experimental/hooks/providers/__init__.py delete mode 100644 tests/strands/experimental/hooks/providers/test_correct_tool_use_hook_provider.py diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index c86b64ff3..e258cb324 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -20,7 +20,6 @@ from pydantic import BaseModel from ..event_loop.event_loop import event_loop_cycle, run_tool -from ..experimental.hooks.events import EventLoopFailureEvent from ..handlers.callback_handler import PrintingCallbackHandler, null_callback_handler from ..hooks import ( AfterInvocationEvent, @@ -38,7 +37,7 @@ from ..tools.registry import ToolRegistry from ..tools.watcher import ToolWatcher from ..types.content import ContentBlock, Message, Messages -from ..types.exceptions import ContextWindowOverflowException +from ..types.exceptions import ContextWindowOverflowException, MaxTokensReachedException from 
..types.tools import ToolResult, ToolUse
 from ..types.traces import AttributeValue
 from .agent_result import AgentResult
@@ -54,13 +53,14 @@

 T = TypeVar("T", bound=BaseModel)

-# Sentinel class and object to distinguish between explicit None and default parameter value
+# Sentinel classes to distinguish between explicit None and default parameter value
 class _DefaultCallbackHandlerSentinel:
     """Sentinel class to distinguish between explicit None and default parameter value."""

     pass

+
 _DEFAULT_CALLBACK_HANDLER = _DefaultCallbackHandlerSentinel()
 _DEFAULT_AGENT_NAME = "Strands Agents"
 _DEFAULT_AGENT_ID = "default"
@@ -247,7 +247,7 @@ def __init__(
             state: stateful information for the agent. Can be either an AgentState object, or a json serializable dict.
                 Defaults to an empty AgentState object.
             hooks: hooks to be added to the agent hook registry
-                Defaults to None.
+                Defaults to an empty set if None.
             session_manager: Manager for handling agent sessions including conversation history and state.
                 If provided, enables session-based persistence and state management.
         """
@@ -587,29 +587,17 @@ async def _execute_event_loop_cycle(self, invocation_state: dict[str, Any]) -> A
         except ContextWindowOverflowException as e:
             # Try reducing the context size and retrying
             self.conversation_manager.reduce_context(self, e=e)
+        except MaxTokensReachedException as e:
+            # Recover conversation state after the token limit is exceeded, then continue with the next cycle
+            self.conversation_manager.handle_token_limit_reached(self, e=e)

-            # Sync agent after reduce_context to keep conversation_manager_state up to date in the session
-            if self._session_manager:
-                self._session_manager.sync_agent(self)
-
-            # If the events have been handled, attempt to restart the event loop in the now-healthy state
-            events = self._execute_event_loop_cycle(invocation_state)
-            async for event in events:
-                yield event
-        except Exception as e:
-            """
-            Catch all other exceptions which are unrecoverable without intervention.
-            Reraise the exception if EventLoopFailureEvent.should_continue_loop is false
-            """
-            event_loop_failure_event = EventLoopFailureEvent(agent=self, exception=e)
-            self.hooks.invoke_callbacks(event_loop_failure_event)
-            if not event_loop_failure_event.should_continue_loop:
-                raise
+        # Sync agent after handling the exception to keep conversation_manager_state up to date in the session
+        if self._session_manager:
+            self._session_manager.sync_agent(self)

-            # If the events have been handled, attempt to restart the event loop in the now-healthy state
-            events = self._execute_event_loop_cycle(invocation_state)
-            async for event in events:
-                yield event
+        events = self._execute_event_loop_cycle(invocation_state)
+        async for event in events:
+            yield event
diff --git a/src/strands/agent/conversation_manager/conversation_manager.py b/src/strands/agent/conversation_manager/conversation_manager.py
index 2c1ee7847..c2899209b 100644
--- a/src/strands/agent/conversation_manager/conversation_manager.py
+++ b/src/strands/agent/conversation_manager/conversation_manager.py
@@ -4,6 +4,7 @@
 from typing import TYPE_CHECKING, Any, Optional

 from ...types.content import Message
+from ...types.exceptions import MaxTokensReachedException

 if TYPE_CHECKING:
     from ...agent.agent import Agent
@@ -86,3 +87,18 @@ def reduce_context(self, agent: "Agent", e: Optional[Exception] = None, **kwargs
             **kwargs: Additional keyword arguments for future extensibility. 
""" pass + + @abstractmethod + def handle_token_limit_reached(self, agent: "Agent", e: MaxTokensReachedException, **kwargs: Any) -> None: + """Called when MaxTokensReachedException is thrown to recover conversation state. + + This method should implement recovery strategies when the token limit is exceeded and the message array + may be in a broken state. It is called outside the event loop to apply default recovery mechanisms. + + Args: + agent: The agent whose conversation state will be recovered. + This list is modified in-place. + e: The MaxTokensReachedException that triggered the recovery. + **kwargs: Additional keyword arguments for future extensibility. + """ + pass diff --git a/src/strands/agent/conversation_manager/null_conversation_manager.py b/src/strands/agent/conversation_manager/null_conversation_manager.py index 5ff6874e5..29fa1c442 100644 --- a/src/strands/agent/conversation_manager/null_conversation_manager.py +++ b/src/strands/agent/conversation_manager/null_conversation_manager.py @@ -5,7 +5,7 @@ if TYPE_CHECKING: from ...agent.agent import Agent -from ...types.exceptions import ContextWindowOverflowException +from ...types.exceptions import ContextWindowOverflowException, MaxTokensReachedException from .conversation_manager import ConversationManager @@ -44,3 +44,16 @@ def reduce_context(self, agent: "Agent", e: Optional[Exception] = None, **kwargs raise e else: raise ContextWindowOverflowException("Context window overflowed!") + + def handle_token_limit_reached(self, agent: "Agent", e: MaxTokensReachedException, **kwargs: Any) -> None: + """Does not handle token limit recovery and raises the exception. + + Args: + agent: The agent whose conversation state will remain unmodified. + e: The MaxTokensReachedException that triggered the recovery. + **kwargs: Additional keyword arguments for future extensibility. + + Raises: + e: The provided exception. + """ + raise e diff --git a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py index e082abe8e..f96dbff27 100644 --- a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py +++ b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py @@ -7,8 +7,9 @@ from ...agent.agent import Agent from ...types.content import Messages -from ...types.exceptions import ContextWindowOverflowException +from ...types.exceptions import ContextWindowOverflowException, MaxTokensReachedException from .conversation_manager import ConversationManager +from .token_limit_recovery import recover_from_max_tokens_reached logger = logging.getLogger(__name__) @@ -177,3 +178,13 @@ def _find_last_message_with_tool_results(self, messages: Messages) -> Optional[i return idx return None + + def handle_token_limit_reached(self, agent: "Agent", e: MaxTokensReachedException, **kwargs: Any) -> None: + """Apply sliding window strategy for token limit recovery. + + Args: + agent: The agent whose conversation state will be recovered. + e: The MaxTokensReachedException that triggered the recovery. + **kwargs: Additional keyword arguments for future extensibility. 
+ """ + recover_from_max_tokens_reached(agent, e) diff --git a/src/strands/agent/conversation_manager/summarizing_conversation_manager.py b/src/strands/agent/conversation_manager/summarizing_conversation_manager.py index 60e832215..fe0d13fa4 100644 --- a/src/strands/agent/conversation_manager/summarizing_conversation_manager.py +++ b/src/strands/agent/conversation_manager/summarizing_conversation_manager.py @@ -6,8 +6,9 @@ from typing_extensions import override from ...types.content import Message -from ...types.exceptions import ContextWindowOverflowException +from ...types.exceptions import ContextWindowOverflowException, MaxTokensReachedException from .conversation_manager import ConversationManager +from .token_limit_recovery import recover_from_max_tokens_reached if TYPE_CHECKING: from ..agent import Agent @@ -250,3 +251,13 @@ def _adjust_split_point_for_tool_pairs(self, messages: List[Message], split_poin raise ContextWindowOverflowException("Unable to trim conversation context!") return split_point + + def handle_token_limit_reached(self, agent: "Agent", e: MaxTokensReachedException, **kwargs: Any) -> None: + """Apply summarization strategy for token limit recovery. + + Args: + agent: The agent whose conversation state will be recovered. + e: The MaxTokensReachedException that triggered the recovery. + **kwargs: Additional keyword arguments for future extensibility. + """ + recover_from_max_tokens_reached(agent, e) diff --git a/src/strands/agent/conversation_manager/token_limit_recovery.py b/src/strands/agent/conversation_manager/token_limit_recovery.py new file mode 100644 index 000000000..a0935f3a3 --- /dev/null +++ b/src/strands/agent/conversation_manager/token_limit_recovery.py @@ -0,0 +1,66 @@ +"""Shared utility for handling token limit recovery in conversation managers.""" + +import logging +from typing import TYPE_CHECKING + +from ...types.content import ContentBlock, Message +from ...types.exceptions import MaxTokensReachedException +from ...types.tools import ToolUse + +if TYPE_CHECKING: + from ...agent.agent import Agent + +logger = logging.getLogger(__name__) + + +def recover_from_max_tokens_reached(agent: "Agent", exception: MaxTokensReachedException) -> None: + """Handle MaxTokensReachedException by cleaning up orphaned tool uses and adding corrected message. + + This function fixes incomplete tool uses that may occur when the model's response is truncated + due to token limits. It: + + 1. Inspects each content block in the incomplete message for invalid tool uses + 2. Replaces incomplete tool use blocks with informative text messages + 3. Preserves valid content blocks in the corrected message + 4. Adds the corrected message to the agent's conversation history + + Args: + agent: The agent whose conversation will be updated with the corrected message. + exception: The MaxTokensReachedException containing the incomplete message. 
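+
+    Example:
+        An illustrative sketch; the tool name and message shape are taken from
+        the tests accompanying this change:
+
+            exception = MaxTokensReachedException(
+                message="Token limit reached",
+                incomplete_message={
+                    "role": "assistant",
+                    "content": [{"toolUse": {"name": "calculator", "input": {}, "toolUseId": ""}}],
+                },
+            )
+            recover_from_max_tokens_reached(agent, exception)
+            # agent.messages now ends with a text block noting that the
+            # "calculator" tool use was incomplete due to token limits.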
+    """
+    logger.info("Handling MaxTokensReachedException - inspecting incomplete message for invalid tool uses")
+
+    incomplete_message: Message = exception.incomplete_message
+
+    if not incomplete_message["content"]:
+        # Cannot correct invalid content block if content is empty
+        return
+
+    valid_content: list[ContentBlock] = []
+    for content in incomplete_message["content"]:
+        tool_use: ToolUse | None = content.get("toolUse")
+        if not tool_use:
+            valid_content.append(content)
+            continue
+
+        # Check if tool use is incomplete (missing or empty required fields)
+        tool_name = tool_use.get("name")
+        if not (tool_name and tool_use.get("input") and tool_use.get("toolUseId")):
+            # Tool use is incomplete due to max_tokens truncation
+            display_name = tool_name if tool_name else "<unknown>"
+            logger.warning(
+                "tool_name=<%s> | replacing with error message due to max_tokens truncation.", display_name
+            )
+
+            valid_content.append(
+                {
+                    "text": f"The selected tool {display_name}'s tool use was incomplete due "
+                    f"to maximum token limits being reached."
+                }
+            )
+        else:
+            # ToolUse was invalid for an unknown reason. Cannot correct, return without modifying
+            return
+
+    valid_message: Message = {"content": valid_content, "role": incomplete_message["role"]}
+    agent.messages.append(valid_message)
diff --git a/src/strands/experimental/hooks/__init__.py b/src/strands/experimental/hooks/__init__.py
index 384d8a505..098d4cf0d 100644
--- a/src/strands/experimental/hooks/__init__.py
+++ b/src/strands/experimental/hooks/__init__.py
@@ -5,7 +5,6 @@
     AfterToolInvocationEvent,
     BeforeModelInvocationEvent,
     BeforeToolInvocationEvent,
-    EventLoopFailureEvent,
 )

 __all__ = [
@@ -13,5 +12,4 @@
     "AfterToolInvocationEvent",
     "BeforeModelInvocationEvent",
     "AfterModelInvocationEvent",
-    "EventLoopFailureEvent",
 ]
diff --git a/src/strands/experimental/hooks/events.py b/src/strands/experimental/hooks/events.py
index 128882821..d03e65d85 100644
--- a/src/strands/experimental/hooks/events.py
+++ b/src/strands/experimental/hooks/events.py
@@ -121,29 +121,3 @@ class ModelStopResponse:
     def should_reverse_callbacks(self) -> bool:
         """True to invoke callbacks in reverse order."""
         return True
-
-
-@dataclass
-class EventLoopFailureEvent(HookEvent):
-    """Event triggered when the event loop encounters a failure.
-
-    This event is fired when an exception occurs during event loop execution,
-    allowing hook providers to handle the failure or perform recovery actions.
-
-    Attributes:
-        exception: The exception that caused the event loop failure.
-        should_continue_loop: Flag that hooks can set to True to indicate they have
-            handled the exception and the event loop should continue normally.
-
-    Warning:
-        Setting should_continue_loop=True without properly addressing the underlying
-        cause of the exception may result in infinite loops if the same failure
-        condition persists. Hooks should implement appropriate error handling,
-        retry limits, or state modifications to prevent recurring failures.
-    """
-
-    exception: Exception
-    should_continue_loop: bool = False
-
-    def _can_write(self, name: str) -> bool:
-        return name == "should_continue_loop"
diff --git a/src/strands/experimental/hooks/providers/__init__.py b/src/strands/experimental/hooks/providers/__init__.py
deleted file mode 100644
index 5b74733b8..000000000
--- a/src/strands/experimental/hooks/providers/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-"""Hook providers for experimental Strands Agents functionality. 
- -This package contains experimental hook providers that extend the core agent functionality -with additional capabilities. -""" diff --git a/src/strands/experimental/hooks/providers/correct_tool_use_hook_provider.py b/src/strands/experimental/hooks/providers/correct_tool_use_hook_provider.py deleted file mode 100644 index c8b12d98d..000000000 --- a/src/strands/experimental/hooks/providers/correct_tool_use_hook_provider.py +++ /dev/null @@ -1,96 +0,0 @@ -"""Hook provider for correcting incomplete tool uses due to token limits. - -This module provides the CorrectToolUseHookProvider class, which handles scenarios where -the model's response is truncated due to maximum token limits, resulting in incomplete -or malformed tool use entries. The provider automatically corrects these issues to allow -the agent conversation to continue gracefully. -""" - -import logging -from typing import Any - -from strands.experimental.hooks.events import EventLoopFailureEvent -from strands.hooks import HookProvider, HookRegistry, MessageAddedEvent -from strands.types.content import ContentBlock, Message -from strands.types.exceptions import MaxTokensReachedException -from strands.types.tools import ToolUse - -logger = logging.getLogger(__name__) - - -class CorrectToolUseHookProvider(HookProvider): - """Hook provider that handles MaxTokensReachedException by fixing incomplete tool uses. - - This hook provider is triggered when a MaxTokensReachedException occurs during event loop execution. - When the model's response is truncated due to token limits, tool use entries may be incomplete - or missing required fields (name, input, toolUseId). - - The provider fixes these issues by: - - 1. Inspecting each content block in the incomplete message for invalid tool uses - 2. Replacing incomplete tool use blocks with informative text messages - 3. Preserving valid content blocks in the corrected message - 4. Adding the corrected message to the agent's conversation history - 5. Allowing the event loop to continue processing - - If a tool use is invalid for unknown reasons, not due to empty fields, the hook - allows the original exception to propagate to avoid unsafe recovery attempts. - """ - - def register_hooks(self, registry: "HookRegistry", **kwargs: Any) -> None: - """Register hook to handle EventLoopFailureEvent for MaxTokensReachedException.""" - registry.add_callback(EventLoopFailureEvent, self._handle_max_tokens_reached) - - def _handle_max_tokens_reached(self, event: EventLoopFailureEvent) -> None: - """Handle MaxTokensReachedException by cleaning up orphaned tool uses and allowing continuation.""" - if not isinstance(event.exception, MaxTokensReachedException): - return - - logger.info("Handling MaxTokensReachedException - inspecting incomplete message for invalid tool uses") - - incomplete_message: Message = event.exception.incomplete_message - - if not incomplete_message["content"]: - # Cannot correct invalid content block if content is empty - return - - valid_content: list[ContentBlock] = [] - for content in incomplete_message["content"]: - tool_use: ToolUse | None = content.get("toolUse") - if not tool_use: - valid_content.append(content) - continue - - """ - Ideally this would be future proofed using a pydantic validator. Since ToolUse is not implemented - using pydantic, we inspect each field. 
-            """
-            # Check if tool use is incomplete (missing or empty required fields)
-            tool_name = tool_use.get("name")
-            if not (tool_name and tool_use.get("input") and tool_use.get("toolUseId")):
-                """
-                If tool_use does not conform to the expected schema, it means the max_tokens issue resulted in it not
-                being populated correctly.
-
-                It is safe to drop the content block, but we insert a new one to ensure Agent is aware of failure
-                on the next iteration.
-                """
-                display_name = tool_name if tool_name else "<unknown>"
-                logger.warning(
-                    "tool_name=<%s> | replacing with error message due to max_tokens truncation.", display_name
-                )
-
-                valid_content.append(
-                    {
-                        "text": f"The selected tool {display_name}'s tool use was incomplete due "
-                        f"to maximum token limits being reached."
-                    }
-                )
-            else:
-                # ToolUse was invalid for an unknown reason. Cannot correct, so return and allow the exception to propagate.
-                return
-
-        valid_message: Message = {"content": valid_content, "role": incomplete_message["role"]}
-        event.agent.messages.append(valid_message)
-        event.agent.hooks.invoke_callbacks(MessageAddedEvent(agent=event.agent, message=valid_message))
-        event.should_continue_loop = True
diff --git a/tests/strands/agent/conversation_manager/__init__.py b/tests/strands/agent/conversation_manager/__init__.py
new file mode 100644
index 000000000..d5ee2d119
--- /dev/null
+++ b/tests/strands/agent/conversation_manager/__init__.py
@@ -0,0 +1 @@
+# Test package for conversation manager
diff --git a/tests/strands/agent/conversation_manager/test_token_limit_recovery.py b/tests/strands/agent/conversation_manager/test_token_limit_recovery.py
new file mode 100644
index 000000000..8d1655c45
--- /dev/null
+++ b/tests/strands/agent/conversation_manager/test_token_limit_recovery.py
@@ -0,0 +1,200 @@
+"""Tests for token limit recovery utility."""
+
+import pytest
+
+from strands.agent.agent import Agent
+from strands.agent.conversation_manager.token_limit_recovery import recover_from_max_tokens_reached
+from strands.types.content import Message
+from strands.types.exceptions import MaxTokensReachedException
+
+
+def test_recover_from_max_tokens_reached_with_incomplete_tool_use():
+    """Test recovery when incomplete tool use is present in the message."""
+    agent = Agent()
+    initial_message_count = len(agent.messages)
+
+    incomplete_message: Message = {
+        "role": "assistant",
+        "content": [
+            {"text": "I'll help you with that."},
+            {"toolUse": {"name": "calculator", "input": {}, "toolUseId": ""}},  # Missing toolUseId
+        ]
+    }
+
+    exception = MaxTokensReachedException(
+        message="Token limit reached",
+        incomplete_message=incomplete_message
+    )
+
+    recover_from_max_tokens_reached(agent, exception)
+
+    # Should add one corrected message
+    assert len(agent.messages) == initial_message_count + 1
+
+    # Check the corrected message content
+    corrected_message = agent.messages[-1]
+    assert corrected_message["role"] == "assistant"
+    assert len(corrected_message["content"]) == 2
+
+    # First content block should be preserved
+    assert corrected_message["content"][0] == {"text": "I'll help you with that."}
+
+    # Second content block should be replaced with error message
+    assert "text" in corrected_message["content"][1]
+    assert "calculator" in corrected_message["content"][1]["text"]
+    assert "incomplete due to maximum token limits" in corrected_message["content"][1]["text"]
+
+
+def test_recover_from_max_tokens_reached_with_unknown_tool_name():
+    """Test recovery when tool use has no name."""
+    agent = Agent()
+    initial_message_count = 
len(agent.messages)
+
+    incomplete_message: Message = {
+        "role": "assistant",
+        "content": [
+            {"toolUse": {"name": "", "input": {}, "toolUseId": "123"}},  # Missing name
+        ]
+    }
+
+    exception = MaxTokensReachedException(
+        message="Token limit reached",
+        incomplete_message=incomplete_message
+    )
+
+    recover_from_max_tokens_reached(agent, exception)
+
+    # Should add one corrected message
+    assert len(agent.messages) == initial_message_count + 1
+
+    # Check the corrected message content
+    corrected_message = agent.messages[-1]
+    assert corrected_message["role"] == "assistant"
+    assert len(corrected_message["content"]) == 1
+
+    # Content should be replaced with error message using <unknown>
+    assert "text" in corrected_message["content"][0]
+    assert "<unknown>" in corrected_message["content"][0]["text"]
+    assert "incomplete due to maximum token limits" in corrected_message["content"][0]["text"]
+
+
+def test_recover_from_max_tokens_reached_with_valid_tool_use():
+    """Test that valid tool uses are not modified and function returns early."""
+    agent = Agent()
+    initial_message_count = len(agent.messages)
+
+    incomplete_message: Message = {
+        "role": "assistant",
+        "content": [
+            {"text": "I'll help you with that."},
+            {"toolUse": {"name": "calculator", "input": {"expression": "2+2"}, "toolUseId": "123"}},  # Valid
+        ]
+    }
+
+    exception = MaxTokensReachedException(
+        message="Token limit reached",
+        incomplete_message=incomplete_message
+    )
+
+    recover_from_max_tokens_reached(agent, exception)
+
+    # Should not add any message since tool use was valid
+    assert len(agent.messages) == initial_message_count
+
+
+def test_recover_from_max_tokens_reached_with_empty_content():
+    """Test that empty content is handled gracefully."""
+    agent = Agent()
+    initial_message_count = len(agent.messages)
+
+    incomplete_message: Message = {
+        "role": "assistant",
+        "content": []
+    }
+
+    exception = MaxTokensReachedException(
+        message="Token limit reached",
+        incomplete_message=incomplete_message
+    )
+
+    recover_from_max_tokens_reached(agent, exception)
+
+    # Should not add any message since content is empty
+    assert len(agent.messages) == initial_message_count
+
+
+def test_recover_from_max_tokens_reached_with_mixed_content():
+    """Test recovery with mix of valid content and incomplete tool use."""
+    agent = Agent()
+    initial_message_count = len(agent.messages)
+
+    incomplete_message: Message = {
+        "role": "assistant",
+        "content": [
+            {"text": "Let me calculate this for you."},
+            {"toolUse": {"name": "calculator", "input": {}, "toolUseId": ""}},  # Incomplete
+            {"text": "And then I'll explain the result."},
+        ]
+    }
+
+    exception = MaxTokensReachedException(
+        message="Token limit reached",
+        incomplete_message=incomplete_message
+    )
+
+    recover_from_max_tokens_reached(agent, exception)
+
+    # Should add one corrected message
+    assert len(agent.messages) == initial_message_count + 1
+
+    # Check the corrected message content
+    corrected_message = agent.messages[-1]
+    assert corrected_message["role"] == "assistant"
+    assert len(corrected_message["content"]) == 3
+
+    # First and third content blocks should be preserved
+    assert corrected_message["content"][0] == {"text": "Let me calculate this for you."}
+    assert corrected_message["content"][2] == {"text": "And then I'll explain the result."}
+
+    # Second content block should be replaced with error message
+    assert "text" in corrected_message["content"][1]
+    assert "calculator" in corrected_message["content"][1]["text"]
+    assert "incomplete due to maximum token limits" in 
corrected_message["content"][1]["text"]
+
+
+def test_recover_from_max_tokens_reached_preserves_non_tool_content():
+    """Test that non-tool content is preserved as-is."""
+    agent = Agent()
+    initial_message_count = len(agent.messages)
+
+    incomplete_message: Message = {
+        "role": "assistant",
+        "content": [
+            {"text": "Here's some text."},
+            {"image": {"format": "png", "source": {"bytes": "fake_image_data"}}},
+            {"toolUse": {"name": "", "input": {}, "toolUseId": "123"}},  # Incomplete
+        ]
+    }
+
+    exception = MaxTokensReachedException(
+        message="Token limit reached",
+        incomplete_message=incomplete_message
+    )
+
+    recover_from_max_tokens_reached(agent, exception)
+
+    # Should add one corrected message
+    assert len(agent.messages) == initial_message_count + 1
+
+    # Check the corrected message content
+    corrected_message = agent.messages[-1]
+    assert corrected_message["role"] == "assistant"
+    assert len(corrected_message["content"]) == 3
+
+    # First two content blocks should be preserved exactly
+    assert corrected_message["content"][0] == {"text": "Here's some text."}
+    assert corrected_message["content"][1] == {"image": {"format": "png", "source": {"bytes": "fake_image_data"}}}
+
+    # Third content block should be replaced with error message
+    assert "text" in corrected_message["content"][2]
+    assert "<unknown>" in corrected_message["content"][2]["text"]
diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py
index 4e310dace..9dd802f4e 100644
--- a/tests/strands/agent/test_agent.py
+++ b/tests/strands/agent/test_agent.py
@@ -19,7 +19,7 @@
 from strands.models.bedrock import DEFAULT_BEDROCK_MODEL_ID, BedrockModel
 from strands.session.repository_session_manager import RepositorySessionManager
 from strands.types.content import Messages
-from strands.types.exceptions import ContextWindowOverflowException, EventLoopException
+from strands.types.exceptions import ContextWindowOverflowException, EventLoopException, MaxTokensReachedException
 from strands.types.session import Session, SessionAgent, SessionMessage, SessionType
 from tests.fixtures.mock_session_repository import MockedSessionRepository
 from tests.fixtures.mocked_model_provider import MockedModelProvider
@@ -547,6 +547,72 @@ def test_agent__call__tool_truncation_doesnt_infinite_loop(mock_model, agent):
     agent("Test!")


+def test_agent__call__max_tokens_reached_triggers_conversation_manager_recovery(mock_model, agent, agenerator):
+    """Test that MaxTokensReachedException triggers conversation manager handle_token_limit_reached."""
+    conversation_manager_spy = unittest.mock.Mock(wraps=agent.conversation_manager)
+    agent.conversation_manager = conversation_manager_spy
+
+    incomplete_message = {
+        "role": "assistant",
+        "content": [
+            {"text": "I'll help you with that."},
+            {"toolUse": {"name": "calculator", "input": {}, "toolUseId": ""}},  # Missing toolUseId
+        ]
+    }
+
+    mock_model.mock_stream.side_effect = [
+        MaxTokensReachedException(
+            message="Token limit reached",
+            incomplete_message=incomplete_message
+        ),
+        agenerator(
+            [
+                {"contentBlockStart": {"start": {}}},
+                {"contentBlockDelta": {"delta": {"text": "Recovered response"}}},
+                {"contentBlockStop": {}},
+                {"messageStop": {"stopReason": "end_turn"}},
+            ]
+        ),
+    ]
+
+    result = agent("Test message")
+
+    # Verify handle_token_limit_reached was called
+    assert conversation_manager_spy.handle_token_limit_reached.call_count == 1
+
+    # Verify the call was made with the agent and the exception
+    call_args = conversation_manager_spy.handle_token_limit_reached.call_args
+    assert 
call_args.args[0] is agent
+    assert "e" in call_args.kwargs  # exception is passed as a keyword argument
+    assert isinstance(call_args.kwargs["e"], MaxTokensReachedException)
+
+    # Verify apply_management was also called
+    assert conversation_manager_spy.apply_management.call_count > 0
+
+    # Verify the agent continued and produced a result
+    assert result is not None
+
+
+def test_agent__call__max_tokens_reached_with_null_conversation_manager_raises_exception(mock_model, agent):
+    """Test that MaxTokensReachedException with NullConversationManager raises the exception."""
+    agent.conversation_manager = NullConversationManager()
+
+    incomplete_message = {
+        "role": "assistant",
+        "content": [
+            {"toolUse": {"name": "calculator", "input": {}, "toolUseId": ""}},  # Missing toolUseId
+        ]
+    }
+
+    mock_model.mock_stream.side_effect = MaxTokensReachedException(
+        message="Token limit reached",
+        incomplete_message=incomplete_message
+    )
+
+    with pytest.raises(MaxTokensReachedException):
+        agent("Test!")
+
+
 def test_agent__call__retry_with_overwritten_tool(mock_model, agent, tool, agenerator):
     conversation_manager_spy = unittest.mock.Mock(wraps=agent.conversation_manager)
     agent.conversation_manager = conversation_manager_spy
diff --git a/tests/strands/agent/test_agent_hooks.py b/tests/strands/agent/test_agent_hooks.py
index e71c0aa94..cd89fbc7a 100644
--- a/tests/strands/agent/test_agent_hooks.py
+++ b/tests/strands/agent/test_agent_hooks.py
@@ -1,4 +1,4 @@
-from unittest.mock import ANY, Mock, patch
+from unittest.mock import ANY, Mock

 import pytest
 from pydantic import BaseModel

 import strands
@@ -10,7 +10,6 @@
     AfterToolInvocationEvent,
     BeforeModelInvocationEvent,
     BeforeToolInvocationEvent,
-    EventLoopFailureEvent,
 )
 from strands.hooks import (
     AfterInvocationEvent,
@@ -19,7 +18,6 @@
     MessageAddedEvent,
 )
 from strands.types.content import Messages
-from strands.types.exceptions import MaxTokensReachedException
 from strands.types.tools import ToolResult, ToolUse
 from tests.fixtures.mock_hook_provider import MockHookProvider
 from tests.fixtures.mocked_model_provider import MockedModelProvider
@@ -37,7 +35,6 @@ def hook_provider():
             BeforeModelInvocationEvent,
             AfterModelInvocationEvent,
             MessageAddedEvent,
-            EventLoopFailureEvent,
         ]
     )

@@ -295,53 +292,3 @@ async def test_agent_structured_async_output_hooks(agent, hook_provider, user, a
     assert next(events) == AfterInvocationEvent(agent=agent)

     assert len(agent.messages) == 1
-
-
-def test_event_loop_failure_event_exception_rethrown_when_not_handled(agent, hook_provider):
-    """Test that EventLoopFailureEvent is triggered and exceptions are re-thrown when not handled."""
-
-    # Mock event_loop_cycle to raise a general exception (not ContextWindowOverflowException)
-    with patch("strands.agent.agent.event_loop_cycle") as mock_cycle:
-        mock_cycle.side_effect = MaxTokensReachedException("Event loop failure", {"content": [], "role": "assistant"})
-
-        with pytest.raises(MaxTokensReachedException):
-            agent("test message")
-    length, events = hook_provider.get_events()
-    failure_events = [event for event in list(events) if isinstance(event, EventLoopFailureEvent)]
-
-    assert len(failure_events) == 1
-    assert isinstance(failure_events[0].exception, MaxTokensReachedException)
-    assert failure_events[0].should_continue_loop is False
-
-
-def test_event_loop_failure_event_exception_handled_by_hook(agent, hook_provider):
-    """Test that EventLoopFailureEvent allows hooks to handle exceptions and continue execution."""
-
-    first_call = True
-
-    def hook_callback(event: 
EventLoopFailureEvent):
-        nonlocal first_call
-        # Hook handles the exception by setting should_continue_loop to True
-        event.should_continue_loop = first_call
-        first_call = False
-
-    agent.hooks.add_callback(EventLoopFailureEvent, hook_callback)
-
-    # Mock event_loop_cycle to raise a general exception
-    with patch("strands.agent.agent.event_loop_cycle") as mock_cycle:
-        mock_cycle.side_effect = MaxTokensReachedException("Event loop failure", {"content": [], "role": "assistant"})
-
-        # The hook handles the first failure; the second failure is not handled, so the exception propagates
-        with pytest.raises(MaxTokensReachedException):
-            agent("test message")
-
-    length, events = hook_provider.get_events()
-    failure_events = [event for event in list(events) if isinstance(event, EventLoopFailureEvent)]
-
-    assert len(failure_events) == 2
-
-    assert isinstance(failure_events[0].exception, MaxTokensReachedException)
-    assert failure_events[0].should_continue_loop is True
-
-    assert isinstance(failure_events[1].exception, MaxTokensReachedException)
-    assert failure_events[1].should_continue_loop is False
diff --git a/tests/strands/agent/test_conversation_manager.py b/tests/strands/agent/test_conversation_manager.py
index 77d7dcce8..e3452824e 100644
--- a/tests/strands/agent/test_conversation_manager.py
+++ b/tests/strands/agent/test_conversation_manager.py
@@ -3,7 +3,8 @@
 from strands.agent.agent import Agent
 from strands.agent.conversation_manager.null_conversation_manager import NullConversationManager
 from strands.agent.conversation_manager.sliding_window_conversation_manager import SlidingWindowConversationManager
-from strands.types.exceptions import ContextWindowOverflowException
+from strands.types.content import Message
+from strands.types.exceptions import ContextWindowOverflowException, MaxTokensReachedException


 @pytest.fixture
@@ -204,6 +205,44 @@ def test_sliding_window_conversation_manager_with_tool_results_truncated():
     assert messages == expected_messages


+def test_sliding_window_conversation_manager_handle_token_limit_reached():
+    """Test that SlidingWindowConversationManager handles token limit recovery."""
+    manager = SlidingWindowConversationManager()
+    test_agent = Agent()
+    initial_message_count = len(test_agent.messages)
+
+    incomplete_message: Message = {
+        "role": "assistant",
+        "content": [
+            {"text": "I'll help you with that."},
+            {"toolUse": {"name": "calculator", "input": {}, "toolUseId": ""}},  # Missing toolUseId
+        ]
+    }
+
+    test_exception = MaxTokensReachedException(
+        message="Token limit reached",
+        incomplete_message=incomplete_message
+    )
+
+    manager.handle_token_limit_reached(test_agent, test_exception)
+
+    # Should add one corrected message
+    assert len(test_agent.messages) == initial_message_count + 1
+
+    # Check the corrected message content
+    corrected_message = test_agent.messages[-1]
+    assert corrected_message["role"] == "assistant"
+    assert len(corrected_message["content"]) == 2
+
+    # First content block should be preserved
+    assert corrected_message["content"][0] == {"text": "I'll help you with that."}
+
+    # Second content block should be replaced with error message
+    assert "text" in corrected_message["content"][1]
+    assert "calculator" in corrected_message["content"][1]["text"]
+    assert "incomplete due to maximum token limits" in corrected_message["content"][1]["text"]
+
+
 def test_null_conversation_manager_reduce_context_raises_context_window_overflow_exception():
     """Test that NullConversationManager doesn't modify 
messages."""
     manager = NullConversationManager()
 
@@ -246,3 +286,53 @@ def test_null_conversation_does_not_restore_with_incorrect_state():
 
         with pytest.raises(ValueError):
             manager.restore_from_session({})
+
+
+def test_summarizing_conversation_manager_handle_token_limit_reached():
+    """Test that SummarizingConversationManager handles token limit recovery."""
+    from strands.agent.conversation_manager.summarizing_conversation_manager import SummarizingConversationManager
+
+    manager = SummarizingConversationManager()
+    test_agent = Agent()
+    initial_message_count = len(test_agent.messages)
+
+    incomplete_message: Message = {
+        "role": "assistant",
+        "content": [
+            {"toolUse": {"name": "", "input": {}, "toolUseId": "123"}},  # Missing name
+        ]
+    }
+
+    test_exception = MaxTokensReachedException(
+        message="Token limit reached",
+        incomplete_message=incomplete_message
+    )
+
+    manager.handle_token_limit_reached(test_agent, test_exception)
+
+    # Should add one corrected message
+    assert len(test_agent.messages) == initial_message_count + 1
+
+    # Check the corrected message content
+    corrected_message = test_agent.messages[-1]
+    assert corrected_message["role"] == "assistant"
+    assert len(corrected_message["content"]) == 1
+
+    # Content should be replaced with error message using <unknown>
+    assert "text" in corrected_message["content"][0]
+    assert "<unknown>" in corrected_message["content"][0]["text"]
+    assert "incomplete due to maximum token limits" in corrected_message["content"][0]["text"]
+
+
+def test_null_conversation_manager_handle_token_limit_reached_raises_exception():
+    """Test that NullConversationManager raises the provided exception."""
+    manager = NullConversationManager()
+    test_agent = Agent()
+    test_message: Message = {
+        "role": "assistant",
+        "content": [{"text": "Hello"}],
+    }
+    test_exception = MaxTokensReachedException(message="test", incomplete_message=test_message)
+
+    with pytest.raises(MaxTokensReachedException):
+        manager.handle_token_limit_reached(test_agent, test_exception)
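Taken together, the tests above pin down a small recovery contract. Below is a minimal sketch of that contract as the API stands at this point in the series — the manager call and message shapes are taken from the diffs above, and no model is ever invoked:

```python
from strands.agent.agent import Agent
from strands.agent.conversation_manager.sliding_window_conversation_manager import (
    SlidingWindowConversationManager,
)
from strands.types.content import Message
from strands.types.exceptions import MaxTokensReachedException

agent = Agent()

# An assistant message truncated by max_tokens: the toolUse block never
# received its toolUseId, so it cannot be replayed to the model as-is.
incomplete: Message = {
    "role": "assistant",
    "content": [{"toolUse": {"name": "calculator", "input": {}, "toolUseId": ""}}],
}
error = MaxTokensReachedException(message="Token limit reached", incomplete_message=incomplete)

# The sliding window manager repairs the conversation in place: the broken
# toolUse block becomes an explanatory text block that names the tool.
SlidingWindowConversationManager().handle_token_limit_reached(agent, error)
assert "calculator" in agent.messages[-1]["content"][0]["text"]
```

`NullConversationManager`, by contrast, opts out of recovery and re-raises, which is what the last test above asserts.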
diff --git a/tests/strands/experimental/hooks/providers/__init__.py b/tests/strands/experimental/hooks/providers/__init__.py
deleted file mode 100644
index 8b1378917..000000000
--- a/tests/strands/experimental/hooks/providers/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-
diff --git a/tests/strands/experimental/hooks/providers/test_correct_tool_use_hook_provider.py b/tests/strands/experimental/hooks/providers/test_correct_tool_use_hook_provider.py
deleted file mode 100644
index 93d672ab2..000000000
--- a/tests/strands/experimental/hooks/providers/test_correct_tool_use_hook_provider.py
+++ /dev/null
@@ -1,99 +0,0 @@
-"""Unit tests for CorrectToolUseHookProvider."""
-
-from unittest.mock import Mock
-
-import pytest
-
-from strands.experimental.hooks.events import EventLoopFailureEvent
-from strands.experimental.hooks.providers.correct_tool_use_hook_provider import CorrectToolUseHookProvider
-from strands.hooks import HookRegistry
-from strands.types.content import Message
-from strands.types.exceptions import MaxTokensReachedException
-
-
-@pytest.fixture
-def hook_provider():
-    """Create a CorrectToolUseHookProvider instance."""
-    return CorrectToolUseHookProvider()
-
-
-@pytest.fixture
-def mock_agent():
-    """Create a mock agent with messages and hooks."""
-    agent = Mock()
-    agent.messages = []
-    agent.hooks = Mock()
-    return agent
-
-
-@pytest.fixture
-def mock_registry():
-    """Create a mock hook registry."""
-    return Mock(spec=HookRegistry)
-
-
-def test_register_hooks(hook_provider, mock_registry):
-    """Test that the hook provider registers the correct callback."""
-    hook_provider.register_hooks(mock_registry)
-
-    mock_registry.add_callback.assert_called_once_with(EventLoopFailureEvent, hook_provider._handle_max_tokens_reached)
-
-
-def test_handle_non_max_tokens_exception(hook_provider, mock_agent):
-    """Test that non-MaxTokensReachedException events are ignored."""
-    other_exception = ValueError("Some other error")
-    event = EventLoopFailureEvent(agent=mock_agent, exception=other_exception)
-
-    hook_provider._handle_max_tokens_reached(event)
-
-    # Should not modify the agent or event
-    assert len(mock_agent.messages) == 0
-    assert not event.should_continue_loop
-    mock_agent.hooks.invoke_callbacks.assert_not_called()
-
-
-@pytest.mark.parametrize(
-    "incomplete_tool_use,expected_tool_name",
-    [
-        ({"toolUseId": "tool-123", "input": {"param": "value"}}, "<unknown>"),  # Missing name
-        ({"name": "test_tool", "toolUseId": "tool-123"}, "test_tool"),  # Missing input
-        ({"name": "test_tool", "input": {}, "toolUseId": "tool-123"}, "test_tool"),  # Empty input
-        ({"name": "test_tool", "input": {"param": "value"}}, "test_tool"),  # Missing toolUseId
-    ],
-)
-def test_handle_max_tokens_with_incomplete_tool_use(hook_provider, mock_agent, incomplete_tool_use, expected_tool_name):
-    """Test handling various incomplete tool use scenarios."""
-    incomplete_message: Message = {
-        "role": "user",  # Test role preservation
-        "content": [{"text": "I'll use a tool"}, {"toolUse": incomplete_tool_use}],
-    }
-
-    exception = MaxTokensReachedException("Max tokens reached", incomplete_message)
-    event = EventLoopFailureEvent(agent=mock_agent, exception=exception)
-
-    hook_provider._handle_max_tokens_reached(event)
-
-    # Should add corrected message with error text and preserve role
-    assert len(mock_agent.messages) == 1
-    added_message = mock_agent.messages[0]
-    assert added_message["role"] == "user"  # Role preserved
-    assert len(added_message["content"]) == 2
-    assert added_message["content"][0]["text"] == "I'll use a tool"
-    assert f"The selected tool {expected_tool_name}'s tool use was incomplete" in added_message["content"][1]["text"]
-    assert "maximum token limits being reached" in added_message["content"][1]["text"]
-
-    assert event.should_continue_loop
-
-
-def test_handle_max_tokens_with_no_content(hook_provider, mock_agent):
-    """Test handling message with no content blocks."""
-    incomplete_message: Message = {"role": "assistant", "content": []}
-
-    exception = MaxTokensReachedException("Max tokens reached", incomplete_message)
-    event = EventLoopFailureEvent(agent=mock_agent, exception=exception)
-
-    hook_provider._handle_max_tokens_reached(event)
-
-    # Should add empty message and continue
-    assert len(mock_agent.messages) == 0
-    assert not event.should_continue_loop
diff --git a/tests_integ/test_max_tokens_reached.py b/tests_integ/test_max_tokens_reached.py
index 6bad70636..d50452801 100644
--- a/tests_integ/test_max_tokens_reached.py
+++ b/tests_integ/test_max_tokens_reached.py
@@ -3,7 +3,7 @@
 import pytest
 
 from strands import Agent, tool
-from strands.experimental.hooks.providers.correct_tool_use_hook_provider import CorrectToolUseHookProvider
+from strands.agent import NullConversationManager
 from strands.models.bedrock import BedrockModel
 from strands.types.exceptions import MaxTokensReachedException
 
@@ -18,9 +18,9 @@ def story_tool(story: str) -> str:
     return story
 
 
-def test_context_window_overflow():
+def test_max_tokens_reached():
     model = BedrockModel(max_tokens=100)
-    agent = Agent(model=model, tools=[story_tool])
+    agent 
= Agent(model=model, tools=[story_tool], conversation_manager=NullConversationManager()) with pytest.raises(MaxTokensReachedException): agent("Tell me a story!") @@ -31,8 +31,7 @@ def test_context_window_overflow(): def test_max_tokens_reached_with_hook_provider(): """Test that MaxTokensReachedException can be handled by a hook provider.""" model = BedrockModel(max_tokens=100) - hook_provider = CorrectToolUseHookProvider() - agent = Agent(model=model, tools=[story_tool], hooks=[hook_provider]) + agent = Agent(model=model, tools=[story_tool]) # Defaults to include SlidingWindowConversationManager # This should NOT raise an exception because the hook handles it agent("Tell me a story!") From 5906fc2f8d405c2e1326d9b603707f75123285e7 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Tue, 5 Aug 2025 18:33:36 -0400 Subject: [PATCH 28/41] linting --- src/strands/agent/agent.py | 5 +- .../token_limit_recovery.py | 10 +- .../test_token_limit_recovery.py | 116 +++++++----------- tests/strands/agent/test_agent.py | 18 ++- .../agent/test_conversation_manager.py | 48 ++++---- 5 files changed, 82 insertions(+), 115 deletions(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index e258cb324..e749183fc 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -53,14 +53,13 @@ T = TypeVar("T", bound=BaseModel) -# Sentinel classes to distinguish between explicit None and default parameter value +# Sentinel class and object to distinguish between explicit None and default parameter value class _DefaultCallbackHandlerSentinel: """Sentinel class to distinguish between explicit None and default parameter value.""" pass - _DEFAULT_CALLBACK_HANDLER = _DefaultCallbackHandlerSentinel() _DEFAULT_AGENT_NAME = "Strands Agents" _DEFAULT_AGENT_ID = "default" @@ -247,7 +246,7 @@ def __init__( state: stateful information for the agent. Can be either an AgentState object, or a json serializable dict. Defaults to an empty AgentState object. hooks: hooks to be added to the agent hook registry - Defaults to set of if None. + Defaults to None. session_manager: Manager for handling agent sessions including conversation history and state. If provided, enables session-based persistence and state management. """ diff --git a/src/strands/agent/conversation_manager/token_limit_recovery.py b/src/strands/agent/conversation_manager/token_limit_recovery.py index a0935f3a3..ceb32c735 100644 --- a/src/strands/agent/conversation_manager/token_limit_recovery.py +++ b/src/strands/agent/conversation_manager/token_limit_recovery.py @@ -15,15 +15,15 @@ def recover_from_max_tokens_reached(agent: "Agent", exception: MaxTokensReachedException) -> None: """Handle MaxTokensReachedException by cleaning up orphaned tool uses and adding corrected message. - + This function fixes incomplete tool uses that may occur when the model's response is truncated due to token limits. It: - + 1. Inspects each content block in the incomplete message for invalid tool uses 2. Replaces incomplete tool use blocks with informative text messages 3. Preserves valid content blocks in the corrected message 4. Adds the corrected message to the agent's conversation history - + Args: agent: The agent whose conversation will be updated with the corrected message. exception: The MaxTokensReachedException containing the incomplete message. 
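Concretely, the recovery pass enumerated in the docstring above plays out as in this hedged sketch — module path and synchronous call shape as they stand in this commit (later commits in the series rename and async-ify the helper):

```python
from strands.agent.agent import Agent
from strands.agent.conversation_manager.token_limit_recovery import recover_from_max_tokens_reached
from strands.types.content import Message
from strands.types.exceptions import MaxTokensReachedException

agent = Agent()

# A response cut off mid tool call: input and toolUseId never arrived.
incomplete: Message = {
    "role": "assistant",
    "content": [
        {"text": "Let me calculate that."},
        {"toolUse": {"name": "calculator", "input": {}, "toolUseId": ""}},
    ],
}

recover_from_max_tokens_reached(agent, MaxTokensReachedException("Token limit reached", incomplete))

# The valid text block is preserved; the incomplete toolUse block is replaced
# with an explanatory text block, so the next cycle starts from a valid message.
assert agent.messages[-1]["content"][0] == {"text": "Let me calculate that."}
assert "calculator" in agent.messages[-1]["content"][1]["text"]
```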
@@ -48,9 +48,7 @@ def recover_from_max_tokens_reached(agent: "Agent", exception: MaxTokensReachedE
             if not (tool_name and tool_use.get("input") and tool_use.get("toolUseId")):
                 # Tool use is incomplete due to max_tokens truncation
                 display_name = tool_name if tool_name else "<unknown>"
-                logger.warning(
-                    "tool_name=<%s> | replacing with error message due to max_tokens truncation.", display_name
-                )
+                logger.warning("tool_name=<%s> | replacing with error message due to max_tokens truncation.", display_name)
 
                 valid_content.append(
                     {
diff --git a/tests/strands/agent/conversation_manager/test_token_limit_recovery.py b/tests/strands/agent/conversation_manager/test_token_limit_recovery.py
index 8d1655c45..9ae6c8722 100644
--- a/tests/strands/agent/conversation_manager/test_token_limit_recovery.py
+++ b/tests/strands/agent/conversation_manager/test_token_limit_recovery.py
@@ -1,6 +1,5 @@
 """Tests for token limit recovery utility."""
 
-import pytest
 
 from strands.agent.agent import Agent
 from strands.agent.conversation_manager.token_limit_recovery import recover_from_max_tokens_reached
@@ -12,33 +11,30 @@ def test_recover_from_max_tokens_reached_with_incomplete_tool_use():
     """Test recovery when incomplete tool use is present in the message."""
     agent = Agent()
     initial_message_count = len(agent.messages)
-    
+
     incomplete_message: Message = {
         "role": "assistant",
         "content": [
             {"text": "I'll help you with that."},
             {"toolUse": {"name": "calculator", "input": {}, "toolUseId": ""}},  # Missing toolUseId
-        ]
+        ],
     }
-    
-    exception = MaxTokensReachedException(
-        message="Token limit reached",
-        incomplete_message=incomplete_message
-    )
-    
+
+    exception = MaxTokensReachedException(message="Token limit reached", incomplete_message=incomplete_message)
+
     recover_from_max_tokens_reached(agent, exception)
-    
+
     # Should add one corrected message
     assert len(agent.messages) == initial_message_count + 1
-    
+
     # Check the corrected message content
     corrected_message = agent.messages[-1]
     assert corrected_message["role"] == "assistant"
     assert len(corrected_message["content"]) == 2
-    
+
     # First content block should be preserved
     assert corrected_message["content"][0] == {"text": "I'll help you with that."}
-    
+
     # Second content block should be replaced with error message
     assert "text" in corrected_message["content"][1]
     assert "calculator" in corrected_message["content"][1]["text"]
@@ -49,29 +45,26 @@ def test_recover_from_max_tokens_reached_with_unknown_tool_name():
     """Test recovery when tool use has no name."""
     agent = Agent()
     initial_message_count = len(agent.messages)
-    
+
     incomplete_message: Message = {
         "role": "assistant",
         "content": [
             {"toolUse": {"name": "", "input": {}, "toolUseId": "123"}},  # Missing name
-        ]
+        ],
     }
-    
-    exception = MaxTokensReachedException(
-        message="Token limit reached",
-        incomplete_message=incomplete_message
-    )
-    
+
+    exception = MaxTokensReachedException(message="Token limit reached", incomplete_message=incomplete_message)
+
     recover_from_max_tokens_reached(agent, exception)
-    
+
     # Should add one corrected message
     assert len(agent.messages) == initial_message_count + 1
-    
+
     # Check the corrected message content
     corrected_message = agent.messages[-1]
     assert corrected_message["role"] == "assistant"
     assert len(corrected_message["content"]) == 1
-    
+
     # Content should be replaced with error message using <unknown>
     assert "text" in corrected_message["content"][0]
     assert "<unknown>" in corrected_message["content"][0]["text"]
@@ -82,22 +75,19 @@ def test_recover_from_max_tokens_reached_with_valid_tool_use():
     """Test that valid tool uses are not 
modified and function returns early.""" agent = Agent() initial_message_count = len(agent.messages) - + incomplete_message: Message = { "role": "assistant", "content": [ {"text": "I'll help you with that."}, {"toolUse": {"name": "calculator", "input": {"expression": "2+2"}, "toolUseId": "123"}}, # Valid - ] + ], } - - exception = MaxTokensReachedException( - message="Token limit reached", - incomplete_message=incomplete_message - ) - + + exception = MaxTokensReachedException(message="Token limit reached", incomplete_message=incomplete_message) + recover_from_max_tokens_reached(agent, exception) - + # Should not add any message since tool use was valid assert len(agent.messages) == initial_message_count @@ -106,19 +96,13 @@ def test_recover_from_max_tokens_reached_with_empty_content(): """Test that empty content is handled gracefully.""" agent = Agent() initial_message_count = len(agent.messages) - - incomplete_message: Message = { - "role": "assistant", - "content": [] - } - - exception = MaxTokensReachedException( - message="Token limit reached", - incomplete_message=incomplete_message - ) - + + incomplete_message: Message = {"role": "assistant", "content": []} + + exception = MaxTokensReachedException(message="Token limit reached", incomplete_message=incomplete_message) + recover_from_max_tokens_reached(agent, exception) - + # Should not add any message since content is empty assert len(agent.messages) == initial_message_count @@ -127,35 +111,32 @@ def test_recover_from_max_tokens_reached_with_mixed_content(): """Test recovery with mix of valid content and incomplete tool use.""" agent = Agent() initial_message_count = len(agent.messages) - + incomplete_message: Message = { "role": "assistant", "content": [ {"text": "Let me calculate this for you."}, {"toolUse": {"name": "calculator", "input": {}, "toolUseId": ""}}, # Incomplete {"text": "And then I'll explain the result."}, - ] + ], } - - exception = MaxTokensReachedException( - message="Token limit reached", - incomplete_message=incomplete_message - ) - + + exception = MaxTokensReachedException(message="Token limit reached", incomplete_message=incomplete_message) + recover_from_max_tokens_reached(agent, exception) - + # Should add one corrected message assert len(agent.messages) == initial_message_count + 1 - + # Check the corrected message content corrected_message = agent.messages[-1] assert corrected_message["role"] == "assistant" assert len(corrected_message["content"]) == 3 - + # First and third content blocks should be preserved assert corrected_message["content"][0] == {"text": "Let me calculate this for you."} assert corrected_message["content"][2] == {"text": "And then I'll explain the result."} - + # Second content block should be replaced with error message assert "text" in corrected_message["content"][1] assert "calculator" in corrected_message["content"][1]["text"] @@ -166,35 +147,32 @@ def test_recover_from_max_tokens_reached_preserves_non_tool_content(): """Test that non-tool content is preserved as-is.""" agent = Agent() initial_message_count = len(agent.messages) - + incomplete_message: Message = { "role": "assistant", "content": [ {"text": "Here's some text."}, {"image": {"format": "png", "source": {"bytes": "fake_image_data"}}}, {"toolUse": {"name": "", "input": {}, "toolUseId": "123"}}, # Incomplete - ] + ], } - - exception = MaxTokensReachedException( - message="Token limit reached", - incomplete_message=incomplete_message - ) - + + exception = MaxTokensReachedException(message="Token limit reached", 
incomplete_message=incomplete_message)
+
     recover_from_max_tokens_reached(agent, exception)
-    
+
     # Should add one corrected message
     assert len(agent.messages) == initial_message_count + 1
-    
+
     # Check the corrected message content
     corrected_message = agent.messages[-1]
     assert corrected_message["role"] == "assistant"
     assert len(corrected_message["content"]) == 3
-    
+
     # First two content blocks should be preserved exactly
     assert corrected_message["content"][0] == {"text": "Here's some text."}
     assert corrected_message["content"][1] == {"image": {"format": "png", "source": {"bytes": "fake_image_data"}}}
-    
+
     # Third content block should be replaced with error message
     assert "text" in corrected_message["content"][2]
     assert "<unknown>" in corrected_message["content"][2]["text"]
diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py
index 9dd802f4e..87aafe7a2 100644
--- a/tests/strands/agent/test_agent.py
+++ b/tests/strands/agent/test_agent.py
@@ -557,14 +557,11 @@ def test_agent__call__max_tokens_reached_triggers_conversation_manager_recovery(
         "content": [
             {"text": "I'll help you with that."},
             {"toolUse": {"name": "calculator", "input": {}, "toolUseId": ""}},  # Missing toolUseId
-        ]
+        ],
     }
 
     mock_model.mock_stream.side_effect = [
-        MaxTokensReachedException(
-            message="Token limit reached",
-            incomplete_message=incomplete_message
-        ),
+        MaxTokensReachedException(message="Token limit reached", incomplete_message=incomplete_message),
         agenerator(
             [
                 {"contentBlockStart": {"start": {}}},
@@ -579,16 +576,16 @@ def test_agent__call__max_tokens_reached_triggers_conversation_manager_recovery(
 
     # Verify handle_token_limit_reached was called
     assert conversation_manager_spy.handle_token_limit_reached.call_count == 1
-    
+
     # Verify the call was made with the correct exception
     call_args = conversation_manager_spy.handle_token_limit_reached.call_args
     args, kwargs = call_args
     assert len(args) >= 2  # Should have at least agent and exception
     assert isinstance(args[1], MaxTokensReachedException)  # Second argument should be the exception
-    
+
     # Verify apply_management was also called
     assert conversation_manager_spy.apply_management.call_count > 0
-    
+
     # Verify the agent continued and produced a result
     assert result is not None
@@ -601,12 +598,11 @@ def test_agent__call__max_tokens_reached_with_null_conversation_manager_raises_e
         "role": "assistant",
         "content": [
             {"toolUse": {"name": "calculator", "input": {}, "toolUseId": ""}},  # Missing toolUseId
-        ]
+        ],
     }
 
     mock_model.mock_stream.side_effect = MaxTokensReachedException(
-        message="Token limit reached",
-        incomplete_message=incomplete_message
+        message="Token limit reached", incomplete_message=incomplete_message
     )
 
     with pytest.raises(MaxTokensReachedException):
diff --git a/tests/strands/agent/test_conversation_manager.py b/tests/strands/agent/test_conversation_manager.py
index e3452824e..3e5bd56f3 100644
--- a/tests/strands/agent/test_conversation_manager.py
+++ b/tests/strands/agent/test_conversation_manager.py
@@ -4,8 +4,10 @@
 from strands.agent.conversation_manager.null_conversation_manager import NullConversationManager
 from strands.agent.conversation_manager.sliding_window_conversation_manager import SlidingWindowConversationManager
 from strands.types.content import Message
-from strands.types.exceptions import ContextWindowOverflowException, MaxTokensReachedException, MaxTokensReachedException
-from strands.types.content import Message
+from strands.types.exceptions import (
+    ContextWindowOverflowException,
+    MaxTokensReachedException,
+)
 @pytest.fixture
@@ -211,33 +213,30 @@ def test_sliding_window_conversation_manager_handle_token_limit_reached():
     manager = SlidingWindowConversationManager()
     test_agent = Agent()
     initial_message_count = len(test_agent.messages)
-    
+
     incomplete_message: Message = {
         "role": "assistant",
         "content": [
             {"text": "I'll help you with that."},
             {"toolUse": {"name": "calculator", "input": {}, "toolUseId": ""}},  # Missing toolUseId
-        ]
+        ],
     }
-    
-    test_exception = MaxTokensReachedException(
-        message="Token limit reached",
-        incomplete_message=incomplete_message
-    )
-    
+
+    test_exception = MaxTokensReachedException(message="Token limit reached", incomplete_message=incomplete_message)
+
     manager.handle_token_limit_reached(test_agent, test_exception)
-    
+
     # Should add one corrected message
     assert len(test_agent.messages) == initial_message_count + 1
-    
+
     # Check the corrected message content
     corrected_message = test_agent.messages[-1]
     assert corrected_message["role"] == "assistant"
     assert len(corrected_message["content"]) == 2
-    
+
     # First content block should be preserved
     assert corrected_message["content"][0] == {"text": "I'll help you with that."}
-    
+
     # Second content block should be replaced with error message
     assert "text" in corrected_message["content"][1]
     assert "calculator" in corrected_message["content"][1]["text"]
@@ -291,33 +290,30 @@ def test_null_conversation_does_not_restore_with_incorrect_state():
 def test_summarizing_conversation_manager_handle_token_limit_reached():
     """Test that SummarizingConversationManager handles token limit recovery."""
     from strands.agent.conversation_manager.summarizing_conversation_manager import SummarizingConversationManager
-    
+
     manager = SummarizingConversationManager()
     test_agent = Agent()
     initial_message_count = len(test_agent.messages)
-    
+
     incomplete_message: Message = {
         "role": "assistant",
         "content": [
             {"toolUse": {"name": "", "input": {}, "toolUseId": "123"}},  # Missing name
-        ]
+        ],
    }
-    
-    test_exception = MaxTokensReachedException(
-        message="Token limit reached",
-        incomplete_message=incomplete_message
-    )
-    
+
+    test_exception = MaxTokensReachedException(message="Token limit reached", incomplete_message=incomplete_message)
+
     manager.handle_token_limit_reached(test_agent, test_exception)
-    
+
     # Should add one corrected message
     assert len(test_agent.messages) == initial_message_count + 1
-    
+
     # Check the corrected message content
     corrected_message = test_agent.messages[-1]
     assert corrected_message["role"] == "assistant"
     assert len(corrected_message["content"]) == 1
-    
+
     # Content should be replaced with error message using <unknown>
     assert "text" in corrected_message["content"][0]
     assert "<unknown>" in corrected_message["content"][0]["text"]
     assert "incomplete due to maximum token limits" in corrected_message["content"][0]["text"]

From 87445a3224af4d2000b65fc8972abe8d0b9c8220 Mon Sep 17 00:00:00 2001
From: Dean Schmigelski
Date: Wed, 6 Aug 2025 09:45:46 -0400
Subject: [PATCH 29/41] fix: test contained incorrect assertions

---
 src/strands/agent/agent.py                       |  4 ++--
 .../test_token_limit_recovery.py                 |  1 -
 tests/strands/agent/test_agent.py                | 16 ++++++----------
 3 files changed, 8 insertions(+), 13 deletions(-)

diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py
index e749183fc..044ff4e67 100644
--- a/src/strands/agent/agent.py
+++ b/src/strands/agent/agent.py
@@ -585,10 +585,10 @@ async def _execute_event_loop_cycle(self, invocation_state: dict[str, Any]) -> A
                 return
             except ContextWindowOverflowException as e:
                 # Try reducing the context size and retrying
-                self.conversation_manager.reduce_context(self, e=e)
+                self.conversation_manager.reduce_context(agent=self, 
e=e) except MaxTokensReachedException as e: # Recover conversation state after token limit exceeded, then continue with next cycle - self.conversation_manager.handle_token_limit_reached(self, e=e) + self.conversation_manager.handle_token_limit_reached(agent=self, e=e) # Sync agent after handling exception to keep conversation_manager_state up to date in the session if self._session_manager: diff --git a/tests/strands/agent/conversation_manager/test_token_limit_recovery.py b/tests/strands/agent/conversation_manager/test_token_limit_recovery.py index 9ae6c8722..afbe73a39 100644 --- a/tests/strands/agent/conversation_manager/test_token_limit_recovery.py +++ b/tests/strands/agent/conversation_manager/test_token_limit_recovery.py @@ -1,6 +1,5 @@ """Tests for token limit recovery utility.""" - from strands.agent.agent import Agent from strands.agent.conversation_manager.token_limit_recovery import recover_from_max_tokens_reached from strands.types.content import Message diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 87aafe7a2..1bc5ad78a 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -561,7 +561,9 @@ def test_agent__call__max_tokens_reached_triggers_conversation_manager_recovery( } mock_model.mock_stream.side_effect = [ + # First occurrence MaxTokensReachedException(message="Token limit reached", incomplete_message=incomplete_message), + # On retry the loop should succeed agenerator( [ {"contentBlockStart": {"start": {}}}, @@ -572,22 +574,16 @@ def test_agent__call__max_tokens_reached_triggers_conversation_manager_recovery( ), ] - result = agent("Test message") + agent("Test message") # Verify handle_token_limit_reached was called assert conversation_manager_spy.handle_token_limit_reached.call_count == 1 # Verify the call was made with the correct exception call_args = conversation_manager_spy.handle_token_limit_reached.call_args - args, kwargs = call_args - assert len(args) >= 2 # Should have at least agent and exception - assert isinstance(args[1], MaxTokensReachedException) # Second argument should be the exception - - # Verify apply_management was also called - assert conversation_manager_spy.apply_management.call_count > 0 - - # Verify the agent continued and produced a result - assert result is not None + kwargs = list(call_args[1].values()) + assert isinstance(kwargs[0], Agent) + assert isinstance(kwargs[1], MaxTokensReachedException) def test_agent__call__max_tokens_reached_with_null_conversation_manager_raises_exception(mock_model, agent): From 924fea9e68ecca7d310bbad1a8b9d607562b76a6 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Wed, 6 Aug 2025 10:26:50 -0400 Subject: [PATCH 30/41] fix: add event emission --- .../agent/conversation_manager/__init__.py | 2 + .../conversation_manager.py | 5 +- .../null_conversation_manager.py | 27 +++--- ...recover_tool_use_on_max_tokens_reached.py} | 6 +- .../sliding_window_conversation_manager.py | 4 +- .../summarizing_conversation_manager.py | 4 +- .../test_token_limit_recovery.py | 83 ++++++++++++++++--- 7 files changed, 96 insertions(+), 35 deletions(-) rename src/strands/agent/conversation_manager/{token_limit_recovery.py => recover_tool_use_on_max_tokens_reached.py} (89%) diff --git a/src/strands/agent/conversation_manager/__init__.py b/src/strands/agent/conversation_manager/__init__.py index c59623215..7e7e0c6c5 100644 --- a/src/strands/agent/conversation_manager/__init__.py +++ b/src/strands/agent/conversation_manager/__init__.py @@ -15,12 +15,14 @@ from 
.conversation_manager import ConversationManager from .null_conversation_manager import NullConversationManager +from .recover_tool_use_on_max_tokens_reached import recover_tool_use_on_max_tokens_reached from .sliding_window_conversation_manager import SlidingWindowConversationManager from .summarizing_conversation_manager import SummarizingConversationManager __all__ = [ "ConversationManager", "NullConversationManager", + "recover_tool_use_on_max_tokens_reached", "SlidingWindowConversationManager", "SummarizingConversationManager", ] diff --git a/src/strands/agent/conversation_manager/conversation_manager.py b/src/strands/agent/conversation_manager/conversation_manager.py index c2899209b..57ce59f93 100644 --- a/src/strands/agent/conversation_manager/conversation_manager.py +++ b/src/strands/agent/conversation_manager/conversation_manager.py @@ -88,12 +88,11 @@ def reduce_context(self, agent: "Agent", e: Optional[Exception] = None, **kwargs """ pass - @abstractmethod def handle_token_limit_reached(self, agent: "Agent", e: MaxTokensReachedException, **kwargs: Any) -> None: """Called when MaxTokensReachedException is thrown to recover conversation state. This method should implement recovery strategies when the token limit is exceeded and the message array - may be in a broken state. It is called outside the event loop to apply default recovery mechanisms. + may be in a broken state. Args: agent: The agent whose conversation state will be recovered. @@ -101,4 +100,4 @@ def handle_token_limit_reached(self, agent: "Agent", e: MaxTokensReachedExceptio e: The MaxTokensReachedException that triggered the recovery. **kwargs: Additional keyword arguments for future extensibility. """ - pass + raise e diff --git a/src/strands/agent/conversation_manager/null_conversation_manager.py b/src/strands/agent/conversation_manager/null_conversation_manager.py index 29fa1c442..fb9868741 100644 --- a/src/strands/agent/conversation_manager/null_conversation_manager.py +++ b/src/strands/agent/conversation_manager/null_conversation_manager.py @@ -5,7 +5,7 @@ if TYPE_CHECKING: from ...agent.agent import Agent -from ...types.exceptions import ContextWindowOverflowException, MaxTokensReachedException +from ...types.exceptions import ContextWindowOverflowException from .conversation_manager import ConversationManager @@ -45,15 +45,16 @@ def reduce_context(self, agent: "Agent", e: Optional[Exception] = None, **kwargs else: raise ContextWindowOverflowException("Context window overflowed!") - def handle_token_limit_reached(self, agent: "Agent", e: MaxTokensReachedException, **kwargs: Any) -> None: - """Does not handle token limit recovery and raises the exception. - - Args: - agent: The agent whose conversation state will remain unmodified. - e: The MaxTokensReachedException that triggered the recovery. - **kwargs: Additional keyword arguments for future extensibility. - - Raises: - e: The provided exception. - """ - raise e + # + # def handle_token_limit_reached(self, agent: "Agent", e: MaxTokensReachedException, **kwargs: Any) -> None: + # """Does not handle token limit recovery and raises the exception. + # + # Args: + # agent: The agent whose conversation state will remain unmodified. + # e: The MaxTokensReachedException that triggered the recovery. + # **kwargs: Additional keyword arguments for future extensibility. + # + # Raises: + # e: The provided exception. 
+ # """ + # raise e diff --git a/src/strands/agent/conversation_manager/token_limit_recovery.py b/src/strands/agent/conversation_manager/recover_tool_use_on_max_tokens_reached.py similarity index 89% rename from src/strands/agent/conversation_manager/token_limit_recovery.py rename to src/strands/agent/conversation_manager/recover_tool_use_on_max_tokens_reached.py index ceb32c735..516c3ec36 100644 --- a/src/strands/agent/conversation_manager/token_limit_recovery.py +++ b/src/strands/agent/conversation_manager/recover_tool_use_on_max_tokens_reached.py @@ -3,6 +3,7 @@ import logging from typing import TYPE_CHECKING +from ...hooks import MessageAddedEvent from ...types.content import ContentBlock, Message from ...types.exceptions import MaxTokensReachedException from ...types.tools import ToolUse @@ -13,7 +14,7 @@ logger = logging.getLogger(__name__) -def recover_from_max_tokens_reached(agent: "Agent", exception: MaxTokensReachedException) -> None: +def recover_tool_use_on_max_tokens_reached(agent: "Agent", exception: MaxTokensReachedException) -> None: """Handle MaxTokensReachedException by cleaning up orphaned tool uses and adding corrected message. This function fixes incomplete tool uses that may occur when the model's response is truncated @@ -28,7 +29,7 @@ def recover_from_max_tokens_reached(agent: "Agent", exception: MaxTokensReachedE agent: The agent whose conversation will be updated with the corrected message. exception: The MaxTokensReachedException containing the incomplete message. """ - logger.info("Handling MaxTokensReachedException - inspecting incomplete message for invalid tool uses") + logger.info("handling MaxTokensReachedException - inspecting incomplete message for invalid tool uses") incomplete_message: Message = exception.incomplete_message @@ -62,3 +63,4 @@ def recover_from_max_tokens_reached(agent: "Agent", exception: MaxTokensReachedE valid_message: Message = {"content": valid_content, "role": incomplete_message["role"]} agent.messages.append(valid_message) + agent.hooks.invoke_callbacks(MessageAddedEvent(agent=agent, message=valid_message)) diff --git a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py index f96dbff27..0559e0efa 100644 --- a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py +++ b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py @@ -9,7 +9,7 @@ from ...types.content import Messages from ...types.exceptions import ContextWindowOverflowException, MaxTokensReachedException from .conversation_manager import ConversationManager -from .token_limit_recovery import recover_from_max_tokens_reached +from .recover_tool_use_on_max_tokens_reached import recover_tool_use_on_max_tokens_reached logger = logging.getLogger(__name__) @@ -187,4 +187,4 @@ def handle_token_limit_reached(self, agent: "Agent", e: MaxTokensReachedExceptio e: The MaxTokensReachedException that triggered the recovery. **kwargs: Additional keyword arguments for future extensibility. 
""" - recover_from_max_tokens_reached(agent, e) + recover_tool_use_on_max_tokens_reached(agent, e) diff --git a/src/strands/agent/conversation_manager/summarizing_conversation_manager.py b/src/strands/agent/conversation_manager/summarizing_conversation_manager.py index fe0d13fa4..1dc5d907a 100644 --- a/src/strands/agent/conversation_manager/summarizing_conversation_manager.py +++ b/src/strands/agent/conversation_manager/summarizing_conversation_manager.py @@ -8,7 +8,7 @@ from ...types.content import Message from ...types.exceptions import ContextWindowOverflowException, MaxTokensReachedException from .conversation_manager import ConversationManager -from .token_limit_recovery import recover_from_max_tokens_reached +from .recover_tool_use_on_max_tokens_reached import recover_tool_use_on_max_tokens_reached if TYPE_CHECKING: from ..agent import Agent @@ -260,4 +260,4 @@ def handle_token_limit_reached(self, agent: "Agent", e: MaxTokensReachedExceptio e: The MaxTokensReachedException that triggered the recovery. **kwargs: Additional keyword arguments for future extensibility. """ - recover_from_max_tokens_reached(agent, e) + recover_tool_use_on_max_tokens_reached(agent, e) diff --git a/tests/strands/agent/conversation_manager/test_token_limit_recovery.py b/tests/strands/agent/conversation_manager/test_token_limit_recovery.py index afbe73a39..006f5db25 100644 --- a/tests/strands/agent/conversation_manager/test_token_limit_recovery.py +++ b/tests/strands/agent/conversation_manager/test_token_limit_recovery.py @@ -1,14 +1,22 @@ """Tests for token limit recovery utility.""" +from unittest.mock import Mock + from strands.agent.agent import Agent -from strands.agent.conversation_manager.token_limit_recovery import recover_from_max_tokens_reached +from strands.agent.conversation_manager.recover_tool_use_on_max_tokens_reached import ( + recover_tool_use_on_max_tokens_reached, +) +from strands.hooks import MessageAddedEvent from strands.types.content import Message from strands.types.exceptions import MaxTokensReachedException -def test_recover_from_max_tokens_reached_with_incomplete_tool_use(): +def test_recover_tool_use_on_max_tokens_reached_with_incomplete_tool_use(): """Test recovery when incomplete tool use is present in the message.""" agent = Agent() + # Mock the hooks.invoke_callbacks method + mock_invoke_callbacks = Mock() + agent.hooks.invoke_callbacks = mock_invoke_callbacks initial_message_count = len(agent.messages) incomplete_message: Message = { @@ -21,7 +29,7 @@ def test_recover_from_max_tokens_reached_with_incomplete_tool_use(): exception = MaxTokensReachedException(message="Token limit reached", incomplete_message=incomplete_message) - recover_from_max_tokens_reached(agent, exception) + recover_tool_use_on_max_tokens_reached(agent, exception) # Should add one corrected message assert len(agent.messages) == initial_message_count + 1 @@ -39,10 +47,20 @@ def test_recover_from_max_tokens_reached_with_incomplete_tool_use(): assert "calculator" in corrected_message["content"][1]["text"] assert "incomplete due to maximum token limits" in corrected_message["content"][1]["text"] + # Verify that the MessageAddedEvent callback was invoked + mock_invoke_callbacks.assert_called_once() + call_args = mock_invoke_callbacks.call_args[0][0] + assert isinstance(call_args, MessageAddedEvent) + assert call_args.agent == agent + assert call_args.message == corrected_message + -def test_recover_from_max_tokens_reached_with_unknown_tool_name(): +def 
test_recover_tool_use_on_max_tokens_reached_with_unknown_tool_name():
     """Test recovery when tool use has no name."""
     agent = Agent()
+    # Mock the hooks.invoke_callbacks method
+    mock_invoke_callbacks = Mock()
+    agent.hooks.invoke_callbacks = mock_invoke_callbacks
     initial_message_count = len(agent.messages)
 
@@ -54,7 +72,7 @@ def test_recover_from_max_tokens_reached_with_unknown_tool_name():
 
     exception = MaxTokensReachedException(message="Token limit reached", incomplete_message=incomplete_message)
 
-    recover_from_max_tokens_reached(agent, exception)
+    recover_tool_use_on_max_tokens_reached(agent, exception)
 
     # Should add one corrected message
     assert len(agent.messages) == initial_message_count + 1
@@ -69,10 +87,20 @@ def test_recover_from_max_tokens_reached_with_unknown_tool_name():
     assert "<unknown>" in corrected_message["content"][0]["text"]
     assert "incomplete due to maximum token limits" in corrected_message["content"][0]["text"]
 
+    # Verify that the MessageAddedEvent callback was invoked
+    mock_invoke_callbacks.assert_called_once()
+    call_args = mock_invoke_callbacks.call_args[0][0]
+    assert isinstance(call_args, MessageAddedEvent)
+    assert call_args.agent == agent
+    assert call_args.message == corrected_message
 
-def test_recover_from_max_tokens_reached_with_valid_tool_use():
+
+def test_recover_tool_use_on_max_tokens_reached_with_valid_tool_use():
     """Test that valid tool uses are not modified and function returns early."""
     agent = Agent()
+    # Mock the hooks.invoke_callbacks method
+    mock_invoke_callbacks = Mock()
+    agent.hooks.invoke_callbacks = mock_invoke_callbacks
     initial_message_count = len(agent.messages)
 
@@ -85,30 +113,42 @@ def test_recover_from_max_tokens_reached_with_valid_tool_use():
 
     exception = MaxTokensReachedException(message="Token limit reached", incomplete_message=incomplete_message)
 
-    recover_from_max_tokens_reached(agent, exception)
+    recover_tool_use_on_max_tokens_reached(agent, exception)
 
     # Should not add any message since tool use was valid
     assert len(agent.messages) == initial_message_count
 
+    # Verify that the MessageAddedEvent callback was NOT invoked
+    mock_invoke_callbacks.assert_not_called()
+
 
-def 
test_recover_from_max_tokens_reached_with_mixed_content():
+def test_recover_tool_use_on_max_tokens_reached_with_mixed_content():
     """Test recovery with mix of valid content and incomplete tool use."""
     agent = Agent()
+    # Mock the hooks.invoke_callbacks method
+    mock_invoke_callbacks = Mock()
+    agent.hooks.invoke_callbacks = mock_invoke_callbacks
     initial_message_count = len(agent.messages)
 
@@ -122,7 +162,7 @@ def test_recover_from_max_tokens_reached_with_mixed_content():
 
     exception = MaxTokensReachedException(message="Token limit reached", incomplete_message=incomplete_message)
 
-    recover_from_max_tokens_reached(agent, exception)
+    recover_tool_use_on_max_tokens_reached(agent, exception)
 
     # Should add one corrected message
     assert len(agent.messages) == initial_message_count + 1
@@ -141,10 +181,20 @@ def test_recover_from_max_tokens_reached_with_mixed_content():
     assert "calculator" in corrected_message["content"][1]["text"]
     assert "incomplete due to maximum token limits" in corrected_message["content"][1]["text"]
 
+    # Verify that the MessageAddedEvent callback was invoked
+    mock_invoke_callbacks.assert_called_once()
+    call_args = mock_invoke_callbacks.call_args[0][0]
+    assert isinstance(call_args, MessageAddedEvent)
+    assert call_args.agent == agent
+    assert call_args.message == corrected_message
+
 
-def test_recover_from_max_tokens_reached_preserves_non_tool_content():
+def test_recover_tool_use_on_max_tokens_reached_preserves_non_tool_content():
     """Test that non-tool content is preserved as-is."""
     agent = Agent()
+    # Mock the hooks.invoke_callbacks method
+    mock_invoke_callbacks = Mock()
+    agent.hooks.invoke_callbacks = mock_invoke_callbacks
     initial_message_count = len(agent.messages)
 
@@ -158,7 +208,7 @@ def test_recover_from_max_tokens_reached_preserves_non_tool_content():
 
     exception = MaxTokensReachedException(message="Token limit reached", incomplete_message=incomplete_message)
 
-    recover_from_max_tokens_reached(agent, exception)
+    recover_tool_use_on_max_tokens_reached(agent, exception)
 
     # Should add one corrected message
     assert len(agent.messages) == initial_message_count + 1
@@ -175,3 +225,10 @@ def test_recover_from_max_tokens_reached_preserves_non_tool_content():
     # Third content block should be replaced with error message
     assert "text" in corrected_message["content"][2]
     assert "<unknown>" in corrected_message["content"][2]["text"]
+
+    # Verify that the MessageAddedEvent callback was invoked
+    mock_invoke_callbacks.assert_called_once()
+    call_args = mock_invoke_callbacks.call_args[0][0]
+    assert isinstance(call_args, MessageAddedEvent)
+    assert call_args.agent == agent
+    assert call_args.message == corrected_message

From 104f6b425fbbe5a2414b8f10281469d67d6ab1de Mon Sep 17 00:00:00 2001
From: Dean Schmigelski
Date: Wed, 6 Aug 2025 14:03:48 -0400
Subject: [PATCH 31/41] feat: move to async

---
 src/strands/agent/agent.py                        |  2 +-
 .../conversation_manager.py                       |  2 +-
 .../null_conversation_manager.py                  | 14 -------
 .../recover_tool_use_on_max_tokens_reached.py     |  6 +--
 .../sliding_window_conversation_manager.py        | 20 +++++-----
 .../summarizing_conversation_manager.py           | 20 +++++-----
 ...recover_tool_use_on_max_tokens_reached.py}     | 38 ++++++++++++-------
 .../agent/test_conversation_manager.py            | 15 +++++---
 8 files changed, 58 insertions(+), 59 deletions(-)
 rename tests/strands/agent/conversation_manager/{test_token_limit_recovery.py => test_recover_tool_use_on_max_tokens_reached.py} (86%)

diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py
index 044ff4e67..1f63f7996 100644
--- a/src/strands/agent/agent.py
+++ b/src/strands/agent/agent.py
@@ -588,7 +588,7 @@ async def _execute_event_loop_cycle(self, invocation_state: dict[str, Any]) -> A
             self.conversation_manager.reduce_context(agent=self, e=e)
         except MaxTokensReachedException as e:
             # Recover conversation state after token limit exceeded, then continue with next cycle
-            
self.conversation_manager.handle_token_limit_reached(agent=self, e=e) + await self.conversation_manager.handle_token_limit_reached(agent=self, e=e) # Sync agent after handling exception to keep conversation_manager_state up to date in the session if self._session_manager: diff --git a/src/strands/agent/conversation_manager/conversation_manager.py b/src/strands/agent/conversation_manager/conversation_manager.py index 57ce59f93..1b42e1fbf 100644 --- a/src/strands/agent/conversation_manager/conversation_manager.py +++ b/src/strands/agent/conversation_manager/conversation_manager.py @@ -88,7 +88,7 @@ def reduce_context(self, agent: "Agent", e: Optional[Exception] = None, **kwargs """ pass - def handle_token_limit_reached(self, agent: "Agent", e: MaxTokensReachedException, **kwargs: Any) -> None: + async def handle_token_limit_reached(self, agent: "Agent", e: MaxTokensReachedException, **kwargs: Any) -> None: """Called when MaxTokensReachedException is thrown to recover conversation state. This method should implement recovery strategies when the token limit is exceeded and the message array diff --git a/src/strands/agent/conversation_manager/null_conversation_manager.py b/src/strands/agent/conversation_manager/null_conversation_manager.py index fb9868741..5ff6874e5 100644 --- a/src/strands/agent/conversation_manager/null_conversation_manager.py +++ b/src/strands/agent/conversation_manager/null_conversation_manager.py @@ -44,17 +44,3 @@ def reduce_context(self, agent: "Agent", e: Optional[Exception] = None, **kwargs raise e else: raise ContextWindowOverflowException("Context window overflowed!") - - # - # def handle_token_limit_reached(self, agent: "Agent", e: MaxTokensReachedException, **kwargs: Any) -> None: - # """Does not handle token limit recovery and raises the exception. - # - # Args: - # agent: The agent whose conversation state will remain unmodified. - # e: The MaxTokensReachedException that triggered the recovery. - # **kwargs: Additional keyword arguments for future extensibility. - # - # Raises: - # e: The provided exception. - # """ - # raise e diff --git a/src/strands/agent/conversation_manager/recover_tool_use_on_max_tokens_reached.py b/src/strands/agent/conversation_manager/recover_tool_use_on_max_tokens_reached.py index 516c3ec36..e9e056a69 100644 --- a/src/strands/agent/conversation_manager/recover_tool_use_on_max_tokens_reached.py +++ b/src/strands/agent/conversation_manager/recover_tool_use_on_max_tokens_reached.py @@ -14,7 +14,7 @@ logger = logging.getLogger(__name__) -def recover_tool_use_on_max_tokens_reached(agent: "Agent", exception: MaxTokensReachedException) -> None: +async def recover_tool_use_on_max_tokens_reached(agent: "Agent", exception: MaxTokensReachedException) -> None: """Handle MaxTokensReachedException by cleaning up orphaned tool uses and adding corrected message. This function fixes incomplete tool uses that may occur when the model's response is truncated @@ -35,7 +35,7 @@ def recover_tool_use_on_max_tokens_reached(agent: "Agent", exception: MaxTokensR if not incomplete_message["content"]: # Cannot correct invalid content block if content is empty - return + raise exception valid_content: list[ContentBlock] = [] for content in incomplete_message["content"]: @@ -59,7 +59,7 @@ def recover_tool_use_on_max_tokens_reached(agent: "Agent", exception: MaxTokensR ) else: # ToolUse was invalid for an unknown reason. 
Cannot correct, return without modifying - return + raise exception valid_message: Message = {"content": valid_content, "role": incomplete_message["role"]} agent.messages.append(valid_message) diff --git a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py index 0559e0efa..58710493d 100644 --- a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py +++ b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py @@ -113,6 +113,16 @@ def reduce_context(self, agent: "Agent", e: Optional[Exception] = None, **kwargs # Overwrite message history messages[:] = messages[trim_index:] + async def handle_token_limit_reached(self, agent: "Agent", e: MaxTokensReachedException, **kwargs: Any) -> None: + """Apply sliding window strategy for token limit recovery. + + Args: + agent: The agent whose conversation state will be recovered. + e: The MaxTokensReachedException that triggered the recovery. + **kwargs: Additional keyword arguments for future extensibility. + """ + await recover_tool_use_on_max_tokens_reached(agent, e) + def _truncate_tool_results(self, messages: Messages, msg_idx: int) -> bool: """Truncate tool results in a message to reduce context size. @@ -178,13 +188,3 @@ def _find_last_message_with_tool_results(self, messages: Messages) -> Optional[i return idx return None - - def handle_token_limit_reached(self, agent: "Agent", e: MaxTokensReachedException, **kwargs: Any) -> None: - """Apply sliding window strategy for token limit recovery. - - Args: - agent: The agent whose conversation state will be recovered. - e: The MaxTokensReachedException that triggered the recovery. - **kwargs: Additional keyword arguments for future extensibility. - """ - recover_tool_use_on_max_tokens_reached(agent, e) diff --git a/src/strands/agent/conversation_manager/summarizing_conversation_manager.py b/src/strands/agent/conversation_manager/summarizing_conversation_manager.py index 1dc5d907a..1c3dc7d38 100644 --- a/src/strands/agent/conversation_manager/summarizing_conversation_manager.py +++ b/src/strands/agent/conversation_manager/summarizing_conversation_manager.py @@ -167,6 +167,16 @@ def reduce_context(self, agent: "Agent", e: Optional[Exception] = None, **kwargs logger.error("Summarization failed: %s", summarization_error) raise summarization_error from e + async def handle_token_limit_reached(self, agent: "Agent", e: MaxTokensReachedException, **kwargs: Any) -> None: + """Apply summarization strategy for token limit recovery. + + Args: + agent: The agent whose conversation state will be recovered. + e: The MaxTokensReachedException that triggered the recovery. + **kwargs: Additional keyword arguments for future extensibility. + """ + await recover_tool_use_on_max_tokens_reached(agent, e) + def _generate_summary(self, messages: List[Message], agent: "Agent") -> Message: """Generate a summary of the provided messages. @@ -251,13 +261,3 @@ def _adjust_split_point_for_tool_pairs(self, messages: List[Message], split_poin raise ContextWindowOverflowException("Unable to trim conversation context!") return split_point - - def handle_token_limit_reached(self, agent: "Agent", e: MaxTokensReachedException, **kwargs: Any) -> None: - """Apply summarization strategy for token limit recovery. - - Args: - agent: The agent whose conversation state will be recovered. - e: The MaxTokensReachedException that triggered the recovery. 
- **kwargs: Additional keyword arguments for future extensibility. - """ - recover_tool_use_on_max_tokens_reached(agent, e) diff --git a/tests/strands/agent/conversation_manager/test_token_limit_recovery.py b/tests/strands/agent/conversation_manager/test_recover_tool_use_on_max_tokens_reached.py similarity index 86% rename from tests/strands/agent/conversation_manager/test_token_limit_recovery.py rename to tests/strands/agent/conversation_manager/test_recover_tool_use_on_max_tokens_reached.py index 006f5db25..77fc35c39 100644 --- a/tests/strands/agent/conversation_manager/test_token_limit_recovery.py +++ b/tests/strands/agent/conversation_manager/test_recover_tool_use_on_max_tokens_reached.py @@ -2,6 +2,8 @@ from unittest.mock import Mock +import pytest + from strands.agent.agent import Agent from strands.agent.conversation_manager.recover_tool_use_on_max_tokens_reached import ( recover_tool_use_on_max_tokens_reached, @@ -11,7 +13,8 @@ from strands.types.exceptions import MaxTokensReachedException -def test_recover_tool_use_on_max_tokens_reached_with_incomplete_tool_use(): +@pytest.mark.asyncio +async def test_recover_tool_use_on_max_tokens_reached_with_incomplete_tool_use(): """Test recovery when incomplete tool use is present in the message.""" agent = Agent() # Mock the hooks.invoke_callbacks method @@ -29,7 +32,7 @@ def test_recover_tool_use_on_max_tokens_reached_with_incomplete_tool_use(): exception = MaxTokensReachedException(message="Token limit reached", incomplete_message=incomplete_message) - recover_tool_use_on_max_tokens_reached(agent, exception) + await recover_tool_use_on_max_tokens_reached(agent, exception) # Should add one corrected message assert len(agent.messages) == initial_message_count + 1 @@ -55,7 +58,8 @@ def test_recover_tool_use_on_max_tokens_reached_with_incomplete_tool_use(): assert call_args.message == corrected_message -def test_recover_tool_use_on_max_tokens_reached_with_unknown_tool_name(): +@pytest.mark.asyncio +async def test_recover_tool_use_on_max_tokens_reached_with_unknown_tool_name(): """Test recovery when tool use has no name.""" agent = Agent() # Mock the hooks.invoke_callbacks method @@ -72,7 +76,7 @@ def test_recover_tool_use_on_max_tokens_reached_with_unknown_tool_name(): exception = MaxTokensReachedException(message="Token limit reached", incomplete_message=incomplete_message) - recover_tool_use_on_max_tokens_reached(agent, exception) + await recover_tool_use_on_max_tokens_reached(agent, exception) # Should add one corrected message assert len(agent.messages) == initial_message_count + 1 @@ -95,8 +99,9 @@ def test_recover_tool_use_on_max_tokens_reached_with_unknown_tool_name(): assert call_args.message == corrected_message -def test_recover_tool_use_on_max_tokens_reached_with_valid_tool_use(): - """Test that valid tool uses are not modified and function returns early.""" +@pytest.mark.asyncio +async def test_recover_tool_use_on_max_tokens_reached_with_valid_tool_use(): + """Test that an exception that is raised without recoverability, re-raises exception.""" agent = Agent() # Mock the hooks.invoke_callbacks method mock_invoke_callbacks = Mock() @@ -113,7 +118,8 @@ def test_recover_tool_use_on_max_tokens_reached_with_valid_tool_use(): exception = MaxTokensReachedException(message="Token limit reached", incomplete_message=incomplete_message) - recover_tool_use_on_max_tokens_reached(agent, exception) + with pytest.raises(MaxTokensReachedException): + await recover_tool_use_on_max_tokens_reached(agent, exception) # Should not add any message since 
tool use was valid assert len(agent.messages) == initial_message_count @@ -122,8 +128,9 @@ def test_recover_tool_use_on_max_tokens_reached_with_valid_tool_use(): mock_invoke_callbacks.assert_not_called() -def test_recover_tool_use_on_max_tokens_reached_with_empty_content(): - """Test that empty content is handled gracefully.""" +@pytest.mark.asyncio +async def test_recover_tool_use_on_max_tokens_reached_with_empty_content(): + """Test that an exception that is raised without recoverability, re-raises exception.""" agent = Agent() # Mock the hooks.invoke_callbacks method mock_invoke_callbacks = Mock() @@ -134,7 +141,8 @@ def test_recover_tool_use_on_max_tokens_reached_with_empty_content(): exception = MaxTokensReachedException(message="Token limit reached", incomplete_message=incomplete_message) - recover_tool_use_on_max_tokens_reached(agent, exception) + with pytest.raises(MaxTokensReachedException): + await recover_tool_use_on_max_tokens_reached(agent, exception) # Should not add any message since content is empty assert len(agent.messages) == initial_message_count @@ -143,7 +151,8 @@ def test_recover_tool_use_on_max_tokens_reached_with_empty_content(): mock_invoke_callbacks.assert_not_called() -def test_recover_tool_use_on_max_tokens_reached_with_mixed_content(): +@pytest.mark.asyncio +async def test_recover_tool_use_on_max_tokens_reached_with_mixed_content(): """Test recovery with mix of valid content and incomplete tool use.""" agent = Agent() # Mock the hooks.invoke_callbacks method @@ -162,7 +171,7 @@ def test_recover_tool_use_on_max_tokens_reached_with_mixed_content(): exception = MaxTokensReachedException(message="Token limit reached", incomplete_message=incomplete_message) - recover_tool_use_on_max_tokens_reached(agent, exception) + await recover_tool_use_on_max_tokens_reached(agent, exception) # Should add one corrected message assert len(agent.messages) == initial_message_count + 1 @@ -189,7 +198,8 @@ def test_recover_tool_use_on_max_tokens_reached_with_mixed_content(): assert call_args.message == corrected_message -def test_recover_tool_use_on_max_tokens_reached_preserves_non_tool_content(): +@pytest.mark.asyncio +async def test_recover_tool_use_on_max_tokens_reached_preserves_non_tool_content(): """Test that non-tool content is preserved as-is.""" agent = Agent() # Mock the hooks.invoke_callbacks method @@ -208,7 +218,7 @@ def test_recover_tool_use_on_max_tokens_reached_preserves_non_tool_content(): exception = MaxTokensReachedException(message="Token limit reached", incomplete_message=incomplete_message) - recover_tool_use_on_max_tokens_reached(agent, exception) + await recover_tool_use_on_max_tokens_reached(agent, exception) # Should add one corrected message assert len(agent.messages) == initial_message_count + 1 diff --git a/tests/strands/agent/test_conversation_manager.py b/tests/strands/agent/test_conversation_manager.py index 3e5bd56f3..83af6c429 100644 --- a/tests/strands/agent/test_conversation_manager.py +++ b/tests/strands/agent/test_conversation_manager.py @@ -208,7 +208,8 @@ def test_sliding_window_conversation_manager_with_tool_results_truncated(): assert messages == expected_messages -def test_sliding_window_conversation_manager_handle_token_limit_reached(): +@pytest.mark.asyncio +async def test_sliding_window_conversation_manager_handle_token_limit_reached(): """Test that SlidingWindowConversationManager handles token limit recovery.""" manager = SlidingWindowConversationManager() test_agent = Agent() @@ -224,7 +225,7 @@ def 
test_sliding_window_conversation_manager_handle_token_limit_reached(): test_exception = MaxTokensReachedException(message="Token limit reached", incomplete_message=incomplete_message) - manager.handle_token_limit_reached(test_agent, test_exception) + await manager.handle_token_limit_reached(test_agent, test_exception) # Should add one corrected message assert len(test_agent.messages) == initial_message_count + 1 @@ -287,7 +288,8 @@ def test_null_conversation_does_not_restore_with_incorrect_state(): manager.restore_from_session({}) -def test_summarizing_conversation_manager_handle_token_limit_reached(): +@pytest.mark.asyncio +async def test_summarizing_conversation_manager_handle_token_limit_reached(): """Test that SummarizingConversationManager handles token limit recovery.""" from strands.agent.conversation_manager.summarizing_conversation_manager import SummarizingConversationManager @@ -304,7 +306,7 @@ def test_summarizing_conversation_manager_handle_token_limit_reached(): test_exception = MaxTokensReachedException(message="Token limit reached", incomplete_message=incomplete_message) - manager.handle_token_limit_reached(test_agent, test_exception) + await manager.handle_token_limit_reached(test_agent, test_exception) # Should add one corrected message assert len(test_agent.messages) == initial_message_count + 1 @@ -320,7 +322,8 @@ def test_summarizing_conversation_manager_handle_token_limit_reached(): assert "incomplete due to maximum token limits" in corrected_message["content"][0]["text"] -def test_null_conversation_manager_handle_token_limit_reached_raises_exception(): +@pytest.mark.asyncio +async def test_null_conversation_manager_handle_token_limit_reached_raises_exception(): """Test that NullConversationManager raises the provided exception.""" manager = NullConversationManager() test_agent = Agent() @@ -331,4 +334,4 @@ def test_null_conversation_manager_handle_token_limit_reached_raises_exception() test_exception = MaxTokensReachedException(message="test", incomplete_message=test_message) with pytest.raises(MaxTokensReachedException): - manager.handle_token_limit_reached(test_agent, test_exception) + await manager.handle_token_limit_reached(test_agent, test_exception) From 11b91f417c95d50254e159793d1aff027ceacfbb Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Wed, 6 Aug 2025 14:15:52 -0400 Subject: [PATCH 32/41] feat: add additional error case where no tool uses were fixed --- .../recover_tool_use_on_max_tokens_reached.py | 6 ++++++ .../test_recover_tool_use_on_max_tokens_reached.py | 13 ++++++++++--- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/src/strands/agent/conversation_manager/recover_tool_use_on_max_tokens_reached.py b/src/strands/agent/conversation_manager/recover_tool_use_on_max_tokens_reached.py index e9e056a69..8fddd4af5 100644 --- a/src/strands/agent/conversation_manager/recover_tool_use_on_max_tokens_reached.py +++ b/src/strands/agent/conversation_manager/recover_tool_use_on_max_tokens_reached.py @@ -38,6 +38,7 @@ async def recover_tool_use_on_max_tokens_reached(agent: "Agent", exception: MaxT raise exception valid_content: list[ContentBlock] = [] + has_corrected_content = False for content in incomplete_message["content"]: tool_use: ToolUse | None = content.get("toolUse") if not tool_use: @@ -57,10 +58,15 @@ async def recover_tool_use_on_max_tokens_reached(agent: "Agent", exception: MaxT f"to maximum token limits being reached." } ) + has_corrected_content = True else: # ToolUse was invalid for an unknown reason. 
Cannot correct, return without modifying raise exception + if not has_corrected_content: + # No ToolUse were modified, meaning this method could not have resolved the root cause + raise exception + valid_message: Message = {"content": valid_content, "role": incomplete_message["role"]} agent.messages.append(valid_message) agent.hooks.invoke_callbacks(MessageAddedEvent(agent=agent, message=valid_message)) diff --git a/tests/strands/agent/conversation_manager/test_recover_tool_use_on_max_tokens_reached.py b/tests/strands/agent/conversation_manager/test_recover_tool_use_on_max_tokens_reached.py index 77fc35c39..8fe576a87 100644 --- a/tests/strands/agent/conversation_manager/test_recover_tool_use_on_max_tokens_reached.py +++ b/tests/strands/agent/conversation_manager/test_recover_tool_use_on_max_tokens_reached.py @@ -128,8 +128,15 @@ async def test_recover_tool_use_on_max_tokens_reached_with_valid_tool_use(): mock_invoke_callbacks.assert_not_called() +@pytest.mark.parametrize( + "content,description", + [ + ([], "empty content"), + ([{"text": "Just some text with no tools to edit."}], "text-only content"), + ], +) @pytest.mark.asyncio -async def test_recover_tool_use_on_max_tokens_reached_with_empty_content(): +async def test_recover_tool_use_on_max_tokens_reached_with_empty_content(content, description): """Test that an exception that is raised without recoverability, re-raises exception.""" agent = Agent() # Mock the hooks.invoke_callbacks method @@ -137,14 +144,14 @@ async def test_recover_tool_use_on_max_tokens_reached_with_empty_content(): agent.hooks.invoke_callbacks = mock_invoke_callbacks initial_message_count = len(agent.messages) - incomplete_message: Message = {"role": "assistant", "content": []} + incomplete_message: Message = {"role": "assistant", "content": content} exception = MaxTokensReachedException(message="Token limit reached", incomplete_message=incomplete_message) with pytest.raises(MaxTokensReachedException): await recover_tool_use_on_max_tokens_reached(agent, exception) - # Should not add any message since content is empty + # Should not add any message since there's nothing to recover assert len(agent.messages) == initial_message_count # Verify that the MessageAddedEvent callback was NOT invoked From 1da9ba76c2ef131825d0fac077a7e6b88c88565e Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Wed, 6 Aug 2025 14:32:53 -0400 Subject: [PATCH 33/41] feat: add max tokens reached test --- .../recover_tool_use_on_max_tokens_reached.py | 5 +---- .../test_recover_tool_use_on_max_tokens_reached.py | 11 ++--------- 2 files changed, 3 insertions(+), 13 deletions(-) diff --git a/src/strands/agent/conversation_manager/recover_tool_use_on_max_tokens_reached.py b/src/strands/agent/conversation_manager/recover_tool_use_on_max_tokens_reached.py index 8fddd4af5..35d597e2a 100644 --- a/src/strands/agent/conversation_manager/recover_tool_use_on_max_tokens_reached.py +++ b/src/strands/agent/conversation_manager/recover_tool_use_on_max_tokens_reached.py @@ -32,6 +32,7 @@ async def recover_tool_use_on_max_tokens_reached(agent: "Agent", exception: MaxT logger.info("handling MaxTokensReachedException - inspecting incomplete message for invalid tool uses") incomplete_message: Message = exception.incomplete_message + logger.warning(f"incomplete message {incomplete_message}") if not incomplete_message["content"]: # Cannot correct invalid content block if content is empty @@ -63,10 +64,6 @@ async def recover_tool_use_on_max_tokens_reached(agent: "Agent", exception: MaxT # ToolUse was invalid for an 
unknown reason. Cannot correct, return without modifying raise exception - if not has_corrected_content: - # No ToolUse were modified, meaning this method could not have resolved the root cause - raise exception - valid_message: Message = {"content": valid_content, "role": incomplete_message["role"]} agent.messages.append(valid_message) agent.hooks.invoke_callbacks(MessageAddedEvent(agent=agent, message=valid_message)) diff --git a/tests/strands/agent/conversation_manager/test_recover_tool_use_on_max_tokens_reached.py b/tests/strands/agent/conversation_manager/test_recover_tool_use_on_max_tokens_reached.py index 8fe576a87..7d3770699 100644 --- a/tests/strands/agent/conversation_manager/test_recover_tool_use_on_max_tokens_reached.py +++ b/tests/strands/agent/conversation_manager/test_recover_tool_use_on_max_tokens_reached.py @@ -128,15 +128,8 @@ async def test_recover_tool_use_on_max_tokens_reached_with_valid_tool_use(): mock_invoke_callbacks.assert_not_called() -@pytest.mark.parametrize( - "content,description", - [ - ([], "empty content"), - ([{"text": "Just some text with no tools to edit."}], "text-only content"), - ], -) @pytest.mark.asyncio -async def test_recover_tool_use_on_max_tokens_reached_with_empty_content(content, description): +async def test_recover_tool_use_on_max_tokens_reached_with_empty_content(): """Test that an exception that is raised without recoverability, re-raises exception.""" agent = Agent() # Mock the hooks.invoke_callbacks method @@ -144,7 +137,7 @@ async def test_recover_tool_use_on_max_tokens_reached_with_empty_content(content agent.hooks.invoke_callbacks = mock_invoke_callbacks initial_message_count = len(agent.messages) - incomplete_message: Message = {"role": "assistant", "content": content} + incomplete_message: Message = {"role": "assistant", "content": []} exception = MaxTokensReachedException(message="Token limit reached", incomplete_message=incomplete_message) From 623f3c799c9f9fa844b3d86c4a19f086e66a60f3 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Wed, 6 Aug 2025 14:34:37 -0400 Subject: [PATCH 34/41] linting --- .../recover_tool_use_on_max_tokens_reached.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/strands/agent/conversation_manager/recover_tool_use_on_max_tokens_reached.py b/src/strands/agent/conversation_manager/recover_tool_use_on_max_tokens_reached.py index 35d597e2a..8c1baa554 100644 --- a/src/strands/agent/conversation_manager/recover_tool_use_on_max_tokens_reached.py +++ b/src/strands/agent/conversation_manager/recover_tool_use_on_max_tokens_reached.py @@ -39,7 +39,6 @@ async def recover_tool_use_on_max_tokens_reached(agent: "Agent", exception: MaxT raise exception valid_content: list[ContentBlock] = [] - has_corrected_content = False for content in incomplete_message["content"]: tool_use: ToolUse | None = content.get("toolUse") if not tool_use: From 66c4c07f6a34ff59cb6d4ca864c63302392f2d53 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Wed, 6 Aug 2025 14:38:18 -0400 Subject: [PATCH 35/41] feat: add max tokens reached test --- .../recover_tool_use_on_max_tokens_reached.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/strands/agent/conversation_manager/recover_tool_use_on_max_tokens_reached.py b/src/strands/agent/conversation_manager/recover_tool_use_on_max_tokens_reached.py index 8c1baa554..e9e056a69 100644 --- a/src/strands/agent/conversation_manager/recover_tool_use_on_max_tokens_reached.py +++ b/src/strands/agent/conversation_manager/recover_tool_use_on_max_tokens_reached.py @@ -32,7 +32,6 @@ async def 
recover_tool_use_on_max_tokens_reached(agent: "Agent", exception: MaxT logger.info("handling MaxTokensReachedException - inspecting incomplete message for invalid tool uses") incomplete_message: Message = exception.incomplete_message - logger.warning(f"incomplete message {incomplete_message}") if not incomplete_message["content"]: # Cannot correct invalid content block if content is empty @@ -58,7 +57,6 @@ async def recover_tool_use_on_max_tokens_reached(agent: "Agent", exception: MaxT f"to maximum token limits being reached." } ) - has_corrected_content = True else: # ToolUse was invalid for an unknown reason. Cannot correct, return without modifying raise exception From 4b5c5a72dae6617b66ef21672e98f89805849a3d Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Thu, 7 Aug 2025 16:50:00 -0400 Subject: [PATCH 36/41] feat: switch to a default behavior to recover from max tokens reached --- src/strands/agent/agent.py | 21 +- .../agent/conversation_manager/__init__.py | 2 - .../conversation_manager.py | 15 - .../recover_tool_use_on_max_tokens_reached.py | 66 ----- .../sliding_window_conversation_manager.py | 13 +- .../summarizing_conversation_manager.py | 13 +- .../_recover_message_on_max_tokens_reached.py | 76 +++++ src/strands/event_loop/event_loop.py | 32 ++- src/strands/types/exceptions.py | 6 +- .../agent/conversation_manager/__init__.py | 1 - ..._recover_tool_use_on_max_tokens_reached.py | 244 ---------------- tests/strands/agent/test_agent.py | 60 +--- .../agent/test_conversation_manager.py | 91 +----- tests/strands/event_loop/test_event_loop.py | 55 ++-- ...t_recover_message_on_max_tokens_reached.py | 267 ++++++++++++++++++ tests_integ/test_max_tokens_reached.py | 24 +- 16 files changed, 417 insertions(+), 569 deletions(-) delete mode 100644 src/strands/agent/conversation_manager/recover_tool_use_on_max_tokens_reached.py create mode 100644 src/strands/event_loop/_recover_message_on_max_tokens_reached.py delete mode 100644 tests/strands/agent/conversation_manager/__init__.py delete mode 100644 tests/strands/agent/conversation_manager/test_recover_tool_use_on_max_tokens_reached.py create mode 100644 tests/strands/event_loop/test_recover_message_on_max_tokens_reached.py diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 1f63f7996..111509e3a 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -37,7 +37,7 @@ from ..tools.registry import ToolRegistry from ..tools.watcher import ToolWatcher from ..types.content import ContentBlock, Message, Messages -from ..types.exceptions import ContextWindowOverflowException, MaxTokensReachedException +from ..types.exceptions import ContextWindowOverflowException from ..types.tools import ToolResult, ToolUse from ..types.traces import AttributeValue from .agent_result import AgentResult @@ -582,21 +582,18 @@ async def _execute_event_loop_cycle(self, invocation_state: dict[str, Any]) -> A ) async for event in events: yield event - return + except ContextWindowOverflowException as e: # Try reducing the context size and retrying - self.conversation_manager.reduce_context(agent=self, e=e) - except MaxTokensReachedException as e: - # Recover conversation state after token limit exceeded, then continue with next cycle - await self.conversation_manager.handle_token_limit_reached(agent=self, e=e) + self.conversation_manager.reduce_context(self, e=e) - # Sync agent after handling exception to keep conversation_manager_state up to date in the session - if self._session_manager: - 
self._session_manager.sync_agent(self) + # Sync agent after reduce_context to keep conversation_manager_state up to date in the session + if self._session_manager: + self._session_manager.sync_agent(self) - events = self._execute_event_loop_cycle(invocation_state) - async for event in events: - yield event + events = self._execute_event_loop_cycle(invocation_state) + async for event in events: + yield event def _record_tool_execution( self, diff --git a/src/strands/agent/conversation_manager/__init__.py b/src/strands/agent/conversation_manager/__init__.py index 7e7e0c6c5..c59623215 100644 --- a/src/strands/agent/conversation_manager/__init__.py +++ b/src/strands/agent/conversation_manager/__init__.py @@ -15,14 +15,12 @@ from .conversation_manager import ConversationManager from .null_conversation_manager import NullConversationManager -from .recover_tool_use_on_max_tokens_reached import recover_tool_use_on_max_tokens_reached from .sliding_window_conversation_manager import SlidingWindowConversationManager from .summarizing_conversation_manager import SummarizingConversationManager __all__ = [ "ConversationManager", "NullConversationManager", - "recover_tool_use_on_max_tokens_reached", "SlidingWindowConversationManager", "SummarizingConversationManager", ] diff --git a/src/strands/agent/conversation_manager/conversation_manager.py b/src/strands/agent/conversation_manager/conversation_manager.py index 1b42e1fbf..2c1ee7847 100644 --- a/src/strands/agent/conversation_manager/conversation_manager.py +++ b/src/strands/agent/conversation_manager/conversation_manager.py @@ -4,7 +4,6 @@ from typing import TYPE_CHECKING, Any, Optional from ...types.content import Message -from ...types.exceptions import MaxTokensReachedException if TYPE_CHECKING: from ...agent.agent import Agent @@ -87,17 +86,3 @@ def reduce_context(self, agent: "Agent", e: Optional[Exception] = None, **kwargs **kwargs: Additional keyword arguments for future extensibility. """ pass - - async def handle_token_limit_reached(self, agent: "Agent", e: MaxTokensReachedException, **kwargs: Any) -> None: - """Called when MaxTokensReachedException is thrown to recover conversation state. - - This method should implement recovery strategies when the token limit is exceeded and the message array - may be in a broken state. - - Args: - agent: The agent whose conversation state will be recovered. - This list is modified in-place. - e: The MaxTokensReachedException that triggered the recovery. - **kwargs: Additional keyword arguments for future extensibility. - """ - raise e diff --git a/src/strands/agent/conversation_manager/recover_tool_use_on_max_tokens_reached.py b/src/strands/agent/conversation_manager/recover_tool_use_on_max_tokens_reached.py deleted file mode 100644 index e9e056a69..000000000 --- a/src/strands/agent/conversation_manager/recover_tool_use_on_max_tokens_reached.py +++ /dev/null @@ -1,66 +0,0 @@ -"""Shared utility for handling token limit recovery in conversation managers.""" - -import logging -from typing import TYPE_CHECKING - -from ...hooks import MessageAddedEvent -from ...types.content import ContentBlock, Message -from ...types.exceptions import MaxTokensReachedException -from ...types.tools import ToolUse - -if TYPE_CHECKING: - from ...agent.agent import Agent - -logger = logging.getLogger(__name__) - - -async def recover_tool_use_on_max_tokens_reached(agent: "Agent", exception: MaxTokensReachedException) -> None: - """Handle MaxTokensReachedException by cleaning up orphaned tool uses and adding corrected message. 
-
-    This function fixes incomplete tool uses that may occur when the model's response is truncated
-    due to token limits. It:
-
-    1. Inspects each content block in the incomplete message for invalid tool uses
-    2. Replaces incomplete tool use blocks with informative text messages
-    3. Preserves valid content blocks in the corrected message
-    4. Adds the corrected message to the agent's conversation history
-
-    Args:
-        agent: The agent whose conversation will be updated with the corrected message.
-        exception: The MaxTokensReachedException containing the incomplete message.
-    """
-    logger.info("handling MaxTokensReachedException - inspecting incomplete message for invalid tool uses")
-
-    incomplete_message: Message = exception.incomplete_message
-
-    if not incomplete_message["content"]:
-        # Cannot correct invalid content block if content is empty
-        raise exception
-
-    valid_content: list[ContentBlock] = []
-    for content in incomplete_message["content"]:
-        tool_use: ToolUse | None = content.get("toolUse")
-        if not tool_use:
-            valid_content.append(content)
-            continue
-
-        # Check if tool use is incomplete (missing or empty required fields)
-        tool_name = tool_use.get("name")
-        if not (tool_name and tool_use.get("input") and tool_use.get("toolUseId")):
-            # Tool use is incomplete due to max_tokens truncation
-            display_name = tool_name if tool_name else "<unknown>"
-            logger.warning("tool_name=<%s> | replacing with error message due to max_tokens truncation.", display_name)
-
-            valid_content.append(
-                {
-                    "text": f"The selected tool {display_name}'s tool use was incomplete due "
-                    f"to maximum token limits being reached."
-                }
-            )
-        else:
-            # ToolUse was invalid for an unknown reason. Cannot correct, return without modifying
-            raise exception
-
-    valid_message: Message = {"content": valid_content, "role": incomplete_message["role"]}
-    agent.messages.append(valid_message)
-    agent.hooks.invoke_callbacks(MessageAddedEvent(agent=agent, message=valid_message))
diff --git a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py
index 58710493d..e082abe8e 100644
--- a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py
+++ b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py
@@ -7,9 +7,8 @@
 
 from ...agent.agent import Agent
 from ...types.content import Messages
-from ...types.exceptions import ContextWindowOverflowException, MaxTokensReachedException
+from ...types.exceptions import ContextWindowOverflowException
 from .conversation_manager import ConversationManager
-from .recover_tool_use_on_max_tokens_reached import recover_tool_use_on_max_tokens_reached
 
 logger = logging.getLogger(__name__)
 
@@ -113,16 +112,6 @@ def reduce_context(self, agent: "Agent", e: Optional[Exception] = None, **kwargs
         # Overwrite message history
         messages[:] = messages[trim_index:]
 
-    async def handle_token_limit_reached(self, agent: "Agent", e: MaxTokensReachedException, **kwargs: Any) -> None:
-        """Apply sliding window strategy for token limit recovery.
-
-        Args:
-            agent: The agent whose conversation state will be recovered.
-            e: The MaxTokensReachedException that triggered the recovery.
-            **kwargs: Additional keyword arguments for future extensibility.
-        """
-        await recover_tool_use_on_max_tokens_reached(agent, e)
-
     def _truncate_tool_results(self, messages: Messages, msg_idx: int) -> bool:
         """Truncate tool results in a message to reduce context size.
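Taken together, the conversation-manager diffs in this patch drop token-limit recovery from the manager interface entirely; recovery becomes a default event-loop behavior. A minimal sketch of the caller-facing result, mirroring the tests_integ/test_max_tokens_reached.py update later in this patch (max_tokens=100 is only an illustrative value to force truncation, and configured Bedrock credentials are assumed):

```python
from strands import Agent
from strands.models.bedrock import BedrockModel
from strands.types.exceptions import MaxTokensReachedException

model = BedrockModel(max_tokens=100)  # deliberately small to force truncation
agent = Agent(model=model)

try:
    agent("Tell me a story!")
except MaxTokensReachedException:
    # The event loop has already appended a recovered message to agent.messages,
    # so the same agent instance can simply be asked again.
    result = agent("What is 3+3?")
    assert result.stop_reason == "end_turn"
```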
diff --git a/src/strands/agent/conversation_manager/summarizing_conversation_manager.py b/src/strands/agent/conversation_manager/summarizing_conversation_manager.py index 1c3dc7d38..60e832215 100644 --- a/src/strands/agent/conversation_manager/summarizing_conversation_manager.py +++ b/src/strands/agent/conversation_manager/summarizing_conversation_manager.py @@ -6,9 +6,8 @@ from typing_extensions import override from ...types.content import Message -from ...types.exceptions import ContextWindowOverflowException, MaxTokensReachedException +from ...types.exceptions import ContextWindowOverflowException from .conversation_manager import ConversationManager -from .recover_tool_use_on_max_tokens_reached import recover_tool_use_on_max_tokens_reached if TYPE_CHECKING: from ..agent import Agent @@ -167,16 +166,6 @@ def reduce_context(self, agent: "Agent", e: Optional[Exception] = None, **kwargs logger.error("Summarization failed: %s", summarization_error) raise summarization_error from e - async def handle_token_limit_reached(self, agent: "Agent", e: MaxTokensReachedException, **kwargs: Any) -> None: - """Apply summarization strategy for token limit recovery. - - Args: - agent: The agent whose conversation state will be recovered. - e: The MaxTokensReachedException that triggered the recovery. - **kwargs: Additional keyword arguments for future extensibility. - """ - await recover_tool_use_on_max_tokens_reached(agent, e) - def _generate_summary(self, messages: List[Message], agent: "Agent") -> Message: """Generate a summary of the provided messages. diff --git a/src/strands/event_loop/_recover_message_on_max_tokens_reached.py b/src/strands/event_loop/_recover_message_on_max_tokens_reached.py new file mode 100644 index 000000000..e4b208fdb --- /dev/null +++ b/src/strands/event_loop/_recover_message_on_max_tokens_reached.py @@ -0,0 +1,76 @@ +"""Message recovery utilities for handling max token limit scenarios. + +This module provides functionality to recover and clean up incomplete messages that occur +when model responses are truncated due to maximum token limits being reached. It specifically +handles cases where tool use blocks are incomplete or malformed due to truncation. +""" + +import logging + +from ..types.content import ContentBlock, Message +from ..types.tools import ToolUse + +logger = logging.getLogger(__name__) + + +def recover_message_on_max_tokens_reached(message: Message) -> Message: + """Recover and clean up incomplete messages when max token limits are reached. + + When a model response is truncated due to maximum token limits, tool use blocks may be + incomplete or malformed. This function inspects the message content and: + + 1. Identifies incomplete tool use blocks (missing name, input, or toolUseId) + 2. Replaces incomplete tool uses with informative error messages + 3. Preserves all valid content blocks (text and complete tool uses) + 4. Returns a cleaned message suitable for conversation history + + This recovery mechanism ensures that the conversation can continue gracefully even when + model responses are truncated, providing clear feedback about what happened. + + Args: + message: The potentially incomplete message from the model that was truncated + due to max token limits. + + Returns: + A cleaned Message with incomplete tool uses replaced by explanatory text content. + The returned message maintains the same role as the input message. 
+
+    Example:
+        If a message contains an incomplete tool use like:
+        ```
+        {"toolUse": {"name": "calculator"}}  # missing input and toolUseId
+        ```
+
+        It will be replaced with:
+        ```
+        {"text": "The selected tool calculator's tool use was incomplete due to maximum token limits being reached."}
+        ```
+    """
+    logger.info("handling max_tokens stop reason - inspecting incomplete message for invalid tool uses")
+
+    valid_content: list[ContentBlock] = []
+    for content in message["content"] or []:
+        tool_use: ToolUse | None = content.get("toolUse")
+        if not tool_use:
+            valid_content.append(content)
+            continue
+
+        # Check if tool use is incomplete (missing or empty required fields)
+        tool_name = tool_use.get("name")
+        if tool_name and tool_use.get("input") and tool_use.get("toolUseId"):
+            # As far as we can tell, tool use is valid if this condition is true
+            valid_content.append(content)
+            continue
+
+        # Tool use is incomplete due to max_tokens truncation
+        display_name = tool_name if tool_name else "<unknown>"
+        logger.warning("tool_name=<%s> | replacing with error message due to max_tokens truncation.", display_name)
+
+        valid_content.append(
+            {
+                "text": f"The selected tool {display_name}'s tool use was incomplete due "
+                f"to maximum token limits being reached."
+            }
+        )
+
+    return {"content": valid_content, "role": message["role"]}
diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py
index ae21d4c6d..b36f73155 100644
--- a/src/strands/event_loop/event_loop.py
+++ b/src/strands/event_loop/event_loop.py
@@ -36,6 +36,7 @@
 )
 from ..types.streaming import Metrics, StopReason
 from ..types.tools import ToolChoice, ToolChoiceAuto, ToolConfig, ToolGenerator, ToolResult, ToolUse
+from ._recover_message_on_max_tokens_reached import recover_message_on_max_tokens_reached
 from .streaming import stream_messages
 
 if TYPE_CHECKING:
@@ -156,6 +157,9 @@
                 )
             )
 
+            if stop_reason == "max_tokens":
+                message = recover_message_on_max_tokens_reached(message)
+
             if model_invoke_span:
                 tracer.end_model_invoke_span(model_invoke_span, message, usage, stop_reason)
             break  # Success! Break out of retry loop
@@ -192,6 +196,19 @@
             raise e
 
     try:
+        # Add message in trace and mark the end of the stream messages trace
+        stream_trace.add_message(message)
+        stream_trace.end()
+
+        # Add the response message to the conversation
+        agent.messages.append(message)
+        agent.hooks.invoke_callbacks(MessageAddedEvent(agent=agent, message=message))
+        yield {"callback": {"message": message}}
+
+        # Update metrics
+        agent.event_loop_metrics.update_usage(usage)
+        agent.event_loop_metrics.update_metrics(metrics)
+
         if stop_reason == "max_tokens":
             """
             Handle max_tokens limit reached by the model.
@@ -205,21 +222,8 @@
                     "Agent has reached an unrecoverable state due to max_tokens limit. "
                    "For more information see: "
                    "https://strandsagents.com/latest/user-guide/concepts/agents/agent-loop/#maxtokensreachedexception"
-                ),
-                incomplete_message=message,
+                )
             )
-            # Add message in trace and mark the end of the stream messages trace
-            stream_trace.add_message(message)
-            stream_trace.end()
-
-            # Add the response message to the conversation
-            agent.messages.append(message)
-            agent.hooks.invoke_callbacks(MessageAddedEvent(agent=agent, message=message))
-            yield {"callback": {"message": message}}
-
-            # Update metrics
-            agent.event_loop_metrics.update_usage(usage)
-            agent.event_loop_metrics.update_metrics(metrics)
 
         # If the model is requesting to use tools
         if stop_reason == "tool_use":
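The subtle part of the event_loop.py diff above is the reordering: the (possibly recovered) message is now appended to the conversation and surfaced through hooks and callbacks before the max_tokens exception is raised. A condensed, pseudocode-level sketch of the new control flow, using the names from the diff (not the full implementation):

```python
if stop_reason == "max_tokens":
    message = recover_message_on_max_tokens_reached(message)  # replace truncated toolUse blocks

agent.messages.append(message)  # conversation history stays consistent
agent.hooks.invoke_callbacks(MessageAddedEvent(agent=agent, message=message))

if stop_reason == "max_tokens":
    raise MaxTokensReachedException(...)  # raised only after the recovery above
```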
" "For more information see: " "https://strandsagents.com/latest/user-guide/concepts/agents/agent-loop/#maxtokensreachedexception" - ), - incomplete_message=message, + ) ) - # Add message in trace and mark the end of the stream messages trace - stream_trace.add_message(message) - stream_trace.end() - - # Add the response message to the conversation - agent.messages.append(message) - agent.hooks.invoke_callbacks(MessageAddedEvent(agent=agent, message=message)) - yield {"callback": {"message": message}} - - # Update metrics - agent.event_loop_metrics.update_usage(usage) - agent.event_loop_metrics.update_metrics(metrics) # If the model is requesting to use tools if stop_reason == "tool_use": diff --git a/src/strands/types/exceptions.py b/src/strands/types/exceptions.py index 71ea28b9f..90f2b8d7f 100644 --- a/src/strands/types/exceptions.py +++ b/src/strands/types/exceptions.py @@ -2,8 +2,6 @@ from typing import Any -from strands.types.content import Message - class EventLoopException(Exception): """Exception raised by the event loop.""" @@ -28,14 +26,12 @@ class MaxTokensReachedException(Exception): the complexity of the response, or when the model naturally reaches its configured output limit during generation. """ - def __init__(self, message: str, incomplete_message: Message): + def __init__(self, message: str): """Initialize the exception with an error message and the incomplete message object. Args: message: The error message describing the token limit issue - incomplete_message: The valid Message object with incomplete content due to token limits """ - self.incomplete_message = incomplete_message super().__init__(message) diff --git a/tests/strands/agent/conversation_manager/__init__.py b/tests/strands/agent/conversation_manager/__init__.py deleted file mode 100644 index d5ee2d119..000000000 --- a/tests/strands/agent/conversation_manager/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# Test package for conversation manager diff --git a/tests/strands/agent/conversation_manager/test_recover_tool_use_on_max_tokens_reached.py b/tests/strands/agent/conversation_manager/test_recover_tool_use_on_max_tokens_reached.py deleted file mode 100644 index 7d3770699..000000000 --- a/tests/strands/agent/conversation_manager/test_recover_tool_use_on_max_tokens_reached.py +++ /dev/null @@ -1,244 +0,0 @@ -"""Tests for token limit recovery utility.""" - -from unittest.mock import Mock - -import pytest - -from strands.agent.agent import Agent -from strands.agent.conversation_manager.recover_tool_use_on_max_tokens_reached import ( - recover_tool_use_on_max_tokens_reached, -) -from strands.hooks import MessageAddedEvent -from strands.types.content import Message -from strands.types.exceptions import MaxTokensReachedException - - -@pytest.mark.asyncio -async def test_recover_tool_use_on_max_tokens_reached_with_incomplete_tool_use(): - """Test recovery when incomplete tool use is present in the message.""" - agent = Agent() - # Mock the hooks.invoke_callbacks method - mock_invoke_callbacks = Mock() - agent.hooks.invoke_callbacks = mock_invoke_callbacks - initial_message_count = len(agent.messages) - - incomplete_message: Message = { - "role": "assistant", - "content": [ - {"text": "I'll help you with that."}, - {"toolUse": {"name": "calculator", "input": {}, "toolUseId": ""}}, # Missing toolUseId - ], - } - - exception = MaxTokensReachedException(message="Token limit reached", incomplete_message=incomplete_message) - - await recover_tool_use_on_max_tokens_reached(agent, exception) - - # Should add one corrected 
diff --git a/tests/strands/agent/conversation_manager/test_recover_tool_use_on_max_tokens_reached.py b/tests/strands/agent/conversation_manager/test_recover_tool_use_on_max_tokens_reached.py
deleted file mode 100644
index 7d3770699..000000000
--- a/tests/strands/agent/conversation_manager/test_recover_tool_use_on_max_tokens_reached.py
+++ /dev/null
@@ -1,244 +0,0 @@
-"""Tests for token limit recovery utility."""
-
-from unittest.mock import Mock
-
-import pytest
-
-from strands.agent.agent import Agent
-from strands.agent.conversation_manager.recover_tool_use_on_max_tokens_reached import (
-    recover_tool_use_on_max_tokens_reached,
-)
-from strands.hooks import MessageAddedEvent
-from strands.types.content import Message
-from strands.types.exceptions import MaxTokensReachedException
-
-
-@pytest.mark.asyncio
-async def test_recover_tool_use_on_max_tokens_reached_with_incomplete_tool_use():
-    """Test recovery when incomplete tool use is present in the message."""
-    agent = Agent()
-    # Mock the hooks.invoke_callbacks method
-    mock_invoke_callbacks = Mock()
-    agent.hooks.invoke_callbacks = mock_invoke_callbacks
-    initial_message_count = len(agent.messages)
-
-    incomplete_message: Message = {
-        "role": "assistant",
-        "content": [
-            {"text": "I'll help you with that."},
-            {"toolUse": {"name": "calculator", "input": {}, "toolUseId": ""}},  # Missing toolUseId
-        ],
-    }
-
-    exception = MaxTokensReachedException(message="Token limit reached", incomplete_message=incomplete_message)
-
-    await recover_tool_use_on_max_tokens_reached(agent, exception)
-
-    # Should add one corrected message
-    assert len(agent.messages) == initial_message_count + 1
-
-    # Check the corrected message content
-    corrected_message = agent.messages[-1]
-    assert corrected_message["role"] == "assistant"
-    assert len(corrected_message["content"]) == 2
-
-    # First content block should be preserved
-    assert corrected_message["content"][0] == {"text": "I'll help you with that."}
-
-    # Second content block should be replaced with error message
-    assert "text" in corrected_message["content"][1]
-    assert "calculator" in corrected_message["content"][1]["text"]
-    assert "incomplete due to maximum token limits" in corrected_message["content"][1]["text"]
-
-    # Verify that the MessageAddedEvent callback was invoked
-    mock_invoke_callbacks.assert_called_once()
-    call_args = mock_invoke_callbacks.call_args[0][0]
-    assert isinstance(call_args, MessageAddedEvent)
-    assert call_args.agent == agent
-    assert call_args.message == corrected_message
-
-
-@pytest.mark.asyncio
-async def test_recover_tool_use_on_max_tokens_reached_with_unknown_tool_name():
-    """Test recovery when tool use has no name."""
-    agent = Agent()
-    # Mock the hooks.invoke_callbacks method
-    mock_invoke_callbacks = Mock()
-    agent.hooks.invoke_callbacks = mock_invoke_callbacks
-    initial_message_count = len(agent.messages)
-
-    incomplete_message: Message = {
-        "role": "assistant",
-        "content": [
-            {"toolUse": {"name": "", "input": {}, "toolUseId": "123"}},  # Missing name
-        ],
-    }
-
-    exception = MaxTokensReachedException(message="Token limit reached", incomplete_message=incomplete_message)
-
-    await recover_tool_use_on_max_tokens_reached(agent, exception)
-
-    # Should add one corrected message
-    assert len(agent.messages) == initial_message_count + 1
-
-    # Check the corrected message content
-    corrected_message = agent.messages[-1]
-    assert corrected_message["role"] == "assistant"
-    assert len(corrected_message["content"]) == 1
-
-    # Content should be replaced with error message using <unknown>
-    assert "text" in corrected_message["content"][0]
-    assert "<unknown>" in corrected_message["content"][0]["text"]
-    assert "incomplete due to maximum token limits" in corrected_message["content"][0]["text"]
-
-    # Verify that the MessageAddedEvent callback was invoked
-    mock_invoke_callbacks.assert_called_once()
-    call_args = mock_invoke_callbacks.call_args[0][0]
-    assert isinstance(call_args, MessageAddedEvent)
-    assert call_args.agent == agent
-    assert call_args.message == corrected_message
-
-
-@pytest.mark.asyncio
-async def test_recover_tool_use_on_max_tokens_reached_with_valid_tool_use():
-    """Test that an exception that is raised without recoverability, re-raises exception."""
-    agent = Agent()
-    # Mock the hooks.invoke_callbacks method
-    mock_invoke_callbacks = Mock()
-    agent.hooks.invoke_callbacks = mock_invoke_callbacks
-    initial_message_count = len(agent.messages)
-
-    incomplete_message: Message = {
-        "role": "assistant",
-        "content": [
-            {"text": "I'll help you with that."},
-            {"toolUse": {"name": "calculator", "input": {"expression": "2+2"}, "toolUseId": "123"}},  # Valid
-        ],
-    }
-
-    exception = MaxTokensReachedException(message="Token limit reached", incomplete_message=incomplete_message)
-
-    with pytest.raises(MaxTokensReachedException):
-        await recover_tool_use_on_max_tokens_reached(agent, exception)
-
-    # Should not add any message since tool use was valid
-    assert len(agent.messages) == initial_message_count
-
-    # Verify that the MessageAddedEvent callback was NOT invoked
-    mock_invoke_callbacks.assert_not_called()
-
-
-@pytest.mark.asyncio
-async def
test_recover_tool_use_on_max_tokens_reached_with_empty_content(): - """Test that an exception that is raised without recoverability, re-raises exception.""" - agent = Agent() - # Mock the hooks.invoke_callbacks method - mock_invoke_callbacks = Mock() - agent.hooks.invoke_callbacks = mock_invoke_callbacks - initial_message_count = len(agent.messages) - - incomplete_message: Message = {"role": "assistant", "content": []} - - exception = MaxTokensReachedException(message="Token limit reached", incomplete_message=incomplete_message) - - with pytest.raises(MaxTokensReachedException): - await recover_tool_use_on_max_tokens_reached(agent, exception) - - # Should not add any message since there's nothing to recover - assert len(agent.messages) == initial_message_count - - # Verify that the MessageAddedEvent callback was NOT invoked - mock_invoke_callbacks.assert_not_called() - - -@pytest.mark.asyncio -async def test_recover_tool_use_on_max_tokens_reached_with_mixed_content(): - """Test recovery with mix of valid content and incomplete tool use.""" - agent = Agent() - # Mock the hooks.invoke_callbacks method - mock_invoke_callbacks = Mock() - agent.hooks.invoke_callbacks = mock_invoke_callbacks - initial_message_count = len(agent.messages) - - incomplete_message: Message = { - "role": "assistant", - "content": [ - {"text": "Let me calculate this for you."}, - {"toolUse": {"name": "calculator", "input": {}, "toolUseId": ""}}, # Incomplete - {"text": "And then I'll explain the result."}, - ], - } - - exception = MaxTokensReachedException(message="Token limit reached", incomplete_message=incomplete_message) - - await recover_tool_use_on_max_tokens_reached(agent, exception) - - # Should add one corrected message - assert len(agent.messages) == initial_message_count + 1 - - # Check the corrected message content - corrected_message = agent.messages[-1] - assert corrected_message["role"] == "assistant" - assert len(corrected_message["content"]) == 3 - - # First and third content blocks should be preserved - assert corrected_message["content"][0] == {"text": "Let me calculate this for you."} - assert corrected_message["content"][2] == {"text": "And then I'll explain the result."} - - # Second content block should be replaced with error message - assert "text" in corrected_message["content"][1] - assert "calculator" in corrected_message["content"][1]["text"] - assert "incomplete due to maximum token limits" in corrected_message["content"][1]["text"] - - # Verify that the MessageAddedEvent callback was invoked - mock_invoke_callbacks.assert_called_once() - call_args = mock_invoke_callbacks.call_args[0][0] - assert isinstance(call_args, MessageAddedEvent) - assert call_args.agent == agent - assert call_args.message == corrected_message - - -@pytest.mark.asyncio -async def test_recover_tool_use_on_max_tokens_reached_preserves_non_tool_content(): - """Test that non-tool content is preserved as-is.""" - agent = Agent() - # Mock the hooks.invoke_callbacks method - mock_invoke_callbacks = Mock() - agent.hooks.invoke_callbacks = mock_invoke_callbacks - initial_message_count = len(agent.messages) - - incomplete_message: Message = { - "role": "assistant", - "content": [ - {"text": "Here's some text."}, - {"image": {"format": "png", "source": {"bytes": "fake_image_data"}}}, - {"toolUse": {"name": "", "input": {}, "toolUseId": "123"}}, # Incomplete - ], - } - - exception = MaxTokensReachedException(message="Token limit reached", incomplete_message=incomplete_message) - - await 
recover_tool_use_on_max_tokens_reached(agent, exception)
-
-    # Should add one corrected message
-    assert len(agent.messages) == initial_message_count + 1
-
-    # Check the corrected message content
-    corrected_message = agent.messages[-1]
-    assert corrected_message["role"] == "assistant"
-    assert len(corrected_message["content"]) == 3
-
-    # First two content blocks should be preserved exactly
-    assert corrected_message["content"][0] == {"text": "Here's some text."}
-    assert corrected_message["content"][1] == {"image": {"format": "png", "source": {"bytes": "fake_image_data"}}}
-
-    # Third content block should be replaced with error message
-    assert "text" in corrected_message["content"][2]
-    assert "<unknown>" in corrected_message["content"][2]["text"]
-
-    # Verify that the MessageAddedEvent callback was invoked
-    mock_invoke_callbacks.assert_called_once()
-    call_args = mock_invoke_callbacks.call_args[0][0]
-    assert isinstance(call_args, MessageAddedEvent)
-    assert call_args.agent == agent
-    assert call_args.message == corrected_message
diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py
index 1bc5ad78a..4e310dace 100644
--- a/tests/strands/agent/test_agent.py
+++ b/tests/strands/agent/test_agent.py
@@ -19,7 +19,7 @@
 from strands.models.bedrock import DEFAULT_BEDROCK_MODEL_ID, BedrockModel
 from strands.session.repository_session_manager import RepositorySessionManager
 from strands.types.content import Messages
-from strands.types.exceptions import ContextWindowOverflowException, EventLoopException, MaxTokensReachedException
+from strands.types.exceptions import ContextWindowOverflowException, EventLoopException
 from strands.types.session import Session, SessionAgent, SessionMessage, SessionType
 from tests.fixtures.mock_session_repository import MockedSessionRepository
 from tests.fixtures.mocked_model_provider import MockedModelProvider
@@ -547,64 +547,6 @@ def test_agent__call__tool_truncation_doesnt_infinite_loop(mock_model, agent):
     agent("Test!")
 
 
-def test_agent__call__max_tokens_reached_triggers_conversation_manager_recovery(mock_model, agent, agenerator):
-    """Test that MaxTokensReachedException triggers conversation manager handle_token_limit_reached."""
-    conversation_manager_spy = unittest.mock.Mock(wraps=agent.conversation_manager)
-    agent.conversation_manager = conversation_manager_spy
-
-    incomplete_message = {
-        "role": "assistant",
-        "content": [
-            {"text": "I'll help you with that."},
-            {"toolUse": {"name": "calculator", "input": {}, "toolUseId": ""}},  # Missing toolUseId
-        ],
-    }
-
-    mock_model.mock_stream.side_effect = [
-        # First occurrence
-        MaxTokensReachedException(message="Token limit reached", incomplete_message=incomplete_message),
-        # On retry the loop should succeed
-        agenerator(
-            [
-                {"contentBlockStart": {"start": {}}},
-                {"contentBlockDelta": {"delta": {"text": "Recovered response"}}},
-                {"contentBlockStop": {}},
-                {"messageStop": {"stopReason": "end_turn"}},
-            ]
-        ),
-    ]
-
-    agent("Test message")
-
-    # Verify handle_token_limit_reached was called
-    assert conversation_manager_spy.handle_token_limit_reached.call_count == 1
-
-    # Verify the call was made with the correct exception
-    call_args = conversation_manager_spy.handle_token_limit_reached.call_args
-    kwargs = list(call_args[1].values())
-    assert isinstance(kwargs[0], Agent)
-    assert isinstance(kwargs[1], MaxTokensReachedException)
-
-
-def test_agent__call__max_tokens_reached_with_null_conversation_manager_raises_exception(mock_model, agent):
-    """Test that MaxTokensReachedException with
NullConversationManager raises the exception.""" - agent.conversation_manager = NullConversationManager() - - incomplete_message = { - "role": "assistant", - "content": [ - {"toolUse": {"name": "calculator", "input": {}, "toolUseId": ""}}, # Missing toolUseId - ], - } - - mock_model.mock_stream.side_effect = MaxTokensReachedException( - message="Token limit reached", incomplete_message=incomplete_message - ) - - with pytest.raises(MaxTokensReachedException): - agent("Test!") - - def test_agent__call__retry_with_overwritten_tool(mock_model, agent, tool, agenerator): conversation_manager_spy = unittest.mock.Mock(wraps=agent.conversation_manager) agent.conversation_manager = conversation_manager_spy diff --git a/tests/strands/agent/test_conversation_manager.py b/tests/strands/agent/test_conversation_manager.py index 83af6c429..77d7dcce8 100644 --- a/tests/strands/agent/test_conversation_manager.py +++ b/tests/strands/agent/test_conversation_manager.py @@ -3,11 +3,7 @@ from strands.agent.agent import Agent from strands.agent.conversation_manager.null_conversation_manager import NullConversationManager from strands.agent.conversation_manager.sliding_window_conversation_manager import SlidingWindowConversationManager -from strands.types.content import Message -from strands.types.exceptions import ( - ContextWindowOverflowException, - MaxTokensReachedException, -) +from strands.types.exceptions import ContextWindowOverflowException @pytest.fixture @@ -208,42 +204,6 @@ def test_sliding_window_conversation_manager_with_tool_results_truncated(): assert messages == expected_messages -@pytest.mark.asyncio -async def test_sliding_window_conversation_manager_handle_token_limit_reached(): - """Test that SlidingWindowConversationManager handles token limit recovery.""" - manager = SlidingWindowConversationManager() - test_agent = Agent() - initial_message_count = len(test_agent.messages) - - incomplete_message: Message = { - "role": "assistant", - "content": [ - {"text": "I'll help you with that."}, - {"toolUse": {"name": "calculator", "input": {}, "toolUseId": ""}}, # Missing toolUseId - ], - } - - test_exception = MaxTokensReachedException(message="Token limit reached", incomplete_message=incomplete_message) - - await manager.handle_token_limit_reached(test_agent, test_exception) - - # Should add one corrected message - assert len(test_agent.messages) == initial_message_count + 1 - - # Check the corrected message content - corrected_message = test_agent.messages[-1] - assert corrected_message["role"] == "assistant" - assert len(corrected_message["content"]) == 2 - - # First content block should be preserved - assert corrected_message["content"][0] == {"text": "I'll help you with that."} - - # Second content block should be replaced with error message - assert "text" in corrected_message["content"][1] - assert "calculator" in corrected_message["content"][1]["text"] - assert "incomplete due to maximum token limits" in corrected_message["content"][1]["text"] - - def test_null_conversation_manager_reduce_context_raises_context_window_overflow_exception(): """Test that NullConversationManager doesn't modify messages.""" manager = NullConversationManager() @@ -286,52 +246,3 @@ def test_null_conversation_does_not_restore_with_incorrect_state(): with pytest.raises(ValueError): manager.restore_from_session({}) - - -@pytest.mark.asyncio -async def test_summarizing_conversation_manager_handle_token_limit_reached(): - """Test that SummarizingConversationManager handles token limit recovery.""" - from 
strands.agent.conversation_manager.summarizing_conversation_manager import SummarizingConversationManager
-
-    manager = SummarizingConversationManager()
-    test_agent = Agent()
-    initial_message_count = len(test_agent.messages)
-
-    incomplete_message: Message = {
-        "role": "assistant",
-        "content": [
-            {"toolUse": {"name": "", "input": {}, "toolUseId": "123"}},  # Missing name
-        ],
-    }
-
-    test_exception = MaxTokensReachedException(message="Token limit reached", incomplete_message=incomplete_message)
-
-    await manager.handle_token_limit_reached(test_agent, test_exception)
-
-    # Should add one corrected message
-    assert len(test_agent.messages) == initial_message_count + 1
-
-    # Check the corrected message content
-    corrected_message = test_agent.messages[-1]
-    assert corrected_message["role"] == "assistant"
-    assert len(corrected_message["content"]) == 1
-
-    # Content should be replaced with error message using <unknown>
-    assert "text" in corrected_message["content"][0]
-    assert "<unknown>" in corrected_message["content"][0]["text"]
-    assert "incomplete due to maximum token limits" in corrected_message["content"][0]["text"]
-
-
-@pytest.mark.asyncio
-async def test_null_conversation_manager_handle_token_limit_reached_raises_exception():
-    """Test that NullConversationManager raises the provided exception."""
-    manager = NullConversationManager()
-    test_agent = Agent()
-    test_message: Message = {
-        "role": "assistant",
-        "content": [{"text": "Hello"}],
-    }
-    test_exception = MaxTokensReachedException(message="test", incomplete_message=test_message)
-
-    with pytest.raises(MaxTokensReachedException):
-        await manager.handle_token_limit_reached(test_agent, test_exception)
diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py
index 3886df8b9..191ab51ba 100644
--- a/tests/strands/event_loop/test_event_loop.py
+++ b/tests/strands/event_loop/test_event_loop.py
@@ -305,8 +305,10 @@ async def test_event_loop_cycle_text_response_error(
     await alist(stream)
 
 
+@patch("strands.event_loop.event_loop.recover_message_on_max_tokens_reached")
 @pytest.mark.asyncio
 async def test_event_loop_cycle_tool_result(
+    mock_recover_message,
     agent,
     model,
     system_prompt,
@@ -339,6 +341,9 @@
 
     assert tru_stop_reason == exp_stop_reason and tru_message == exp_message and tru_request_state == exp_request_state
 
+    # Verify that recover_message_on_max_tokens_reached was NOT called for tool_use stop reason
+    mock_recover_message.assert_not_called()
+
     model.stream.assert_called_with(
         [
             {"role": "user", "content": [{"text": "Hello"}]},
@@ -568,25 +573,35 @@ async def test_event_loop_cycle_max_tokens_exception(
     agenerator,
     alist,
 ):
-    """Test that max_tokens stop reason raises MaxTokensReachedException."""
+    """Test that the max_tokens stop reason calls _recover_message_on_max_tokens_reached and then raises MaxTokensReachedException."""
 
-    # Note the empty toolUse to handle case raised in https://github.com/strands-agents/sdk-python/issues/495
-    model.stream.return_value = agenerator(
-        [
-            {
-                "contentBlockStart": {
-                    "start": {
-                        "toolUse": {},
+    model.stream.side_effect = [
+        agenerator(
+            [
+                {
+                    "contentBlockStart": {
+                        "start": {
+                            "toolUse": {
+                                "toolUseId": "t1",
+                                "name": "asdf",
+                                "input": {},  # empty
+                            },
                         },
                     },
                 },
-            {"contentBlockStop": {}},
-            {"messageStop": {"stopReason": "max_tokens"}},
-        ]
-    )
+                {"contentBlockStop": {}},
+                {"messageStop": {"stopReason": "max_tokens"}},
+            ]
+        ),
+    ]
 
     # Call event_loop_cycle, expecting it to raise MaxTokensReachedException
-    with pytest.raises(MaxTokensReachedException) as exc_info:
+    expected_message = (
+        "Agent has reached an unrecoverable state due to max_tokens limit. "
+        "For more information see: "
+        "https://strandsagents.com/latest/user-guide/concepts/agents/agent-loop/#maxtokensreachedexception"
+    )
+    with pytest.raises(MaxTokensReachedException, match=expected_message):
         stream = strands.event_loop.event_loop.event_loop_cycle(
             agent=agent,
             invocation_state={},
         )
         await alist(stream)
 
     # Verify the exception message contains the expected content
-    expected_message = (
-        "Agent has reached an unrecoverable state due to max_tokens limit. "
-        "For more information see: "
-        "https://strandsagents.com/latest/user-guide/concepts/agents/agent-loop/#maxtokensreachedexception"
-    )
-    assert str(exc_info.value) == expected_message
-
-    # Verify that the message has not been appended to the messages array
-    assert len(agent.messages) == 1
-    assert exc_info.value.incomplete_message not in agent.messages
+    assert len(agent.messages) == 2
+    assert "tool use was incomplete due" in agent.messages[1]["content"][0]["text"]
 
 
 @patch("strands.event_loop.event_loop.get_tracer")
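One caveat about the pytest.raises(..., match=expected_message) pattern introduced above: match is applied with re.search, so the expected string is treated as a regular expression (here the dots in the URL happen to match themselves as wildcards). For literal messages containing regex metacharacters, escaping is the safer pattern, as in this small illustrative sketch:

```python
import re

import pytest

with pytest.raises(ValueError, match=re.escape("limit reached (see docs)")):
    raise ValueError("limit reached (see docs)")
```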
diff --git a/tests/strands/event_loop/test_recover_message_on_max_tokens_reached.py b/tests/strands/event_loop/test_recover_message_on_max_tokens_reached.py
new file mode 100644
index 000000000..e751be161
--- /dev/null
+++ b/tests/strands/event_loop/test_recover_message_on_max_tokens_reached.py
@@ -0,0 +1,267 @@
+"""Tests for token limit recovery utility."""
+
+from strands.event_loop._recover_message_on_max_tokens_reached import (
+    recover_message_on_max_tokens_reached,
+)
+from strands.types.content import Message
+
+
+def test_recover_message_on_max_tokens_reached_with_incomplete_tool_use():
+    """Test recovery when incomplete tool use is present in the message."""
+    incomplete_message: Message = {
+        "role": "assistant",
+        "content": [
+            {"text": "I'll help you with that."},
+            {"toolUse": {"name": "calculator", "input": {}, "toolUseId": ""}},  # Missing toolUseId
+        ],
+    }
+
+    result = recover_message_on_max_tokens_reached(incomplete_message)
+
+    # Check the corrected message content
+    assert result["role"] == "assistant"
+    assert len(result["content"]) == 2
+
+    # First content block should be preserved
+    assert result["content"][0] == {"text": "I'll help you with that."}
+
+    # Second content block should be replaced with error message
+    assert "text" in result["content"][1]
+    assert "calculator" in result["content"][1]["text"]
+    assert "incomplete due to maximum token limits" in result["content"][1]["text"]
+
+
+def test_recover_message_on_max_tokens_reached_with_missing_tool_name():
+    """Test recovery when tool use has no name."""
+    incomplete_message: Message = {
+        "role": "assistant",
+        "content": [
+            {"toolUse": {"name": "", "input": {}, "toolUseId": "123"}},  # Missing name
+        ],
+    }
+
+    result = recover_message_on_max_tokens_reached(incomplete_message)
+
+    # Check the corrected message content
+    assert result["role"] == "assistant"
+    assert len(result["content"]) == 1
+
+    # Content should be replaced with error message using <unknown>
+    assert "text" in result["content"][0]
+    assert "<unknown>" in result["content"][0]["text"]
+    assert "incomplete due to maximum token limits" in result["content"][0]["text"]
+
+
+def test_recover_message_on_max_tokens_reached_with_missing_input():
+    """Test recovery when tool use has no input."""
+    incomplete_message: Message = {
+        "role": "assistant",
+        "content":
[ + {"toolUse": {"name": "calculator", "toolUseId": "123"}}, # Missing input + ], + } + + result = recover_message_on_max_tokens_reached(incomplete_message) + + # Check the corrected message content + assert result["role"] == "assistant" + assert len(result["content"]) == 1 + + # Content should be replaced with error message + assert "text" in result["content"][0] + assert "calculator" in result["content"][0]["text"] + assert "incomplete due to maximum token limits" in result["content"][0]["text"] + + +def test_recover_message_on_max_tokens_reached_with_missing_tool_use_id(): + """Test recovery when tool use has no toolUseId.""" + incomplete_message: Message = { + "role": "assistant", + "content": [ + {"toolUse": {"name": "calculator", "input": {"expression": "2+2"}}}, # Missing toolUseId + ], + } + + result = recover_message_on_max_tokens_reached(incomplete_message) + + # Check the corrected message content + assert result["role"] == "assistant" + assert len(result["content"]) == 1 + + # Content should be replaced with error message + assert "text" in result["content"][0] + assert "calculator" in result["content"][0]["text"] + assert "incomplete due to maximum token limits" in result["content"][0]["text"] + + +def test_recover_message_on_max_tokens_reached_with_valid_tool_use(): + """Test that valid tool uses are preserved unchanged.""" + complete_message: Message = { + "role": "assistant", + "content": [ + {"text": "I'll help you with that."}, + {"toolUse": {"name": "calculator", "input": {"expression": "2+2"}, "toolUseId": "123"}}, # Valid + ], + } + + result = recover_message_on_max_tokens_reached(complete_message) + + # Should preserve the message exactly as-is + assert result["role"] == "assistant" + assert len(result["content"]) == 2 + assert result["content"][0] == {"text": "I'll help you with that."} + assert result["content"][1] == { + "toolUse": {"name": "calculator", "input": {"expression": "2+2"}, "toolUseId": "123"} + } + + +def test_recover_message_on_max_tokens_reached_with_empty_content(): + """Test handling of message with empty content.""" + empty_message: Message = {"role": "assistant", "content": []} + + result = recover_message_on_max_tokens_reached(empty_message) + + # Should return message with empty content preserved + assert result["role"] == "assistant" + assert result["content"] == [] + + +def test_recover_message_on_max_tokens_reached_with_none_content(): + """Test handling of message with None content.""" + none_content_message: Message = {"role": "assistant", "content": None} + + result = recover_message_on_max_tokens_reached(none_content_message) + + # Should return message with empty content + assert result["role"] == "assistant" + assert result["content"] == [] + + +def test_recover_message_on_max_tokens_reached_with_mixed_content(): + """Test recovery with mix of valid content and incomplete tool use.""" + incomplete_message: Message = { + "role": "assistant", + "content": [ + {"text": "Let me calculate this for you."}, + {"toolUse": {"name": "calculator", "input": {}, "toolUseId": ""}}, # Incomplete + {"text": "And then I'll explain the result."}, + ], + } + + result = recover_message_on_max_tokens_reached(incomplete_message) + + # Check the corrected message content + assert result["role"] == "assistant" + assert len(result["content"]) == 3 + + # First and third content blocks should be preserved + assert result["content"][0] == {"text": "Let me calculate this for you."} + assert result["content"][2] == {"text": "And then I'll explain the result."} + + 
# Second content block should be replaced with error message
+    assert "text" in result["content"][1]
+    assert "calculator" in result["content"][1]["text"]
+    assert "incomplete due to maximum token limits" in result["content"][1]["text"]
+
+
+def test_recover_message_on_max_tokens_reached_preserves_non_tool_content():
+    """Test that non-tool content is preserved as-is."""
+    incomplete_message: Message = {
+        "role": "assistant",
+        "content": [
+            {"text": "Here's some text."},
+            {"image": {"format": "png", "source": {"bytes": "fake_image_data"}}},
+            {"toolUse": {"name": "", "input": {}, "toolUseId": "123"}},  # Incomplete
+        ],
+    }
+
+    result = recover_message_on_max_tokens_reached(incomplete_message)
+
+    # Check the corrected message content
+    assert result["role"] == "assistant"
+    assert len(result["content"]) == 3
+
+    # First two content blocks should be preserved exactly
+    assert result["content"][0] == {"text": "Here's some text."}
+    assert result["content"][1] == {"image": {"format": "png", "source": {"bytes": "fake_image_data"}}}
+
+    # Third content block should be replaced with error message
+    assert "text" in result["content"][2]
+    assert "<unknown>" in result["content"][2]["text"]
+    assert "incomplete due to maximum token limits" in result["content"][2]["text"]
+
+
+def test_recover_message_on_max_tokens_reached_multiple_incomplete_tools():
+    """Test recovery with multiple incomplete tool uses."""
+    incomplete_message: Message = {
+        "role": "assistant",
+        "content": [
+            {"toolUse": {"name": "calculator", "input": {}}},  # Missing toolUseId
+            {"text": "Some text in between."},
+            {"toolUse": {"name": "", "input": {}, "toolUseId": "456"}},  # Missing name
+        ],
+    }
+
+    result = recover_message_on_max_tokens_reached(incomplete_message)
+
+    # Check the corrected message content
+    assert result["role"] == "assistant"
+    assert len(result["content"]) == 3
+
+    # First tool use should be replaced
+    assert "text" in result["content"][0]
+    assert "calculator" in result["content"][0]["text"]
+    assert "incomplete due to maximum token limits" in result["content"][0]["text"]
+
+    # Text content should be preserved
+    assert result["content"][1] == {"text": "Some text in between."}
+
+    # Second tool use should be replaced with <unknown>
+    assert "text" in result["content"][2]
+    assert "<unknown>" in result["content"][2]["text"]
+    assert "incomplete due to maximum token limits" in result["content"][2]["text"]
+
+
+def test_recover_message_on_max_tokens_reached_preserves_user_role():
+    """Test that the function preserves the original message role."""
+    incomplete_message: Message = {
+        "role": "user",
+        "content": [
+            {"toolUse": {"name": "calculator", "input": {}}},  # Missing toolUseId
+        ],
+    }
+
+    result = recover_message_on_max_tokens_reached(incomplete_message)
+
+    # Should preserve the original role
+    assert result["role"] == "user"
+    assert len(result["content"]) == 1
+    assert "text" in result["content"][0]
+    assert "calculator" in result["content"][0]["text"]
+
+
+def test_recover_message_on_max_tokens_reached_with_content_without_tool_use():
+    """Test handling of content blocks that don't have toolUse key."""
+    message: Message = {
+        "role": "assistant",
+        "content": [
+            {"text": "Regular text content."},
+            {"someOtherKey": "someValue"},  # Content without toolUse
+            {"toolUse": {"name": "calculator"}},  # Incomplete tool use
+        ],
+    }
+
+    result = recover_message_on_max_tokens_reached(message)
+
+    # Check the corrected message content
+    assert result["role"] == "assistant"
+    assert len(result["content"]) == 3
+
+    # First two content blocks should be preserved
+    assert result["content"][0] == {"text": "Regular text content."}
+    assert result["content"][1] == {"someOtherKey": "someValue"}
+
+    # Third content block should be replaced with error message
+    assert "text" in result["content"][2]
+    assert "calculator" in result["content"][2]["text"]
+    assert "incomplete due to maximum token limits" in result["content"][2]["text"]
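As the tests above encode, a truncated toolUse block that has lost even its name is reported with the literal placeholder <unknown>. A quick sketch, assuming the private module path from this diff:

```python
from strands.event_loop._recover_message_on_max_tokens_reached import (
    recover_message_on_max_tokens_reached,
)

msg = {"role": "assistant", "content": [{"toolUse": {"input": {}, "toolUseId": "1"}}]}
out = recover_message_on_max_tokens_reached(msg)
assert "<unknown>" in out["content"][0]["text"]
```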
content blocks should be preserved + assert result["content"][0] == {"text": "Regular text content."} + assert result["content"][1] == {"someOtherKey": "someValue"} + + # Third content block should be replaced with error message + assert "text" in result["content"][2] + assert "calculator" in result["content"][2]["text"] + assert "incomplete due to maximum token limits" in result["content"][2]["text"] diff --git a/tests_integ/test_max_tokens_reached.py b/tests_integ/test_max_tokens_reached.py index d50452801..bf5668349 100644 --- a/tests_integ/test_max_tokens_reached.py +++ b/tests_integ/test_max_tokens_reached.py @@ -2,8 +2,8 @@ import pytest +from src.strands.agent import AgentResult from strands import Agent, tool -from strands.agent import NullConversationManager from strands.models.bedrock import BedrockModel from strands.types.exceptions import MaxTokensReachedException @@ -19,23 +19,14 @@ def story_tool(story: str) -> str: def test_max_tokens_reached(): + """Test that MaxTokensReachedException is raised but the agent can still rerun on the second pass""" model = BedrockModel(max_tokens=100) - agent = Agent(model=model, tools=[story_tool], conversation_manager=NullConversationManager()) + agent = Agent(model=model, tools=[story_tool]) + # This should raise an exception with pytest.raises(MaxTokensReachedException): agent("Tell me a story!") - assert len(agent.messages) == 1 - - -def test_max_tokens_reached_with_hook_provider(): - """Test that MaxTokensReachedException can be handled by a hook provider.""" - model = BedrockModel(max_tokens=100) - agent = Agent(model=model, tools=[story_tool]) # Defaults to include SlidingWindowConversationManager - - # This should NOT raise an exception because the hook handles it - agent("Tell me a story!") - # Validate that at least one message contains the incomplete tool use error message expected_text = "tool use was incomplete due to maximum token limits being reached" all_text_content = [ @@ -48,3 +39,10 @@ def test_max_tokens_reached_with_hook_provider(): assert any(expected_text in text for text in all_text_content), ( f"Expected to find message containing '{expected_text}' in agent messages" ) + + # Remove tools from agent and re-run with a generic question + agent.tool_registry.registry = {} + agent.tool_registry.tool_config = {} + + result: AgentResult = agent("What is 3+3") + assert result.stop_reason == "end_turn" From 83ad822de0e777a011ffc927286dd748f1e4cc69 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Fri, 8 Aug 2025 09:59:41 -0400 Subject: [PATCH 37/41] fix: all tool uses now must be replaced --- .../_recover_message_on_max_tokens_reached.py | 38 +++++++++---------- ...t_recover_message_on_max_tokens_reached.py | 12 +++--- 2 files changed, 25 insertions(+), 25 deletions(-) diff --git a/src/strands/event_loop/_recover_message_on_max_tokens_reached.py b/src/strands/event_loop/_recover_message_on_max_tokens_reached.py index e4b208fdb..4282f319d 100644 --- a/src/strands/event_loop/_recover_message_on_max_tokens_reached.py +++ b/src/strands/event_loop/_recover_message_on_max_tokens_reached.py @@ -14,31 +14,36 @@ def recover_message_on_max_tokens_reached(message: Message) -> Message: - """Recover and clean up incomplete messages when max token limits are reached. + """Recover and clean up messages when max token limits are reached. - When a model response is truncated due to maximum token limits, tool use blocks may be - incomplete or malformed. 
This function inspects the message content and:
+    When a model response is truncated due to maximum token limits, all tool use blocks
+    should be replaced with informative error messages since they may be incomplete or
+    unreliable. This function inspects the message content and:
 
-    1. Identifies incomplete tool use blocks (missing name, input, or toolUseId)
-    2. Replaces incomplete tool uses with informative error messages
-    3. Preserves all valid content blocks (text and complete tool uses)
+    1. Identifies all tool use blocks (regardless of validity)
+    2. Replaces all tool uses with informative error messages
+    3. Preserves all non-tool content blocks (text, images, etc.)
     4. Returns a cleaned message suitable for conversation history
 
     This recovery mechanism ensures that the conversation can continue gracefully even when
-    model responses are truncated, providing clear feedback about what happened.
+    model responses are truncated, providing clear feedback about what happened and preventing
+    potentially incomplete or corrupted tool executions.
+
+    TODO: after https://github.com/strands-agents/sdk-python/issues/561 is completed, only the verifiable
+    invalid tool_use content blocks need to be replaced.
 
     Args:
         message: The potentially incomplete message from the model that was truncated due to
             max token limits.
 
     Returns:
-        A cleaned Message with incomplete tool uses replaced by explanatory text content.
+        A cleaned Message with all tool uses replaced by explanatory text content.
         The returned message maintains the same role as the input message.
 
     Example:
-        If a message contains an incomplete tool use like:
+        If a message contains any tool use (complete or incomplete):
         ```
-        {"toolUse": {"name": "calculator"}}  # missing input and toolUseId
+        {"toolUse": {"name": "calculator", "input": {"expression": "2+2"}, "toolUseId": "123"}}
         ```
 
         It will be replaced with:
@@ -46,7 +51,7 @@ def recover_message_on_max_tokens_reached(message: Message) -> Message:
         {"text": "The selected tool calculator's tool use was incomplete due to maximum token limits being reached."}
         ```
     """
-    logger.info("handling max_tokens stop reason - inspecting incomplete message for invalid tool uses")
+    logger.info("handling max_tokens stop reason - replacing all tool uses with error messages")
 
     valid_content: list[ContentBlock] = []
     for content in message["content"] or []:
@@ -55,15 +60,8 @@ def recover_message_on_max_tokens_reached(message: Message) -> Message:
             valid_content.append(content)
             continue
 
-        # Check if tool use is incomplete (missing or empty required fields)
-        tool_name = tool_use.get("name")
-        if tool_name and tool_use.get("input") and tool_use.get("toolUseId"):
-            # As far as we can tell, tool use is valid if this condition is true
-            valid_content.append(content)
-            continue
-
-        # Tool use is incomplete due to max_tokens truncation
-        display_name = tool_name if tool_name else "<unknown>"
+        # Replace all tool uses with error messages when max_tokens is reached
+        display_name = tool_use.get("name", "<unknown>")
         logger.warning("tool_name=<%s> | replacing with error message due to max_tokens truncation.", display_name)
 
         valid_content.append(
diff --git a/tests/strands/event_loop/test_recover_message_on_max_tokens_reached.py b/tests/strands/event_loop/test_recover_message_on_max_tokens_reached.py
index e751be161..402e90966 100644
--- a/tests/strands/event_loop/test_recover_message_on_max_tokens_reached.py
+++ b/tests/strands/event_loop/test_recover_message_on_max_tokens_reached.py
@@ -95,7 +95,7 @@ def test_recover_message_on_max_tokens_reached_with_missing_tool_use_id():
 
 
 def test_recover_message_on_max_tokens_reached_with_valid_tool_use():
-    """Test that valid tool uses are preserved unchanged."""
+    """Test that even valid tool uses are replaced with error messages."""
     complete_message: Message = {
         "role": "assistant",
         "content": [
@@ -106,13 +106,15 @@ def test_recover_message_on_max_tokens_reached_with_valid_tool_use():
 
     result = recover_message_on_max_tokens_reached(complete_message)
 
-    # Should preserve the message exactly as-is
+    # Should replace even valid tool uses with error messages
     assert result["role"] == "assistant"
     assert len(result["content"]) == 2
     assert result["content"][0] == {"text": "I'll help you with that."}
-    assert result["content"][1] == {
-        "toolUse": {"name": "calculator", "input": {"expression": "2+2"}, "toolUseId": "123"}
-    }
+
+    # Valid tool use should also be replaced with error message
+    assert "text" in result["content"][1]
+    assert "calculator" in result["content"][1]["text"]
+    assert "incomplete due to maximum token limits" in result["content"][1]["text"]
 
 
 def test_recover_message_on_max_tokens_reached_with_empty_content():
From faa4618197a33c7673fb9d66844f66bd795c9a5f Mon Sep 17 00:00:00 2001
From: Dean Schmigelski
Date: Fri, 8 Aug 2025 10:03:18 -0400
Subject: [PATCH 38/41] fix: boolean

---
 .../event_loop/_recover_message_on_max_tokens_reached.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/strands/event_loop/_recover_message_on_max_tokens_reached.py b/src/strands/event_loop/_recover_message_on_max_tokens_reached.py
index 4282f319d..74077042c 100644
--- a/src/strands/event_loop/_recover_message_on_max_tokens_reached.py
+++ b/src/strands/event_loop/_recover_message_on_max_tokens_reached.py
@@ -61,7 +61,7 @@ def recover_message_on_max_tokens_reached(message: Message) -> Message:
             continue
 
         # Replace all tool uses with error messages when max_tokens is reached
-        display_name = tool_use.get("name", "<unknown>")
+        display_name = tool_use.get("name") or "<unknown>"
         logger.warning("tool_name=<%s> | replacing with error message due to max_tokens truncation.", display_name)
 
         valid_content.append(
From fa8195f186ff721f7044dac8db02517333bd17cf Mon Sep 17 00:00:00 2001
From: Dean Schmigelski
Date: Fri, 8 Aug 2025 10:35:13 -0400
Subject: [PATCH 39/41] remove todo

---
 .../event_loop/_recover_message_on_max_tokens_reached.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/src/strands/event_loop/_recover_message_on_max_tokens_reached.py b/src/strands/event_loop/_recover_message_on_max_tokens_reached.py
index 74077042c..ab6fb4abe 100644
--- a/src/strands/event_loop/_recover_message_on_max_tokens_reached.py
+++ b/src/strands/event_loop/_recover_message_on_max_tokens_reached.py
@@ -29,9 +29,6 @@ def recover_message_on_max_tokens_reached(message: Message) -> Message:
     model responses are truncated, providing clear feedback about what happened and preventing
     potentially incomplete or corrupted tool executions.
 
-    TODO: after https://github.com/strands-agents/sdk-python/issues/561 is completed, only the verifiable
-    invalid tool_use content blocks need to be replaced.
-
     Args:
         message: The potentially incomplete message from the model that was truncated due to
             max token limits.
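Taken together, patches 37-39 leave `recover_message_on_max_tokens_reached` replacing every `toolUse` block, complete or not, once a response is truncated by `max_tokens`, while preserving all non-tool content. A minimal sketch of the resulting behavior follows; the import path is assumed from the file location shown in the diffs and is not itself stated in the patches:

```python
from strands.event_loop._recover_message_on_max_tokens_reached import (
    recover_message_on_max_tokens_reached,  # assumed path, per src/strands/event_loop/
)

truncated = {
    "role": "assistant",
    "content": [
        {"text": "Let me calculate that."},
        # After patch 37, even this fully-formed tool use is replaced.
        {"toolUse": {"name": "calculator", "input": {"expression": "2+2"}, "toolUseId": "123"}},
    ],
}

recovered = recover_message_on_max_tokens_reached(truncated)

assert recovered["content"][0] == {"text": "Let me calculate that."}  # non-tool content kept
assert "incomplete due to maximum token limits" in recovered["content"][1]["text"]  # tool use replaced
```

Patch 38's switch from `tool_use.get("name", "<unknown>")` to `tool_use.get("name") or "<unknown>"` matters when the `name` key is present but empty or `None`: `dict.get` only falls back when the key is missing, while `or` also covers falsy values.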
From d521a2c6fde191c97b84ef3532ab340e0a425d5e Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Thu, 23 Oct 2025 11:42:56 -0400 Subject: [PATCH 40/41] Update README.md --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 62ed54d47..e0ed01412 100644 --- a/README.md +++ b/README.md @@ -21,6 +21,7 @@ PyPI version Python versions +
Documentation From e57e3980240aa6154684edf7ce5ac8b54801f103 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Thu, 23 Oct 2025 11:43:25 -0400 Subject: [PATCH 41/41] Update README.md --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index e0ed01412..5b545f969 100644 --- a/README.md +++ b/README.md @@ -22,6 +22,8 @@ Python versions + +
Documentation
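For completeness, the caller-side flow that the reworked integration test (tests_integ/test_max_tokens_reached.py) exercises looks roughly like the sketch below: truncation raises, the recovered history stays valid, and the same agent can answer a follow-up. This mirrors the test rather than prescribing production handling; the tool body is a stand-in:

```python
from strands import Agent, tool
from strands.models.bedrock import BedrockModel
from strands.types.exceptions import MaxTokensReachedException

@tool
def story_tool(story: str) -> str:
    """Echo the story back, forcing a large tool-use payload."""
    return story

# max_tokens=100 is deliberately too small, as in the integration test.
agent = Agent(model=BedrockModel(max_tokens=100), tools=[story_tool])

try:
    agent("Tell me a story!")
except MaxTokensReachedException:
    # The truncated tool use was rewritten into explanatory text in agent.messages,
    # so the conversation history remains valid for another turn.
    # The test also empties the tool registry before retrying; mirrored here.
    agent.tool_registry.registry = {}
    agent.tool_registry.tool_config = {}

    result = agent("What is 3+3?")
    assert result.stop_reason == "end_turn"
```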