Skip to content

Commit ac30b2a

Browse files
author
Andrzej Pijanowski
committed
feat: retry with back-off logic for Redis
1 parent c1a7bc1 commit ac30b2a

File tree

4 files changed

+186
-40
lines changed

4 files changed

+186
-40
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
99

1010
### Added
1111

12+
- Added retry with back-off logic for Redis-related functions.
13+
1214
### Changed
1315

1416
### Fixed

stac_fastapi/core/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ dependencies = [
4545
"jsonschema~=4.0.0",
4646
"slowapi~=0.1.9",
4747
"redis==6.4.0",
48+
"retry==0.9.2",
4849
]
4950

5051
[project.urls]

stac_fastapi/core/stac_fastapi/core/redis_utils.py

Lines changed: 85 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,17 @@
22

33
import json
44
import logging
5-
from typing import List, Optional, Tuple
5+
from functools import wraps
6+
from typing import Callable, List, Optional, Tuple
67
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
78

89
from pydantic import Field, field_validator
910
from pydantic_settings import BaseSettings
1011
from redis import asyncio as aioredis
1112
from redis.asyncio.sentinel import Sentinel
13+
from redis.exceptions import ConnectionError as RedisConnectionError
14+
from redis.exceptions import TimeoutError as RedisTimeoutError
15+
from retry import retry # type: ignore
1216

1317
logger = logging.getLogger(__name__)
1418

@@ -143,65 +147,104 @@ def validate_self_link_ttl_standalone(cls, v: int) -> int:
143147
return v
144148

145149

150+
class RedisRetrySettings(BaseSettings):
151+
"""Configuration for Redis retry wrapper."""
152+
153+
redis_query_retries_num: int = Field(
154+
default=3, alias="REDIS_QUERY_RETRIES_NUM", gt=0
155+
)
156+
redis_query_initial_delay: float = Field(
157+
default=1.0, alias="REDIS_QUERY_INITIAL_DELAY", gt=0
158+
)
159+
redis_query_backoff: float = Field(default=2.0, alias="REDIS_QUERY_BACKOFF", gt=1)
160+
161+
146162
# Configure only one Redis configuration
147163
sentinel_settings = RedisSentinelSettings()
148164
standalone_settings = RedisSettings()
165+
retry_settings = RedisRetrySettings()
149166

150167

151-
async def connect_redis() -> Optional[aioredis.Redis]:
168+
def redis_retry(func: Callable) -> Callable:
169+
"""Wrap function in retry with back-off logic."""
170+
171+
@wraps(func)
172+
@retry(
173+
exceptions=(RedisConnectionError, RedisTimeoutError),
174+
tries=retry_settings.redis_query_retries_num,
175+
delay=retry_settings.redis_query_initial_delay,
176+
backoff=retry_settings.redis_query_backoff,
177+
logger=logger,
178+
)
179+
async def wrapper(*args, **kwargs):
180+
return await func(*args, **kwargs)
181+
182+
return wrapper
183+
184+
185+
@redis_retry
186+
async def _connect_redis_internal() -> Optional[aioredis.Redis]:
152187
"""Return a Redis connection Redis or Redis Sentinel."""
153-
try:
154-
if sentinel_settings.REDIS_SENTINEL_HOSTS:
155-
sentinel_nodes = sentinel_settings.get_sentinel_nodes()
156-
sentinel = Sentinel(
157-
sentinel_nodes,
158-
decode_responses=sentinel_settings.REDIS_DECODE_RESPONSES,
159-
)
188+
if sentinel_settings.REDIS_SENTINEL_HOSTS:
189+
sentinel_nodes = sentinel_settings.get_sentinel_nodes()
190+
sentinel = Sentinel(
191+
sentinel_nodes,
192+
decode_responses=sentinel_settings.REDIS_DECODE_RESPONSES,
193+
)
160194

