diff --git a/sentry_sdk/ai/utils.py b/sentry_sdk/ai/utils.py index 06c9a23604..9d585d5375 100644 --- a/sentry_sdk/ai/utils.py +++ b/sentry_sdk/ai/utils.py @@ -101,6 +101,30 @@ def get_start_span_function(): return sentry_sdk.start_span if transaction_exists else sentry_sdk.start_transaction +def _truncate_single_message(message, max_bytes): + # type: (Dict[str, Any], int) -> Dict[str, Any] + """ + Truncate a single message to fit within max_bytes. + If the message is too large, truncate the content field. + """ + if not isinstance(message, dict) or "content" not in message: + return message + content = message.get("content", "") + + if not isinstance(content, str) or len(content) <= max_bytes: + return message + + overhead_message = message.copy() + overhead_message["content"] = "" + overhead_size = len( + json.dumps(overhead_message, separators=(",", ":")).encode("utf-8") + ) + + available_content_bytes = max_bytes - overhead_size - 20 + message["content"] = content[:available_content_bytes] + "..." 
def truncate_messages_by_size(messages, max_bytes=MAX_GEN_AI_MESSAGE_BYTES):
    # type: (List[Dict[str, Any]], int) -> Tuple[List[Dict[str, Any]], int]
    """
    Shrink a list of messages so its compact JSON form fits within max_bytes.

    Each message is first passed through ``_truncate_single_message``. If the
    resulting list serializes to ``max_bytes`` or fewer bytes it is returned
    with an index of ``0``. Otherwise ``_find_truncation_index`` chooses a
    start index and only the messages from that index onwards are returned.

    :param messages: The messages to process.
    :param max_bytes: Byte budget for the serialized list.
    :returns: A tuple of the retained messages and the index they start at.
    """
    shortened = []
    for message in messages:
        shortened.append(_truncate_single_message(message, max_bytes))

    payload = json.dumps(shortened, separators=(",", ":")).encode("utf-8")
    if len(payload) > max_bytes:
        cut = _find_truncation_index(shortened, max_bytes)
        return shortened[cut:], cut

    return shortened, 0
def test_individual_message_truncation(self):
    """An oversized last message is truncated and earlier messages dropped."""
    large_content = "This is a very long message. " * 1000

    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": large_content},
    ]

    result, truncation_index = truncate_messages_by_size(
        messages, max_bytes=MAX_GEN_AI_MESSAGE_BYTES
    )

    assert len(result) > 0

    total_size = len(json.dumps(result, separators=(",", ":")).encode("utf-8"))
    assert total_size <= MAX_GEN_AI_MESSAGE_BYTES

    for msg in result:
        msg_size = len(json.dumps(msg, separators=(",", ":")).encode("utf-8"))
        assert msg_size <= MAX_GEN_AI_MESSAGE_BYTES

    # If the last message is too large, the system message is not present
    system_msgs = [m for m in result if m.get("role") == "system"]
    assert len(system_msgs) == 0

    # Confirm the user message is truncated with '...'
    user_msgs = [m for m in result if m.get("role") == "user"]
    assert len(user_msgs) == 1
    assert user_msgs[0]["content"].endswith("...")
    assert len(user_msgs[0]["content"]) < len(large_content)

    # Only the leading system message was dropped, so the reported index
    # must point at the user message.
    assert truncation_index == 1


def test_combined_individual_and_array_truncation(self):
    """Per-message truncation and list truncation work together."""
    huge_content = "X" * 25000
    medium_content = "Y" * 5000

    messages = [
        {"role": "system", "content": medium_content},
        {"role": "user", "content": huge_content},
        {"role": "assistant", "content": medium_content},
        {"role": "user", "content": "small"},
    ]

    result, truncation_index = truncate_messages_by_size(
        messages, max_bytes=MAX_GEN_AI_MESSAGE_BYTES
    )

    assert len(result) > 0

    # The returned list is the tail starting at truncation_index, so the
    # index must equal the number of dropped leading messages.
    assert truncation_index == len(messages) - len(result)

    total_size = len(json.dumps(result, separators=(",", ":")).encode("utf-8"))
    assert total_size <= MAX_GEN_AI_MESSAGE_BYTES

    for msg in result:
        msg_size = len(json.dumps(msg, separators=(",", ":")).encode("utf-8"))
        assert msg_size <= MAX_GEN_AI_MESSAGE_BYTES

    # The last user "small" message should always be present and untruncated
    last_user_msgs = [
        m for m in result if m.get("role") == "user" and m["content"] == "small"
    ]
    assert len(last_user_msgs) == 1

    # If the huge message is present, it must be truncated
    for user_msg in [
        m for m in result if m.get("role") == "user" and "X" in m["content"]
    ]:
        assert user_msg["content"].endswith("...")
        assert len(user_msg["content"]) < len(huge_content)

    # The medium messages, if present, should not be truncated
    for expected_role in ["system", "assistant"]:
        role_msgs = [m for m in result if m.get("role") == expected_role]
        if role_msgs:
            assert role_msgs[0]["content"].startswith("Y")
            assert len(role_msgs[0]["content"]) <= len(medium_content)
            assert not role_msgs[0]["content"].endswith("...") or len(
                role_msgs[0]["content"]
            ) == len(medium_content)