diff --git a/pyproject.toml b/pyproject.toml index bdf8309ad4..acf3a3443e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ classifiers = [ dependencies = [ "PyYAML>=6.0", "aiohttp", + "cachetools>=5.5.0", # for prompt caching "fastapi>=0.115.0,<1.0", # server "fire", # for MCP in LLS client "httpx", @@ -37,6 +38,7 @@ dependencies = [ "python-dotenv", "pyjwt[crypto]>=2.10.0", # Pull crypto to support RS256 for jwt. Requires 2.10.0+ for ssl_context support. "pydantic>=2.11.9", + "redis>=5.2.0", # for prompt caching (Redis backend) "rich", "starlette", "termcolor", diff --git a/src/llama_stack/providers/utils/cache/__init__.py b/src/llama_stack/providers/utils/cache/__init__.py new file mode 100644 index 0000000000..9e2d6ffe44 --- /dev/null +++ b/src/llama_stack/providers/utils/cache/__init__.py @@ -0,0 +1,37 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +"""Cache store utilities for prompt caching. + +This module provides cache store abstractions and implementations for use in +the Llama Stack server's prompt caching feature. Supports both memory-based +and Redis-based caching with configurable eviction policies and TTL management. + +Example usage: + from llama_stack.providers.utils.cache import MemoryCacheStore, RedisCacheStore + + # Memory cache for development + memory_cache = MemoryCacheStore(max_entries=1000, eviction_policy="lru") + + # Redis cache for production + redis_cache = RedisCacheStore( + host="localhost", + port=6379, + connection_pool_size=10 + ) +""" + +from .cache_store import CacheError, CacheStore, CircuitBreaker +from .memory import MemoryCacheStore +from .redis import RedisCacheStore + +__all__ = [ + "CacheStore", + "CacheError", + "CircuitBreaker", + "MemoryCacheStore", + "RedisCacheStore", +] diff --git a/src/llama_stack/providers/utils/cache/cache_store.py b/src/llama_stack/providers/utils/cache/cache_store.py new file mode 100644 index 0000000000..2c457ce8e7 --- /dev/null +++ b/src/llama_stack/providers/utils/cache/cache_store.py @@ -0,0 +1,256 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +"""Cache store abstraction for prompt caching implementation. + +This module provides a protocol-based abstraction for cache storage backends, +enabling flexible storage implementations (memory, Redis, etc.) for prompt +caching in the Llama Stack server. +""" + +from datetime import timedelta +from typing import Any, Optional, Protocol + +from llama_stack.log import get_logger + +logger = get_logger(__name__) + + +class CacheStore(Protocol): + """Protocol defining the cache store interface. + + This protocol specifies the required methods for cache store implementations. + All implementations must support TTL-based expiration and provide efficient + key-value storage operations. + + Methods support both synchronous and asynchronous usage patterns depending + on the implementation requirements. + """ + + async def get(self, key: str) -> Optional[Any]: + """Retrieve a value from the cache. + + Args: + key: Cache key to retrieve + + Returns: + Cached value if present and not expired, None otherwise + + Raises: + CacheError: If cache backend is unavailable or operation fails + """ + ... + + async def set( + self, + key: str, + value: Any, + ttl: Optional[int] = None, + ) -> None: + """Store a value in the cache with optional TTL. + + Args: + key: Cache key + value: Value to cache (must be serializable) + ttl: Time-to-live in seconds. If None, uses default TTL. + + Raises: + CacheError: If cache backend is unavailable or operation fails + ValueError: If value is not serializable + """ + ... + + async def delete(self, key: str) -> bool: + """Delete a key from the cache. + + Args: + key: Cache key to delete + + Returns: + True if key was deleted, False if key didn't exist + + Raises: + CacheError: If cache backend is unavailable or operation fails + """ + ... + + async def exists(self, key: str) -> bool: + """Check if a key exists in the cache. + + Args: + key: Cache key to check + + Returns: + True if key exists and is not expired, False otherwise + + Raises: + CacheError: If cache backend is unavailable or operation fails + """ + ... + + async def ttl(self, key: str) -> Optional[int]: + """Get the remaining TTL for a key. + + Args: + key: Cache key + + Returns: + Remaining TTL in seconds, None if key doesn't exist or has no TTL + + Raises: + CacheError: If cache backend is unavailable or operation fails + """ + ... + + async def clear(self) -> None: + """Clear all entries from the cache. + + This is primarily useful for testing. Use with caution in production + as it affects all cached data. + + Raises: + CacheError: If cache backend is unavailable or operation fails + """ + ... + + async def size(self) -> int: + """Get the number of entries in the cache. + + Returns: + Number of cached entries + + Raises: + CacheError: If cache backend is unavailable or operation fails + """ + ... + + +class CacheError(Exception): + """Exception raised for cache operation failures. + + This exception is raised when cache operations fail due to backend + unavailability, network issues, or other operational problems. + The system should gracefully degrade when catching this exception. + """ + + def __init__(self, message: str, cause: Optional[Exception] = None): + """Initialize cache error. + + Args: + message: Error description (should start with "Failed to ...") + cause: Optional underlying exception that caused this error + """ + super().__init__(message) + self.cause = cause + + +class CircuitBreaker: + """Circuit breaker pattern for cache backend failure protection. + + Prevents cascade failures by temporarily disabling cache operations + after detecting repeated failures. Automatically attempts recovery + after a timeout period. + + States: + - CLOSED: Normal operation, requests go through + - OPEN: Too many failures, requests are blocked + - HALF_OPEN: Testing if backend has recovered + + Example: + breaker = CircuitBreaker(failure_threshold=10, recovery_timeout=60) + if breaker.is_closed(): + try: + result = await cache.get(key) + breaker.record_success() + except CacheError: + breaker.record_failure() + """ + + def __init__( + self, + failure_threshold: int = 10, + recovery_timeout: int = 60, + ): + """Initialize circuit breaker. + + Args: + failure_threshold: Number of consecutive failures before opening + recovery_timeout: Seconds to wait before attempting recovery + """ + self.failure_threshold = failure_threshold + self.recovery_timeout = recovery_timeout + self.failure_count = 0 + self.last_failure_time: Optional[float] = None + self.state = "CLOSED" # CLOSED, OPEN, HALF_OPEN + + def is_closed(self) -> bool: + """Check if circuit breaker allows operations. + + Returns: + True if operations should proceed, False if blocked + """ + import time + + if self.state == "CLOSED": + return True + + if self.state == "OPEN": + # Check if we should try recovery + if ( + self.last_failure_time is not None + and time.time() - self.last_failure_time >= self.recovery_timeout + ): + self.state = "HALF_OPEN" + logger.info("Circuit breaker entering HALF_OPEN state for recovery test") + return True + return False + + # HALF_OPEN state - allow one request through to test + return True + + def record_success(self) -> None: + """Record a successful operation.""" + if self.state == "HALF_OPEN": + logger.info("Circuit breaker recovery successful, returning to CLOSED state") + self.failure_count = 0 + self.last_failure_time = None + self.state = "CLOSED" + + def record_failure(self) -> None: + """Record a failed operation.""" + import time + + self.failure_count += 1 + self.last_failure_time = time.time() + + if self.state == "HALF_OPEN": + # Recovery attempt failed, go back to OPEN + logger.warning("Circuit breaker recovery failed, returning to OPEN state") + self.state = "OPEN" + elif self.failure_count >= self.failure_threshold: + logger.error( + f"Circuit breaker OPEN after {self.failure_count} failures. " + f"Cache operations disabled for {self.recovery_timeout}s" + ) + self.state = "OPEN" + + def get_state(self) -> str: + """Get current circuit breaker state. + + Returns: + Current state: "CLOSED", "OPEN", or "HALF_OPEN" + """ + return self.state + + def reset(self) -> None: + """Manually reset the circuit breaker to CLOSED state. + + This is primarily useful for testing or administrative overrides. + """ + self.failure_count = 0 + self.last_failure_time = None + self.state = "CLOSED" + logger.info("Circuit breaker manually reset to CLOSED state") diff --git a/src/llama_stack/providers/utils/cache/memory.py b/src/llama_stack/providers/utils/cache/memory.py new file mode 100644 index 0000000000..af4ce6e549 --- /dev/null +++ b/src/llama_stack/providers/utils/cache/memory.py @@ -0,0 +1,334 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +"""In-memory cache store implementation using cachetools. + +This module provides a memory-based cache store suitable for development +and single-node deployments. For production multi-node deployments, +consider using RedisCacheStore instead. +""" + +import sys +import time +from typing import Any, Literal, Optional + +from cachetools import Cache, LFUCache, LRUCache, TTLCache # type: ignore # no types-cachetools available + +from llama_stack.log import get_logger + +from .cache_store import CacheError + +logger = get_logger(__name__) + + +EvictionPolicy = Literal["lru", "lfu", "ttl-only"] + + +class MemoryCacheStore: + """In-memory cache store with configurable eviction policies. + + This implementation uses the cachetools library to provide efficient + in-memory caching with support for multiple eviction policies: + - LRU (Least Recently Used): Evicts least recently accessed items + - LFU (Least Frequently Used): Evicts least frequently accessed items + - TTL-only: Evicts based on time-to-live only + + Thread-safe for concurrent access within a single process. + + Example: + cache = MemoryCacheStore( + max_entries=1000, + default_ttl=600, + eviction_policy="lru" + ) + await cache.set("key", "value", ttl=300) + value = await cache.get("key") + """ + + def __init__( + self, + max_entries: int = 1000, + max_memory_mb: Optional[int] = 512, + default_ttl: int = 600, + eviction_policy: EvictionPolicy = "lru", + ): + """Initialize memory cache store. + + Args: + max_entries: Maximum number of entries to store + max_memory_mb: Maximum memory usage in MB (soft limit, estimated) + default_ttl: Default time-to-live in seconds + eviction_policy: Eviction strategy ("lru", "lfu", "ttl-only") + + Raises: + ValueError: If invalid parameters provided + """ + if max_entries <= 0: + raise ValueError("max_entries must be positive") + if default_ttl <= 0: + raise ValueError("default_ttl must be positive") + if max_memory_mb is not None and max_memory_mb <= 0: + raise ValueError("max_memory_mb must be positive") + + self.max_entries = max_entries + self.max_memory_mb = max_memory_mb + self.default_ttl = default_ttl + self.eviction_policy = eviction_policy + + # Create appropriate cache implementation + self._cache: Cache = self._create_cache() + self._ttl_map: dict[str, float] = {} # Track expiration times + + logger.info( + f"Initialized MemoryCacheStore: policy={eviction_policy}, " + f"max_entries={max_entries}, max_memory={max_memory_mb}MB, " + f"default_ttl={default_ttl}s" + ) + + def _create_cache(self) -> Cache: + """Create cache instance based on eviction policy. + + Returns: + Cache instance configured with chosen policy + """ + if self.eviction_policy == "lru": + return LRUCache(maxsize=self.max_entries) + elif self.eviction_policy == "lfu": + return LFUCache(maxsize=self.max_entries) + elif self.eviction_policy == "ttl-only": + return TTLCache(maxsize=self.max_entries, ttl=self.default_ttl) + else: + raise ValueError(f"Unknown eviction policy: {self.eviction_policy}") + + def _is_expired(self, key: str) -> bool: + """Check if a key has expired based on TTL. + + Args: + key: Cache key to check + + Returns: + True if key has expired, False otherwise + """ + if key not in self._ttl_map: + return False + + expiration_time = self._ttl_map[key] + if time.time() >= expiration_time: + # Clean up expired entry + self._cache.pop(key, None) + self._ttl_map.pop(key, None) + return True + + return False + + async def get(self, key: str) -> Optional[Any]: + """Retrieve a value from the cache. + + Args: + key: Cache key to retrieve + + Returns: + Cached value if present and not expired, None otherwise + + Raises: + CacheError: If cache operation fails + """ + try: + # Check expiration first + if self._is_expired(key): + return None + + value = self._cache.get(key) + if value is not None: + logger.debug(f"Cache hit: {key}") + return value + + except Exception as e: + logger.error(f"Failed to get cache key '{key}': {e}") + raise CacheError(f"Failed to get cache key '{key}'", cause=e) from e + + async def set( + self, + key: str, + value: Any, + ttl: Optional[int] = None, + ) -> None: + """Store a value in the cache with optional TTL. + + Args: + key: Cache key + value: Value to cache + ttl: Time-to-live in seconds. If None, uses default TTL. + + Raises: + CacheError: If cache operation fails + """ + try: + # Use default TTL if not specified + effective_ttl = ttl if ttl is not None else self.default_ttl + + # Store value + self._cache[key] = value + + # Track expiration time + self._ttl_map[key] = time.time() + effective_ttl + + # Check memory usage (soft limit) + if self.max_memory_mb is not None: + self._check_memory_usage() + + logger.debug(f"Cache set: {key} (ttl={effective_ttl}s)") + + except Exception as e: + logger.error(f"Failed to set cache key '{key}': {e}") + raise CacheError(f"Failed to set cache key '{key}'", cause=e) from e + + def _check_memory_usage(self) -> None: + """Check and log if memory usage exceeds soft limit. + + This is a soft limit - we log warnings but don't enforce hard limits. + The cachetools library will handle eviction based on max_entries. + """ + try: + # Get approximate memory usage + cache_size_bytes = sys.getsizeof(self._cache) + sys.getsizeof(self._ttl_map) + + # Convert to MB + cache_size_mb = cache_size_bytes / (1024 * 1024) + + if self.max_memory_mb is not None and cache_size_mb > self.max_memory_mb: + logger.warning( + f"Cache memory usage ({cache_size_mb:.1f}MB) exceeds " + f"soft limit ({self.max_memory_mb}MB). " + f"Consider increasing max_entries or max_memory_mb." + ) + except Exception as e: + # Don't fail on memory check errors + logger.debug(f"Memory usage check failed: {e}") + + async def delete(self, key: str) -> bool: + """Delete a key from the cache. + + Args: + key: Cache key to delete + + Returns: + True if key was deleted, False if key didn't exist + + Raises: + CacheError: If cache operation fails + """ + try: + existed = key in self._cache + self._cache.pop(key, None) + self._ttl_map.pop(key, None) + + if existed: + logger.debug(f"Cache delete: {key}") + + return existed + + except Exception as e: + logger.error(f"Failed to delete cache key '{key}': {e}") + raise CacheError(f"Failed to delete cache key '{key}'", cause=e) from e + + async def exists(self, key: str) -> bool: + """Check if a key exists in the cache. + + Args: + key: Cache key to check + + Returns: + True if key exists and is not expired, False otherwise + + Raises: + CacheError: If cache operation fails + """ + try: + if self._is_expired(key): + return False + return key in self._cache + + except Exception as e: + logger.error(f"Failed to check cache key existence '{key}': {e}") + raise CacheError(f"Failed to check cache key existence '{key}'", cause=e) from e + + async def ttl(self, key: str) -> Optional[int]: + """Get the remaining TTL for a key. + + Args: + key: Cache key + + Returns: + Remaining TTL in seconds, None if key doesn't exist + + Raises: + CacheError: If cache operation fails + """ + try: + if key not in self._ttl_map: + return None + + if self._is_expired(key): + return None + + remaining = int(self._ttl_map[key] - time.time()) + return max(0, remaining) + + except Exception as e: + logger.error(f"Failed to get TTL for cache key '{key}': {e}") + raise CacheError(f"Failed to get TTL for cache key '{key}'", cause=e) from e + + async def clear(self) -> None: + """Clear all entries from the cache. + + Raises: + CacheError: If cache operation fails + """ + try: + self._cache.clear() + self._ttl_map.clear() + logger.info("Cache cleared") + + except Exception as e: + logger.error(f"Failed to clear cache: {e}") + raise CacheError("Failed to clear cache", cause=e) from e + + async def size(self) -> int: + """Get the number of entries in the cache. + + Returns: + Number of cached entries (excluding expired entries) + + Raises: + CacheError: If cache operation fails + """ + try: + # Clean up expired entries first + expired_keys = [ + key for key in list(self._ttl_map.keys()) + if self._is_expired(key) + ] + + return len(self._cache) + + except Exception as e: + logger.error(f"Failed to get cache size: {e}") + raise CacheError("Failed to get cache size", cause=e) from e + + def get_stats(self) -> dict[str, Any]: + """Get cache statistics. + + Returns: + Dictionary with cache statistics including size, policy, and limits + """ + return { + "size": len(self._cache), + "max_entries": self.max_entries, + "max_memory_mb": self.max_memory_mb, + "default_ttl": self.default_ttl, + "eviction_policy": self.eviction_policy, + } diff --git a/src/llama_stack/providers/utils/cache/redis.py b/src/llama_stack/providers/utils/cache/redis.py new file mode 100644 index 0000000000..830b8fd670 --- /dev/null +++ b/src/llama_stack/providers/utils/cache/redis.py @@ -0,0 +1,513 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +"""Redis-based cache store implementation. + +This module provides a production-ready Redis cache store with connection +pooling, retry logic, and comprehensive error handling. Suitable for +distributed deployments and high-throughput scenarios. +""" + +import asyncio +import json +from typing import Any, Optional + +from redis import asyncio as aioredis +from redis.asyncio import ConnectionPool, Redis +from redis.exceptions import ConnectionError, RedisError, TimeoutError + +from llama_stack.log import get_logger + +from .cache_store import CacheError + +logger = get_logger(__name__) + + +class RedisCacheStore: + """Redis-based cache store with connection pooling. + + This implementation provides production-ready caching with: + - Connection pooling for efficient resource usage + - Automatic retry logic for transient failures + - Configurable timeouts to prevent blocking + - JSON serialization for complex data types + - Support for Redis cluster and sentinel + + Example: + cache = RedisCacheStore( + host="localhost", + port=6379, + db=0, + password="secret", + connection_pool_size=10, + timeout_ms=100 + ) + await cache.set("key", {"data": "value"}, ttl=300) + value = await cache.get("key") + """ + + def __init__( + self, + host: str = "localhost", + port: int = 6379, + db: int = 0, + password: Optional[str] = None, + connection_pool_size: int = 10, + timeout_ms: int = 100, + default_ttl: int = 600, + max_retries: int = 3, + key_prefix: str = "llama_stack:", + ): + """Initialize Redis cache store. + + Args: + host: Redis server hostname + port: Redis server port + db: Redis database number (0-15) + password: Optional Redis password + connection_pool_size: Maximum connections in pool + timeout_ms: Operation timeout in milliseconds + default_ttl: Default time-to-live in seconds + max_retries: Maximum retry attempts for failed operations + key_prefix: Prefix for all cache keys (namespace isolation) + + Raises: + ValueError: If invalid parameters provided + """ + if connection_pool_size <= 0: + raise ValueError("connection_pool_size must be positive") + if timeout_ms <= 0: + raise ValueError("timeout_ms must be positive") + if default_ttl <= 0: + raise ValueError("default_ttl must be positive") + if max_retries < 0: + raise ValueError("max_retries must be non-negative") + + self.host = host + self.port = port + self.db = db + self.password = password + self.connection_pool_size = connection_pool_size + self.timeout_ms = timeout_ms + self.default_ttl = default_ttl + self.max_retries = max_retries + self.key_prefix = key_prefix + + # Connection pool (lazy initialization) + self._pool: Optional[ConnectionPool] = None + self._redis: Optional[Redis] = None + + logger.info( + f"Initialized RedisCacheStore: host={host}:{port}, db={db}, " + f"pool_size={connection_pool_size}, timeout={timeout_ms}ms, " + f"default_ttl={default_ttl}s" + ) + + async def _ensure_connection(self) -> Redis: + """Ensure Redis connection is established. + + Returns: + Redis client instance + + Raises: + CacheError: If connection cannot be established + """ + if self._redis is not None: + return self._redis + + try: + # Create connection pool + self._pool = ConnectionPool( + host=self.host, + port=self.port, + db=self.db, + password=self.password, + max_connections=self.connection_pool_size, + socket_timeout=self.timeout_ms / 1000.0, + socket_connect_timeout=self.timeout_ms / 1000.0, + decode_responses=True, + ) + + # Create Redis client + self._redis = Redis(connection_pool=self._pool) + + # Test connection + await asyncio.wait_for( + self._redis.ping(), + timeout=self.timeout_ms / 1000.0 + ) + + logger.info(f"Connected to Redis at {self.host}:{self.port}") + return self._redis + + except (ConnectionError, TimeoutError) as e: + logger.error(f"Failed to connect to Redis: {e}") + raise CacheError(f"Failed to connect to Redis at {self.host}:{self.port}", cause=e) from e + except Exception as e: + logger.error(f"Failed to initialize Redis connection: {e}") + raise CacheError("Failed to initialize Redis connection", cause=e) from e + + def _make_key(self, key: str) -> str: + """Create prefixed cache key for namespace isolation. + + Args: + key: Base cache key + + Returns: + Prefixed key + """ + return f"{self.key_prefix}{key}" + + def _serialize(self, value: Any) -> str: + """Serialize value for storage. + + Args: + value: Value to serialize + + Returns: + JSON-serialized string + + Raises: + ValueError: If value cannot be serialized + """ + try: + return json.dumps(value) + except (TypeError, ValueError) as e: + raise ValueError(f"Value is not JSON-serializable: {e}") from e + + def _deserialize(self, data: str) -> Any: + """Deserialize stored value. + + Args: + data: JSON-serialized string + + Returns: + Deserialized value + + Raises: + ValueError: If data cannot be deserialized + """ + try: + return json.loads(data) + except (TypeError, ValueError) as e: + logger.warning(f"Failed to deserialize cache value: {e}") + return None + + async def _retry_operation(self, operation, *args, **kwargs) -> Any: + """Retry an operation with exponential backoff. + + Args: + operation: Async function to retry + *args: Positional arguments for operation + **kwargs: Keyword arguments for operation + + Returns: + Operation result + + Raises: + CacheError: If all retries fail + """ + last_error = None + + for attempt in range(self.max_retries + 1): + try: + return await asyncio.wait_for( + operation(*args, **kwargs), + timeout=self.timeout_ms / 1000.0 + ) + except (ConnectionError, TimeoutError) as e: + last_error = e + if attempt < self.max_retries: + backoff = 2 ** attempt * 0.1 # 100ms, 200ms, 400ms + logger.warning( + f"Redis operation failed (attempt {attempt + 1}/{self.max_retries + 1}), " + f"retrying in {backoff}s: {e}" + ) + await asyncio.sleep(backoff) + else: + logger.error(f"Redis operation failed after {self.max_retries + 1} attempts") + except Exception as e: + # Don't retry on non-transient errors + raise CacheError(f"Redis operation failed: {e}", cause=e) from e + + raise CacheError(f"Redis operation failed after {self.max_retries + 1} attempts", cause=last_error) from last_error + + async def get(self, key: str) -> Optional[Any]: + """Retrieve a value from the cache. + + Args: + key: Cache key to retrieve + + Returns: + Cached value if present and not expired, None otherwise + + Raises: + CacheError: If cache operation fails + """ + try: + redis = await self._ensure_connection() + prefixed_key = self._make_key(key) + + data = await self._retry_operation(redis.get, prefixed_key) + + if data is None: + return None + + value = self._deserialize(data) + if value is not None: + logger.debug(f"Cache hit: {key}") + return value + + except CacheError: + raise + except Exception as e: + logger.error(f"Failed to get cache key '{key}': {e}") + raise CacheError(f"Failed to get cache key '{key}'", cause=e) from e + + async def set( + self, + key: str, + value: Any, + ttl: Optional[int] = None, + ) -> None: + """Store a value in the cache with optional TTL. + + Args: + key: Cache key + value: Value to cache (must be JSON-serializable) + ttl: Time-to-live in seconds. If None, uses default TTL. + + Raises: + CacheError: If cache operation fails + ValueError: If value is not serializable + """ + try: + redis = await self._ensure_connection() + prefixed_key = self._make_key(key) + + # Serialize value + data = self._serialize(value) + + # Use default TTL if not specified + effective_ttl = ttl if ttl is not None else self.default_ttl + + # Store with TTL + await self._retry_operation( + redis.setex, + prefixed_key, + effective_ttl, + data + ) + + logger.debug(f"Cache set: {key} (ttl={effective_ttl}s)") + + except ValueError: + raise + except CacheError: + raise + except Exception as e: + logger.error(f"Failed to set cache key '{key}': {e}") + raise CacheError(f"Failed to set cache key '{key}'", cause=e) from e + + async def delete(self, key: str) -> bool: + """Delete a key from the cache. + + Args: + key: Cache key to delete + + Returns: + True if key was deleted, False if key didn't exist + + Raises: + CacheError: If cache operation fails + """ + try: + redis = await self._ensure_connection() + prefixed_key = self._make_key(key) + + deleted_count = await self._retry_operation(redis.delete, prefixed_key) + + if deleted_count > 0: + logger.debug(f"Cache delete: {key}") + + return bool(deleted_count > 0) + + except CacheError: + raise + except Exception as e: + logger.error(f"Failed to delete cache key '{key}': {e}") + raise CacheError(f"Failed to delete cache key '{key}'", cause=e) from e + + async def exists(self, key: str) -> bool: + """Check if a key exists in the cache. + + Args: + key: Cache key to check + + Returns: + True if key exists and is not expired, False otherwise + + Raises: + CacheError: If cache operation fails + """ + try: + redis = await self._ensure_connection() + prefixed_key = self._make_key(key) + + exists = await self._retry_operation(redis.exists, prefixed_key) + return bool(exists > 0) + + except CacheError: + raise + except Exception as e: + logger.error(f"Failed to check cache key existence '{key}': {e}") + raise CacheError(f"Failed to check cache key existence '{key}'", cause=e) from e + + async def ttl(self, key: str) -> Optional[int]: + """Get the remaining TTL for a key. + + Args: + key: Cache key + + Returns: + Remaining TTL in seconds, None if key doesn't exist or has no TTL + + Raises: + CacheError: If cache operation fails + """ + try: + redis = await self._ensure_connection() + prefixed_key = self._make_key(key) + + ttl_seconds = await self._retry_operation(redis.ttl, prefixed_key) + + # Redis returns -2 if key doesn't exist, -1 if no TTL + if ttl_seconds == -2: + return None + if ttl_seconds == -1: + return None + + return int(max(0, ttl_seconds)) + + except CacheError: + raise + except Exception as e: + logger.error(f"Failed to get TTL for cache key '{key}': {e}") + raise CacheError(f"Failed to get TTL for cache key '{key}'", cause=e) from e + + async def clear(self) -> None: + """Clear all entries from the cache. + + This deletes all keys matching the key_prefix pattern. + + Raises: + CacheError: If cache operation fails + """ + try: + redis = await self._ensure_connection() + pattern = f"{self.key_prefix}*" + + # Scan and delete keys matching pattern + cursor = 0 + deleted_total = 0 + + while True: + cursor, keys = await self._retry_operation( + redis.scan, + cursor=cursor, + match=pattern, + count=100 + ) + + if keys: + deleted_count = await self._retry_operation(redis.delete, *keys) + deleted_total += deleted_count + + if cursor == 0: + break + + logger.info(f"Cache cleared: deleted {deleted_total} keys") + + except CacheError: + raise + except Exception as e: + logger.error(f"Failed to clear cache: {e}") + raise CacheError("Failed to clear cache", cause=e) from e + + async def size(self) -> int: + """Get the number of entries in the cache. + + Returns: + Number of cached entries matching key_prefix + + Raises: + CacheError: If cache operation fails + """ + try: + redis = await self._ensure_connection() + pattern = f"{self.key_prefix}*" + + # Count keys matching pattern + cursor = 0 + count = 0 + + while True: + cursor, keys = await self._retry_operation( + redis.scan, + cursor=cursor, + match=pattern, + count=100 + ) + + count += len(keys) + + if cursor == 0: + break + + return count + + except CacheError: + raise + except Exception as e: + logger.error(f"Failed to get cache size: {e}") + raise CacheError("Failed to get cache size", cause=e) from e + + async def close(self) -> None: + """Close Redis connection and cleanup resources. + + This should be called when the cache is no longer needed. + """ + try: + if self._redis is not None: + await self._redis.close() + self._redis = None + + if self._pool is not None: + await self._pool.disconnect() + self._pool = None + + logger.info("Redis connection closed") + + except Exception as e: + logger.warning(f"Error closing Redis connection: {e}") + + def get_stats(self) -> dict[str, Any]: + """Get cache statistics. + + Returns: + Dictionary with cache configuration and connection info + """ + return { + "host": self.host, + "port": self.port, + "db": self.db, + "connection_pool_size": self.connection_pool_size, + "timeout_ms": self.timeout_ms, + "default_ttl": self.default_ttl, + "max_retries": self.max_retries, + "key_prefix": self.key_prefix, + "connected": self._redis is not None, + } diff --git a/tests/unit/providers/utils/cache/__init__.py b/tests/unit/providers/utils/cache/__init__.py new file mode 100644 index 0000000000..53fafda8e9 --- /dev/null +++ b/tests/unit/providers/utils/cache/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +"""Unit tests for cache store implementations.""" diff --git a/tests/unit/providers/utils/cache/test_cache_store.py b/tests/unit/providers/utils/cache/test_cache_store.py new file mode 100644 index 0000000000..3ac31505bd --- /dev/null +++ b/tests/unit/providers/utils/cache/test_cache_store.py @@ -0,0 +1,257 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +"""Unit tests for cache store base classes and utilities.""" + +import asyncio + +import pytest + +from llama_stack.providers.utils.cache import CacheError, CircuitBreaker + + +class TestCacheError: + """Test suite for CacheError exception.""" + + def test_init_with_message(self): + """Test CacheError initialization with message.""" + error = CacheError("Failed to connect to cache") + assert str(error) == "Failed to connect to cache" + assert error.cause is None + + def test_init_with_cause(self): + """Test CacheError initialization with underlying cause.""" + cause = ValueError("Invalid value") + error = CacheError("Failed to set cache key", cause=cause) + assert str(error) == "Failed to set cache key" + assert error.cause == cause + + +class TestCircuitBreaker: + """Test suite for CircuitBreaker.""" + + def test_init_default_params(self): + """Test initialization with default parameters.""" + breaker = CircuitBreaker() + assert breaker.failure_threshold == 10 + assert breaker.recovery_timeout == 60 + assert breaker.failure_count == 0 + assert breaker.last_failure_time is None + assert breaker.state == "CLOSED" + + def test_init_custom_params(self): + """Test initialization with custom parameters.""" + breaker = CircuitBreaker(failure_threshold=5, recovery_timeout=30) + assert breaker.failure_threshold == 5 + assert breaker.recovery_timeout == 30 + + def test_is_closed_initial_state(self): + """Test is_closed in initial state.""" + breaker = CircuitBreaker() + assert breaker.is_closed() is True + assert breaker.get_state() == "CLOSED" + + def test_record_success(self): + """Test recording successful operations.""" + breaker = CircuitBreaker() + + # Record some failures + breaker.record_failure() + breaker.record_failure() + assert breaker.failure_count == 2 + + # Record success should reset + breaker.record_success() + assert breaker.failure_count == 0 + assert breaker.last_failure_time is None + assert breaker.state == "CLOSED" + + def test_record_failure_below_threshold(self): + """Test recording failures below threshold.""" + breaker = CircuitBreaker(failure_threshold=5) + + # Record failures below threshold + for i in range(4): + breaker.record_failure() + assert breaker.is_closed() is True + assert breaker.state == "CLOSED" + + assert breaker.failure_count == 4 + + def test_record_failure_reach_threshold(self): + """Test circuit breaker opens when threshold reached.""" + breaker = CircuitBreaker(failure_threshold=3) + + # Record failures to reach threshold + for i in range(3): + breaker.record_failure() + + # Should be open now + assert breaker.state == "OPEN" + assert breaker.is_closed() is False + + def test_circuit_open_blocks_requests(self): + """Test that open circuit blocks requests.""" + breaker = CircuitBreaker(failure_threshold=3, recovery_timeout=10) + + # Open the circuit + for i in range(3): + breaker.record_failure() + + assert breaker.is_closed() is False + assert breaker.state == "OPEN" + + async def test_recovery_timeout(self): + """Test circuit breaker recovery after timeout.""" + breaker = CircuitBreaker(failure_threshold=3, recovery_timeout=1) + + # Open the circuit + for i in range(3): + breaker.record_failure() + + assert breaker.state == "OPEN" + assert breaker.is_closed() is False + + # Wait for recovery timeout + await asyncio.sleep(1.1) + + # Should enter HALF_OPEN state + assert breaker.is_closed() is True + assert breaker.state == "HALF_OPEN" + + async def test_half_open_success_closes_circuit(self): + """Test successful request in HALF_OPEN closes circuit.""" + breaker = CircuitBreaker(failure_threshold=3, recovery_timeout=1) + + # Open the circuit + for i in range(3): + breaker.record_failure() + + # Wait for recovery + await asyncio.sleep(1.1) + + # Trigger state transition by calling is_closed() + assert breaker.is_closed() is True + assert breaker.state == "HALF_OPEN" + + # Record success + breaker.record_success() + assert breaker.state == "CLOSED" + assert breaker.failure_count == 0 + + async def test_half_open_failure_reopens_circuit(self): + """Test failed request in HALF_OPEN reopens circuit.""" + breaker = CircuitBreaker(failure_threshold=3, recovery_timeout=1) + + # Open the circuit + for i in range(3): + breaker.record_failure() + + # Wait for recovery + await asyncio.sleep(1.1) + + # Trigger state transition by calling is_closed() + assert breaker.is_closed() is True + assert breaker.state == "HALF_OPEN" + + # Record failure + breaker.record_failure() + assert breaker.state == "OPEN" + + def test_reset(self): + """Test manual reset of circuit breaker.""" + breaker = CircuitBreaker(failure_threshold=3) + + # Open the circuit + for i in range(3): + breaker.record_failure() + + assert breaker.state == "OPEN" + + # Manual reset + breaker.reset() + assert breaker.state == "CLOSED" + assert breaker.failure_count == 0 + assert breaker.last_failure_time is None + + def test_get_state(self): + """Test getting circuit breaker state.""" + breaker = CircuitBreaker(failure_threshold=3) + + # Initial state + assert breaker.get_state() == "CLOSED" + + # After failures + breaker.record_failure() + assert breaker.get_state() == "CLOSED" + + # Open state + for i in range(2): + breaker.record_failure() + assert breaker.get_state() == "OPEN" + + async def test_multiple_recovery_attempts(self): + """Test multiple recovery attempts.""" + breaker = CircuitBreaker(failure_threshold=2, recovery_timeout=1) + + # Open the circuit + breaker.record_failure() + breaker.record_failure() + assert breaker.state == "OPEN" + + # First recovery attempt fails + await asyncio.sleep(1.1) + assert breaker.is_closed() is True # Trigger state check + assert breaker.state == "HALF_OPEN" + breaker.record_failure() + assert breaker.state == "OPEN" + + # Second recovery attempt succeeds + await asyncio.sleep(1.1) + assert breaker.is_closed() is True # Trigger state check + assert breaker.state == "HALF_OPEN" + breaker.record_success() + assert breaker.state == "CLOSED" + + def test_failure_count_tracking(self): + """Test failure count tracking.""" + breaker = CircuitBreaker(failure_threshold=5) + + # Track failures + assert breaker.failure_count == 0 + + breaker.record_failure() + assert breaker.failure_count == 1 + + breaker.record_failure() + assert breaker.failure_count == 2 + + # Success resets count + breaker.record_success() + assert breaker.failure_count == 0 + + async def test_concurrent_operations(self): + """Test circuit breaker with concurrent operations.""" + breaker = CircuitBreaker(failure_threshold=10) + + async def record_failures(count: int): + for _ in range(count): + breaker.record_failure() + await asyncio.sleep(0.01) + + # Concurrent failures + await asyncio.gather( + record_failures(3), + record_failures(3), + record_failures(3), + ) + + assert breaker.failure_count == 9 + assert breaker.state == "CLOSED" + + # One more should open it + breaker.record_failure() + assert breaker.state == "OPEN" diff --git a/tests/unit/providers/utils/cache/test_memory_cache.py b/tests/unit/providers/utils/cache/test_memory_cache.py new file mode 100644 index 0000000000..b6ff278999 --- /dev/null +++ b/tests/unit/providers/utils/cache/test_memory_cache.py @@ -0,0 +1,332 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +"""Unit tests for MemoryCacheStore implementation.""" + +import asyncio + +import pytest + +from llama_stack.providers.utils.cache import CacheError, MemoryCacheStore + + +class TestMemoryCacheStore: + """Test suite for MemoryCacheStore.""" + + async def test_init_default_params(self): + """Test initialization with default parameters.""" + cache = MemoryCacheStore() + assert cache.max_entries == 1000 + assert cache.max_memory_mb == 512 + assert cache.default_ttl == 600 + assert cache.eviction_policy == "lru" + + async def test_init_custom_params(self): + """Test initialization with custom parameters.""" + cache = MemoryCacheStore( + max_entries=500, + max_memory_mb=256, + default_ttl=300, + eviction_policy="lfu", + ) + assert cache.max_entries == 500 + assert cache.max_memory_mb == 256 + assert cache.default_ttl == 300 + assert cache.eviction_policy == "lfu" + + async def test_init_invalid_params(self): + """Test initialization with invalid parameters.""" + with pytest.raises(ValueError, match="max_entries must be positive"): + MemoryCacheStore(max_entries=0) + + with pytest.raises(ValueError, match="default_ttl must be positive"): + MemoryCacheStore(default_ttl=0) + + with pytest.raises(ValueError, match="max_memory_mb must be positive"): + MemoryCacheStore(max_memory_mb=0) + + with pytest.raises(ValueError, match="Unknown eviction policy"): + MemoryCacheStore(eviction_policy="invalid") # type: ignore + + async def test_set_and_get(self): + """Test basic set and get operations.""" + cache = MemoryCacheStore() + + # Set value + await cache.set("key1", "value1") + + # Get value + value = await cache.get("key1") + assert value == "value1" + + async def test_get_nonexistent_key(self): + """Test getting a non-existent key.""" + cache = MemoryCacheStore() + value = await cache.get("nonexistent") + assert value is None + + async def test_set_with_custom_ttl(self): + """Test setting value with custom TTL.""" + cache = MemoryCacheStore(default_ttl=10) + + # Set with custom TTL + await cache.set("key1", "value1", ttl=1) + + # Value should exist initially + value = await cache.get("key1") + assert value == "value1" + + # Wait for expiration + await asyncio.sleep(1.1) + + # Value should be expired + value = await cache.get("key1") + assert value is None + + async def test_set_complex_value(self): + """Test storing complex data types.""" + cache = MemoryCacheStore() + + # Test dictionary + data = {"nested": {"key": "value"}, "list": [1, 2, 3]} + await cache.set("complex", data) + value = await cache.get("complex") + assert value == data + + # Test list + list_data = [1, "two", {"three": 3}] + await cache.set("list", list_data) + value = await cache.get("list") + assert value == list_data + + async def test_delete(self): + """Test deleting a key.""" + cache = MemoryCacheStore() + + # Set and delete + await cache.set("key1", "value1") + deleted = await cache.delete("key1") + assert deleted is True + + # Verify deleted + value = await cache.get("key1") + assert value is None + + # Delete non-existent key + deleted = await cache.delete("nonexistent") + assert deleted is False + + async def test_exists(self): + """Test checking key existence.""" + cache = MemoryCacheStore() + + # Non-existent key + exists = await cache.exists("key1") + assert exists is False + + # Existing key + await cache.set("key1", "value1") + exists = await cache.exists("key1") + assert exists is True + + # Expired key + await cache.set("key2", "value2", ttl=1) + await asyncio.sleep(1.1) + exists = await cache.exists("key2") + assert exists is False + + async def test_ttl(self): + """Test getting remaining TTL.""" + cache = MemoryCacheStore() + + # Non-existent key + ttl = await cache.ttl("nonexistent") + assert ttl is None + + # Key with TTL + await cache.set("key1", "value1", ttl=10) + ttl = await cache.ttl("key1") + assert ttl is not None + assert 8 <= ttl <= 10 # Allow some tolerance + + # Expired key + await cache.set("key2", "value2", ttl=1) + await asyncio.sleep(1.1) + ttl = await cache.ttl("key2") + assert ttl is None + + async def test_clear(self): + """Test clearing all entries.""" + cache = MemoryCacheStore() + + # Add multiple entries + await cache.set("key1", "value1") + await cache.set("key2", "value2") + await cache.set("key3", "value3") + + # Clear + await cache.clear() + + # Verify all cleared + assert await cache.get("key1") is None + assert await cache.get("key2") is None + assert await cache.get("key3") is None + + async def test_size(self): + """Test getting cache size.""" + cache = MemoryCacheStore() + + # Empty cache + size = await cache.size() + assert size == 0 + + # Add entries + await cache.set("key1", "value1") + await cache.set("key2", "value2") + size = await cache.size() + assert size == 2 + + # Delete entry + await cache.delete("key1") + size = await cache.size() + assert size == 1 + + # Clear cache + await cache.clear() + size = await cache.size() + assert size == 0 + + async def test_lru_eviction(self): + """Test LRU eviction policy.""" + cache = MemoryCacheStore(max_entries=3, eviction_policy="lru") + + # Fill cache + await cache.set("key1", "value1") + await cache.set("key2", "value2") + await cache.set("key3", "value3") + + # Access key1 to make it recently used + await cache.get("key1") + + # Add new entry, should evict key2 (least recently used) + await cache.set("key4", "value4") + + # key2 should be evicted + assert await cache.get("key1") == "value1" + assert await cache.get("key2") is None + assert await cache.get("key3") == "value3" + assert await cache.get("key4") == "value4" + + async def test_lfu_eviction(self): + """Test LFU eviction policy.""" + cache = MemoryCacheStore(max_entries=3, eviction_policy="lfu") + + # Fill cache + await cache.set("key1", "value1") + await cache.set("key2", "value2") + await cache.set("key3", "value3") + + # Access key1 multiple times + await cache.get("key1") + await cache.get("key1") + await cache.get("key1") + + # Access key2 twice + await cache.get("key2") + await cache.get("key2") + + # key3 accessed once (least frequently) + + # Add new entry, should evict key3 (least frequently used) + await cache.set("key4", "value4") + + # key3 should be evicted + assert await cache.get("key1") == "value1" + assert await cache.get("key2") == "value2" + assert await cache.get("key3") is None + assert await cache.get("key4") == "value4" + + async def test_concurrent_access(self): + """Test concurrent access to cache.""" + cache = MemoryCacheStore() + + async def set_value(key: str, value: str): + await cache.set(key, value) + + async def get_value(key: str): + return await cache.get(key) + + # Concurrent sets + await asyncio.gather( + set_value("key1", "value1"), + set_value("key2", "value2"), + set_value("key3", "value3"), + ) + + # Concurrent gets + results = await asyncio.gather( + get_value("key1"), + get_value("key2"), + get_value("key3"), + ) + + assert results == ["value1", "value2", "value3"] + + async def test_update_existing_key(self): + """Test updating an existing key.""" + cache = MemoryCacheStore() + + # Set initial value + await cache.set("key1", "value1") + assert await cache.get("key1") == "value1" + + # Update value + await cache.set("key1", "value2") + assert await cache.get("key1") == "value2" + + async def test_get_stats(self): + """Test getting cache statistics.""" + cache = MemoryCacheStore( + max_entries=100, + max_memory_mb=128, + default_ttl=300, + eviction_policy="lru", + ) + + await cache.set("key1", "value1") + await cache.set("key2", "value2") + + stats = cache.get_stats() + + assert stats["size"] == 2 + assert stats["max_entries"] == 100 + assert stats["max_memory_mb"] == 128 + assert stats["default_ttl"] == 300 + assert stats["eviction_policy"] == "lru" + + async def test_ttl_expiration_cleanup(self): + """Test that expired entries are cleaned up properly.""" + cache = MemoryCacheStore() + + # Set entry with short TTL + await cache.set("key1", "value1", ttl=1) + await cache.set("key2", "value2", ttl=10) + + # Initially both exist + assert await cache.size() == 2 + + # Wait for key1 to expire + await asyncio.sleep(1.1) + + # Accessing expired key should clean it up + assert await cache.get("key1") is None + + # Size should reflect cleanup + size = await cache.size() + assert size == 1 + + # key2 should still exist + assert await cache.get("key2") == "value2" diff --git a/tests/unit/providers/utils/cache/test_redis_cache.py b/tests/unit/providers/utils/cache/test_redis_cache.py new file mode 100644 index 0000000000..aef5da31d9 --- /dev/null +++ b/tests/unit/providers/utils/cache/test_redis_cache.py @@ -0,0 +1,421 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +"""Unit tests for RedisCacheStore implementation.""" + +import json +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from llama_stack.providers.utils.cache import CacheError, RedisCacheStore + + +class TestRedisCacheStore: + """Test suite for RedisCacheStore.""" + + async def test_init_default_params(self): + """Test initialization with default parameters.""" + cache = RedisCacheStore() + assert cache.host == "localhost" + assert cache.port == 6379 + assert cache.db == 0 + assert cache.password is None + assert cache.connection_pool_size == 10 + assert cache.timeout_ms == 100 + assert cache.default_ttl == 600 + assert cache.max_retries == 3 + assert cache.key_prefix == "llama_stack:" + + async def test_init_custom_params(self): + """Test initialization with custom parameters.""" + cache = RedisCacheStore( + host="redis.example.com", + port=6380, + db=1, + password="secret", + connection_pool_size=20, + timeout_ms=200, + default_ttl=300, + max_retries=5, + key_prefix="test:", + ) + assert cache.host == "redis.example.com" + assert cache.port == 6380 + assert cache.db == 1 + assert cache.password == "secret" + assert cache.connection_pool_size == 20 + assert cache.timeout_ms == 200 + assert cache.default_ttl == 300 + assert cache.max_retries == 5 + assert cache.key_prefix == "test:" + + async def test_init_invalid_params(self): + """Test initialization with invalid parameters.""" + with pytest.raises(ValueError, match="connection_pool_size must be positive"): + RedisCacheStore(connection_pool_size=0) + + with pytest.raises(ValueError, match="timeout_ms must be positive"): + RedisCacheStore(timeout_ms=0) + + with pytest.raises(ValueError, match="default_ttl must be positive"): + RedisCacheStore(default_ttl=0) + + with pytest.raises(ValueError, match="max_retries must be non-negative"): + RedisCacheStore(max_retries=-1) + + @patch("llama_stack.providers.utils.cache.redis.ConnectionPool") + @patch("llama_stack.providers.utils.cache.redis.Redis") + async def test_ensure_connection(self, mock_redis_class, mock_pool_class): + """Test connection establishment.""" + # Setup mocks + mock_pool = MagicMock() + mock_pool_class.return_value = mock_pool + + mock_redis = AsyncMock() + mock_redis.ping = AsyncMock() + mock_redis_class.return_value = mock_redis + + # Create cache + cache = RedisCacheStore() + + # Ensure connection + redis = await cache._ensure_connection() + + # Verify connection was established + assert redis == mock_redis + mock_pool_class.assert_called_once() + mock_redis.ping.assert_called_once() + + @patch("llama_stack.providers.utils.cache.redis.ConnectionPool") + @patch("llama_stack.providers.utils.cache.redis.Redis") + async def test_connection_failure(self, mock_redis_class, mock_pool_class): + """Test connection failure handling.""" + from redis.exceptions import ConnectionError as RedisConnectionError + + # Setup mocks to fail + mock_redis = AsyncMock() + mock_redis.ping = AsyncMock(side_effect=RedisConnectionError("Connection refused")) + mock_redis_class.return_value = mock_redis + + # Create cache + cache = RedisCacheStore() + + # Connection should fail + with pytest.raises(CacheError, match="Failed to connect to Redis"): + await cache._ensure_connection() + + def test_make_key(self): + """Test key prefixing.""" + cache = RedisCacheStore(key_prefix="test:") + assert cache._make_key("mykey") == "test:mykey" + assert cache._make_key("another") == "test:another" + + def test_serialize_deserialize(self): + """Test value serialization.""" + cache = RedisCacheStore() + + # Simple value + assert cache._serialize("hello") == '"hello"' + assert cache._deserialize('"hello"') == "hello" + + # Dictionary + data = {"key": "value", "number": 42} + serialized = cache._serialize(data) + assert cache._deserialize(serialized) == data + + # List + list_data = [1, 2, "three"] + serialized = cache._serialize(list_data) + assert cache._deserialize(serialized) == list_data + + def test_serialize_error(self): + """Test serialization error handling.""" + cache = RedisCacheStore() + + # Object that can't be serialized + class NonSerializable: + pass + + with pytest.raises(ValueError, match="Value is not JSON-serializable"): + cache._serialize(NonSerializable()) + + @patch("llama_stack.providers.utils.cache.redis.ConnectionPool") + @patch("llama_stack.providers.utils.cache.redis.Redis") + async def test_set_and_get(self, mock_redis_class, mock_pool_class): + """Test set and get operations.""" + # Setup mocks + mock_redis = AsyncMock() + mock_redis.ping = AsyncMock() + mock_redis.get = AsyncMock(return_value=json.dumps("value1")) + mock_redis.setex = AsyncMock() + mock_redis_class.return_value = mock_redis + + # Create cache + cache = RedisCacheStore() + + # Set value + await cache.set("key1", "value1") + mock_redis.setex.assert_called_once() + + # Get value + value = await cache.get("key1") + assert value == "value1" + mock_redis.get.assert_called_once() + + @patch("llama_stack.providers.utils.cache.redis.ConnectionPool") + @patch("llama_stack.providers.utils.cache.redis.Redis") + async def test_get_nonexistent_key(self, mock_redis_class, mock_pool_class): + """Test getting a non-existent key.""" + # Setup mocks + mock_redis = AsyncMock() + mock_redis.ping = AsyncMock() + mock_redis.get = AsyncMock(return_value=None) + mock_redis_class.return_value = mock_redis + + # Create cache + cache = RedisCacheStore() + + # Get non-existent key + value = await cache.get("nonexistent") + assert value is None + + @patch("llama_stack.providers.utils.cache.redis.ConnectionPool") + @patch("llama_stack.providers.utils.cache.redis.Redis") + async def test_set_with_custom_ttl(self, mock_redis_class, mock_pool_class): + """Test setting value with custom TTL.""" + # Setup mocks + mock_redis = AsyncMock() + mock_redis.ping = AsyncMock() + mock_redis.setex = AsyncMock() + mock_redis_class.return_value = mock_redis + + # Create cache + cache = RedisCacheStore(default_ttl=600) + + # Set with custom TTL + await cache.set("key1", "value1", ttl=300) + + # Verify setex was called with custom TTL + call_args = mock_redis.setex.call_args + assert call_args[0][1] == 300 # TTL argument + + @patch("llama_stack.providers.utils.cache.redis.ConnectionPool") + @patch("llama_stack.providers.utils.cache.redis.Redis") + async def test_delete(self, mock_redis_class, mock_pool_class): + """Test deleting a key.""" + # Setup mocks + mock_redis = AsyncMock() + mock_redis.ping = AsyncMock() + mock_redis.delete = AsyncMock(return_value=1) # 1 key deleted + mock_redis_class.return_value = mock_redis + + # Create cache + cache = RedisCacheStore() + + # Delete key + deleted = await cache.delete("key1") + assert deleted is True + + # Delete non-existent key + mock_redis.delete = AsyncMock(return_value=0) + deleted = await cache.delete("nonexistent") + assert deleted is False + + @patch("llama_stack.providers.utils.cache.redis.ConnectionPool") + @patch("llama_stack.providers.utils.cache.redis.Redis") + async def test_exists(self, mock_redis_class, mock_pool_class): + """Test checking key existence.""" + # Setup mocks + mock_redis = AsyncMock() + mock_redis.ping = AsyncMock() + mock_redis.exists = AsyncMock(return_value=1) # Exists + mock_redis_class.return_value = mock_redis + + # Create cache + cache = RedisCacheStore() + + # Check existing key + exists = await cache.exists("key1") + assert exists is True + + # Check non-existent key + mock_redis.exists = AsyncMock(return_value=0) + exists = await cache.exists("nonexistent") + assert exists is False + + @patch("llama_stack.providers.utils.cache.redis.ConnectionPool") + @patch("llama_stack.providers.utils.cache.redis.Redis") + async def test_ttl(self, mock_redis_class, mock_pool_class): + """Test getting remaining TTL.""" + # Setup mocks + mock_redis = AsyncMock() + mock_redis.ping = AsyncMock() + mock_redis.ttl = AsyncMock(return_value=300) + mock_redis_class.return_value = mock_redis + + # Create cache + cache = RedisCacheStore() + + # Get TTL + ttl = await cache.ttl("key1") + assert ttl == 300 + + # Key doesn't exist + mock_redis.ttl = AsyncMock(return_value=-2) + ttl = await cache.ttl("nonexistent") + assert ttl is None + + # Key has no TTL + mock_redis.ttl = AsyncMock(return_value=-1) + ttl = await cache.ttl("no_ttl_key") + assert ttl is None + + @patch("llama_stack.providers.utils.cache.redis.ConnectionPool") + @patch("llama_stack.providers.utils.cache.redis.Redis") + async def test_clear(self, mock_redis_class, mock_pool_class): + """Test clearing all entries.""" + # Setup mocks + mock_redis = AsyncMock() + mock_redis.ping = AsyncMock() + mock_redis.scan = AsyncMock( + side_effect=[ + (10, ["llama_stack:key1", "llama_stack:key2"]), + (0, ["llama_stack:key3"]), # cursor 0 indicates end + ] + ) + mock_redis.delete = AsyncMock(return_value=3) + mock_redis_class.return_value = mock_redis + + # Create cache + cache = RedisCacheStore() + + # Clear cache + await cache.clear() + + # Verify scan and delete were called + assert mock_redis.scan.call_count == 2 + mock_redis.delete.assert_called() + + @patch("llama_stack.providers.utils.cache.redis.ConnectionPool") + @patch("llama_stack.providers.utils.cache.redis.Redis") + async def test_size(self, mock_redis_class, mock_pool_class): + """Test getting cache size.""" + # Setup mocks + mock_redis = AsyncMock() + mock_redis.ping = AsyncMock() + mock_redis.scan = AsyncMock( + side_effect=[ + (10, ["llama_stack:key1", "llama_stack:key2"]), + (0, ["llama_stack:key3"]), + ] + ) + mock_redis_class.return_value = mock_redis + + # Create cache + cache = RedisCacheStore() + + # Get size + size = await cache.size() + assert size == 3 + + @patch("llama_stack.providers.utils.cache.redis.ConnectionPool") + @patch("llama_stack.providers.utils.cache.redis.Redis") + async def test_retry_logic(self, mock_redis_class, mock_pool_class): + """Test retry logic for transient failures.""" + from redis.exceptions import TimeoutError as RedisTimeoutError + + # Setup mocks - fail twice, then succeed + mock_redis = AsyncMock() + mock_redis.ping = AsyncMock() + mock_redis.get = AsyncMock( + side_effect=[ + RedisTimeoutError("Timeout"), + RedisTimeoutError("Timeout"), + json.dumps("success"), + ] + ) + mock_redis_class.return_value = mock_redis + + # Create cache with retries + cache = RedisCacheStore(max_retries=3) + + # Should succeed after retries + value = await cache.get("key1") + assert value == "success" + assert mock_redis.get.call_count == 3 + + @patch("llama_stack.providers.utils.cache.redis.ConnectionPool") + @patch("llama_stack.providers.utils.cache.redis.Redis") + async def test_retry_exhaustion(self, mock_redis_class, mock_pool_class): + """Test behavior when all retries are exhausted.""" + from redis.exceptions import TimeoutError as RedisTimeoutError + + # Setup mocks - always fail + mock_redis = AsyncMock() + mock_redis.ping = AsyncMock() + mock_redis.get = AsyncMock(side_effect=RedisTimeoutError("Timeout")) + mock_redis_class.return_value = mock_redis + + # Create cache with limited retries + cache = RedisCacheStore(max_retries=2) + + # Should raise CacheError after exhausting retries + with pytest.raises(CacheError, match="failed after .* attempts"): + await cache.get("key1") + + # Should have tried 3 times (initial + 2 retries) + assert mock_redis.get.call_count == 3 + + @patch("llama_stack.providers.utils.cache.redis.ConnectionPool") + @patch("llama_stack.providers.utils.cache.redis.Redis") + async def test_close(self, mock_redis_class, mock_pool_class): + """Test closing Redis connection.""" + # Setup mocks + mock_redis = AsyncMock() + mock_redis.ping = AsyncMock() + mock_redis.close = AsyncMock() + mock_redis_class.return_value = mock_redis + + mock_pool = AsyncMock() + mock_pool.disconnect = AsyncMock() + mock_pool_class.return_value = mock_pool + + # Create cache and establish connection + cache = RedisCacheStore() + await cache._ensure_connection() + + # Close connection + await cache.close() + + # Verify cleanup + mock_redis.close.assert_called_once() + mock_pool.disconnect.assert_called_once() + + def test_get_stats(self): + """Test getting cache statistics.""" + cache = RedisCacheStore( + host="redis.example.com", + port=6380, + db=1, + connection_pool_size=20, + timeout_ms=200, + default_ttl=300, + max_retries=5, + key_prefix="test:", + ) + + stats = cache.get_stats() + + assert stats["host"] == "redis.example.com" + assert stats["port"] == 6380 + assert stats["db"] == 1 + assert stats["connection_pool_size"] == 20 + assert stats["timeout_ms"] == 200 + assert stats["default_ttl"] == 300 + assert stats["max_retries"] == 5 + assert stats["key_prefix"] == "test:" + assert stats["connected"] is False # Not connected yet diff --git a/uv.lock b/uv.lock index a343eb5d87..04bda7e823 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.12" resolution-markers = [ "(python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.13' and sys_platform != 'darwin' and sys_platform != 'linux')", @@ -1996,6 +1996,7 @@ dependencies = [ { name = "aiohttp" }, { name = "aiosqlite" }, { name = "asyncpg" }, + { name = "cachetools" }, { name = "fastapi" }, { name = "fire" }, { name = "h11" }, @@ -2013,6 +2014,7 @@ dependencies = [ { name = "python-dotenv" }, { name = "python-multipart" }, { name = "pyyaml" }, + { name = "redis" }, { name = "rich" }, { name = "sqlalchemy", extra = ["asyncio"] }, { name = "starlette" }, @@ -2147,6 +2149,7 @@ requires-dist = [ { name = "aiohttp" }, { name = "aiosqlite", specifier = ">=0.21.0" }, { name = "asyncpg" }, + { name = "cachetools", specifier = ">=5.5.0" }, { name = "fastapi", specifier = ">=0.115.0,<1.0" }, { name = "fire" }, { name = "h11", specifier = ">=0.16.0" }, @@ -2166,6 +2169,7 @@ requires-dist = [ { name = "python-multipart", specifier = ">=0.0.20" }, { name = "pyyaml", specifier = ">=6.0" }, { name = "pyyaml", specifier = ">=6.0.2" }, + { name = "redis", specifier = ">=5.2.0" }, { name = "rich" }, { name = "sqlalchemy", extras = ["asyncio"], specifier = ">=2.0.41" }, { name = "starlette" }, @@ -4398,6 +4402,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ef/33/d8df6a2b214ffbe4138db9a1efe3248f67dc3c671f82308bea1582ecbbb7/qdrant_client-1.15.1-py3-none-any.whl", hash = "sha256:2b975099b378382f6ca1cfb43f0d59e541be6e16a5892f282a4b8de7eff5cb63", size = 337331, upload-time = "2025-07-31T19:35:17.539Z" }, ] +[[package]] +name = "redis" +version = "7.0.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/57/8f/f125feec0b958e8d22c8f0b492b30b1991d9499a4315dfde466cf4289edc/redis-7.0.1.tar.gz", hash = "sha256:c949df947dca995dc68fdf5a7863950bf6df24f8d6022394585acc98e81624f1", size = 4755322, upload-time = "2025-10-27T14:34:00.33Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e9/97/9f22a33c475cda519f20aba6babb340fb2f2254a02fb947816960d1e669a/redis-7.0.1-py3-none-any.whl", hash = "sha256:4977af3c7d67f8f0eb8b6fec0dafc9605db9343142f634041fb0235f67c0588a", size = 339938, upload-time = "2025-10-27T14:33:58.553Z" }, +] + [[package]] name = "referencing" version = "0.36.2" @@ -4656,6 +4669,8 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/6b/fa/3234f913fe9a6525a7b97c6dad1f51e72b917e6872e051a5e2ffd8b16fbb/ruamel.yaml.clib-0.2.14-cp314-cp314-macosx_15_0_arm64.whl", hash = "sha256:70eda7703b8126f5e52fcf276e6c0f40b0d314674f896fc58c47b0aef2b9ae83", size = 137970, upload-time = "2025-09-22T19:51:09.472Z" }, { url = "https://files.pythonhosted.org/packages/ef/ec/4edbf17ac2c87fa0845dd366ef8d5852b96eb58fcd65fc1ecf5fe27b4641/ruamel.yaml.clib-0.2.14-cp314-cp314-musllinux_1_2_i686.whl", hash = "sha256:a0cb71ccc6ef9ce36eecb6272c81afdc2f565950cdcec33ae8e6cd8f7fc86f27", size = 739639, upload-time = "2025-09-22T19:51:10.566Z" }, { url = "https://files.pythonhosted.org/packages/15/18/b0e1fafe59051de9e79cdd431863b03593ecfa8341c110affad7c8121efc/ruamel.yaml.clib-0.2.14-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:e7cb9ad1d525d40f7d87b6df7c0ff916a66bc52cb61b66ac1b2a16d0c1b07640", size = 764456, upload-time = "2025-09-22T19:51:11.736Z" }, + { url = "https://files.pythonhosted.org/packages/e7/cd/150fdb96b8fab27fe08d8a59fe67554568727981806e6bc2677a16081ec7/ruamel_yaml_clib-0.2.14-cp314-cp314-win32.whl", hash = "sha256:9b4104bf43ca0cd4e6f738cb86326a3b2f6eef00f417bd1e7efb7bdffe74c539", size = 102394, upload-time = "2025-11-14T21:57:36.703Z" }, + { url = "https://files.pythonhosted.org/packages/bd/e6/a3fa40084558c7e1dc9546385f22a93949c890a8b2e445b2ba43935f51da/ruamel_yaml_clib-0.2.14-cp314-cp314-win_amd64.whl", hash = "sha256:13997d7d354a9890ea1ec5937a219817464e5cc344805b37671562a401ca3008", size = 122673, upload-time = "2025-11-14T21:57:38.177Z" }, ] [[package]]