Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.

### Added

- Added retry with back-off logic for Redis related functions. [#528](https://github.com/stac-utils/stac-fastapi-elasticsearch-opensearch/pull/528)

### Changed

### Fixed
Expand Down
1 change: 1 addition & 0 deletions stac_fastapi/core/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ dependencies = [
"jsonschema~=4.0.0",
"slowapi~=0.1.9",
"redis==6.4.0",
"retry==0.9.2",
]

[project.urls]
Expand Down
184 changes: 97 additions & 87 deletions stac_fastapi/core/stac_fastapi/core/redis_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,35 +2,39 @@

import json
import logging
from typing import List, Optional, Tuple
from functools import wraps
from typing import Callable, List, Optional, Tuple, cast
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse

from pydantic import Field, field_validator
from pydantic_settings import BaseSettings
from redis import asyncio as aioredis
from redis.asyncio.sentinel import Sentinel
from redis.exceptions import ConnectionError as RedisConnectionError
from redis.exceptions import TimeoutError as RedisTimeoutError
from retry import retry # type: ignore

logger = logging.getLogger(__name__)


class RedisSentinelSettings(BaseSettings):
"""Configuration for connecting to Redis Sentinel."""
class RedisCommonSettings(BaseSettings):
"""Common configuration for Redis Sentinel and Redis Standalone."""

REDIS_SENTINEL_HOSTS: str = ""
REDIS_SENTINEL_PORTS: str = "26379"
REDIS_SENTINEL_MASTER_NAME: str = "master"
REDIS_DB: int = 15

REDIS_MAX_CONNECTIONS: Optional[int] = None
REDIS_RETRY_TIMEOUT: bool = True
REDIS_DECODE_RESPONSES: bool = True
REDIS_CLIENT_NAME: str = "stac-fastapi-app"
REDIS_HEALTH_CHECK_INTERVAL: int = Field(default=30, gt=0)
REDIS_SELF_LINK_TTL: int = 1800

REDIS_QUERY_RETRIES_NUM: int = Field(default=3, gt=0)
REDIS_QUERY_INITIAL_DELAY: float = Field(default=1.0, gt=0)
REDIS_QUERY_BACKOFF: float = Field(default=2.0, gt=1)

@field_validator("REDIS_DB")
@classmethod
def validate_db_sentinel(cls, v: int) -> int:
def validate_db(cls, v: int) -> int:
"""Validate REDIS_DB is not negative integer."""
if v < 0:
raise ValueError("REDIS_DB must be a positive integer")
Expand All @@ -46,12 +50,20 @@ def validate_max_connections(cls, v):

@field_validator("REDIS_SELF_LINK_TTL")
@classmethod
def validate_self_link_ttl_sentinel(cls, v: int) -> int:
"""Validate REDIS_SELF_LINK_TTL is not a negative integer."""
def validate_self_link_ttl(cls, v: int) -> int:
"""Validate REDIS_SELF_LINK_TTL is negative."""
if v < 0:
raise ValueError("REDIS_SELF_LINK_TTL must be a positive integer")
return v


class RedisSentinelSettings(RedisCommonSettings):
"""Configuration for connecting to Redis Sentinel."""

REDIS_SENTINEL_HOSTS: str = ""
REDIS_SENTINEL_PORTS: str = "26379"
REDIS_SENTINEL_MASTER_NAME: str = "master"

def get_sentinel_hosts(self) -> List[str]:
"""Parse Redis Sentinel hosts from string to list."""
if not self.REDIS_SENTINEL_HOSTS:
Expand Down Expand Up @@ -96,19 +108,11 @@ def get_sentinel_nodes(self) -> List[Tuple[str, int]]:
return [(str(host), int(port)) for host, port in zip(hosts, ports)]


class RedisSettings(BaseSettings):
class RedisSettings(RedisCommonSettings):
"""Configuration for connecting Redis."""

REDIS_HOST: str = ""
REDIS_PORT: int = 6379
REDIS_DB: int = 15

REDIS_MAX_CONNECTIONS: Optional[int] = None
REDIS_RETRY_TIMEOUT: bool = True
REDIS_DECODE_RESPONSES: bool = True
REDIS_CLIENT_NAME: str = "stac-fastapi-app"
REDIS_HEALTH_CHECK_INTERVAL: int = Field(default=30, gt=0)
REDIS_SELF_LINK_TTL: int = 1800

@field_validator("REDIS_PORT")
@classmethod
Expand All @@ -118,89 +122,93 @@ def validate_port_standalone(cls, v: int) -> int:
raise ValueError("REDIS_PORT must be a positive integer")
return v

@field_validator("REDIS_DB")
@classmethod
def validate_db_standalone(cls, v: int) -> int:
"""Validate REDIS_DB is not a negative integer."""
if v < 0:
raise ValueError("REDIS_DB must be a positive integer")
return v

@field_validator("REDIS_MAX_CONNECTIONS", mode="before")
@classmethod
def validate_max_connections(cls, v):
"""Handle empty/None values for REDIS_MAX_CONNECTIONS."""
if v in ["", "null", "Null", "NULL", "none", "None", "NONE", None]:
return None
return v

@field_validator("REDIS_SELF_LINK_TTL")
@classmethod
def validate_self_link_ttl_standalone(cls, v: int) -> int:
"""Validate REDIS_SELF_LINK_TTL is negative."""
if v < 0:
raise ValueError("REDIS_SELF_LINK_TTL must be a positive integer")
return v


# Configure only one Redis configuration
sentinel_settings = RedisSentinelSettings()
standalone_settings = RedisSettings()
settings: RedisCommonSettings = cast(
RedisCommonSettings,
sentinel_settings if sentinel_settings.REDIS_SENTINEL_HOSTS else RedisSettings(),
)


def redis_retry(func: Callable) -> Callable:
"""Retry with back-off decorator for Redis connections."""

@wraps(func)
@retry(
exceptions=(RedisConnectionError, RedisTimeoutError),
tries=settings.REDIS_QUERY_RETRIES_NUM,
delay=settings.REDIS_QUERY_INITIAL_DELAY,
backoff=settings.REDIS_QUERY_BACKOFF,
logger=logger,
)
async def wrapper(*args, **kwargs):
return await func(*args, **kwargs)

return wrapper

async def connect_redis() -> Optional[aioredis.Redis]:

@redis_retry
async def _connect_redis_internal() -> Optional[aioredis.Redis]:
"""Return a Redis connection Redis or Redis Sentinel."""
try:
if sentinel_settings.REDIS_SENTINEL_HOSTS:
sentinel_nodes = sentinel_settings.get_sentinel_nodes()
sentinel = Sentinel(
sentinel_nodes,
decode_responses=sentinel_settings.REDIS_DECODE_RESPONSES,
)
if sentinel_settings.REDIS_SENTINEL_HOSTS:
sentinel_nodes = settings.get_sentinel_nodes()
sentinel = Sentinel(
sentinel_nodes,
decode_responses=settings.REDIS_DECODE_RESPONSES,
)

redis = sentinel.master_for(
service_name=sentinel_settings.REDIS_SENTINEL_MASTER_NAME,
db=sentinel_settings.REDIS_DB,
decode_responses=sentinel_settings.REDIS_DECODE_RESPONSES,
retry_on_timeout=sentinel_settings.REDIS_RETRY_TIMEOUT,
client_name=sentinel_settings.REDIS_CLIENT_NAME,
max_connections=sentinel_settings.REDIS_MAX_CONNECTIONS,
health_check_interval=sentinel_settings.REDIS_HEALTH_CHECK_INTERVAL,
)
logger.info("Connected to Redis Sentinel")

elif standalone_settings.REDIS_HOST:
pool = aioredis.ConnectionPool(
host=standalone_settings.REDIS_HOST,
port=standalone_settings.REDIS_PORT,
db=standalone_settings.REDIS_DB,
max_connections=standalone_settings.REDIS_MAX_CONNECTIONS,
decode_responses=standalone_settings.REDIS_DECODE_RESPONSES,
retry_on_timeout=standalone_settings.REDIS_RETRY_TIMEOUT,
health_check_interval=standalone_settings.REDIS_HEALTH_CHECK_INTERVAL,
)
redis = aioredis.Redis(
connection_pool=pool, client_name=standalone_settings.REDIS_CLIENT_NAME
)
logger.info("Connected to Redis")
else:
logger.warning("No Redis configuration found")
return None
redis = sentinel.master_for(
service_name=settings.REDIS_SENTINEL_MASTER_NAME,
db=settings.REDIS_DB,
decode_responses=settings.REDIS_DECODE_RESPONSES,
retry_on_timeout=settings.REDIS_RETRY_TIMEOUT,
client_name=settings.REDIS_CLIENT_NAME,
max_connections=settings.REDIS_MAX_CONNECTIONS,
health_check_interval=settings.REDIS_HEALTH_CHECK_INTERVAL,
)
logger.info("Connected to Redis Sentinel")

elif settings.REDIS_HOST:
pool = aioredis.ConnectionPool(
host=settings.REDIS_HOST,
port=settings.REDIS_PORT,
db=settings.REDIS_DB,
max_connections=settings.REDIS_MAX_CONNECTIONS,
decode_responses=settings.REDIS_DECODE_RESPONSES,
retry_on_timeout=settings.REDIS_RETRY_TIMEOUT,
health_check_interval=settings.REDIS_HEALTH_CHECK_INTERVAL,
)
redis = aioredis.Redis(
connection_pool=pool, client_name=settings.REDIS_CLIENT_NAME
)
logger.info("Connected to Redis")
else:
logger.warning("No Redis configuration found")
return None

return redis

return redis

async def connect_redis() -> Optional[aioredis.Redis]:
"""Handle Redis connection."""
try:
return await _connect_redis_internal()
except (
aioredis.ConnectionError,
aioredis.TimeoutError,
) as e:
logger.error(f"Redis connection failed after retries: {e}")
except aioredis.ConnectionError as e:
logger.error(f"Redis connection error: {e}")
return None
except aioredis.AuthenticationError as e:
logger.error(f"Redis authentication error: {e}")
return None
except aioredis.TimeoutError as e:
logger.error(f"Redis timeout error: {e}")
return None
except Exception as e:
logger.error(f"Failed to connect to Redis: {e}")
return None
return None


def get_redis_key(url: str, token: str) -> str:
Expand Down Expand Up @@ -230,19 +238,21 @@ def build_url_with_token(base_url: str, token: str) -> str:
)


