Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -619,6 +619,8 @@ private void writeUtilStubs(Symbol serviceSymbol) {
writer.addImport("smithy_http", "tuples_to_fields");
writer.addImport("smithy_http.aio", "HTTPResponse", "_HTTPResponse");
writer.addImport("smithy_core.aio.utils", "async_list");
writer.addImport("smithy_core.aio.interfaces", "ClientErrorInfo");
writer.addStdlibImport("typing", "Any");

writer.write("""
class $1L($2T):
Expand All @@ -634,6 +636,10 @@ class $3L:
def __init__(self, *, client_config: HTTPClientConfiguration | None = None):
self._client_config = client_config

def get_error_info(self, exception: Exception, **kwargs: Any) -> ClientErrorInfo:
\"\"\"Get information about an exception.\"\"\"
return ClientErrorInfo(is_timeout_error=False)

async def send(
self, request: HTTPRequest, *, request_config: HTTPRequestConfiguration | None = None
) -> HTTPResponse:
Expand All @@ -657,6 +663,10 @@ def __init__(
self.fields = tuples_to_fields(headers or [])
self.body = body

def get_error_info(self, exception: Exception, **kwargs: Any) -> ClientErrorInfo:
\"\"\"Get information about an exception.\"\"\"
return ClientErrorInfo(is_timeout_error=False)

async def send(
self, request: HTTPRequest, *, request_config: HTTPRequestConfiguration | None = None
) -> _HTTPResponse:
Expand Down
46 changes: 27 additions & 19 deletions packages/smithy-core/src/smithy_core/aio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from ..auth import AuthParams
from ..deserializers import DeserializeableShape, ShapeDeserializer
from ..endpoints import EndpointResolverParams
from ..exceptions import RetryError, SmithyError
from ..exceptions import ClientTimeoutError, RetryError, SmithyError
from ..interceptors import (
InputContext,
Interceptor,
Expand Down Expand Up @@ -448,24 +448,32 @@ async def _handle_attempt[I: SerializeableShape, O: DeserializeableShape](

_LOGGER.debug("Sending request %s", request_context.transport_request)

if request_future is not None:
# If we have an input event stream (or duplex event stream) then we
# need to let the client return ASAP so that it can start sending
# events. So here we start the transport send in a background task
# then set the result of the request future. It's important to sequence
# it just like that so that the client gets a stream that's ready
# to send.
transport_task = asyncio.create_task(
self.transport.send(request=request_context.transport_request)
)
request_future.set_result(request_context)
transport_response = await transport_task
else:
# If we don't have an input stream, there's no point in creating a
# task, so we just immediately await the coroutine.
transport_response = await self.transport.send(
request=request_context.transport_request
)
try:
if request_future is not None:
# If we have an input event stream (or duplex event stream) then we
# need to let the client return ASAP so that it can start sending
# events. So here we start the transport send in a background task
# then set the result of the request future. It's important to sequence
# it just like that so that the client gets a stream that's ready
# to send.
transport_task = asyncio.create_task(
self.transport.send(request=request_context.transport_request)
)
request_future.set_result(request_context)
transport_response = await transport_task
else:
# If we don't have an input stream, there's no point in creating a
# task, so we just immediately await the coroutine.
transport_response = await self.transport.send(
request=request_context.transport_request
)
except Exception as e:
error_info = self.transport.get_error_info(e)
if error_info.is_timeout_error:
raise ClientTimeoutError(
message=f"Client timeout occurred: {e}"
) from e
raise

_LOGGER.debug("Received response: %s", transport_response)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from collections.abc import AsyncIterable, Callable
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable

from ...documents import TypeRegistry
Expand All @@ -10,6 +11,15 @@
from ...interfaces import StreamingBlob as SyncStreamingBlob
from .eventstream import EventPublisher, EventReceiver


@dataclass(frozen=True)
class ClientErrorInfo:
"""Information about an error from a transport."""

is_timeout_error: bool
"""Whether this error represents a timeout condition."""


if TYPE_CHECKING:
from typing_extensions import TypeForm

Expand Down Expand Up @@ -86,7 +96,23 @@ async def resolve_endpoint(self, params: EndpointResolverParams[Any]) -> Endpoin


class ClientTransport[I: Request, O: Response](Protocol):
"""Protocol-agnostic representation of a client tranport (e.g. an HTTP client)."""
"""Protocol-agnostic representation of a client transport (e.g. an HTTP client).

Transport implementations must define the get_error_info method to determine which
exceptions represent timeout conditions for that transport.
"""

def get_error_info(self, exception: Exception, **kwargs: Any) -> ClientErrorInfo:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a bit hesitant to make this a required piece of ClientTransport. This is breaking for existing versions of smithy_http since clients don't implement this. We need to do one of the following:

  • Make this optional and handle it gracefully
  • Include a breaking changelog entry so we know to version bump properly in the next release

I think I prefer the first option

"""Get information about an exception.

Args:
exception: The exception to analyze
**kwargs: Additional context for analysis

Returns:
ClientErrorInfo with timeout information.
"""
...

async def send(self, request: I) -> O:
"""Send a request over the transport and receive the response."""
Expand Down
17 changes: 17 additions & 0 deletions packages/smithy-core/src/smithy_core/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ class CallError(SmithyError):
is_throttling_error: bool = False
"""Whether the error is a throttling error."""

is_timeout_error: bool = False
"""Whether the error represents a timeout condition."""

def __post_init__(self):
super().__init__(self.message)

Expand All @@ -61,6 +64,20 @@ class ModeledError(CallError):
fault: Fault = "client"


@dataclass(kw_only=True)
class ClientTimeoutError(CallError):
"""Exception raised when a client-side timeout occurs.

This error indicates that the client transport layer encountered a timeout while
attempting to communicate with the server. This typically occurs when network
requests take longer than the configured timeout period.
"""

fault: Fault = "client"
is_timeout_error: bool = True
is_retry_safe: bool | None = True


class SerializationError(SmithyError):
"""Base exception type for exceptions raised during serialization."""

Expand Down
8 changes: 7 additions & 1 deletion packages/smithy-http/src/smithy_http/aio/aiohttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
except ImportError:
HAS_AIOHTTP = False # type: ignore

from smithy_core.aio.interfaces import StreamingBlob
from smithy_core.aio.interfaces import ClientErrorInfo, StreamingBlob
from smithy_core.aio.types import AsyncBytesReader
from smithy_core.aio.utils import async_list
from smithy_core.exceptions import MissingDependencyError
Expand Down Expand Up @@ -52,6 +52,12 @@ def __post_init__(self) -> None:
class AIOHTTPClient(HTTPClient):
"""Implementation of :py:class:`.interfaces.HTTPClient` using aiohttp."""

def get_error_info(self, exception: Exception, **kwargs: Any) -> ClientErrorInfo:
if isinstance(exception, TimeoutError):
return ClientErrorInfo(is_timeout_error=True)

return ClientErrorInfo(is_timeout_error=False)

def __init__(
self,
*,
Expand Down
13 changes: 13 additions & 0 deletions packages/smithy-http/src/smithy_http/aio/crt.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from inspect import iscoroutinefunction
from typing import TYPE_CHECKING, Any

from awscrt.exceptions import AwsCrtError

if TYPE_CHECKING:
# pyright doesn't like optional imports. This is reasonable because if we use these
# in type hints then they'd result in runtime errors.
Expand All @@ -33,6 +35,7 @@

from smithy_core import interfaces as core_interfaces
from smithy_core.aio import interfaces as core_aio_interfaces
from smithy_core.aio.interfaces import ClientErrorInfo
from smithy_core.aio.types import AsyncBytesReader
from smithy_core.exceptions import MissingDependencyError

Expand Down Expand Up @@ -133,6 +136,16 @@ class AWSCRTHTTPClient(http_aio_interfaces.HTTPClient):
_HTTP_PORT = 80
_HTTPS_PORT = 443

def get_error_info(self, exception: Exception, **kwargs: Any) -> ClientErrorInfo:
timeout_indicators = (
"AWS_IO_SOCKET_TIMEOUT",
"AWS_IO_SOCKET_CLOSED",
)
if isinstance(exception, AwsCrtError) and exception.name in timeout_indicators:
return ClientErrorInfo(is_timeout_error=True)

return ClientErrorInfo(is_timeout_error=False)

def __init__(
self,
eventloop: _AWSCRTEventLoop | None = None,
Expand Down
11 changes: 8 additions & 3 deletions packages/smithy-http/src/smithy_http/aio/protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,6 @@ async def _create_error(
)
return error_shape.deserialize(deserializer)

is_throttle = response.status == 429
message = (
f"Unknown error for operation {operation.schema.id} "
f"- status: {response.status}"
Expand All @@ -224,11 +223,17 @@ async def _create_error(
message += f" - id: {error_id}"
if response.reason is not None:
message += f" - reason: {response.status}"

is_timeout = response.status == 408
is_throttle = response.status == 429
fault = "client" if response.status < 500 else "server"

return CallError(
message=message,
fault="client" if response.status < 500 else "server",
fault=fault,
is_throttling_error=is_throttle,
is_retry_safe=is_throttle or None,
is_timeout_error=is_timeout,
is_retry_safe=is_throttle or is_timeout or None,
)

def _matches_content_type(self, response: HTTPResponse) -> bool:
Expand Down
34 changes: 30 additions & 4 deletions packages/smithy-http/tests/unit/aio/test_protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Any
from unittest.mock import Mock

import pytest
from smithy_core import URI
Expand All @@ -11,14 +12,15 @@
from smithy_core.interfaces import URI as URIInterface
from smithy_core.schemas import APIOperation
from smithy_core.shapes import ShapeID
from smithy_core.types import TypedProperties as ConcreteTypedProperties
from smithy_http import Fields
from smithy_http.aio import HTTPRequest
from smithy_http.aio import HTTPRequest, HTTPResponse
from smithy_http.aio.interfaces import HTTPRequest as HTTPRequestInterface
from smithy_http.aio.interfaces import HTTPResponse as HTTPResponseInterface
from smithy_http.aio.protocols import HttpClientProtocol
from smithy_http.aio.protocols import HttpBindingClientProtocol, HttpClientProtocol


class TestProtocol(HttpClientProtocol):
class MockProtocol(HttpClientProtocol):
_id = ShapeID("ns.foo#bar")

@property
Expand Down Expand Up @@ -125,7 +127,7 @@ def deserialize_response(
def test_http_protocol_joins_uris(
request_uri: URI, endpoint_uri: URI, expected: URI
) -> None:
protocol = TestProtocol()
protocol = MockProtocol()
request = HTTPRequest(
destination=request_uri,
method="GET",
Expand All @@ -135,3 +137,27 @@ def test_http_protocol_joins_uris(
updated_request = protocol.set_service_endpoint(request=request, endpoint=endpoint)
actual = updated_request.destination
assert actual == expected


@pytest.mark.asyncio
async def test_http_408_creates_timeout_error() -> None:
protocol = Mock(spec=HttpBindingClientProtocol)
protocol.error_identifier = Mock()
protocol.error_identifier.identify.return_value = None

response = HTTPResponse(status=408, fields=Fields())

error = await HttpBindingClientProtocol._create_error( # type: ignore[reportPrivateUsage]
protocol,
operation=Mock(),
request=HTTPRequest(
destination=URI(host="example.com"), method="POST", fields=Fields()
),
response=response,
response_body=b"",
error_registry=TypeRegistry({}),
context=ConcreteTypedProperties(),
)

assert error.is_timeout_error is True
assert error.fault == "client"
Loading
Loading