diff --git a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/HttpProtocolTestGenerator.java b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/HttpProtocolTestGenerator.java index 83f1258f4..ae0f098b0 100644 --- a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/HttpProtocolTestGenerator.java +++ b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/HttpProtocolTestGenerator.java @@ -619,6 +619,8 @@ private void writeUtilStubs(Symbol serviceSymbol) { writer.addImport("smithy_http", "tuples_to_fields"); writer.addImport("smithy_http.aio", "HTTPResponse", "_HTTPResponse"); writer.addImport("smithy_core.aio.utils", "async_list"); + writer.addImport("smithy_core.aio.interfaces", "ClientErrorInfo"); + writer.addStdlibImport("typing", "Any"); writer.write(""" class $1L($2T): @@ -634,6 +636,10 @@ class $3L: def __init__(self, *, client_config: HTTPClientConfiguration | None = None): self._client_config = client_config + def get_error_info(self, exception: Exception, **kwargs: Any) -> ClientErrorInfo: + \"\"\"Get information about an exception.\"\"\" + return ClientErrorInfo(is_timeout_error=False) + async def send( self, request: HTTPRequest, *, request_config: HTTPRequestConfiguration | None = None ) -> HTTPResponse: @@ -657,6 +663,10 @@ def __init__( self.fields = tuples_to_fields(headers or []) self.body = body + def get_error_info(self, exception: Exception, **kwargs: Any) -> ClientErrorInfo: + \"\"\"Get information about an exception.\"\"\" + return ClientErrorInfo(is_timeout_error=False) + async def send( self, request: HTTPRequest, *, request_config: HTTPRequestConfiguration | None = None ) -> _HTTPResponse: diff --git a/packages/smithy-core/src/smithy_core/aio/client.py b/packages/smithy-core/src/smithy_core/aio/client.py index bf27c440c..a1d09723e 100644 --- a/packages/smithy-core/src/smithy_core/aio/client.py +++ b/packages/smithy-core/src/smithy_core/aio/client.py @@ -12,7 +12,7 @@ from ..auth import AuthParams from ..deserializers import DeserializeableShape, ShapeDeserializer from ..endpoints import EndpointResolverParams -from ..exceptions import RetryError, SmithyError +from ..exceptions import ClientTimeoutError, RetryError, SmithyError from ..interceptors import ( InputContext, Interceptor, @@ -448,24 +448,32 @@ async def _handle_attempt[I: SerializeableShape, O: DeserializeableShape]( _LOGGER.debug("Sending request %s", request_context.transport_request) - if request_future is not None: - # If we have an input event stream (or duplex event stream) then we - # need to let the client return ASAP so that it can start sending - # events. So here we start the transport send in a background task - # then set the result of the request future. It's important to sequence - # it just like that so that the client gets a stream that's ready - # to send. - transport_task = asyncio.create_task( - self.transport.send(request=request_context.transport_request) - ) - request_future.set_result(request_context) - transport_response = await transport_task - else: - # If we don't have an input stream, there's no point in creating a - # task, so we just immediately await the coroutine. - transport_response = await self.transport.send( - request=request_context.transport_request - ) + try: + if request_future is not None: + # If we have an input event stream (or duplex event stream) then we + # need to let the client return ASAP so that it can start sending + # events. So here we start the transport send in a background task + # then set the result of the request future. It's important to sequence + # it just like that so that the client gets a stream that's ready + # to send. + transport_task = asyncio.create_task( + self.transport.send(request=request_context.transport_request) + ) + request_future.set_result(request_context) + transport_response = await transport_task + else: + # If we don't have an input stream, there's no point in creating a + # task, so we just immediately await the coroutine. + transport_response = await self.transport.send( + request=request_context.transport_request + ) + except Exception as e: + error_info = self.transport.get_error_info(e) + if error_info.is_timeout_error: + raise ClientTimeoutError( + message=f"Client timeout occurred: {e}" + ) from e + raise _LOGGER.debug("Received response: %s", transport_response) diff --git a/packages/smithy-core/src/smithy_core/aio/interfaces/__init__.py b/packages/smithy-core/src/smithy_core/aio/interfaces/__init__.py index 31d772125..21d6911ad 100644 --- a/packages/smithy-core/src/smithy_core/aio/interfaces/__init__.py +++ b/packages/smithy-core/src/smithy_core/aio/interfaces/__init__.py @@ -1,6 +1,7 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 from collections.abc import AsyncIterable, Callable +from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable from ...documents import TypeRegistry @@ -10,6 +11,15 @@ from ...interfaces import StreamingBlob as SyncStreamingBlob from .eventstream import EventPublisher, EventReceiver + +@dataclass(frozen=True) +class ClientErrorInfo: + """Information about an error from a transport.""" + + is_timeout_error: bool + """Whether this error represents a timeout condition.""" + + if TYPE_CHECKING: from typing_extensions import TypeForm @@ -86,7 +96,23 @@ async def resolve_endpoint(self, params: EndpointResolverParams[Any]) -> Endpoin class ClientTransport[I: Request, O: Response](Protocol): - """Protocol-agnostic representation of a client tranport (e.g. an HTTP client).""" + """Protocol-agnostic representation of a client transport (e.g. an HTTP client). + + Transport implementations must define the get_error_info method to determine which + exceptions represent timeout conditions for that transport. + """ + + def get_error_info(self, exception: Exception, **kwargs: Any) -> ClientErrorInfo: + """Get information about an exception. + + Args: + exception: The exception to analyze + **kwargs: Additional context for analysis + + Returns: + ClientErrorInfo with timeout information. + """ + ... async def send(self, request: I) -> O: """Send a request over the transport and receive the response.""" diff --git a/packages/smithy-core/src/smithy_core/exceptions.py b/packages/smithy-core/src/smithy_core/exceptions.py index 0e28bd530..0a99976f9 100644 --- a/packages/smithy-core/src/smithy_core/exceptions.py +++ b/packages/smithy-core/src/smithy_core/exceptions.py @@ -50,6 +50,9 @@ class CallError(SmithyError): is_throttling_error: bool = False """Whether the error is a throttling error.""" + is_timeout_error: bool = False + """Whether the error represents a timeout condition.""" + def __post_init__(self): super().__init__(self.message) @@ -61,6 +64,20 @@ class ModeledError(CallError): fault: Fault = "client" +@dataclass(kw_only=True) +class ClientTimeoutError(CallError): + """Exception raised when a client-side timeout occurs. + + This error indicates that the client transport layer encountered a timeout while + attempting to communicate with the server. This typically occurs when network + requests take longer than the configured timeout period. + """ + + fault: Fault = "client" + is_timeout_error: bool = True + is_retry_safe: bool | None = True + + class SerializationError(SmithyError): """Base exception type for exceptions raised during serialization.""" diff --git a/packages/smithy-http/src/smithy_http/aio/aiohttp.py b/packages/smithy-http/src/smithy_http/aio/aiohttp.py index 83f4c191f..5d80791f9 100644 --- a/packages/smithy-http/src/smithy_http/aio/aiohttp.py +++ b/packages/smithy-http/src/smithy_http/aio/aiohttp.py @@ -20,7 +20,7 @@ except ImportError: HAS_AIOHTTP = False # type: ignore -from smithy_core.aio.interfaces import StreamingBlob +from smithy_core.aio.interfaces import ClientErrorInfo, StreamingBlob from smithy_core.aio.types import AsyncBytesReader from smithy_core.aio.utils import async_list from smithy_core.exceptions import MissingDependencyError @@ -52,6 +52,12 @@ def __post_init__(self) -> None: class AIOHTTPClient(HTTPClient): """Implementation of :py:class:`.interfaces.HTTPClient` using aiohttp.""" + def get_error_info(self, exception: Exception, **kwargs: Any) -> ClientErrorInfo: + if isinstance(exception, TimeoutError): + return ClientErrorInfo(is_timeout_error=True) + + return ClientErrorInfo(is_timeout_error=False) + def __init__( self, *, diff --git a/packages/smithy-http/src/smithy_http/aio/crt.py b/packages/smithy-http/src/smithy_http/aio/crt.py index a450ef9c9..fdee231e9 100644 --- a/packages/smithy-http/src/smithy_http/aio/crt.py +++ b/packages/smithy-http/src/smithy_http/aio/crt.py @@ -8,6 +8,8 @@ from inspect import iscoroutinefunction from typing import TYPE_CHECKING, Any +from awscrt.exceptions import AwsCrtError + if TYPE_CHECKING: # pyright doesn't like optional imports. This is reasonable because if we use these # in type hints then they'd result in runtime errors. @@ -33,6 +35,7 @@ from smithy_core import interfaces as core_interfaces from smithy_core.aio import interfaces as core_aio_interfaces +from smithy_core.aio.interfaces import ClientErrorInfo from smithy_core.aio.types import AsyncBytesReader from smithy_core.exceptions import MissingDependencyError @@ -133,6 +136,16 @@ class AWSCRTHTTPClient(http_aio_interfaces.HTTPClient): _HTTP_PORT = 80 _HTTPS_PORT = 443 + def get_error_info(self, exception: Exception, **kwargs: Any) -> ClientErrorInfo: + timeout_indicators = ( + "AWS_IO_SOCKET_TIMEOUT", + "AWS_IO_SOCKET_CLOSED", + ) + if isinstance(exception, AwsCrtError) and exception.name in timeout_indicators: + return ClientErrorInfo(is_timeout_error=True) + + return ClientErrorInfo(is_timeout_error=False) + def __init__( self, eventloop: _AWSCRTEventLoop | None = None, diff --git a/packages/smithy-http/src/smithy_http/aio/protocols.py b/packages/smithy-http/src/smithy_http/aio/protocols.py index cf25036fe..af32cee16 100644 --- a/packages/smithy-http/src/smithy_http/aio/protocols.py +++ b/packages/smithy-http/src/smithy_http/aio/protocols.py @@ -215,7 +215,6 @@ async def _create_error( ) return error_shape.deserialize(deserializer) - is_throttle = response.status == 429 message = ( f"Unknown error for operation {operation.schema.id} " f"- status: {response.status}" @@ -224,11 +223,17 @@ async def _create_error( message += f" - id: {error_id}" if response.reason is not None: message += f" - reason: {response.status}" + + is_timeout = response.status == 408 + is_throttle = response.status == 429 + fault = "client" if response.status < 500 else "server" + return CallError( message=message, - fault="client" if response.status < 500 else "server", + fault=fault, is_throttling_error=is_throttle, - is_retry_safe=is_throttle or None, + is_timeout_error=is_timeout, + is_retry_safe=is_throttle or is_timeout or None, ) def _matches_content_type(self, response: HTTPResponse) -> bool: diff --git a/packages/smithy-http/tests/unit/aio/test_protocols.py b/packages/smithy-http/tests/unit/aio/test_protocols.py index ecdb15cfa..7d668cb53 100644 --- a/packages/smithy-http/tests/unit/aio/test_protocols.py +++ b/packages/smithy-http/tests/unit/aio/test_protocols.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 from typing import Any +from unittest.mock import Mock import pytest from smithy_core import URI @@ -11,14 +12,15 @@ from smithy_core.interfaces import URI as URIInterface from smithy_core.schemas import APIOperation from smithy_core.shapes import ShapeID +from smithy_core.types import TypedProperties as ConcreteTypedProperties from smithy_http import Fields -from smithy_http.aio import HTTPRequest +from smithy_http.aio import HTTPRequest, HTTPResponse from smithy_http.aio.interfaces import HTTPRequest as HTTPRequestInterface from smithy_http.aio.interfaces import HTTPResponse as HTTPResponseInterface -from smithy_http.aio.protocols import HttpClientProtocol +from smithy_http.aio.protocols import HttpBindingClientProtocol, HttpClientProtocol -class TestProtocol(HttpClientProtocol): +class MockProtocol(HttpClientProtocol): _id = ShapeID("ns.foo#bar") @property @@ -125,7 +127,7 @@ def deserialize_response( def test_http_protocol_joins_uris( request_uri: URI, endpoint_uri: URI, expected: URI ) -> None: - protocol = TestProtocol() + protocol = MockProtocol() request = HTTPRequest( destination=request_uri, method="GET", @@ -135,3 +137,27 @@ def test_http_protocol_joins_uris( updated_request = protocol.set_service_endpoint(request=request, endpoint=endpoint) actual = updated_request.destination assert actual == expected + + +@pytest.mark.asyncio +async def test_http_408_creates_timeout_error() -> None: + protocol = Mock(spec=HttpBindingClientProtocol) + protocol.error_identifier = Mock() + protocol.error_identifier.identify.return_value = None + + response = HTTPResponse(status=408, fields=Fields()) + + error = await HttpBindingClientProtocol._create_error( # type: ignore[reportPrivateUsage] + protocol, + operation=Mock(), + request=HTTPRequest( + destination=URI(host="example.com"), method="POST", fields=Fields() + ), + response=response, + response_body=b"", + error_registry=TypeRegistry({}), + context=ConcreteTypedProperties(), + ) + + assert error.is_timeout_error is True + assert error.fault == "client" diff --git a/packages/smithy-http/tests/unit/aio/test_timeout_errors.py b/packages/smithy-http/tests/unit/aio/test_timeout_errors.py new file mode 100644 index 000000000..01e96d19f --- /dev/null +++ b/packages/smithy-http/tests/unit/aio/test_timeout_errors.py @@ -0,0 +1,85 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +from typing import TYPE_CHECKING + +import pytest +from smithy_core.aio.interfaces import ClientErrorInfo + +if TYPE_CHECKING: + from smithy_http.aio.aiohttp import AIOHTTPClient + from smithy_http.aio.crt import AWSCRTHTTPClient + +try: + from smithy_http.aio.aiohttp import AIOHTTPClient + + has_aiohttp = True +except ImportError: + has_aiohttp = False + +try: + from awscrt.exceptions import AwsCrtError # type: ignore + from smithy_http.aio.crt import AWSCRTHTTPClient + + has_crt = True +except ImportError: + has_crt = False + + +@pytest.mark.skipif(not has_aiohttp, reason="aiohttp not available") +class TestAIOHTTPTimeoutErrorHandling: + """Test timeout error handling for AIOHTTPClient.""" + + @pytest.fixture + async def client(self) -> "AIOHTTPClient": + return AIOHTTPClient() + + @pytest.mark.asyncio + async def test_timeout_error_detection(self, client: "AIOHTTPClient") -> None: + """Test timeout error detection for standard TimeoutError.""" + timeout_err = TimeoutError("Connection timed out") + result = client.get_error_info(timeout_err) + assert result == ClientErrorInfo(is_timeout_error=True) + + @pytest.mark.asyncio + async def test_non_timeout_error_detection(self, client: "AIOHTTPClient") -> None: + """Test non-timeout error detection.""" + other_err = ValueError("Not a timeout") + result = client.get_error_info(other_err) + assert result == ClientErrorInfo(is_timeout_error=False) + + +@pytest.mark.skipif(not has_crt, reason="AWS CRT not available") +class TestAWSCRTTimeoutErrorHandling: + """Test timeout error handling for AWSCRTHTTPClient.""" + + @pytest.fixture + def client(self) -> "AWSCRTHTTPClient": + return AWSCRTHTTPClient() + + @pytest.mark.parametrize( + "error_name,expected_timeout", + [ + ("AWS_IO_SOCKET_TIMEOUT", True), + ("AWS_IO_SOCKET_CLOSED", True), + ("AWS_IO_SOCKET_CONNECTION_REFUSED", False), + ], + ) + def test_crt_error_detection( + self, client: "AWSCRTHTTPClient", error_name: str, expected_timeout: bool + ) -> None: + """Test CRT error detection for various error types.""" + if not has_crt: + pytest.skip("AWS CRT not available") + + crt_err = AwsCrtError( # type: ignore + code=0, name=error_name, message=f"CRT error: {error_name}" + ) + result = client.get_error_info(crt_err) + assert result == ClientErrorInfo(is_timeout_error=expected_timeout) + + def test_non_crt_error_detection(self, client: "AWSCRTHTTPClient") -> None: + """Test non-CRT error detection.""" + other_err = ValueError("Not a timeout") + result = client.get_error_info(other_err) + assert result == ClientErrorInfo(is_timeout_error=False)