@redis_retry
async def save_prev_link(
redis: aioredis.Redis, next_url: str, current_url: str, next_token: str
) -> None:
"""Save the current page as the previous link for the next URL."""
if next_url and next_token:
if sentinel_settings.REDIS_SENTINEL_HOSTS:
ttl_seconds = sentinel_settings.REDIS_SELF_LINK_TTL
elif standalone_settings.REDIS_HOST:
ttl_seconds = standalone_settings.REDIS_SELF_LINK_TTL
ttl_seconds = settings.REDIS_SELF_LINK_TTL
elif settings.REDIS_HOST:
ttl_seconds = settings.REDIS_SELF_LINK_TTL
key = get_redis_key(next_url, next_token)
await redis.setex(key, ttl_seconds, current_url)


@redis_retry
async def get_prev_link(
redis: aioredis.Redis, current_url: str, current_token: str
) -> Optional[str]:
Expand Down
90 changes: 90 additions & 0 deletions stac_fastapi/tests/redis/test_redis_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import pytest
from redis.exceptions import ConnectionError as RedisConnectionError

import stac_fastapi.core.redis_utils as redis_utils
from stac_fastapi.core.redis_utils import connect_redis, get_prev_link, save_prev_link


Expand Down Expand Up @@ -46,3 +48,91 @@ async def test_redis_utils_functions():
redis, "http://mywebsite.com/search", "non_existent_token"
)
assert non_existent is None


