Commit f5e2070

Improve OpenAI error handling (#918)
1 parent 54bc162 commit f5e2070

File tree

3 files changed: +243 −6 lines changed


src/strands/models/openai.py

Lines changed: 41 additions & 6 deletions
@@ -15,6 +15,7 @@
 from typing_extensions import Unpack, override
 
 from ..types.content import ContentBlock, Messages
+from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException
 from ..types.streaming import StreamEvent
 from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse
 from ._validation import validate_config_keys
@@ -372,6 +373,10 @@ async def stream(
 
         Yields:
             Formatted message chunks from the model.
+
+        Raises:
+            ContextWindowOverflowException: If the input exceeds the model's context window.
+            ModelThrottledException: If the request is throttled by OpenAI (rate limits).
         """
         logger.debug("formatting request")
         request = self.format_request(messages, tool_specs, system_prompt, tool_choice)
@@ -383,7 +388,20 @@
         # client. The asyncio event loop does not allow connections to be shared. For more details, please refer to
         # https://github.com/encode/httpx/discussions/2959.
         async with openai.AsyncOpenAI(**self.client_args) as client:
-            response = await client.chat.completions.create(**request)
+            try:
+                response = await client.chat.completions.create(**request)
+            except openai.BadRequestError as e:
+                # Check if this is a context length exceeded error
+                if hasattr(e, "code") and e.code == "context_length_exceeded":
+                    logger.warning("OpenAI threw context window overflow error")
+                    raise ContextWindowOverflowException(str(e)) from e
+                # Re-raise other BadRequestError exceptions
+                raise
+            except openai.RateLimitError as e:
+                # All rate limit errors should be treated as throttling, not context overflow
+                # Rate limits (including TPM) require waiting/retrying, not context reduction
+                logger.warning("OpenAI threw rate limit error")
+                raise ModelThrottledException(str(e)) from e
 
             logger.debug("got response from model")
             yield self.format_chunk({"chunk_type": "message_start"})
@@ -452,16 +470,33 @@ async def structured_output(
 
         Yields:
             Model events with the last being the structured output.
+
+        Raises:
+            ContextWindowOverflowException: If the input exceeds the model's context window.
+            ModelThrottledException: If the request is throttled by OpenAI (rate limits).
         """
         # We initialize an OpenAI context on every request so as to avoid connection sharing in the underlying httpx
         # client. The asyncio event loop does not allow connections to be shared. For more details, please refer to
         # https://github.com/encode/httpx/discussions/2959.
         async with openai.AsyncOpenAI(**self.client_args) as client:
-            response: ParsedChatCompletion = await client.beta.chat.completions.parse(
-                model=self.get_config()["model_id"],
-                messages=self.format_request(prompt, system_prompt=system_prompt)["messages"],
-                response_format=output_model,
-            )
+            try:
+                response: ParsedChatCompletion = await client.beta.chat.completions.parse(
+                    model=self.get_config()["model_id"],
+                    messages=self.format_request(prompt, system_prompt=system_prompt)["messages"],
+                    response_format=output_model,
+                )
+            except openai.BadRequestError as e:
+                # Check if this is a context length exceeded error
+                if hasattr(e, "code") and e.code == "context_length_exceeded":
+                    logger.warning("OpenAI threw context window overflow error")
+                    raise ContextWindowOverflowException(str(e)) from e
+                # Re-raise other BadRequestError exceptions
+                raise
+            except openai.RateLimitError as e:
+                # All rate limit errors should be treated as throttling, not context overflow
+                # Rate limits (including TPM) require waiting/retrying, not context reduction
+                logger.warning("OpenAI threw rate limit error")
+                raise ModelThrottledException(str(e)) from e
 
         parsed: T | None = None
         # Find the first choice with tool_calls
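
With this change, callers can branch on the Strands exception types instead of the raw OpenAI SDK errors. A minimal sketch of that pattern, not part of this commit (the backoff policy and attempt count are illustrative):

import asyncio

from strands.models.openai import OpenAIModel
from strands.types.exceptions import ContextWindowOverflowException, ModelThrottledException


async def stream_with_backoff(model: OpenAIModel, messages, max_attempts: int = 3):
    """Retry throttled requests with exponential backoff; never retry context overflow."""
    for attempt in range(max_attempts):
        try:
            # Collect the formatted chunks; a real caller would likely process them as they arrive.
            return [chunk async for chunk in model.stream(messages)]
        except ModelThrottledException:
            if attempt == max_attempts - 1:
                raise
            await asyncio.sleep(2**attempt)  # wait 1s, 2s, 4s, ...
        except ContextWindowOverflowException:
            # Retrying cannot help here; the conversation must be shortened instead.
            raise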

tests/strands/models/test_openai.py

Lines changed: 148 additions & 0 deletions
@@ -1,10 +1,12 @@
 import unittest.mock
 
+import openai
 import pydantic
 import pytest
 
 import strands
 from strands.models.openai import OpenAIModel
+from strands.types.exceptions import ContextWindowOverflowException, ModelThrottledException
 
 
 @pytest.fixture
@@ -752,3 +754,149 @@ def test_tool_choice_none_no_warning(model, messages, captured_warnings):
     model.format_request(messages, tool_choice=None)
 
     assert len(captured_warnings) == 0
+
+
+@pytest.mark.asyncio
+async def test_stream_context_overflow_exception(openai_client, model, messages):
+    """Test that OpenAI context overflow errors are properly converted to ContextWindowOverflowException."""
+    # Create a mock OpenAI BadRequestError with context_length_exceeded code
+    mock_error = openai.BadRequestError(
+        message="This model's maximum context length is 4096 tokens. However, your messages resulted in 5000 tokens.",
+        response=unittest.mock.MagicMock(),
+        body={"error": {"code": "context_length_exceeded"}},
+    )
+    mock_error.code = "context_length_exceeded"
+
+    # Configure the mock client to raise the context overflow error
+    openai_client.chat.completions.create.side_effect = mock_error
+
+    # Test that the stream method converts the error properly
+    with pytest.raises(ContextWindowOverflowException) as exc_info:
+        async for _ in model.stream(messages):
+            pass
+
+    # Verify the exception message contains the original error
+    assert "maximum context length" in str(exc_info.value)
+    assert exc_info.value.__cause__ == mock_error
+
+
+@pytest.mark.asyncio
+async def test_stream_other_bad_request_errors_passthrough(openai_client, model, messages):
+    """Test that other BadRequestError exceptions are not converted to ContextWindowOverflowException."""
+    # Create a mock OpenAI BadRequestError with a different error code
+    mock_error = openai.BadRequestError(
+        message="Invalid parameter value",
+        response=unittest.mock.MagicMock(),
+        body={"error": {"code": "invalid_parameter"}},
+    )
+    mock_error.code = "invalid_parameter"
+
+    # Configure the mock client to raise the non-context error
+    openai_client.chat.completions.create.side_effect = mock_error
+
+    # Test that other BadRequestError exceptions pass through unchanged
+    with pytest.raises(openai.BadRequestError) as exc_info:
+        async for _ in model.stream(messages):
+            pass
+
+    # Verify the original exception is raised, not ContextWindowOverflowException
+    assert exc_info.value == mock_error
+
+
+@pytest.mark.asyncio
+async def test_structured_output_context_overflow_exception(openai_client, model, messages, test_output_model_cls):
+    """Test that structured output also handles context overflow properly."""
+    # Create a mock OpenAI BadRequestError with context_length_exceeded code
+    mock_error = openai.BadRequestError(
+        message="This model's maximum context length is 4096 tokens. However, your messages resulted in 5000 tokens.",
+        response=unittest.mock.MagicMock(),
+        body={"error": {"code": "context_length_exceeded"}},
+    )
+    mock_error.code = "context_length_exceeded"
+
+    # Configure the mock client to raise the context overflow error
+    openai_client.beta.chat.completions.parse.side_effect = mock_error
+
+    # Test that the structured_output method converts the error properly
+    with pytest.raises(ContextWindowOverflowException) as exc_info:
+        async for _ in model.structured_output(test_output_model_cls, messages):
+            pass
+
+    # Verify the exception message contains the original error
+    assert "maximum context length" in str(exc_info.value)
+    assert exc_info.value.__cause__ == mock_error
+
+
+@pytest.mark.asyncio
+async def test_stream_rate_limit_as_throttle(openai_client, model, messages):
+    """Test that all rate limit errors are converted to ModelThrottledException."""
+
+    # Create a mock OpenAI RateLimitError (any type of rate limit)
+    mock_error = openai.RateLimitError(
+        message="Request too large for gpt-4o on tokens per min (TPM): Limit 30000, Requested 117505.",
+        response=unittest.mock.MagicMock(),
+        body={"error": {"code": "rate_limit_exceeded"}},
+    )
+    mock_error.code = "rate_limit_exceeded"
+
+    # Configure the mock client to raise the rate limit error
+    openai_client.chat.completions.create.side_effect = mock_error
+
+    # Test that the stream method converts the error properly
+    with pytest.raises(ModelThrottledException) as exc_info:
+        async for _ in model.stream(messages):
+            pass
+
+    # Verify the exception message contains the original error
+    assert "tokens per min" in str(exc_info.value)
+    assert exc_info.value.__cause__ == mock_error
+
+
+@pytest.mark.asyncio
+async def test_stream_request_rate_limit_as_throttle(openai_client, model, messages):
+    """Test that request-based rate limit errors are converted to ModelThrottledException."""
+
+    # Create a mock OpenAI RateLimitError for request-based rate limiting
+    mock_error = openai.RateLimitError(
+        message="Rate limit reached for requests per minute.",
+        response=unittest.mock.MagicMock(),
+        body={"error": {"code": "rate_limit_exceeded"}},
+    )
+    mock_error.code = "rate_limit_exceeded"
+
+    # Configure the mock client to raise the request rate limit error
+    openai_client.chat.completions.create.side_effect = mock_error
+
+    # Test that the stream method converts the error properly
+    with pytest.raises(ModelThrottledException) as exc_info:
+        async for _ in model.stream(messages):
+            pass
+
+    # Verify the exception message contains the original error
+    assert "Rate limit reached" in str(exc_info.value)
+    assert exc_info.value.__cause__ == mock_error
+
+
+@pytest.mark.asyncio
+async def test_structured_output_rate_limit_as_throttle(openai_client, model, messages, test_output_model_cls):
+    """Test that structured output handles rate limit errors properly."""
+
+    # Create a mock OpenAI RateLimitError
+    mock_error = openai.RateLimitError(
+        message="Request too large for gpt-4o on tokens per min (TPM): Limit 30000, Requested 117505.",
+        response=unittest.mock.MagicMock(),
+        body={"error": {"code": "rate_limit_exceeded"}},
+    )
+    mock_error.code = "rate_limit_exceeded"
+
+    # Configure the mock client to raise the rate limit error
+    openai_client.beta.chat.completions.parse.side_effect = mock_error
+
+    # Test that the structured_output method converts the error properly
+    with pytest.raises(ModelThrottledException) as exc_info:
+        async for _ in model.structured_output(test_output_model_cls, messages):
+            pass
+
+    # Verify the exception message contains the original error
+    assert "tokens per min" in str(exc_info.value)
+    assert exc_info.value.__cause__ == mock_error
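
The openai_client fixture used above is defined earlier in test_openai.py and is not part of this diff. A hypothetical sketch of what such a fixture could look like (the patch target and mock wiring are assumptions, shown only so the tests above read in context):

import unittest.mock

import pytest


@pytest.fixture
def openai_client():
    # Hypothetical: patch the AsyncOpenAI class as referenced from strands.models.openai
    # and expose the client object that the `async with` block would yield.
    with unittest.mock.patch("strands.models.openai.openai.AsyncOpenAI") as mock_cls:
        mock_client = unittest.mock.AsyncMock()
        mock_cls.return_value.__aenter__.return_value = mock_client
        yield mock_client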

tests_integ/models/test_model_openai.py

Lines changed: 54 additions & 0 deletions
@@ -1,11 +1,13 @@
 import os
+import unittest.mock
 
 import pydantic
 import pytest
 
 import strands
 from strands import Agent, tool
 from strands.models.openai import OpenAIModel
+from strands.types.exceptions import ContextWindowOverflowException, ModelThrottledException
 from tests_integ.models import providers
 
 # these tests only run if we have the openai api key
@@ -167,3 +169,55 @@ def tool_with_image_return():
     # 'user', but this message with role 'tool' contains an image URL."
     # See https://github.com/strands-agents/sdk-python/issues/320 for additional details
     agent("Run the the tool and analyze the image")
+
+
+def test_context_window_overflow_integration():
+    """Integration test for context window overflow with OpenAI.
+
+    This test verifies that when a request exceeds the model's context window,
+    the OpenAI model properly raises a ContextWindowOverflowException.
+    """
+    # Use gpt-4o-mini which has a smaller context window to make this test more reliable
+    mini_model = OpenAIModel(
+        model_id="gpt-4o-mini-2024-07-18",
+        client_args={
+            "api_key": os.getenv("OPENAI_API_KEY"),
+        },
+    )
+
+    agent = Agent(model=mini_model)
+
+    # Create a very long text that should exceed context window
+    # This text is designed to be long enough to exceed context but not hit token rate limits
+    long_text = (
+        "This text is longer than context window, but short enough to not get caught in token rate limit. " * 6800
+    )
+
+    # This should raise ContextWindowOverflowException which gets handled by conversation manager
+    # The agent should attempt to reduce context and retry
+    with pytest.raises(ContextWindowOverflowException):
+        agent(long_text)
+
+
+def test_rate_limit_throttling_integration_no_retries(model):
+    """Integration test for rate limit handling with retries disabled.
+
+    This test verifies that when a request exceeds OpenAI's rate limits,
+    the model properly raises a ModelThrottledException. We disable retries
+    to avoid waiting for the exponential backoff during testing.
+    """
+    # Patch the event loop constants to disable retries for this test
+    with unittest.mock.patch("strands.event_loop.event_loop.MAX_ATTEMPTS", 1):
+        agent = Agent(model=model)
+
+        # Create a message that's very long to trigger token-per-minute rate limits
+        # This should be large enough to exceed TPM limits immediately
+        very_long_text = "Really long text " * 20000
+
+        # This should raise ModelThrottledException without retries
+        with pytest.raises(ModelThrottledException) as exc_info:
+            agent(very_long_text)
+
+        # Verify it's a rate limit error
+        error_message = str(exc_info.value).lower()
+        assert "rate limit" in error_message or "tokens per min" in error_message
