From cd6926f9b3f5124bdfbceb314f4d7ef8b8f635e5 Mon Sep 17 00:00:00 2001 From: "dayuan.jiang" Date: Tue, 14 Oct 2025 23:07:49 +0900 Subject: [PATCH 1/2] feat: add cache_messages --- src/strands/models/bedrock.py | 35 ++++++++++++++++++- tests/strands/models/test_bedrock.py | 52 ++++++++++++++++++++++++++++ 2 files changed, 86 insertions(+), 1 deletion(-) diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index c6a500597..5feea0757 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -72,6 +72,8 @@ class BedrockConfig(TypedDict, total=False): additional_response_field_paths: Additional response field paths to extract cache_prompt: Cache point type for the system prompt cache_tools: Cache point type for tools + cache_messages: Cache point type for messages. If set to "default", removes all existing cache points + from messages and adds a cache point at the end of the last message. guardrail_id: ID of the guardrail to apply guardrail_trace: Guardrail trace mode. Defaults to enabled. guardrail_version: Version of the guardrail to apply @@ -95,6 +97,7 @@ class BedrockConfig(TypedDict, total=False): additional_response_field_paths: Optional[list[str]] cache_prompt: Optional[str] cache_tools: Optional[str] + cache_messages: Optional[str] guardrail_id: Optional[str] guardrail_trace: Optional[Literal["enabled", "disabled", "enabled_full"]] guardrail_stream_processing_mode: Optional[Literal["sync", "async"]] @@ -185,6 +188,24 @@ def get_config(self) -> BedrockConfig: """ return self.config + def _remove_cache_points_from_messages(self, messages: Messages) -> Messages: + """Remove all cache points from messages. + + Args: + messages: List of messages to process. + + Returns: + Messages with cache points removed. + """ + cleaned_messages: Messages = [] + for message in messages: + if "content" in message and isinstance(message["content"], list): + cleaned_content = [item for item in message["content"] if "cachePoint" not in item] + cleaned_messages.append({"role": message["role"], "content": cleaned_content}) + else: + cleaned_messages.append(message) + return cleaned_messages + def format_request( self, messages: Messages, @@ -203,9 +224,21 @@ def format_request( Returns: A Bedrock converse stream request. 
""" + # Handle cache_messages configuration + processed_messages = messages + if self.config.get("cache_messages") == "default": + # Remove all existing cache points from messages + processed_messages = self._remove_cache_points_from_messages(messages) + # Add cache point to the end of the last message + if processed_messages and len(processed_messages) > 0: + last_message = processed_messages[-1] + if "content" in last_message and isinstance(last_message["content"], list): + # Create a new list with the cache point appended + last_message["content"] = [*last_message["content"], {"cachePoint": {"type": "default"}}] + return { "modelId": self.config["model_id"], - "messages": self._format_bedrock_messages(messages), + "messages": self._format_bedrock_messages(processed_messages), "system": [ *([{"text": system_prompt}] if system_prompt else []), *([{"cachePoint": {"type": self.config["cache_prompt"]}}] if self.config.get("cache_prompt") else []), diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index 96fee67fa..0c97a415c 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -492,6 +492,58 @@ def test_format_request_cache(model, messages, model_id, tool_spec, cache_type): assert tru_request == exp_request +def test_format_request_cache_messages(model, model_id, cache_type): + """Test that cache_messages removes existing cache points and adds one at the end.""" + # Messages with existing cache points that should be removed + messages_with_cache = [ + { + "role": "user", + "content": [ + {"text": "First message"}, + {"cachePoint": {"type": "default"}}, # Should be removed + ], + }, + { + "role": "assistant", + "content": [ + {"text": "Response"}, + {"cachePoint": {"type": "default"}}, # Should be removed + ], + }, + { + "role": "user", + "content": [{"text": "Second message"}], + }, + ] + + model.update_config(cache_messages=cache_type) + tru_request = model.format_request(messages_with_cache) + + # Verify all old cache points are removed and new one is at the end + messages = tru_request["messages"] + + # Check first message has no cache point + assert messages[0]["content"] == [{"text": "First message"}] + + # Check second message has no cache point + assert messages[1]["content"] == [{"text": "Response"}] + + # Check last message has cache point at the end + assert messages[2]["content"] == [ + {"text": "Second message"}, + {"cachePoint": {"type": cache_type}}, + ] + + # Verify the full request structure + exp_request = { + "inferenceConfig": {}, + "modelId": model_id, + "messages": messages, + "system": [], + } + assert tru_request == exp_request + + @pytest.mark.asyncio async def test_stream_throttling_exception_from_event_stream_error(bedrock_client, model, messages, alist): error_message = "Rate exceeded" From b2207bb7aa902d688c84481c99679496c77587db Mon Sep 17 00:00:00 2001 From: "dayuan.jiang" Date: Thu, 16 Oct 2025 00:27:23 +0900 Subject: [PATCH 2/2] fix: keep the orignal messages not be changed --- src/strands/models/bedrock.py | 35 ++++---------- tests/strands/models/test_bedrock.py | 70 ++++++++++++++++++++++++---- 2 files changed, 69 insertions(+), 36 deletions(-) diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index 5feea0757..d0525d5e7 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -72,8 +72,8 @@ class BedrockConfig(TypedDict, total=False): additional_response_field_paths: Additional response field paths to extract cache_prompt: Cache 
         cache_tools: Cache point type for tools
-        cache_messages: Cache point type for messages. If set to "default", removes all existing cache points
-            from messages and adds a cache point at the end of the last message.
+        cache_messages: Cache point type for messages. If set to "default", adds a cache point at the end
+            of the last message.
         guardrail_id: ID of the guardrail to apply
         guardrail_trace: Guardrail trace mode. Defaults to enabled.
         guardrail_version: Version of the guardrail to apply
@@ -188,24 +188,6 @@ def get_config(self) -> BedrockConfig:
         """
         return self.config
 
-    def _remove_cache_points_from_messages(self, messages: Messages) -> Messages:
-        """Remove all cache points from messages.
-
-        Args:
-            messages: List of messages to process.
-
-        Returns:
-            Messages with cache points removed.
-        """
-        cleaned_messages: Messages = []
-        for message in messages:
-            if "content" in message and isinstance(message["content"], list):
-                cleaned_content = [item for item in message["content"] if "cachePoint" not in item]
-                cleaned_messages.append({"role": message["role"], "content": cleaned_content})
-            else:
-                cleaned_messages.append(message)
-        return cleaned_messages
-
     def format_request(
         self,
         messages: Messages,
@@ -227,14 +209,15 @@
         # Handle cache_messages configuration
         processed_messages = messages
         if self.config.get("cache_messages") == "default":
-            # Remove all existing cache points from messages
-            processed_messages = self._remove_cache_points_from_messages(messages)
-            # Add cache point to the end of the last message
-            if processed_messages and len(processed_messages) > 0:
+            # Add cache point to the end of the last message (copy first to avoid modifying the original)
+            if messages:
+                # Create a shallow copy of the messages list
+                processed_messages = list(messages)
                 last_message = processed_messages[-1]
                 if "content" in last_message and isinstance(last_message["content"], list):
-                    # Create a new list with the cache point appended
-                    last_message["content"] = [*last_message["content"], {"cachePoint": {"type": "default"}}]
+                    # Create a new message dict with updated content
+                    new_content = [*last_message["content"], {"cachePoint": {"type": "default"}}]
+                    processed_messages[-1] = {"role": last_message["role"], "content": new_content}
 
         return {
             "modelId": self.config["model_id"],
diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py
index 0c97a415c..a5bd58ef6 100644
--- a/tests/strands/models/test_bedrock.py
+++ b/tests/strands/models/test_bedrock.py
@@ -493,21 +493,21 @@ def test_format_request_cache(model, messages, model_id, tool_spec, cache_type):
 
 
 def test_format_request_cache_messages(model, model_id, cache_type):
-    """Test that cache_messages removes existing cache points and adds one at the end."""
-    # Messages with existing cache points that should be removed
+    """Test that cache_messages preserves existing cache points and adds one at the end."""
+    # Messages with existing cache points that should be preserved
    messages_with_cache = [
         {
             "role": "user",
             "content": [
                 {"text": "First message"},
-                {"cachePoint": {"type": "default"}},  # Should be removed
+                {"cachePoint": {"type": "default"}},  # Should be preserved
             ],
         },
         {
             "role": "assistant",
             "content": [
                 {"text": "Response"},
-                {"cachePoint": {"type": "default"}},  # Should be removed
+                {"cachePoint": {"type": "default"}},  # Should be preserved
             ],
         },
         {
             "role": "user",
             "content": [{"text": "Second message"}],
         },
     ]
 
@@ -519,16 +519,22 @@
model.update_config(cache_messages=cache_type) tru_request = model.format_request(messages_with_cache) - # Verify all old cache points are removed and new one is at the end + # Verify existing cache points are preserved and new one is added at the end messages = tru_request["messages"] - # Check first message has no cache point - assert messages[0]["content"] == [{"text": "First message"}] + # Check first message still has its cache point + assert messages[0]["content"] == [ + {"text": "First message"}, + {"cachePoint": {"type": "default"}}, + ] - # Check second message has no cache point - assert messages[1]["content"] == [{"text": "Response"}] + # Check second message still has its cache point + assert messages[1]["content"] == [ + {"text": "Response"}, + {"cachePoint": {"type": "default"}}, + ] - # Check last message has cache point at the end + # Check third message (last) has new cache point at the end assert messages[2]["content"] == [ {"text": "Second message"}, {"cachePoint": {"type": cache_type}}, @@ -544,6 +550,50 @@ def test_format_request_cache_messages(model, model_id, cache_type): assert tru_request == exp_request +def test_format_request_cache_messages_does_not_modify_original(model, cache_type): + """Test that format_request does not modify the original messages when cache_messages is set.""" + # Create original messages + original_messages = [ + { + "role": "user", + "content": [ + {"text": "First message"}, + {"cachePoint": {"type": "default"}}, + ], + }, + { + "role": "assistant", + "content": [ + {"text": "Response"}, + ], + }, + { + "role": "user", + "content": [{"text": "Second message"}], + }, + ] + + # Create a deep copy for comparison + import copy + + expected_messages = copy.deepcopy(original_messages) + + # Call format_request with cache_messages enabled + model.update_config(cache_messages=cache_type) + _ = model.format_request(original_messages) + + # Verify original messages are unchanged + assert original_messages == expected_messages + + # Verify content lists are unchanged + assert original_messages[0]["content"] == [ + {"text": "First message"}, + {"cachePoint": {"type": "default"}}, + ] + assert original_messages[1]["content"] == [{"text": "Response"}] + assert original_messages[2]["content"] == [{"text": "Second message"}] + + @pytest.mark.asyncio async def test_stream_throttling_exception_from_event_stream_error(bedrock_client, model, messages, alist): error_message = "Rate exceeded"
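
Together, the two patches make `cache_messages="default"` append a single `cachePoint` block to the last message at request-format time while leaving the caller's message list untouched. Below is a minimal usage sketch of the final behavior; it assumes the class implemented in `src/strands/models/bedrock.py` is the Strands SDK's public `BedrockModel`, that AWS region/credential configuration comes from the environment, and that the `model_id` shown is purely illustrative:

```python
# Minimal sketch (not part of the patches): exercising cache_messages end to end.
# Assumes BedrockModel is the public class in src/strands/models/bedrock.py and
# that the model_id value is illustrative.
from strands.models import BedrockModel

model = BedrockModel(model_id="us.anthropic.claude-3-7-sonnet-20250219-v1:0")
model.update_config(cache_messages="default")  # same call the tests above use

messages = [{"role": "user", "content": [{"text": "Summarize the attached report."}]}]
request = model.format_request(messages)

# PATCH 1/2: the formatted request ends with a cache point after the last message.
assert request["messages"][-1]["content"][-1] == {"cachePoint": {"type": "default"}}

# PATCH 2/2: the caller's message list is not mutated by format_request.
assert messages[-1]["content"] == [{"text": "Summarize the attached report."}]
```

Because only the formatted request carries the cache point, repeated calls to `format_request` do not accumulate `cachePoint` blocks in the conversation history.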