161-
redis = sentinel.master_for(
162-
service_name=sentinel_settings.REDIS_SENTINEL_MASTER_NAME,
163-
db=sentinel_settings.REDIS_DB,
164-
decode_responses=sentinel_settings.REDIS_DECODE_RESPONSES,
165-
retry_on_timeout=sentinel_settings.REDIS_RETRY_TIMEOUT,
166-
client_name=sentinel_settings.REDIS_CLIENT_NAME,
167-
max_connections=sentinel_settings.REDIS_MAX_CONNECTIONS,
168-
health_check_interval=sentinel_settings.REDIS_HEALTH_CHECK_INTERVAL,
169-
)
170-
logger.info("Connected to Redis Sentinel")
195+
redis = sentinel.master_for(
196+
service_name=sentinel_settings.REDIS_SENTINEL_MASTER_NAME,
197+
db=sentinel_settings.REDIS_DB,
198+
decode_responses=sentinel_settings.REDIS_DECODE_RESPONSES,
199+
retry_on_timeout=sentinel_settings.REDIS_RETRY_TIMEOUT,
200+
client_name=sentinel_settings.REDIS_CLIENT_NAME,
201+
max_connections=sentinel_settings.REDIS_MAX_CONNECTIONS,
202+
health_check_interval=sentinel_settings.REDIS_HEALTH_CHECK_INTERVAL,
203+
)
204+
logger.info("Connected to Redis Sentinel")
205+
206+
elif standalone_settings.REDIS_HOST:
207+
pool = aioredis.ConnectionPool(
208+
host=standalone_settings.REDIS_HOST,
209+
port=standalone_settings.REDIS_PORT,
210+
db=standalone_settings.REDIS_DB,
211+
max_connections=standalone_settings.REDIS_MAX_CONNECTIONS,
212+
decode_responses=standalone_settings.REDIS_DECODE_RESPONSES,
213+
retry_on_timeout=standalone_settings.REDIS_RETRY_TIMEOUT,
214+
health_check_interval=standalone_settings.REDIS_HEALTH_CHECK_INTERVAL,
215+
)
216+
redis = aioredis.Redis(
217+
connection_pool=pool, client_name=standalone_settings.REDIS_CLIENT_NAME
218+
)
219+
logger.info("Connected to Redis")
220+
else:
221+
logger.warning("No Redis configuration found")
222+
return None
171223

172-
elif standalone_settings.REDIS_HOST:
173-
pool = aioredis.ConnectionPool(
174-
host=standalone_settings.REDIS_HOST,
175-
port=standalone_settings.REDIS_PORT,
176-
db=standalone_settings.REDIS_DB,
177-
max_connections=standalone_settings.REDIS_MAX_CONNECTIONS,
178-
decode_responses=standalone_settings.REDIS_DECODE_RESPONSES,
179-
retry_on_timeout=standalone_settings.REDIS_RETRY_TIMEOUT,
180-
health_check_interval=standalone_settings.REDIS_HEALTH_CHECK_INTERVAL,
181-
)
182-
redis = aioredis.Redis(
183-
connection_pool=pool, client_name=standalone_settings.REDIS_CLIENT_NAME
184-
)
185-
logger.info("Connected to Redis")
186-
else:
187-
logger.warning("No Redis configuration found")
188-
return None
224+
return redis
189225

190-
return redis
191226

227+
async def connect_redis() -> Optional[aioredis.Redis]:
228+
"""Handle Redis connection."""
229+
try:
230+
return await _connect_redis_internal()
231+
except (
232+
aioredis.ConnectionError,
233+
aioredis.TimeoutError,
234+
) as e:
235+
logger.error(f"Redis connection failed after retries: {e}")
192236
except aioredis.ConnectionError as e:
193237
logger.error(f"Redis connection error: {e}")
194238
return None
195239
except aioredis.AuthenticationError as e:
196240
logger.error(f"Redis authentication error: {e}")
197241
return None
198-
except aioredis.TimeoutError as e:
199-
logger.error(f"Redis timeout error: {e}")
200-
return None
201242
except Exception as e:
202243
logger.error(f"Failed to connect to Redis: {e}")
203244
return None
204245

246+
return None
247+
205248

206249
def get_redis_key(url: str, token: str) -> str:
207250
"""Create Redis key using URL path and token."""
@@ -230,6 +273,7 @@ def build_url_with_token(base_url: str, token: str) -> str:
230273
)
231274

232275

