diff --git a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/ClientGenerator.java b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/ClientGenerator.java index 1cac89621..fd7a78b15 100644 --- a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/ClientGenerator.java +++ b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/ClientGenerator.java @@ -183,6 +183,7 @@ private void writeSharedOperationInit(PythonWriter writer, OperationShape operat writer.putContext("operation", symbolProvider.toSymbol(operation)); writer.addImport("smithy_core.aio.client", "ClientCall"); + writer.addImport("smithy_core.aio.client", "CLIENT_ID"); writer.addImport("smithy_core.interceptors", "InterceptorChain"); writer.addImport("smithy_core.types", "TypedProperties"); writer.addImport("smithy_core.aio.client", "RequestPipeline"); @@ -207,12 +208,12 @@ raise ExpectationNotMetError("protocol and transport MUST be set on the config t call = ClientCall( input=input, operation=${operation:T}, - context=TypedProperties({"config": config}), + context=TypedProperties({"config": config, CLIENT_ID.key: str(id(self))}), interceptor=InterceptorChain(config.interceptors), auth_scheme_resolver=config.auth_scheme_resolver, supported_auth_schemes=config.auth_schemes, endpoint_resolver=config.endpoint_resolver, - retry_strategy=config.retry_strategy, + retry_strategy_resolver=config.retry_strategy_resolver, ) """, writer.consumer(w -> writeDefaultPlugins(w, defaultPlugins))); diff --git a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/HttpProtocolTestGenerator.java b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/HttpProtocolTestGenerator.java index 83f1258f4..00a7b7108 100644 --- a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/HttpProtocolTestGenerator.java +++ b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/HttpProtocolTestGenerator.java @@ -181,13 +181,13 @@ private void generateRequestTest(OperationShape operation, HttpRequestTestCase t } else { path = ""; } - writer.addImport("smithy_core.retries", "SimpleRetryStrategy"); + writer.addImport("smithy_core.retries", "RetryStrategyOptions"); writeClientBlock(context.symbolProvider().toSymbol(service), testCase, Optional.of(() -> { writer.write(""" config = $T( endpoint_uri="https://$L/$L", transport = $T(), - retry_strategy=SimpleRetryStrategy(max_attempts=1), + retry_options=RetryStrategyOptions(max_attempts=1), ) """, CodegenUtils.getConfigSymbol(context.settings()), diff --git a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/generators/ConfigGenerator.java b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/generators/ConfigGenerator.java index 45b6324d7..e6d14bf65 100644 --- a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/generators/ConfigGenerator.java +++ b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/generators/ConfigGenerator.java @@ -53,18 +53,45 @@ public final class ConfigGenerator implements Runnable { .initialize(writer -> writer.write("self.interceptors = interceptors or []")) .build(), ConfigProperty.builder() - .name("retry_strategy") + .name("retry_strategy_resolver") .type(Symbol.builder() - .name("RetryStrategy") - .namespace("smithy_core.interfaces.retries", ".") - .addDependency(SmithyPythonDependency.SMITHY_CORE) + .name("RetryStrategyResolver[RetryStrategy]") + .addReference(Symbol.builder() + .name("RetryStrategyResolver") + .namespace("smithy_core.interfaces.retries", ".") + .addDependency(SmithyPythonDependency.SMITHY_CORE) + .build()) + .addReference(Symbol.builder() + .name("RetryStrategy") + .namespace("smithy_core.interfaces.retries", ".") + .addDependency(SmithyPythonDependency.SMITHY_CORE) + .build()) + .build()) + .documentation("The retry strategy resolver for resolving retry strategies per client.") + .nullable(false) + .initialize(writer -> { + writer.addDependency(SmithyPythonDependency.SMITHY_CORE); + writer.addImport("smithy_core.retries", "CachingRetryStrategyResolver"); + writer.write( + "self.retry_strategy_resolver = retry_strategy_resolver or CachingRetryStrategyResolver()"); + }) + .build(), + ConfigProperty.builder() + .name("retry_options") + .type(Symbol.builder() + .name("RetryStrategyOptions") + .addReference(Symbol.builder() + .name("RetryStrategyOptions") + .namespace("smithy_core.retries", ".") + .addDependency(SmithyPythonDependency.SMITHY_CORE) + .build()) .build()) - .documentation("The retry strategy for issuing retry tokens and computing retry delays.") + .documentation("Options for configuring retry behavior.") .nullable(false) .initialize(writer -> { writer.addDependency(SmithyPythonDependency.SMITHY_CORE); - writer.addImport("smithy_core.retries", "SimpleRetryStrategy"); - writer.write("self.retry_strategy = retry_strategy or SimpleRetryStrategy()"); + writer.addImport("smithy_core.retries", "RetryStrategyOptions"); + writer.write("self.retry_options = retry_options or RetryStrategyOptions()"); }) .build(), ConfigProperty.builder() @@ -379,7 +406,7 @@ private void writeInitParams(PythonWriter writer, Collection pro } private void documentProperties(PythonWriter writer, Collection properties) { - writer.writeDocs(() ->{ + writer.writeDocs(() -> { var iter = properties.iterator(); writer.write("\nConstructor.\n"); while (iter.hasNext()) { diff --git a/packages/smithy-core/src/smithy_core/aio/client.py b/packages/smithy-core/src/smithy_core/aio/client.py index bf27c440c..08a061588 100644 --- a/packages/smithy-core/src/smithy_core/aio/client.py +++ b/packages/smithy-core/src/smithy_core/aio/client.py @@ -22,7 +22,7 @@ ) from ..interfaces import Endpoint, TypedProperties from ..interfaces.auth import AuthOption, AuthSchemeResolver -from ..interfaces.retries import RetryStrategy +from ..interfaces.retries import RetryStrategy, RetryStrategyResolver from ..schemas import APIOperation from ..serializers import SerializeableShape from ..shapes import ShapeID @@ -44,6 +44,10 @@ AUTH_SCHEME = PropertyKey(key="auth_scheme", value_type=AuthScheme[Any, Any, Any, Any]) +CLIENT_ID = PropertyKey(key="client_id", value_type=str) +"""A unique identifier for the client instance. +""" + _UNRESOLVED = URI(host="", path="/") _LOGGER = logging.getLogger(__name__) @@ -77,8 +81,8 @@ class ClientCall[I: SerializeableShape, O: DeserializeableShape]: endpoint_resolver: EndpointResolver """The endpoint resolver for the operation.""" - retry_strategy: RetryStrategy - """The retry strategy to use for the operation.""" + retry_strategy_resolver: RetryStrategyResolver[RetryStrategy] + """The retry strategy resolver for the operation.""" retry_scope: str | None = None """The retry scope for the operation.""" @@ -329,7 +333,9 @@ async def _retry[I: SerializeableShape, O: DeserializeableShape]( if not call.retryable(): return await self._handle_attempt(call, request_context, request_future) - retry_strategy = call.retry_strategy + retry_strategy = await call.retry_strategy_resolver.resolve_retry_strategy( + properties=request_context.properties + ) retry_token = retry_strategy.acquire_initial_retry_token( token_scope=call.retry_scope ) diff --git a/packages/smithy-core/src/smithy_core/interfaces/retries.py b/packages/smithy-core/src/smithy_core/interfaces/retries.py index a5c9d428b..374d10fd2 100644 --- a/packages/smithy-core/src/smithy_core/interfaces/retries.py +++ b/packages/smithy-core/src/smithy_core/interfaces/retries.py @@ -1,7 +1,9 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 from dataclasses import dataclass -from typing import Protocol, runtime_checkable +from typing import Literal, Protocol, runtime_checkable + +from . import TypedProperties @runtime_checkable @@ -52,6 +54,9 @@ class RetryToken(Protocol): """Delay in seconds to wait before the retry attempt.""" +RetryStrategyType = Literal["simple"] + + class RetryStrategy(Protocol): """Issuer of :py:class:`RetryToken`s.""" @@ -100,3 +105,14 @@ def record_success(self, *, token: RetryToken) -> None: :param token: The token used for the previous successful attempt. """ ... + + +class RetryStrategyResolver[RS: RetryStrategy](Protocol): + """Used to resolve a RetryStrategy for a given caller.""" + + async def resolve_retry_strategy(self, *, properties: TypedProperties) -> RS: + """Resolve the retry strategy for the caller. + + :param properties: Properties including caller identification and config. + """ + ... diff --git a/packages/smithy-core/src/smithy_core/retries.py b/packages/smithy-core/src/smithy_core/retries.py index 06bf6f988..170a9e768 100644 --- a/packages/smithy-core/src/smithy_core/retries.py +++ b/packages/smithy-core/src/smithy_core/retries.py @@ -1,12 +1,103 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 +import asyncio import random from collections.abc import Callable from dataclasses import dataclass from enum import Enum +from functools import lru_cache +from typing import Any, Protocol +from .aio.client import CLIENT_ID from .exceptions import RetryError from .interfaces import retries as retries_interface +from .interfaces.retries import RetryStrategy, RetryStrategyResolver, RetryStrategyType +from .types import PropertyKey + + +@dataclass(kw_only=True, frozen=True) +class RetryStrategyOptions: + """Options for configuring retry behavior.""" + + retry_mode: RetryStrategyType = "simple" + """The retry mode to use.""" + + max_attempts: int = 3 + """Maximum number of attempts (initial attempt plus retries).""" + + +class RetryConfig(Protocol): + """Protocol for config objects that support retry configuration.""" + + retry_options: RetryStrategyOptions + + +RETRY_CONFIG = PropertyKey(key="config", value_type=RetryConfig) + + +class CachingRetryStrategyResolver(RetryStrategyResolver[RetryStrategy]): + """Caching retry strategy resolver that creates and caches retry strategies per caller. + + This resolver maintains a cache of retry strategies keyed by a unique identifier + for each caller. This allows multiple operations from the same caller to share + a single retry strategy instance, which is important for strategies that maintain + state across retries (e.g., token buckets, rate limiters). + + """ + + def __init__(self) -> None: + self._locks: dict[str, asyncio.Lock] = {} + self._main_lock = asyncio.Lock() + + def __deepcopy__(self, memo: dict[int, Any]) -> "CachingRetryStrategyResolver": + """Return self to preserve cache across operation-level config copies.""" + return self + + @lru_cache(maxsize=50) + def _create_retry_strategy_cached( + self, retry_id: str, retry_mode: RetryStrategyType, max_attempts: int + ) -> RetryStrategy: + return self._create_retry_strategy(retry_mode, max_attempts) + + async def resolve_retry_strategy( + self, *, properties: retries_interface.TypedProperties + ) -> RetryStrategy: + """Get or create a retry strategy for the caller. + + :param properties: Properties map that must contain the CLIENT_ID property key + with a unique identifier for the caller, and a "config" key with a + retry_strategy attribute (RetryStrategyOptions) specifying the strategy + configuration. Strategies are cached per client and options combination. + :raises ValueError: If CLIENT_ID is not present in properties. + """ + retry_id = properties.get(CLIENT_ID.key) + if retry_id is None: + raise ValueError( + f"Properties must contain '{CLIENT_ID.key}' key with a unique identifier for the caller" + ) + + # Get retry options from config + config = properties[RETRY_CONFIG] + options = config.retry_options + + async with self._main_lock: + if retry_id not in self._locks: + self._locks[retry_id] = asyncio.Lock() + lock = self._locks[retry_id] + + async with lock: + return self._create_retry_strategy_cached( + retry_id, options.retry_mode, options.max_attempts + ) + + def _create_retry_strategy( + self, retry_mode: RetryStrategyType, max_attempts: int + ) -> RetryStrategy: + match retry_mode: + case "simple": + return SimpleRetryStrategy(max_attempts=max_attempts) + case _: + raise ValueError(f"Unknown retry mode: {retry_mode}") class ExponentialBackoffJitterType(Enum): diff --git a/packages/smithy-core/tests/unit/test_retries.py b/packages/smithy-core/tests/unit/test_retries.py index 0b3c23be4..691049ef7 100644 --- a/packages/smithy-core/tests/unit/test_retries.py +++ b/packages/smithy-core/tests/unit/test_retries.py @@ -1,10 +1,28 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 +import asyncio +from dataclasses import dataclass, field import pytest +from smithy_core.aio.client import CLIENT_ID from smithy_core.exceptions import CallError, RetryError -from smithy_core.retries import ExponentialBackoffJitterType as EBJT -from smithy_core.retries import ExponentialRetryBackoffStrategy, SimpleRetryStrategy +from smithy_core.retries import ( + CachingRetryStrategyResolver, + ExponentialRetryBackoffStrategy, + RetryStrategyOptions, + SimpleRetryStrategy, +) +from smithy_core.retries import ( + ExponentialBackoffJitterType as EBJT, +) +from smithy_core.types import TypedProperties + + +@dataclass +class MockConfig: + """Mock config for testing retry resolver.""" + + retry_options: RetryStrategyOptions = field(default_factory=RetryStrategyOptions) @pytest.mark.parametrize( @@ -100,3 +118,83 @@ def test_simple_retry_does_not_retry_unsafe() -> None: token = strategy.acquire_initial_retry_token() with pytest.raises(RetryError): strategy.refresh_retry_token_for_retry(token_to_renew=token, error=error) + + +@pytest.mark.asyncio +async def test_caching_retry_strategy_default_resolution() -> None: + resolver = CachingRetryStrategyResolver() + properties = TypedProperties( + {CLIENT_ID.key: "test-client-1", "config": MockConfig()} + ) + + strategy = await resolver.resolve_retry_strategy(properties=properties) + + assert isinstance(strategy, SimpleRetryStrategy) + + +@pytest.mark.asyncio +async def test_caching_retry_strategy_resolver_caches_per_client() -> None: + resolver = CachingRetryStrategyResolver() + config = MockConfig() + properties1 = TypedProperties({CLIENT_ID.key: "test-caller-1", "config": config}) + properties2 = TypedProperties({CLIENT_ID.key: "test-caller-2", "config": config}) + + strategy1a = await resolver.resolve_retry_strategy(properties=properties1) + strategy1b = await resolver.resolve_retry_strategy(properties=properties1) + strategy2 = await resolver.resolve_retry_strategy(properties=properties2) + + assert strategy1a is strategy1b + assert strategy1a is not strategy2 + + +@pytest.mark.asyncio +async def test_caching_retry_strategy_resolver_concurrent_access() -> None: + resolver = CachingRetryStrategyResolver() + properties = TypedProperties( + {CLIENT_ID.key: "test-caller-concurrent", "config": MockConfig()} + ) + + strategies = await asyncio.gather( + resolver.resolve_retry_strategy(properties=properties), + resolver.resolve_retry_strategy(properties=properties), + resolver.resolve_retry_strategy(properties=properties), + ) + + assert strategies[0] is strategies[1] + assert strategies[1] is strategies[2] + + +@pytest.mark.asyncio +async def test_caching_retry_strategy_resolver_caches_by_options() -> None: + resolver = CachingRetryStrategyResolver() + + config1 = MockConfig(retry_options=RetryStrategyOptions(max_attempts=3)) + config2 = MockConfig(retry_options=RetryStrategyOptions(max_attempts=5)) + + properties1 = TypedProperties({CLIENT_ID.key: "test-client", "config": config1}) + properties2 = TypedProperties({CLIENT_ID.key: "test-client", "config": config2}) + + strategy1 = await resolver.resolve_retry_strategy(properties=properties1) + strategy2 = await resolver.resolve_retry_strategy(properties=properties2) + + assert strategy1 is not strategy2 + assert strategy1.max_attempts == 3 + assert strategy2.max_attempts == 5 + + +@pytest.mark.asyncio +async def test_caching_retry_strategy_resolver_requires_client_id() -> None: + resolver = CachingRetryStrategyResolver() + properties = TypedProperties({}) + + with pytest.raises(ValueError, match=CLIENT_ID.key): + await resolver.resolve_retry_strategy(properties=properties) + + +def test_caching_retry_strategy_resolver_survives_deepcopy() -> None: + from copy import deepcopy + + resolver = CachingRetryStrategyResolver() + resolver_copy = deepcopy(resolver) + + assert resolver is resolver_copy