diff --git a/CHANGELOG.md b/CHANGELOG.md
index 2cd3b7345..e66e2b972 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
 
 ### Added
 
+- Added retry with back-off logic for Redis-related functions. [#528](https://github.com/stac-utils/stac-fastapi-elasticsearch-opensearch/pull/528)
 - Added nanosecond precision datetime filtering that ensures nanosecond precision support in filtering by datetime. This is configured via the `USE_DATETIME_NANOS` environment variable, while maintaining microseconds compatibility for datetime precision. [#529](https://github.com/stac-utils/stac-fastapi-elasticsearch-opensearch/pull/529)
 
 ### Changed
diff --git a/stac_fastapi/core/pyproject.toml b/stac_fastapi/core/pyproject.toml
index 5498956c9..4e6a347b2 100644
--- a/stac_fastapi/core/pyproject.toml
+++ b/stac_fastapi/core/pyproject.toml
@@ -45,6 +45,7 @@ dependencies = [
     "jsonschema~=4.0.0",
     "slowapi~=0.1.9",
     "redis==6.4.0",
+    "retry==0.9.2",
 ]
 
 [project.urls]
diff --git a/stac_fastapi/core/stac_fastapi/core/redis_utils.py b/stac_fastapi/core/stac_fastapi/core/redis_utils.py
index f1e3fe74f..09424803d 100644
--- a/stac_fastapi/core/stac_fastapi/core/redis_utils.py
+++ b/stac_fastapi/core/stac_fastapi/core/redis_utils.py
@@ -2,25 +2,25 @@
 
 import json
 import logging
-from typing import List, Optional, Tuple
+from functools import wraps
+from typing import Callable, List, Optional, Tuple, cast
 from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
 
 from pydantic import Field, field_validator
 from pydantic_settings import BaseSettings
 from redis import asyncio as aioredis
 from redis.asyncio.sentinel import Sentinel
+from redis.exceptions import ConnectionError as RedisConnectionError
+from redis.exceptions import TimeoutError as RedisTimeoutError
+from retry import retry  # type: ignore
 
 logger = logging.getLogger(__name__)
 
 
-class RedisSentinelSettings(BaseSettings):
-    """Configuration for connecting to Redis Sentinel."""
+class RedisCommonSettings(BaseSettings):
+    """Common configuration for Redis Sentinel and Redis Standalone."""
 
-    REDIS_SENTINEL_HOSTS: str = ""
-    REDIS_SENTINEL_PORTS: str = "26379"
-    REDIS_SENTINEL_MASTER_NAME: str = "master"
     REDIS_DB: int = 15
-
     REDIS_MAX_CONNECTIONS: Optional[int] = None
     REDIS_RETRY_TIMEOUT: bool = True
     REDIS_DECODE_RESPONSES: bool = True
@@ -28,9 +28,13 @@
     REDIS_HEALTH_CHECK_INTERVAL: int = Field(default=30, gt=0)
     REDIS_SELF_LINK_TTL: int = 1800
 
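+    # Retry tuning for Redis operations: the total number of attempts, the
+    # initial delay in seconds, and the multiplier applied to the delay after
+    # each failed attempt.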
"26379" + REDIS_SENTINEL_MASTER_NAME: str = "master" + def get_sentinel_hosts(self) -> List[str]: """Parse Redis Sentinel hosts from string to list.""" if not self.REDIS_SENTINEL_HOSTS: @@ -96,19 +108,11 @@ def get_sentinel_nodes(self) -> List[Tuple[str, int]]: return [(str(host), int(port)) for host, port in zip(hosts, ports)] -class RedisSettings(BaseSettings): +class RedisSettings(RedisCommonSettings): """Configuration for connecting Redis.""" REDIS_HOST: str = "" REDIS_PORT: int = 6379 - REDIS_DB: int = 15 - - REDIS_MAX_CONNECTIONS: Optional[int] = None - REDIS_RETRY_TIMEOUT: bool = True - REDIS_DECODE_RESPONSES: bool = True - REDIS_CLIENT_NAME: str = "stac-fastapi-app" - REDIS_HEALTH_CHECK_INTERVAL: int = Field(default=30, gt=0) - REDIS_SELF_LINK_TTL: int = 1800 @field_validator("REDIS_PORT") @classmethod @@ -118,89 +122,93 @@ def validate_port_standalone(cls, v: int) -> int: raise ValueError("REDIS_PORT must be a positive integer") return v - @field_validator("REDIS_DB") - @classmethod - def validate_db_standalone(cls, v: int) -> int: - """Validate REDIS_DB is not a negative integer.""" - if v < 0: - raise ValueError("REDIS_DB must be a positive integer") - return v - - @field_validator("REDIS_MAX_CONNECTIONS", mode="before") - @classmethod - def validate_max_connections(cls, v): - """Handle empty/None values for REDIS_MAX_CONNECTIONS.""" - if v in ["", "null", "Null", "NULL", "none", "None", "NONE", None]: - return None - return v - - @field_validator("REDIS_SELF_LINK_TTL") - @classmethod - def validate_self_link_ttl_standalone(cls, v: int) -> int: - """Validate REDIS_SELF_LINK_TTL is negative.""" - if v < 0: - raise ValueError("REDIS_SELF_LINK_TTL must be a positive integer") - return v - # Configure only one Redis configuration sentinel_settings = RedisSentinelSettings() -standalone_settings = RedisSettings() +settings: RedisCommonSettings = cast( + RedisCommonSettings, + sentinel_settings if sentinel_settings.REDIS_SENTINEL_HOSTS else RedisSettings(), +) + + +def redis_retry(func: Callable) -> Callable: + """Retry with back-off decorator for Redis connections.""" + + @wraps(func) + @retry( + exceptions=(RedisConnectionError, RedisTimeoutError), + tries=settings.REDIS_QUERY_RETRIES_NUM, + delay=settings.REDIS_QUERY_INITIAL_DELAY, + backoff=settings.REDIS_QUERY_BACKOFF, + logger=logger, + ) + async def wrapper(*args, **kwargs): + return await func(*args, **kwargs) + return wrapper -async def connect_redis() -> Optional[aioredis.Redis]: + +@redis_retry +async def _connect_redis_internal() -> Optional[aioredis.Redis]: """Return a Redis connection Redis or Redis Sentinel.""" - try: - if sentinel_settings.REDIS_SENTINEL_HOSTS: - sentinel_nodes = sentinel_settings.get_sentinel_nodes() - sentinel = Sentinel( - sentinel_nodes, - decode_responses=sentinel_settings.REDIS_DECODE_RESPONSES, - ) + if sentinel_settings.REDIS_SENTINEL_HOSTS: + sentinel_nodes = settings.get_sentinel_nodes() + sentinel = Sentinel( + sentinel_nodes, + decode_responses=settings.REDIS_DECODE_RESPONSES, + ) - redis = sentinel.master_for( - service_name=sentinel_settings.REDIS_SENTINEL_MASTER_NAME, - db=sentinel_settings.REDIS_DB, - decode_responses=sentinel_settings.REDIS_DECODE_RESPONSES, - retry_on_timeout=sentinel_settings.REDIS_RETRY_TIMEOUT, - client_name=sentinel_settings.REDIS_CLIENT_NAME, - max_connections=sentinel_settings.REDIS_MAX_CONNECTIONS, - health_check_interval=sentinel_settings.REDIS_HEALTH_CHECK_INTERVAL, - ) - logger.info("Connected to Redis Sentinel") - - elif standalone_settings.REDIS_HOST: - 
+    @wraps(func)
+    @retry(
+        exceptions=(RedisConnectionError, RedisTimeoutError),
+        tries=settings.REDIS_QUERY_RETRIES_NUM,
+        delay=settings.REDIS_QUERY_INITIAL_DELAY,
+        backoff=settings.REDIS_QUERY_BACKOFF,
+        logger=logger,
+    )
+    async def wrapper(*args, **kwargs):
+        return await func(*args, **kwargs)
+
+    return wrapper
+
 
-async def connect_redis() -> Optional[aioredis.Redis]:
+@redis_retry
+async def _connect_redis_internal() -> Optional[aioredis.Redis]:
     """Return a Redis connection Redis or Redis Sentinel."""
-    try:
-        if sentinel_settings.REDIS_SENTINEL_HOSTS:
-            sentinel_nodes = sentinel_settings.get_sentinel_nodes()
-            sentinel = Sentinel(
-                sentinel_nodes,
-                decode_responses=sentinel_settings.REDIS_DECODE_RESPONSES,
-            )
+    if sentinel_settings.REDIS_SENTINEL_HOSTS:
+        sentinel_nodes = settings.get_sentinel_nodes()
+        sentinel = Sentinel(
+            sentinel_nodes,
+            decode_responses=settings.REDIS_DECODE_RESPONSES,
+        )
 
-            redis = sentinel.master_for(
-                service_name=sentinel_settings.REDIS_SENTINEL_MASTER_NAME,
-                db=sentinel_settings.REDIS_DB,
-                decode_responses=sentinel_settings.REDIS_DECODE_RESPONSES,
-                retry_on_timeout=sentinel_settings.REDIS_RETRY_TIMEOUT,
-                client_name=sentinel_settings.REDIS_CLIENT_NAME,
-                max_connections=sentinel_settings.REDIS_MAX_CONNECTIONS,
-                health_check_interval=sentinel_settings.REDIS_HEALTH_CHECK_INTERVAL,
-            )
-            logger.info("Connected to Redis Sentinel")
-
-        elif standalone_settings.REDIS_HOST:
-            pool = aioredis.ConnectionPool(
-                host=standalone_settings.REDIS_HOST,
-                port=standalone_settings.REDIS_PORT,
-                db=standalone_settings.REDIS_DB,
-                max_connections=standalone_settings.REDIS_MAX_CONNECTIONS,
-                decode_responses=standalone_settings.REDIS_DECODE_RESPONSES,
-                retry_on_timeout=standalone_settings.REDIS_RETRY_TIMEOUT,
-                health_check_interval=standalone_settings.REDIS_HEALTH_CHECK_INTERVAL,
-            )
-            redis = aioredis.Redis(
-                connection_pool=pool, client_name=standalone_settings.REDIS_CLIENT_NAME
-            )
-            logger.info("Connected to Redis")
-        else:
-            logger.warning("No Redis configuration found")
-            return None
+        redis = sentinel.master_for(
+            service_name=settings.REDIS_SENTINEL_MASTER_NAME,
+            db=settings.REDIS_DB,
+            decode_responses=settings.REDIS_DECODE_RESPONSES,
+            retry_on_timeout=settings.REDIS_RETRY_TIMEOUT,
+            client_name=settings.REDIS_CLIENT_NAME,
+            max_connections=settings.REDIS_MAX_CONNECTIONS,
+            health_check_interval=settings.REDIS_HEALTH_CHECK_INTERVAL,
+        )
+        logger.info("Connected to Redis Sentinel")
+
+    elif settings.REDIS_HOST:
+        pool = aioredis.ConnectionPool(
+            host=settings.REDIS_HOST,
+            port=settings.REDIS_PORT,
+            db=settings.REDIS_DB,
+            max_connections=settings.REDIS_MAX_CONNECTIONS,
+            decode_responses=settings.REDIS_DECODE_RESPONSES,
+            retry_on_timeout=settings.REDIS_RETRY_TIMEOUT,
+            health_check_interval=settings.REDIS_HEALTH_CHECK_INTERVAL,
+        )
+        redis = aioredis.Redis(
+            connection_pool=pool, client_name=settings.REDIS_CLIENT_NAME
+        )
+        logger.info("Connected to Redis")
+    else:
+        logger.warning("No Redis configuration found")
+        return None
+
+    return redis
 
-        return redis
 
+async def connect_redis() -> Optional[aioredis.Redis]:
+    """Handle Redis connection."""
+    try:
+        return await _connect_redis_internal()
+    except (
+        aioredis.ConnectionError,
+        aioredis.TimeoutError,
+    ) as e:
+        logger.error(f"Redis connection failed after retries: {e}")
     except aioredis.ConnectionError as e:
         logger.error(f"Redis connection error: {e}")
         return None
     except aioredis.AuthenticationError as e:
         logger.error(f"Redis authentication error: {e}")
         return None
-    except aioredis.TimeoutError as e:
-        logger.error(f"Redis timeout error: {e}")
-        return None
     except Exception as e:
         logger.error(f"Failed to connect to Redis: {e}")
         return None
+    return None
 
 
 def get_redis_key(url: str, token: str) -> str:
@@ -230,19 +238,21 @@ def build_url_with_token(base_url: str, token: str) -> str:
     )
 
 
+@redis_retry
 async def save_prev_link(
     redis: aioredis.Redis, next_url: str, current_url: str, next_token: str
 ) -> None:
     """Save the current page as the previous link for the next URL."""
     if next_url and next_token:
         if sentinel_settings.REDIS_SENTINEL_HOSTS:
-            ttl_seconds = sentinel_settings.REDIS_SELF_LINK_TTL
-        elif standalone_settings.REDIS_HOST:
-            ttl_seconds = standalone_settings.REDIS_SELF_LINK_TTL
+            ttl_seconds = settings.REDIS_SELF_LINK_TTL
+        elif settings.REDIS_HOST:
+            ttl_seconds = settings.REDIS_SELF_LINK_TTL
         key = get_redis_key(next_url, next_token)
         await redis.setex(key, ttl_seconds, current_url)
 
 
+@redis_retry
 async def get_prev_link(
     redis: aioredis.Redis, current_url: str, current_token: str
 ) -> Optional[str]:
diff --git a/stac_fastapi/tests/redis/test_redis_utils.py b/stac_fastapi/tests/redis/test_redis_utils.py
index 404f59a26..304cc0169 100644
--- a/stac_fastapi/tests/redis/test_redis_utils.py
+++ b/stac_fastapi/tests/redis/test_redis_utils.py
@@ -1,5 +1,7 @@
 import pytest
 
+from redis.exceptions import ConnectionError as RedisConnectionError
+import stac_fastapi.core.redis_utils as redis_utils
 from stac_fastapi.core.redis_utils import connect_redis, get_prev_link, save_prev_link
@@ -46,3 +48,91 @@ async def test_redis_utils_functions():
         redis, "http://mywebsite.com/search", "non_existent_token"
     )
     assert non_existent is None
+
+
+@pytest.mark.asyncio
+async def test_redis_retry_retries_until_success(monkeypatch):
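+    # Pin the retry settings so the test is deterministic: three attempts with
+    # no delay between them.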
+    monkeypatch.setattr(
+        redis_utils.settings, "REDIS_QUERY_RETRIES_NUM", 3, raising=False
+    )
+    monkeypatch.setattr(
+        redis_utils.settings, "REDIS_QUERY_INITIAL_DELAY", 0, raising=False
+    )
+    monkeypatch.setattr(redis_utils.settings, "REDIS_QUERY_BACKOFF", 2.0, raising=False)
+
+    captured_kwargs = {}
+
+    def fake_retry(**kwargs):
+        captured_kwargs.update(kwargs)
+
+        def decorator(func):
+            async def wrapped(*args, **inner_kwargs):
+                attempts = 0
+                while True:
+                    try:
+                        attempts += 1
+                        return await func(*args, **inner_kwargs)
+                    except kwargs["exceptions"] as exc:
+                        if attempts >= kwargs["tries"]:
+                            raise exc
+                        continue
+
+            return wrapped
+
+        return decorator
+
+    monkeypatch.setattr(redis_utils, "retry", fake_retry)
+
+    call_counter = {"count": 0}
+
+    @redis_utils.redis_retry
+    async def flaky() -> str:
+        call_counter["count"] += 1
+        if call_counter["count"] < 3:
+            raise RedisConnectionError("transient failure")
+        return "success"
+
+    result = await flaky()
+
+    assert result == "success"
+    assert call_counter["count"] == 3
+    assert captured_kwargs["tries"] == redis_utils.settings.REDIS_QUERY_RETRIES_NUM
+    assert captured_kwargs["delay"] == redis_utils.settings.REDIS_QUERY_INITIAL_DELAY
+    assert captured_kwargs["backoff"] == redis_utils.settings.REDIS_QUERY_BACKOFF
+
+
+@pytest.mark.asyncio
+async def test_redis_retry_raises_after_exhaustion(monkeypatch):
+    monkeypatch.setattr(
+        redis_utils.settings, "REDIS_QUERY_RETRIES_NUM", 3, raising=False
+    )
+    monkeypatch.setattr(
+        redis_utils.settings, "REDIS_QUERY_INITIAL_DELAY", 0, raising=False
+    )
+    monkeypatch.setattr(redis_utils.settings, "REDIS_QUERY_BACKOFF", 2.0, raising=False)
+
+    def fake_retry(**kwargs):
+        def decorator(func):
+            async def wrapped(*args, **inner_kwargs):
+                attempts = 0
+                while True:
+                    try:
+                        attempts += 1
+                        return await func(*args, **inner_kwargs)
+                    except kwargs["exceptions"] as exc:
+                        if attempts >= kwargs["tries"]:
+                            raise exc
+                        continue
+
+            return wrapped
+
+        return decorator
+
+    monkeypatch.setattr(redis_utils, "retry", fake_retry)
+
+    @redis_utils.redis_retry
+    async def always_fail() -> str:
+        raise RedisConnectionError("permanent failure")
+
+    with pytest.raises(RedisConnectionError):
+        await always_fail()