276+
@redis_retry
233277
async def save_prev_link(
234278
redis: aioredis.Redis, next_url: str, current_url: str, next_token: str
235279
) -> None:
@@ -243,6 +287,7 @@ async def save_prev_link(
243287
await redis.setex(key, ttl_seconds, current_url)
244288

245289

290+
@redis_retry
246291
async def get_prev_link(
247292
redis: aioredis.Redis, current_url: str, current_token: str
248293
) -> Optional[str]:

stac_fastapi/tests/redis/test_redis_utils.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import pytest
2+
from redis.exceptions import ConnectionError as RedisConnectionError
23

4+
import stac_fastapi.core.redis_utils as redis_utils
35
from stac_fastapi.core.redis_utils import connect_redis, get_prev_link, save_prev_link
46

57

@@ -46,3 +48,99 @@ async def test_redis_utils_functions():
4648
redis, "http://mywebsite.com/search", "non_existent_token"
4749
)
4850
assert non_existent is None
51+
52+
53+
@pytest.mark.asyncio
54+
async def test_redis_retry_retries_until_success(monkeypatch):
55+
monkeypatch.setattr(
56+
redis_utils.retry_settings, "redis_query_retries_num", 3, raising=False
57+
)
58+
monkeypatch.setattr(
59+
redis_utils.retry_settings, "redis_query_initial_delay", 0, raising=False
60+
)
61+
monkeypatch.setattr(
62+
redis_utils.retry_settings, "redis_query_backoff", 2.0, raising=False
63+
)
64+
65+
captured_kwargs = {}
66+
67+
def fake_retry(**kwargs):
68+
captured_kwargs.update(kwargs)
69+
70+
def decorator(func):
71+
async def wrapped(*args, **inner_kwargs):
72+
attempts = 0
73+
while True:
74+
try:
75+
attempts += 1
76+
return await func(*args, **inner_kwargs)
77+
except kwargs["exceptions"] as exc:
78+
if attempts >= kwargs["tries"]:
79+
raise exc
80+
continue
81+
82+
return wrapped
83+
84+
return decorator
85+
86+
monkeypatch.setattr(redis_utils, "retry", fake_retry)
87+
88+
call_counter = {"count": 0}
89+
90+
@redis_utils.redis_retry
91+
async def flaky() -> str:
92+
call_counter["count"] += 1
93+
if call_counter["count"] < 3:
94+
raise RedisConnectionError("transient failure")
95+
return "success"
96+
97+
result = await flaky()
98+
99+
assert result == "success"
100+
assert call_counter["count"] == 3
101+
assert (
102+
captured_kwargs["tries"] == redis_utils.retry_settings.redis_query_retries_num
103+
)
104+
assert (
105+
captured_kwargs["delay"] == redis_utils.retry_settings.redis_query_initial_delay
106+
)
107+
assert captured_kwargs["backoff"] == redis_utils.retry_settings.redis_query_backoff
108+
109+
110+
@pytest.mark.asyncio
111+
async def test_redis_retry_raises_after_exhaustion(monkeypatch):
112+
monkeypatch.setattr(
113+
redis_utils.retry_settings, "redis_query_retries_num", 3, raising=False
114+
)
115+
monkeypatch.setattr(
116+
redis_utils.retry_settings, "redis_query_initial_delay", 0, raising=False
117+
)
118+
monkeypatch.setattr(
119+
redis_utils.retry_settings, "redis_query_backoff", 2.0, raising=False
120+
)
121+
122+
def fake_retry(**kwargs):
123+
def decorator(func):
124+
async def wrapped(*args, **inner_kwargs):
125+
attempts = 0
126+
while True:
127+
try:
128+
attempts += 1
129+
return await func(*args, **inner_kwargs)
130+
except kwargs["exceptions"] as exc:
131+
if attempts >= kwargs["tries"]:
132+
raise exc
133+
continue
134+
135+
return wrapped
136+
137+
return decorator
138+
139+
monkeypatch.setattr(redis_utils, "retry", fake_retry)
140+
141+
@redis_utils.redis_retry
142+
async def always_fail() -> str:
143+
raise RedisConnectionError("pernament failure")
144+
145+
with pytest.raises(RedisConnectionError):
146+
await always_fail()

0 commit comments

Comments
 (0)