Skip to content

Commit e057aea

Browse files
authored
add a 'tag' parameter to the 'cache' decorator for future functionality (#16)
1 parent c99b3d1 commit e057aea

File tree

7 files changed

+61
-16
lines changed

7 files changed

+61
-16
lines changed

TODO.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ These below are from Issues or PRs in the original repository.
1313
- add an option to have a separate logging file for cache hits and misses?
1414
- remove creating a test Redis from `redis.py`. This should not be done in the
1515
production logic, but set up in the test logic.
16-
- add a `cache_key` or `tag` parameter to the `cache` decorator to allow for
17-
custom cache keys
1816
- remove the FakeRedis from the `_connect_` function. This should be set up in
1917
the test logic not production code.
18+
- catch invalid cache type exceptions and raise a more informative error
19+
message.

fastapi_redis_cache/cache.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
"""The main cache decorator code and helpers."""
22

3+
from __future__ import annotations
4+
35
import asyncio
46
from datetime import timedelta
57
from functools import partial, update_wrapper, wraps
@@ -23,14 +25,19 @@
2325

2426

2527
def cache(
26-
*, expire: Union[int, timedelta] = ONE_YEAR_IN_SECONDS
28+
*,
29+
expire: Union[int, timedelta] = ONE_YEAR_IN_SECONDS,
30+
tag: str | None = None,
2731
) -> Callable[..., Any]:
2832
"""Enable caching behavior for the decorated function.
2933
3034
Args:
3135
expire (Union[int, timedelta], optional): The number of seconds
3236
from now when the cached response should expire. Defaults to
3337
31,536,000 seconds (i.e., the number of seconds in one year).
38+
tag (str, optional): A tag to associate with the cached response. This
39+
can later be used to invalidate all cached responses with the same
40+
tag, or for further fine-grained cache expiry. Defaults to None.
3441
"""
3542

3643
def outer_wrapper(func: Callable[..., Any]) -> Callable[..., Any]:
@@ -61,7 +68,7 @@ async def inner_wrapper(
6168
# if the redis client is not connected or request is not
6269
# cacheable, no caching behavior is performed.
6370
return await get_api_response_async(func, *args, **kwargs)
64-
key = redis_cache.get_cache_key(func, *args, **kwargs)
71+
key = redis_cache.get_cache_key(tag, func, *args, **kwargs)
6572
ttl, in_cache = redis_cache.check_cache(key)
6673
if in_cache:
6774
redis_cache.set_response_headers(
@@ -93,6 +100,10 @@ async def inner_wrapper(
93100
response_data = await get_api_response_async(func, *args, **kwargs)
94101
ttl = calculate_ttl(expire)
95102
cached = redis_cache.add_to_cache(key, response_data, ttl)
103+
if tag:
104+
# if tag is provided, add the key to the tag set. This should
105+
# help us search quicker for keys to invalidate.
106+
redis_cache.add_key_to_tag_set(tag, key)
96107
if cached:
97108
redis_cache.set_response_headers(
98109
response,

fastapi_redis_cache/client.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -135,13 +135,33 @@ def request_is_not_cacheable(self, request: Request) -> bool:
135135
)
136136

137137
def get_cache_key(
138-
self, func: Callable[..., Any], *args: Any, **kwargs: Any
138+
self,
139+
tag: str | None,
140+
func: Callable[..., Any],
141+
*args: Any,
142+
**kwargs: Any,
139143
) -> str:
140144
"""Return a key to use for caching the response of a function."""
141145
return get_cache_key(
142-
self.prefix, self.ignore_arg_types, func, *args, **kwargs
146+
self.prefix, tag, self.ignore_arg_types, func, *args, **kwargs
143147
)
144148

149+
def add_key_to_tag_set(self, tag: str, key: str) -> None:
150+
"""Add a key to a set of keys associated with a tag.
151+
152+
Searching for keys to invalidate is faster when they are grouped by tag
153+
as it reduces the number of keys to search through.
154+
155+
However, keys are not removed from the tag set when they expire so will
156+
need to handle possibly stale keys when invalidating.
157+
"""
158+
if self.redis:
159+
self.redis.sadd(tag, key)
160+
161+
def get_tagged_keys(self, tag: str) -> set[str]:
162+
"""Return a set of keys associated with a tag."""
163+
return self.redis.smembers(tag) if self.redis else set()
164+
145165
def check_cache(self, key: str) -> tuple[int, str]:
146166
"""Check if `key` is in the cache and return its TTL and value."""
147167
if not self.redis:

fastapi_redis_cache/key_gen.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,23 @@
11
"""Helper functions for generating cache keys."""
22

3+
from __future__ import annotations
4+
35
from inspect import Signature, signature
46
from typing import TYPE_CHECKING, Any, Callable
57

68
from fastapi import Request, Response
79

8-
from fastapi_redis_cache.types import ArgType, SigParameters
9-
1010
if TYPE_CHECKING: # pragma: no cover
1111
from collections import OrderedDict
1212

13+
from fastapi_redis_cache.types import ArgType, SigParameters
14+
1315
ALWAYS_IGNORE_ARG_TYPES = [Response, Request]
1416

1517

1618
def get_cache_key( # noqa: D417
1719
prefix: str,
20+
tag: str | None,
1821
ignore_arg_types: list[ArgType],
1922
func: Callable[..., Any],
2023
*args: list[Any],
@@ -25,6 +28,9 @@ def get_cache_key( # noqa: D417
2528
Args:
2629
prefix (`str`): Customizable namespace value that will prefix all cache
2730
keys.
31+
tag (`str`): Customizable tag value that will be inserted into the
32+
cache key. This can be used to group related keys together or to
33+
help expire a key in other routes.
2834
ignore_arg_types (`list[ArgType]`): Each argument to the API endpoint
2935
function is used to compose the cache key by calling `str(arg)`. If
3036
there are any keys that should not be used in this way (i.e.,
@@ -42,17 +48,18 @@ def get_cache_key( # noqa: D417
4248
ignore_arg_types.extend(ALWAYS_IGNORE_ARG_TYPES)
4349
ignore_arg_types = list(set(ignore_arg_types))
4450
prefix = f"{prefix}:" if prefix else ""
51+
tag_string = f"::{tag}" if tag else ""
4552

4653
sig = signature(func)
4754
sig_params = sig.parameters
4855
func_args = get_func_args(sig, *args, **kwargs)
4956
args_str = get_args_str(sig_params, func_args, ignore_arg_types)
50-
return f"{prefix}{func.__module__}.{func.__name__}({args_str})"
57+
return f"{prefix}{func.__module__}.{func.__name__}({args_str}){tag_string}"
5158

5259

5360
def get_func_args(
5461
sig: Signature, *args: list[Any], **kwargs: dict[Any, Any]
55-
) -> "OrderedDict[str, Any]":
62+
) -> OrderedDict[str, Any]:
5663
"""Return a dict object containing name and value of function arguments."""
5764
func_args = sig.bind(*args, **kwargs)
5865
func_args.apply_defaults()
@@ -61,7 +68,7 @@ def get_func_args(
6168

6269
def get_args_str(
6370
sig_params: SigParameters,
64-
func_args: "OrderedDict[str, Any]",
71+
func_args: OrderedDict[str, Any],
6572
ignore_arg_types: list[ArgType],
6673
) -> str:
6774
"""Return a string with name and value of all args.

fastapi_redis_cache/util.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
"""Define utility functions for the fastapi_redis_cache package."""
22

3+
from __future__ import annotations
4+
35
import json
46
from datetime import date, datetime
57
from decimal import Decimal
@@ -78,3 +80,8 @@ def serialize_json(json_dict: dict[str, Any]) -> str:
7880
def deserialize_json(json_str: str) -> Any: # noqa: ANN401
7981
"""Deserialize a JSON string to a dictionary."""
8082
return json.loads(json_str, object_hook=object_hook)
83+
84+
85+
def get_tag_from_key(key: str) -> str | None:
86+
"""Return the tag from the key or None if not found."""
87+
return key.split("::")[-1] if "::" in key else None

requirements-dev.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ pynacl==1.5.0 ; python_version >= "3.9" and python_version < "4.0"
5959
pytest-asyncio==0.21.1 ; python_version >= "3.9" and python_version < "4.0"
6060
pytest-cov==4.1.0 ; python_version >= "3.9" and python_version < "4.0"
6161
pytest-env==1.1.3 ; python_version >= "3.9" and python_version < "4.0"
62-
pytest-mock==3.12.0 ; python_version >= "3.9" and python_version < "4.0"
62+
pytest-mock==3.14.0 ; python_version >= "3.9" and python_version < "4.0"
6363
pytest-order==1.2.0 ; python_version >= "3.9" and python_version < "4.0"
6464
pytest-randomly==3.15.0 ; python_version >= "3.9" and python_version < "4.0"
6565
pytest-reverse==1.7.0 ; python_version >= "3.9" and python_version < "4.0"
@@ -74,7 +74,7 @@ redis==5.0.3 ; python_version >= "3.9" and python_version < "4.0"
7474
requests==2.31.0 ; python_version >= "3.9" and python_version < "4.0"
7575
rich==13.7.1 ; python_version >= "3.9" and python_version < "4.0"
7676
rtoml==0.9.0 ; python_version >= "3.9" and python_version < "4.0"
77-
ruff==0.3.3 ; python_version >= "3.9" and python_version < "4.0"
77+
ruff==0.3.4 ; python_version >= "3.9" and python_version < "4.0"
7878
setuptools==69.2.0 ; python_version >= "3.9" and python_version < "4.0"
7979
simple-toml-settings==0.6.0 ; python_version >= "3.9" and python_version < "4.0"
8080
six==1.16.0 ; python_version >= "3.9" and python_version < "4.0"

tests/live_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def cache_never_expire(
5050

5151

5252
@app.get("/cache_expires")
53-
@cache(expire=timedelta(seconds=5))
53+
@cache(expire=timedelta(seconds=5), tag="test_tag_1")
5454
async def cache_expires() -> dict[str, Union[bool, str]]:
5555
"""Route where the cache expires after 5 seconds."""
5656
return {
@@ -60,7 +60,7 @@ async def cache_expires() -> dict[str, Union[bool, str]]:
6060

6161

6262
@app.get("/cache_json_encoder")
63-
@cache()
63+
@cache(tag="test_tag_1")
6464
def cache_json_encoder() -> (
6565
dict[str, Union[bool, str, datetime, date, Decimal]]
6666
):
@@ -77,7 +77,7 @@ def cache_json_encoder() -> (
7777

7878

7979
@app.get("/cache_one_hour")
80-
@cache_one_hour()
80+
@cache_one_hour(tag="test_tag_2")
8181
def partial_cache_one_hour(response: Response) -> dict[str, Union[bool, str]]:
8282
"""Route where the cache expires after one hour."""
8383
return {

0 commit comments

Comments
 (0)