Skip to content

Commit 1721d58

Browse files
Fix Celery SoftTimeLimitExceeded exception handling
When using the OpenAI client in Celery tasks with soft time limits, the client's broad exception handling was catching SoftTimeLimitExceeded and treating it as a retryable connection error. This prevented Celery tasks from properly handling timeouts and running cleanup logic.

This change adds a check to identify termination signals (like Celery's SoftTimeLimitExceeded or asyncio's CancelledError) and re-raises them immediately without retry. This allows task executors to properly handle these signals.

Changes:
- Added _should_not_retry() helper to identify termination signals
- Modified sync and async exception handlers to check before retrying
- Added test to verify termination signals are not retried

Fixes #2737
1 parent 6574bcd commit 1721d58

File tree

2 files changed

+69
-0
lines changed

2 files changed

+69
-0
lines changed

src/openai/_base_client.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,36 @@
8888
log: logging.Logger = logging.getLogger(__name__)
8989
log.addFilter(SensitiveHeadersFilter())
9090

91+
92+
def _should_not_retry(exc: Exception) -> bool:
93+
"""
94+
Check if an exception should propagate immediately without retry.
95+
96+
This includes task cancellation signals from async frameworks
97+
and task executors like Celery that should not be caught and retried.
98+
99+
Args:
100+
exc: The exception to check
101+
102+
Returns:
103+
True if the exception should propagate without retry, False otherwise
104+
"""
105+
exc_class = exc.__class__
106+
exc_module = exc_class.__module__
107+
exc_name = exc_class.__name__
108+
109+
# Celery task termination (don't import celery - check by name)
110+
# Examples: SoftTimeLimitExceeded, TimeLimitExceeded, Terminated
111+
if exc_module.startswith("celery") and ("Limit" in exc_name or "Terminated" in exc_name):
112+
return True
113+
114+
# asyncio cancellation
115+
if exc_module == "asyncio" and exc_name == "CancelledError":
116+
return True
117+
118+
return False
119+
120+
91121
# TODO: make base page type vars covariant
92122
SyncPageT = TypeVar("SyncPageT", bound="BaseSyncPage[Any]")
93123
AsyncPageT = TypeVar("AsyncPageT", bound="BaseAsyncPage[Any]")
@@ -1001,6 +1031,10 @@ def request(
10011031
except Exception as err:
10021032
log.debug("Encountered Exception", exc_info=True)
10031033

1034+
# Check if this is a termination signal that should not be retried
1035+
if _should_not_retry(err):
1036+
raise
1037+
10041038
if remaining_retries > 0:
10051039
self._sleep_for_retry(
10061040
retries_taken=retries_taken,
@@ -1548,6 +1582,10 @@ async def request(
15481582
except Exception as err:
15491583
log.debug("Encountered Exception", exc_info=True)
15501584

1585+
# Check if this is a termination signal that should not be retried
1586+
if _should_not_retry(err):
1587+
raise
1588+
15511589
if remaining_retries > 0:
15521590
await self._sleep_for_retry(
15531591
retries_taken=retries_taken,

tests/test_client.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -885,6 +885,37 @@ def retry_handler(_request: httpx.Request) -> httpx.Response:
885885

886886
assert response.http_request.headers.get("x-stainless-retry-count") == "42"
887887

888+
@mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
def test_termination_signal_not_retried(self, respx_mock: MockRouter, client: OpenAI) -> None:
    """Test that termination signals (like Celery's SoftTimeLimitExceeded) are not retried."""
    client = client.with_options(max_retries=3)

    # Stand-in for Celery's SoftTimeLimitExceeded. The class itself carries the
    # real name: assigning ``__name__`` inside a class body does not rename the
    # class (``type.__name__`` is a data descriptor), whereas ``__module__``
    # set in the body is honoured.
    class SoftTimeLimitExceeded(Exception):
        """Mock of celery.exceptions.SoftTimeLimitExceeded"""

        __module__ = "celery.exceptions"

    # Guard against the mock silently drifting from the identity the client inspects
    assert SoftTimeLimitExceeded.__name__ == "SoftTimeLimitExceeded"
    assert SoftTimeLimitExceeded.__module__ == "celery.exceptions"

    # Mock the request to raise our termination signal
    respx_mock.post("/chat/completions").mock(side_effect=SoftTimeLimitExceeded("Time limit exceeded"))

    # Verify the exception propagates without retry
    with pytest.raises(SoftTimeLimitExceeded):
        client.chat.completions.create(
            messages=[
                {
                    "content": "string",
                    "role": "developer",
                }
            ],
            model="gpt-4o",
        )

    # Verify only one attempt was made (no retries)
    assert len(respx_mock.calls) == 1
918+
888919
@pytest.mark.parametrize("failures_before_success", [0, 2, 4])
889920
@mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
890921
@pytest.mark.respx(base_url=base_url)

0 commit comments

Comments
 (0)