@@ -1,10 +1,12 @@
import unittest.mock

+import openai
import pydantic
import pytest

import strands
from strands.models.openai import OpenAIModel
+from strands.types.exceptions import ContextWindowOverflowException, ModelThrottledException


@pytest.fixture
@@ -752,3 +754,149 @@ def test_tool_choice_none_no_warning(model, messages, captured_warnings): |
    model.format_request(messages, tool_choice=None)

    assert len(captured_warnings) == 0
+
+
+@pytest.mark.asyncio
+async def test_stream_context_overflow_exception(openai_client, model, messages):
+    """Test that OpenAI context overflow errors are properly converted to ContextWindowOverflowException."""
+    # Create a mock OpenAI BadRequestError with the context_length_exceeded code
+    mock_error = openai.BadRequestError(
+        message="This model's maximum context length is 4096 tokens. However, your messages resulted in 5000 tokens.",
+        response=unittest.mock.MagicMock(),
+        body={"error": {"code": "context_length_exceeded"}},
+    )
+    mock_error.code = "context_length_exceeded"
+
+    # Configure the mock client to raise the context overflow error
+    openai_client.chat.completions.create.side_effect = mock_error
+
+    # Test that the stream method converts the error properly
+    with pytest.raises(ContextWindowOverflowException) as exc_info:
+        async for _ in model.stream(messages):
+            pass
+
+    # Verify the exception message contains the original error
+    assert "maximum context length" in str(exc_info.value)
+    assert exc_info.value.__cause__ == mock_error
+
+
+@pytest.mark.asyncio
+async def test_stream_other_bad_request_errors_passthrough(openai_client, model, messages):
+    """Test that other BadRequestError exceptions are not converted to ContextWindowOverflowException."""
+    # Create a mock OpenAI BadRequestError with a different error code
+    mock_error = openai.BadRequestError(
+        message="Invalid parameter value",
+        response=unittest.mock.MagicMock(),
+        body={"error": {"code": "invalid_parameter"}},
+    )
+    mock_error.code = "invalid_parameter"
+
+    # Configure the mock client to raise the non-context error
+    openai_client.chat.completions.create.side_effect = mock_error
+
+    # Test that other BadRequestError exceptions pass through unchanged
+    with pytest.raises(openai.BadRequestError) as exc_info:
+        async for _ in model.stream(messages):
+            pass
+
+    # Verify the original exception is raised, not ContextWindowOverflowException
+    assert exc_info.value == mock_error
+
+
+@pytest.mark.asyncio
+async def test_structured_output_context_overflow_exception(openai_client, model, messages, test_output_model_cls):
+    """Test that structured output also handles context overflow properly."""
+    # Create a mock OpenAI BadRequestError with the context_length_exceeded code
+    mock_error = openai.BadRequestError(
+        message="This model's maximum context length is 4096 tokens. However, your messages resulted in 5000 tokens.",
+        response=unittest.mock.MagicMock(),
+        body={"error": {"code": "context_length_exceeded"}},
+    )
+    mock_error.code = "context_length_exceeded"
+
+    # Configure the mock client to raise the context overflow error
+    openai_client.beta.chat.completions.parse.side_effect = mock_error
+
+    # Test that the structured_output method converts the error properly
+    with pytest.raises(ContextWindowOverflowException) as exc_info:
+        async for _ in model.structured_output(test_output_model_cls, messages):
+            pass
+
+    # Verify the exception message contains the original error
+    assert "maximum context length" in str(exc_info.value)
+    assert exc_info.value.__cause__ == mock_error
+
+
+@pytest.mark.asyncio
+async def test_stream_rate_limit_as_throttle(openai_client, model, messages):
+    """Test that token-based (TPM) rate limit errors are converted to ModelThrottledException."""
+
+    # Create a mock OpenAI RateLimitError for token-based rate limiting
+    mock_error = openai.RateLimitError(
+        message="Request too large for gpt-4o on tokens per min (TPM): Limit 30000, Requested 117505.",
+        response=unittest.mock.MagicMock(),
+        body={"error": {"code": "rate_limit_exceeded"}},
+    )
+    mock_error.code = "rate_limit_exceeded"
+
+    # Configure the mock client to raise the rate limit error
+    openai_client.chat.completions.create.side_effect = mock_error
+
+    # Test that the stream method converts the error properly
+    with pytest.raises(ModelThrottledException) as exc_info:
+        async for _ in model.stream(messages):
+            pass
+
+    # Verify the exception message contains the original error
+    assert "tokens per min" in str(exc_info.value)
+    assert exc_info.value.__cause__ == mock_error
+
+
+@pytest.mark.asyncio
+async def test_stream_request_rate_limit_as_throttle(openai_client, model, messages):
+    """Test that request-based rate limit errors are converted to ModelThrottledException."""
+
+    # Create a mock OpenAI RateLimitError for request-based rate limiting
+    mock_error = openai.RateLimitError(
+        message="Rate limit reached for requests per minute.",
+        response=unittest.mock.MagicMock(),
+        body={"error": {"code": "rate_limit_exceeded"}},
+    )
+    mock_error.code = "rate_limit_exceeded"
+
+    # Configure the mock client to raise the request rate limit error
+    openai_client.chat.completions.create.side_effect = mock_error
+
+    # Test that the stream method converts the error properly
+    with pytest.raises(ModelThrottledException) as exc_info:
+        async for _ in model.stream(messages):
+            pass
+
+    # Verify the exception message contains the original error
+    assert "Rate limit reached" in str(exc_info.value)
+    assert exc_info.value.__cause__ == mock_error
+
+
+@pytest.mark.asyncio
+async def test_structured_output_rate_limit_as_throttle(openai_client, model, messages, test_output_model_cls):
+    """Test that structured output handles rate limit errors properly."""
+
+    # Create a mock OpenAI RateLimitError
+    mock_error = openai.RateLimitError(
+        message="Request too large for gpt-4o on tokens per min (TPM): Limit 30000, Requested 117505.",
+        response=unittest.mock.MagicMock(),
+        body={"error": {"code": "rate_limit_exceeded"}},
+    )
+    mock_error.code = "rate_limit_exceeded"
+
+    # Configure the mock client to raise the rate limit error
+    openai_client.beta.chat.completions.parse.side_effect = mock_error
+
+    # Test that the structured_output method converts the error properly
+    with pytest.raises(ModelThrottledException) as exc_info:
+        async for _ in model.structured_output(test_output_model_cls, messages):
+            pass
+
+    # Verify the exception message contains the original error
+    assert "tokens per min" in str(exc_info.value)
+    assert exc_info.value.__cause__ == mock_error
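Taken together, these tests pin down an error-translation contract rather than a particular call site: a `BadRequestError` whose code is `context_length_exceeded` must surface as `ContextWindowOverflowException`, any `RateLimitError` must surface as `ModelThrottledException`, other client errors must propagate unchanged, and the original SDK error must be preserved as `__cause__`. A minimal sketch of that mapping is below; the helper name `_map_openai_error` and the exception constructor arguments are assumptions for illustration, not the actual `OpenAIModel` implementation.

```python
import openai

from strands.types.exceptions import ContextWindowOverflowException, ModelThrottledException


def _map_openai_error(error: openai.APIStatusError) -> None:
    """Re-raise an OpenAI SDK error as the matching strands exception (hypothetical helper)."""
    if isinstance(error, openai.BadRequestError) and error.code == "context_length_exceeded":
        # Context overflows become ContextWindowOverflowException, chained to the original error.
        raise ContextWindowOverflowException(str(error)) from error
    if isinstance(error, openai.RateLimitError):
        # Any rate limit (token- or request-based) surfaces as throttling.
        raise ModelThrottledException(str(error)) from error
    # Everything else, e.g. other 400s, propagates unchanged.
    raise error
```

Chaining with `raise ... from error` is what lets the tests assert both on the message (`"maximum context length"`, `"tokens per min"`) and on `exc_info.value.__cause__` being the original mock error.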