Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 6 additions & 11 deletions src/cachier/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ def _function_thread(core, key, func, args, kwds):
core.set_entry(key, func_res)
except BaseException as exc:
print(f"Function call failed with the following exception:\n{exc}")
finally:
core.mark_entry_not_calculated(key)


def _calc_entry(
Expand Down Expand Up @@ -358,11 +360,7 @@ def _call(*args, max_age: Optional[timedelta] = None, **kwds):
)
nonneg_max_age = False
else:
max_allowed_age = (
min(_stale_after, max_age)
if max_age is not None
else _stale_after
)
max_allowed_age = min(_stale_after, max_age)
# note: if max_age < 0, we always consider a value stale
if nonneg_max_age and (now - entry.time <= max_allowed_age):
_print("And it is fresh!")
Expand All @@ -380,12 +378,9 @@ def _call(*args, max_age: Optional[timedelta] = None, **kwds):
if _next_time:
_print("Async calc and return stale")
core.mark_entry_being_calculated(key)
try:
_get_executor().submit(
_function_thread, core, key, func, args, kwds
)
finally:
core.mark_entry_not_calculated(key)
_get_executor().submit(
_function_thread, core, key, func, args, kwds
)
return entry.value
_print("Calling decorated function and waiting")
return _calc_entry(core, key, func, args, kwds, _print)
Expand Down
11 changes: 8 additions & 3 deletions src/cachier/cores/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,17 @@ class RecalculationNeeded(Exception):


def _get_func_str(func: Callable) -> str:
return f".{func.__module__}.{func.__name__}"
"""Return a string identifier for the function (module + name).

We accept Any here because static analysis can't always prove that the
runtime object will have __module__ and __name__, but at runtime the
decorated functions always do.

"""
return f".{func.__module__}.{func.__name__}"

class _BaseCore:
__metaclass__ = abc.ABCMeta

class _BaseCore(metaclass=abc.ABCMeta):
def __init__(
self,
hash_func: Optional[HashFunc],
Expand Down
82 changes: 63 additions & 19 deletions src/cachier/cores/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,34 +84,70 @@ def get_entry_by_key(self, key: str) -> Tuple[str, Optional[CacheEntry]]:
if not cached_data:
return key, None

# helper to fetch field regardless of bytes/str keys
def _raw(field: str):
# try bytes key first, then str key
bkey = field.encode("utf-8")
if bkey in cached_data:
return cached_data[bkey]
return cached_data.get(field)

# Deserialize the value
value = None
if cached_data.get(b"value"):
value = pickle.loads(cached_data[b"value"])
raw_value = _raw("value")
if raw_value is not None:
try:
if isinstance(raw_value, bytes):
value = pickle.loads(raw_value)
elif isinstance(raw_value, str):
# try to recover by encoding; prefer utf-8 but fall back
# to latin-1 in case raw binary was coerced to str
try:
value = pickle.loads(raw_value.encode("utf-8"))
except Exception:
value = pickle.loads(raw_value.encode("latin-1"))
Copy link

Copilot AI Oct 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The bare except Exception: clause should be more specific. Consider catching UnicodeDecodeError or pickle.PickleError specifically, or at minimum log the exception for debugging purposes.

Copilot uses AI. Check for mistakes.
else:
# unexpected type; attempt pickle.loads directly
try:
value = pickle.loads(raw_value)
except Exception:
value = None
except Exception as exc:
warnings.warn(
f"Redis value deserialization failed: {exc}",
stacklevel=2,
)

# Parse timestamp
timestamp_str = cached_data.get(b"timestamp", b"").decode("utf-8")
raw_ts = _raw("timestamp") or b""
if isinstance(raw_ts, bytes):
try:
timestamp_str = raw_ts.decode("utf-8")
except Exception:
timestamp_str = raw_ts.decode("latin-1", errors="ignore")
else:
timestamp_str = str(raw_ts)
timestamp = (
datetime.fromisoformat(timestamp_str)
if timestamp_str
else datetime.now()
)

# Parse boolean fields
stale = (
cached_data.get(b"stale", b"false").decode("utf-8").lower()
== "true"
)
processing = (
cached_data.get(b"processing", b"false")
.decode("utf-8")
.lower()
== "true"
)
completed = (
cached_data.get(b"completed", b"false").decode("utf-8").lower()
== "true"
)
def _bool_field(name: str) -> bool:
raw = _raw(name) or b"false"
if isinstance(raw, bytes):
try:
s = raw.decode("utf-8")
except Exception:
s = raw.decode("latin-1", errors="ignore")
else:
s = str(raw)
return s.lower() == "true"

stale = _bool_field("stale")
processing = _bool_field("processing")
completed = _bool_field("completed")

entry = CacheEntry(
value=value,
Expand All @@ -126,9 +162,9 @@ def get_entry_by_key(self, key: str) -> Tuple[str, Optional[CacheEntry]]:
return key, None

def set_entry(self, key: str, func_res: Any) -> bool:
"""Map the given result to the given key in Redis."""
if not self._should_store(func_res):
return False
"""Map the given result to the given key in Redis."""
redis_client = self._resolve_redis_client()
redis_key = self._get_redis_key(key)

Expand Down Expand Up @@ -242,8 +278,16 @@ def delete_stale_entries(self, stale_after: timedelta) -> None:
ts = redis_client.hget(key, "timestamp")
if ts is None:
continue
# ts may be bytes or str depending on client configuration
if isinstance(ts, bytes):
try:
ts_s = ts.decode("utf-8")
except Exception:
ts_s = ts.decode("latin-1", errors="ignore")
else:
ts_s = str(ts)
try:
ts_val = datetime.fromisoformat(ts.decode("utf-8"))
ts_val = datetime.fromisoformat(ts_s)
except Exception as exc:
warnings.warn(
f"Redis timestamp parse failed: {exc}", stacklevel=2
Expand Down
Loading