2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -9,6 +9,8 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.

### Added

- Added retry with back-off logic for Redis related functions. [#528](https://github.com/stac-utils/stac-fastapi-elasticsearch-opensearch/pull/528)

### Changed

### Fixed
1 change: 1 addition & 0 deletions stac_fastapi/core/pyproject.toml
@@ -45,6 +45,7 @@ dependencies = [
    "jsonschema~=4.0.0",
    "slowapi~=0.1.9",
    "redis==6.4.0",
    "retry==0.9.2",
]

[project.urls]
125 changes: 85 additions & 40 deletions stac_fastapi/core/stac_fastapi/core/redis_utils.py
@@ -2,13 +2,17 @@

import json
import logging
from typing import List, Optional, Tuple
from functools import wraps
from typing import Callable, List, Optional, Tuple
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse

from pydantic import Field, field_validator
from pydantic_settings import BaseSettings
from redis import asyncio as aioredis
from redis.asyncio.sentinel import Sentinel
from redis.exceptions import ConnectionError as RedisConnectionError
from redis.exceptions import TimeoutError as RedisTimeoutError
from retry import retry # type: ignore

logger = logging.getLogger(__name__)

@@ -143,65 +147,104 @@ def validate_self_link_ttl_standalone(cls, v: int) -> int:
        return v


class RedisRetrySettings(BaseSettings):
    """Configuration for the Redis retry wrapper."""

    redis_query_retries_num: int = Field(
        default=3, alias="REDIS_QUERY_RETRIES_NUM", gt=0
    )
    redis_query_initial_delay: float = Field(
        default=1.0, alias="REDIS_QUERY_INITIAL_DELAY", gt=0
    )
    redis_query_backoff: float = Field(default=2.0, alias="REDIS_QUERY_BACKOFF", gt=1)


# Module-level settings; only one Redis configuration (Sentinel or standalone)
# should be set.
sentinel_settings = RedisSentinelSettings()
standalone_settings = RedisSettings()
retry_settings = RedisRetrySettings()
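For illustration, a minimal sketch of how these settings would resolve from the environment, assuming pydantic-settings' default behavior of using each field's alias as the environment variable name (the values here are hypothetical, not defaults shipped by this PR):

```python
import os

# Hypothetical values: each alias above doubles as the environment
# variable name that pydantic-settings looks up.
os.environ["REDIS_QUERY_RETRIES_NUM"] = "5"
os.environ["REDIS_QUERY_INITIAL_DELAY"] = "0.5"
os.environ["REDIS_QUERY_BACKOFF"] = "2.0"

from stac_fastapi.core.redis_utils import RedisRetrySettings

settings = RedisRetrySettings()
assert settings.redis_query_retries_num == 5
assert settings.redis_query_initial_delay == 0.5
assert settings.redis_query_backoff == 2.0
```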


def redis_retry(func: Callable) -> Callable:
    """Wrap an async function in retry-with-back-off logic."""

    @wraps(func)
    @retry(
        exceptions=(RedisConnectionError, RedisTimeoutError),
        tries=retry_settings.redis_query_retries_num,
        delay=retry_settings.redis_query_initial_delay,
        backoff=retry_settings.redis_query_backoff,
        logger=logger,
    )
    async def wrapper(*args, **kwargs):
        return await func(*args, **kwargs)

    return wrapper
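One caveat worth noting: the `retry` package is synchronous and was designed for regular callables. Calling an `async def` merely returns a coroutine without raising, so exceptions surface only at `await` time and the decorator above may never observe them. For comparison, a minimal asyncio-native sketch of the same back-off behavior (`async_redis_retry` is an illustrative name, not part of this PR):

```python
import asyncio
from functools import wraps
from typing import Callable

from redis.exceptions import ConnectionError as RedisConnectionError
from redis.exceptions import TimeoutError as RedisTimeoutError


def async_redis_retry(
    tries: int = 3, delay: float = 1.0, backoff: float = 2.0
) -> Callable:
    """Illustrative async retry decorator with exponential back-off."""

    def decorator(func: Callable) -> Callable:
        @wraps(func)
        async def wrapper(*args, **kwargs):
            wait = delay
            for attempt in range(1, tries + 1):
                try:
                    return await func(*args, **kwargs)
                except (RedisConnectionError, RedisTimeoutError):
                    if attempt == tries:
                        raise
                    # Sleep without blocking the event loop, then back off.
                    await asyncio.sleep(wait)
                    wait *= backoff

        return wrapper

    return decorator
```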


@redis_retry
async def _connect_redis_internal() -> Optional[aioredis.Redis]:
"""Return a Redis connection Redis or Redis Sentinel."""
    if sentinel_settings.REDIS_SENTINEL_HOSTS:
        sentinel_nodes = sentinel_settings.get_sentinel_nodes()
        sentinel = Sentinel(
            sentinel_nodes,
            decode_responses=sentinel_settings.REDIS_DECODE_RESPONSES,
        )

        redis = sentinel.master_for(
            service_name=sentinel_settings.REDIS_SENTINEL_MASTER_NAME,
            db=sentinel_settings.REDIS_DB,
            decode_responses=sentinel_settings.REDIS_DECODE_RESPONSES,
            retry_on_timeout=sentinel_settings.REDIS_RETRY_TIMEOUT,
            client_name=sentinel_settings.REDIS_CLIENT_NAME,
            max_connections=sentinel_settings.REDIS_MAX_CONNECTIONS,
            health_check_interval=sentinel_settings.REDIS_HEALTH_CHECK_INTERVAL,
        )
        logger.info("Connected to Redis Sentinel")

    elif standalone_settings.REDIS_HOST:
        pool = aioredis.ConnectionPool(
            host=standalone_settings.REDIS_HOST,
            port=standalone_settings.REDIS_PORT,
            db=standalone_settings.REDIS_DB,
            max_connections=standalone_settings.REDIS_MAX_CONNECTIONS,
            decode_responses=standalone_settings.REDIS_DECODE_RESPONSES,
            retry_on_timeout=standalone_settings.REDIS_RETRY_TIMEOUT,
            health_check_interval=standalone_settings.REDIS_HEALTH_CHECK_INTERVAL,
        )
        redis = aioredis.Redis(
            connection_pool=pool, client_name=standalone_settings.REDIS_CLIENT_NAME
        )
        logger.info("Connected to Redis")
    else:
        logger.warning("No Redis configuration found")
        return None

    return redis

async def connect_redis() -> Optional[aioredis.Redis]:
    """Handle the Redis connection, logging failures once retries are exhausted."""
    try:
        return await _connect_redis_internal()
    except (
        aioredis.ConnectionError,
        aioredis.TimeoutError,
    ) as e:
        logger.error(f"Redis connection failed after retries: {e}")
    except aioredis.AuthenticationError as e:
        logger.error(f"Redis authentication error: {e}")
    except Exception as e:
        logger.error(f"Failed to connect to Redis: {e}")
    return None
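A hypothetical call site, to show the intended contract: a `None` return means Redis stayed unreachable after all retries, so callers degrade gracefully rather than fail the request (the variable names below are placeholders, not part of this PR):

```python
# Sketch of a caller, assuming next_url/current_url/next_token are in scope.
redis = await connect_redis()
if redis is not None:
    # Persist the previous-page link; skipped entirely when Redis is down.
    await save_prev_link(redis, next_url, current_url, next_token)
```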


def get_redis_key(url: str, token: str) -> str:
    """Create Redis key using URL path and token."""
@@ -230,6 +273,7 @@ def build_url_with_token(base_url: str, token: str) -> str:
)


@redis_retry
async def save_prev_link(
    redis: aioredis.Redis, next_url: str, current_url: str, next_token: str
) -> None:
@@ -243,6 +287,7 @@ async def save_prev_link(
    await redis.setex(key, ttl_seconds, current_url)


@redis_retry
async def get_prev_link(
    redis: aioredis.Redis, current_url: str, current_token: str
) -> Optional[str]:
98 changes: 98 additions & 0 deletions stac_fastapi/tests/redis/test_redis_utils.py
@@ -1,5 +1,7 @@
import pytest
from redis.exceptions import ConnectionError as RedisConnectionError

import stac_fastapi.core.redis_utils as redis_utils
from stac_fastapi.core.redis_utils import connect_redis, get_prev_link, save_prev_link


@@ -46,3 +48,99 @@ async def test_redis_utils_functions():
        redis, "http://mywebsite.com/search", "non_existent_token"
    )
    assert non_existent is None


@pytest.mark.asyncio
async def test_redis_retry_retries_until_success(monkeypatch):
    monkeypatch.setattr(
        redis_utils.retry_settings, "redis_query_retries_num", 3, raising=False
    )
    monkeypatch.setattr(
        redis_utils.retry_settings, "redis_query_initial_delay", 0, raising=False
    )
    monkeypatch.setattr(
        redis_utils.retry_settings, "redis_query_backoff", 2.0, raising=False
    )

    captured_kwargs = {}

    def fake_retry(**kwargs):
        captured_kwargs.update(kwargs)

        def decorator(func):
            async def wrapped(*args, **inner_kwargs):
                attempts = 0
                while True:
                    try:
                        attempts += 1
                        return await func(*args, **inner_kwargs)
                    except kwargs["exceptions"] as exc:
                        if attempts >= kwargs["tries"]:
                            raise exc
                        continue

            return wrapped

        return decorator

    monkeypatch.setattr(redis_utils, "retry", fake_retry)

    call_counter = {"count": 0}

    @redis_utils.redis_retry
    async def flaky() -> str:
        call_counter["count"] += 1
        if call_counter["count"] < 3:
            raise RedisConnectionError("transient failure")
        return "success"

    result = await flaky()

    assert result == "success"
    assert call_counter["count"] == 3
    assert (
        captured_kwargs["tries"] == redis_utils.retry_settings.redis_query_retries_num
    )
    assert (
        captured_kwargs["delay"] == redis_utils.retry_settings.redis_query_initial_delay
    )
    assert captured_kwargs["backoff"] == redis_utils.retry_settings.redis_query_backoff


@pytest.mark.asyncio
async def test_redis_retry_raises_after_exhaustion(monkeypatch):
    monkeypatch.setattr(
        redis_utils.retry_settings, "redis_query_retries_num", 3, raising=False
    )
    monkeypatch.setattr(
        redis_utils.retry_settings, "redis_query_initial_delay", 0, raising=False
    )
    monkeypatch.setattr(
        redis_utils.retry_settings, "redis_query_backoff", 2.0, raising=False
    )

    def fake_retry(**kwargs):
        def decorator(func):
            async def wrapped(*args, **inner_kwargs):
                attempts = 0
                while True:
                    try:
                        attempts += 1
                        return await func(*args, **inner_kwargs)
                    except kwargs["exceptions"] as exc:
                        if attempts >= kwargs["tries"]:
                            raise exc
                        continue

            return wrapped

        return decorator

    monkeypatch.setattr(redis_utils, "retry", fake_retry)

    @redis_utils.redis_retry
    async def always_fail() -> str:
        raise RedisConnectionError("permanent failure")

    with pytest.raises(RedisConnectionError):
        await always_fail()