@pytest.mark.asyncio
async def test_redis_retry_retries_until_success(monkeypatch):
monkeypatch.setattr(
redis_utils.settings, "REDIS_QUERY_RETRIES_NUM", 3, raising=False
)
monkeypatch.setattr(
redis_utils.settings, "REDIS_QUERY_INITIAL_DELAY", 0, raising=False
)
monkeypatch.setattr(redis_utils.settings, "REDIS_QUERY_BACKOFF", 2.0, raising=False)

captured_kwargs = {}

def fake_retry(**kwargs):
captured_kwargs.update(kwargs)

def decorator(func):
async def wrapped(*args, **inner_kwargs):
attempts = 0
while True:
try:
attempts += 1
return await func(*args, **inner_kwargs)
except kwargs["exceptions"] as exc:
if attempts >= kwargs["tries"]:
raise exc
continue

return wrapped

return decorator

monkeypatch.setattr(redis_utils, "retry", fake_retry)

call_counter = {"count": 0}

@redis_utils.redis_retry
async def flaky() -> str:
call_counter["count"] += 1
if call_counter["count"] < 3:
raise RedisConnectionError("transient failure")
return "success"

result = await flaky()

assert result == "success"
assert call_counter["count"] == 3
assert captured_kwargs["tries"] == redis_utils.settings.REDIS_QUERY_RETRIES_NUM
assert captured_kwargs["delay"] == redis_utils.settings.REDIS_QUERY_INITIAL_DELAY
assert captured_kwargs["backoff"] == redis_utils.settings.REDIS_QUERY_BACKOFF


@pytest.mark.asyncio
async def test_redis_retry_raises_after_exhaustion(monkeypatch):
monkeypatch.setattr(
redis_utils.settings, "REDIS_QUERY_RETRIES_NUM", 3, raising=False
)
monkeypatch.setattr(
redis_utils.settings, "REDIS_QUERY_INITIAL_DELAY", 0, raising=False
)
monkeypatch.setattr(redis_utils.settings, "REDIS_QUERY_BACKOFF", 2.0, raising=False)

def fake_retry(**kwargs):
def decorator(func):
async def wrapped(*args, **inner_kwargs):
attempts = 0
while True:
try:
attempts += 1
return await func(*args, **inner_kwargs)
except kwargs["exceptions"] as exc:
if attempts >= kwargs["tries"]:
raise exc
continue

return wrapped

return decorator

monkeypatch.setattr(redis_utils, "retry", fake_retry)

@redis_utils.redis_retry
async def always_fail() -> str:
raise RedisConnectionError("pernament failure")

with pytest.raises(RedisConnectionError):
await always_fail()