Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ private void writeSharedOperationInit(PythonWriter writer, OperationShape operat

writer.putContext("operation", symbolProvider.toSymbol(operation));
writer.addImport("smithy_core.aio.client", "ClientCall");
writer.addImport("smithy_core.aio.client", "CLIENT_ID");
writer.addImport("smithy_core.interceptors", "InterceptorChain");
writer.addImport("smithy_core.types", "TypedProperties");
writer.addImport("smithy_core.aio.client", "RequestPipeline");
Expand All @@ -207,12 +208,12 @@ raise ExpectationNotMetError("protocol and transport MUST be set on the config t
call = ClientCall(
input=input,
operation=${operation:T},
context=TypedProperties({"config": config}),
context=TypedProperties({"config": config, CLIENT_ID.key: str(id(self))}),
interceptor=InterceptorChain(config.interceptors),
auth_scheme_resolver=config.auth_scheme_resolver,
supported_auth_schemes=config.auth_schemes,
endpoint_resolver=config.endpoint_resolver,
retry_strategy=config.retry_strategy,
retry_strategy_resolver=config.retry_strategy_resolver,
)
""", writer.consumer(w -> writeDefaultPlugins(w, defaultPlugins)));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -181,13 +181,13 @@ private void generateRequestTest(OperationShape operation, HttpRequestTestCase t
} else {
path = "";
}
writer.addImport("smithy_core.retries", "SimpleRetryStrategy");
writer.addImport("smithy_core.retries", "RetryStrategyOptions");
writeClientBlock(context.symbolProvider().toSymbol(service), testCase, Optional.of(() -> {
writer.write("""
config = $T(
endpoint_uri="https://$L/$L",
transport = $T(),
retry_strategy=SimpleRetryStrategy(max_attempts=1),
retry_options=RetryStrategyOptions(max_attempts=1),
)
""",
CodegenUtils.getConfigSymbol(context.settings()),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,18 +53,45 @@ public final class ConfigGenerator implements Runnable {
.initialize(writer -> writer.write("self.interceptors = interceptors or []"))
.build(),
ConfigProperty.builder()
.name("retry_strategy")
.name("retry_strategy_resolver")
.type(Symbol.builder()
.name("RetryStrategy")
.namespace("smithy_core.interfaces.retries", ".")
.addDependency(SmithyPythonDependency.SMITHY_CORE)
.name("RetryStrategyResolver[RetryStrategy]")
.addReference(Symbol.builder()
.name("RetryStrategyResolver")
.namespace("smithy_core.interfaces.retries", ".")
.addDependency(SmithyPythonDependency.SMITHY_CORE)
.build())
.addReference(Symbol.builder()
.name("RetryStrategy")
.namespace("smithy_core.interfaces.retries", ".")
.addDependency(SmithyPythonDependency.SMITHY_CORE)
.build())
.build())
.documentation("The retry strategy resolver for resolving retry strategies per client.")
.nullable(false)
.initialize(writer -> {
writer.addDependency(SmithyPythonDependency.SMITHY_CORE);
writer.addImport("smithy_core.retries", "CachingRetryStrategyResolver");
writer.write(
"self.retry_strategy_resolver = retry_strategy_resolver or CachingRetryStrategyResolver()");
})
.build(),
ConfigProperty.builder()
.name("retry_options")
.type(Symbol.builder()
.name("RetryStrategyOptions")
.addReference(Symbol.builder()
.name("RetryStrategyOptions")
.namespace("smithy_core.retries", ".")
.addDependency(SmithyPythonDependency.SMITHY_CORE)
.build())
.build())
.documentation("The retry strategy for issuing retry tokens and computing retry delays.")
.documentation("Options for configuring retry behavior.")
.nullable(false)
.initialize(writer -> {
writer.addDependency(SmithyPythonDependency.SMITHY_CORE);
writer.addImport("smithy_core.retries", "SimpleRetryStrategy");
writer.write("self.retry_strategy = retry_strategy or SimpleRetryStrategy()");
writer.addImport("smithy_core.retries", "RetryStrategyOptions");
writer.write("self.retry_options = retry_options or RetryStrategyOptions()");
})
.build(),
ConfigProperty.builder()
Expand Down Expand Up @@ -379,7 +406,7 @@ private void writeInitParams(PythonWriter writer, Collection<ConfigProperty> pro
}

private void documentProperties(PythonWriter writer, Collection<ConfigProperty> properties) {
writer.writeDocs(() ->{
writer.writeDocs(() -> {
var iter = properties.iterator();
writer.write("\nConstructor.\n");
while (iter.hasNext()) {
Expand Down
14 changes: 10 additions & 4 deletions packages/smithy-core/src/smithy_core/aio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
)
from ..interfaces import Endpoint, TypedProperties
from ..interfaces.auth import AuthOption, AuthSchemeResolver
from ..interfaces.retries import RetryStrategy
from ..interfaces.retries import RetryStrategy, RetryStrategyResolver
from ..schemas import APIOperation
from ..serializers import SerializeableShape
from ..shapes import ShapeID
Expand All @@ -44,6 +44,10 @@

AUTH_SCHEME = PropertyKey(key="auth_scheme", value_type=AuthScheme[Any, Any, Any, Any])

CLIENT_ID = PropertyKey(key="client_id", value_type=str)
"""A unique identifier for the client instance.
"""

_UNRESOLVED = URI(host="", path="/")
_LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -77,8 +81,8 @@ class ClientCall[I: SerializeableShape, O: DeserializeableShape]:
endpoint_resolver: EndpointResolver
"""The endpoint resolver for the operation."""

retry_strategy: RetryStrategy
"""The retry strategy to use for the operation."""
retry_strategy_resolver: RetryStrategyResolver[RetryStrategy]
"""The retry strategy resolver for the operation."""

retry_scope: str | None = None
"""The retry scope for the operation."""
Expand Down Expand Up @@ -329,7 +333,9 @@ async def _retry[I: SerializeableShape, O: DeserializeableShape](
if not call.retryable():
return await self._handle_attempt(call, request_context, request_future)

retry_strategy = call.retry_strategy
retry_strategy = await call.retry_strategy_resolver.resolve_retry_strategy(
properties=request_context.properties
)
retry_token = retry_strategy.acquire_initial_retry_token(
token_scope=call.retry_scope
)
Expand Down
18 changes: 17 additions & 1 deletion packages/smithy-core/src/smithy_core/interfaces/retries.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
from typing import Protocol, runtime_checkable
from typing import Literal, Protocol, runtime_checkable

from . import TypedProperties


@runtime_checkable
Expand Down Expand Up @@ -52,6 +54,9 @@ class RetryToken(Protocol):
"""Delay in seconds to wait before the retry attempt."""


RetryStrategyType = Literal["simple"]


class RetryStrategy(Protocol):
"""Issuer of :py:class:`RetryToken`s."""

Expand Down Expand Up @@ -100,3 +105,14 @@ def record_success(self, *, token: RetryToken) -> None:
:param token: The token used for the previous successful attempt.
"""
...


class RetryStrategyResolver[RS: RetryStrategy](Protocol):
"""Used to resolve a RetryStrategy for a given caller."""

async def resolve_retry_strategy(self, *, properties: TypedProperties) -> RS:
"""Resolve the retry strategy for the caller.

:param properties: Properties including caller identification and config.
"""
...
91 changes: 91 additions & 0 deletions packages/smithy-core/src/smithy_core/retries.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,103 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import asyncio
import random
from collections.abc import Callable
from dataclasses import dataclass
from enum import Enum
from functools import lru_cache
from typing import Any, Protocol

from .aio.client import CLIENT_ID
from .exceptions import RetryError
from .interfaces import retries as retries_interface
from .interfaces.retries import RetryStrategy, RetryStrategyResolver, RetryStrategyType
from .types import PropertyKey


@dataclass(kw_only=True, frozen=True)
class RetryStrategyOptions:
"""Options for configuring retry behavior."""

retry_mode: RetryStrategyType = "simple"
"""The retry mode to use."""

max_attempts: int = 3
"""Maximum number of attempts (initial attempt plus retries)."""


class RetryConfig(Protocol):
"""Protocol for config objects that support retry configuration."""

retry_options: RetryStrategyOptions


RETRY_CONFIG = PropertyKey(key="config", value_type=RetryConfig)


class CachingRetryStrategyResolver(RetryStrategyResolver[RetryStrategy]):
"""Caching retry strategy resolver that creates and caches retry strategies per caller.

This resolver maintains a cache of retry strategies keyed by a unique identifier
for each caller. This allows multiple operations from the same caller to share
a single retry strategy instance, which is important for strategies that maintain
state across retries (e.g., token buckets, rate limiters).

"""

def __init__(self) -> None:
self._locks: dict[str, asyncio.Lock] = {}
self._main_lock = asyncio.Lock()

def __deepcopy__(self, memo: dict[int, Any]) -> "CachingRetryStrategyResolver":
"""Return self to preserve cache across operation-level config copies."""
return self

@lru_cache(maxsize=50)
def _create_retry_strategy_cached(
self, retry_id: str, retry_mode: RetryStrategyType, max_attempts: int
) -> RetryStrategy:
return self._create_retry_strategy(retry_mode, max_attempts)

async def resolve_retry_strategy(
self, *, properties: retries_interface.TypedProperties
) -> RetryStrategy:
"""Get or create a retry strategy for the caller.

:param properties: Properties map that must contain the CLIENT_ID property key
with a unique identifier for the caller, and a "config" key with a
retry_strategy attribute (RetryStrategyOptions) specifying the strategy
configuration. Strategies are cached per client and options combination.
:raises ValueError: If CLIENT_ID is not present in properties.
"""
retry_id = properties.get(CLIENT_ID.key)
if retry_id is None:
raise ValueError(
f"Properties must contain '{CLIENT_ID.key}' key with a unique identifier for the caller"
)

# Get retry options from config
config = properties[RETRY_CONFIG]
options = config.retry_options

async with self._main_lock:
if retry_id not in self._locks:
self._locks[retry_id] = asyncio.Lock()
lock = self._locks[retry_id]

async with lock:
return self._create_retry_strategy_cached(
retry_id, options.retry_mode, options.max_attempts
)

def _create_retry_strategy(
self, retry_mode: RetryStrategyType, max_attempts: int
) -> RetryStrategy:
match retry_mode:
case "simple":
return SimpleRetryStrategy(max_attempts=max_attempts)
case _:
raise ValueError(f"Unknown retry mode: {retry_mode}")


class ExponentialBackoffJitterType(Enum):
Expand Down
102 changes: 100 additions & 2 deletions packages/smithy-core/tests/unit/test_retries.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,28 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import asyncio
from dataclasses import dataclass, field

import pytest
from smithy_core.aio.client import CLIENT_ID
from smithy_core.exceptions import CallError, RetryError
from smithy_core.retries import ExponentialBackoffJitterType as EBJT
from smithy_core.retries import ExponentialRetryBackoffStrategy, SimpleRetryStrategy
from smithy_core.retries import (
CachingRetryStrategyResolver,
ExponentialRetryBackoffStrategy,
RetryStrategyOptions,
SimpleRetryStrategy,
)
from smithy_core.retries import (
ExponentialBackoffJitterType as EBJT,
)
from smithy_core.types import TypedProperties


@dataclass
class MockConfig:
"""Mock config for testing retry resolver."""

retry_options: RetryStrategyOptions = field(default_factory=RetryStrategyOptions)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -100,3 +118,83 @@ def test_simple_retry_does_not_retry_unsafe() -> None:
token = strategy.acquire_initial_retry_token()
with pytest.raises(RetryError):
strategy.refresh_retry_token_for_retry(token_to_renew=token, error=error)


@pytest.mark.asyncio
async def test_caching_retry_strategy_default_resolution() -> None:
resolver = CachingRetryStrategyResolver()
properties = TypedProperties(
{CLIENT_ID.key: "test-client-1", "config": MockConfig()}
)

strategy = await resolver.resolve_retry_strategy(properties=properties)

assert isinstance(strategy, SimpleRetryStrategy)


@pytest.mark.asyncio
async def test_caching_retry_strategy_resolver_caches_per_client() -> None:
resolver = CachingRetryStrategyResolver()
config = MockConfig()
properties1 = TypedProperties({CLIENT_ID.key: "test-caller-1", "config": config})
properties2 = TypedProperties({CLIENT_ID.key: "test-caller-2", "config": config})

strategy1a = await resolver.resolve_retry_strategy(properties=properties1)
strategy1b = await resolver.resolve_retry_strategy(properties=properties1)
strategy2 = await resolver.resolve_retry_strategy(properties=properties2)

assert strategy1a is strategy1b
assert strategy1a is not strategy2


@pytest.mark.asyncio
async def test_caching_retry_strategy_resolver_concurrent_access() -> None:
resolver = CachingRetryStrategyResolver()
properties = TypedProperties(
{CLIENT_ID.key: "test-caller-concurrent", "config": MockConfig()}
)

strategies = await asyncio.gather(
resolver.resolve_retry_strategy(properties=properties),
resolver.resolve_retry_strategy(properties=properties),
resolver.resolve_retry_strategy(properties=properties),
)

assert strategies[0] is strategies[1]
assert strategies[1] is strategies[2]


@pytest.mark.asyncio
async def test_caching_retry_strategy_resolver_caches_by_options() -> None:
resolver = CachingRetryStrategyResolver()

config1 = MockConfig(retry_options=RetryStrategyOptions(max_attempts=3))
config2 = MockConfig(retry_options=RetryStrategyOptions(max_attempts=5))

properties1 = TypedProperties({CLIENT_ID.key: "test-client", "config": config1})
properties2 = TypedProperties({CLIENT_ID.key: "test-client", "config": config2})

strategy1 = await resolver.resolve_retry_strategy(properties=properties1)
strategy2 = await resolver.resolve_retry_strategy(properties=properties2)

assert strategy1 is not strategy2
assert strategy1.max_attempts == 3
assert strategy2.max_attempts == 5


@pytest.mark.asyncio
async def test_caching_retry_strategy_resolver_requires_client_id() -> None:
resolver = CachingRetryStrategyResolver()
properties = TypedProperties({})

with pytest.raises(ValueError, match=CLIENT_ID.key):
await resolver.resolve_retry_strategy(properties=properties)


def test_caching_retry_strategy_resolver_survives_deepcopy() -> None:
from copy import deepcopy

resolver = CachingRetryStrategyResolver()
resolver_copy = deepcopy(resolver)

assert resolver is resolver_copy
Loading