Skip to content

Commit a588d4c

Browse files
hazaiPouyanpi
andauthored
feat(cache): Add LFU caching system for models (currently applied to content safety checks) #1436
Implement a pluggable caching infrastructure to reduce redundant LLM calls in content safety checks. The system features a Least Frequently Used (LFU) eviction policy with optional statistics tracking and periodic logging. Key components: - CacheInterface: Abstract base defining cache contract - LFUCache: Thread-safe LFU implementation with configurable stats - Cache utilities: Key normalization, LLM stats extraction/restoration - Content safety integration: Automatic caching in check_input action - Configuration: Cache settings in RailsConfig with per-model caches The caching layer is transparent to existing code and can be enabled via configuration without code changes. --------- Signed-off-by: Pouyan <13303554+Pouyanpi@users.noreply.github.com> Co-authored-by: Pouyan <13303554+Pouyanpi@users.noreply.github.com>
1 parent 81e7c0b commit a588d4c

File tree

18 files changed

+3304
-2
lines changed

18 files changed

+3304
-2
lines changed

nemoguardrails/library/content_safety/actions.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,13 @@
2121
from nemoguardrails.actions.actions import action
2222
from nemoguardrails.actions.llm.utils import llm_call
2323
from nemoguardrails.context import llm_call_info_var
24+
from nemoguardrails.llm.cache import CacheInterface
25+
from nemoguardrails.llm.cache.utils import (
26+
CacheEntry,
27+
create_normalized_cache_key,
28+
extract_llm_stats_for_cache,
29+
get_from_cache_and_restore_stats,
30+
)
2431
from nemoguardrails.llm.taskmanager import LLMTaskManager
2532
from nemoguardrails.logging.explain import LLMCallInfo
2633

@@ -33,6 +40,7 @@ async def content_safety_check_input(
3340
llm_task_manager: LLMTaskManager,
3441
model_name: Optional[str] = None,
3542
context: Optional[dict] = None,
43+
model_caches: Optional[Dict[str, CacheInterface]] = None,
3644
**kwargs,
3745
) -> dict:
3846
_MAX_TOKENS = 3
@@ -75,6 +83,15 @@ async def content_safety_check_input(
7583

7684
max_tokens = max_tokens or _MAX_TOKENS
7785

86+
cache = model_caches.get(model_name) if model_caches else None
87+
88+
if cache:
89+
cache_key = create_normalized_cache_key(check_input_prompt)
90+
cached_result = get_from_cache_and_restore_stats(cache, cache_key)
91+
if cached_result is not None:
92+
log.debug(f"Content safety cache hit for model '{model_name}'")
93+
return cached_result
94+
7895
result = await llm_call(
7996
llm,
8097
check_input_prompt,
@@ -86,7 +103,18 @@ async def content_safety_check_input(
86103

87104
is_safe, *violated_policies = result
88105

89-
return {"allowed": is_safe, "policy_violations": violated_policies}
106+
final_result = {"allowed": is_safe, "policy_violations": violated_policies}
107+
108+
if cache:
109+
cache_key = create_normalized_cache_key(check_input_prompt)
110+
cache_entry: CacheEntry = {
111+
"result": final_result,
112+
"llm_stats": extract_llm_stats_for_cache(),
113+
}
114+
cache.put(cache_key, cache_entry)
115+
log.debug(f"Content safety result cached for model '{model_name}'")
116+
117+
return final_result
90118

91119

92120
def content_safety_check_output_mapping(result: dict) -> bool:
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""General-purpose caching utilities for NeMo Guardrails."""
17+
18+
from nemoguardrails.llm.cache.interface import CacheInterface
19+
from nemoguardrails.llm.cache.lfu import LFUCache
20+
21+
__all__ = ["CacheInterface", "LFUCache"]
Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""
17+
Cache interface for NeMo Guardrails caching system.
18+
19+
This module defines the abstract base class for cache implementations
20+
that can be used interchangeably throughout the guardrails system.
21+
"""
22+
23+
from abc import ABC, abstractmethod
24+
from typing import Any, Callable, Optional
25+
26+
27+
class CacheInterface(ABC):
28+
"""
29+
Abstract base class defining the interface for cache implementations.
30+
31+
All cache implementations must inherit from this class and implement
32+
the required methods to ensure compatibility with the caching system.
33+
"""
34+
35+
@abstractmethod
36+
def get(self, key: Any, default: Any = None) -> Any:
37+
"""
38+
Retrieve an item from the cache.
39+
40+
Args:
41+
key: The key to look up in the cache.
42+
default: Value to return if key is not found (default: None).
43+
44+
Returns:
45+
The value associated with the key, or default if not found.
46+
"""
47+
pass
48+
49+
@abstractmethod
50+
def put(self, key: Any, value: Any) -> None:
51+
"""
52+
Store an item in the cache.
53+
54+
If the cache is at maxsize, this method should evict an item
55+
according to the cache's eviction policy (e.g., LFU, LRU, etc.).
56+
57+
Args:
58+
key: The key to store.
59+
value: The value to associate with the key.
60+
"""
61+
pass
62+
63+
@abstractmethod
64+
def size(self) -> int:
65+
"""
66+
Get the current number of items in the cache.
67+
68+
Returns:
69+
The number of items currently stored in the cache.
70+
"""
71+
pass
72+
73+
@abstractmethod
74+
def is_empty(self) -> bool:
75+
"""
76+
Check if the cache is empty.
77+
78+
Returns:
79+
True if the cache contains no items, False otherwise.
80+
"""
81+
pass
82+
83+
@abstractmethod
84+
def clear(self) -> None:
85+
"""
86+
Remove all items from the cache.
87+
88+
After calling this method, the cache should be empty.
89+
"""
90+
pass
91+
92+
def contains(self, key: Any) -> bool:
93+
"""
94+
Check if a key exists in the cache.
95+
96+
This is an optional method that can be overridden for efficiency.
97+
The default implementation uses get() to check existence.
98+
99+
Args:
100+
key: The key to check.
101+
102+
Returns:
103+
True if the key exists in the cache, False otherwise.
104+
"""
105+
# Default implementation - can be overridden for efficiency
106+
sentinel = object()
107+
return self.get(key, sentinel) is not sentinel
108+
109+
@property
110+
@abstractmethod
111+
def maxsize(self) -> int:
112+
"""
113+
Get the maximum size of the cache.
114+
115+
Returns:
116+
The maximum number of items the cache can hold.
117+
"""
118+
pass
119+
120+
def get_stats(self) -> dict:
121+
"""
122+
Get cache statistics.
123+
124+
Returns:
125+
Dictionary with cache statistics. The format and contents
126+
may vary by implementation. Common fields include:
127+
- hits: Number of cache hits
128+
- misses: Number of cache misses
129+
- evictions: Number of items evicted
130+
- hit_rate: Percentage of requests that were hits
131+
- current_size: Current number of items in cache
132+
- maxsize: Maximum size of the cache
133+
134+
The default implementation returns a message indicating that
135+
statistics tracking is not supported.
136+
"""
137+
return {
138+
"message": "Statistics tracking is not supported by this cache implementation"
139+
}
140+
141+
def reset_stats(self) -> None:
142+
"""
143+
Reset cache statistics.
144+
145+
This is an optional method that cache implementations can override
146+
if they support statistics tracking. The default implementation does nothing.
147+
"""
148+
# Default no-op implementation
149+
pass
150+
151+
def log_stats_now(self) -> None:
152+
"""
153+
Force immediate logging of cache statistics.
154+
155+
This is an optional method that cache implementations can override
156+
if they support statistics logging. The default implementation does nothing.
157+
158+
Implementations that support statistics logging should output the
159+
current cache statistics to their configured logging backend.
160+
"""
161+
# Default no-op implementation
162+
pass
163+
164+
def supports_stats_logging(self) -> bool:
165+
"""
166+
Check if this cache implementation supports statistics logging.
167+
168+
Returns:
169+
True if the cache supports statistics logging, False otherwise.
170+
171+
The default implementation returns False. Cache implementations
172+
that support statistics logging should override this to return True
173+
when logging is enabled.
174+
"""
175+
return False
176+
177+
async def get_or_compute(
178+
self, key: Any, compute_fn: Callable[[], Any], default: Any = None
179+
) -> Any:
180+
"""
181+
Atomically get a value from the cache or compute it if not present.
182+
183+
This method ensures that the compute function is called at most once
184+
even in the presence of concurrent requests for the same key.
185+
186+
Args:
187+
key: The key to look up
188+
compute_fn: Async function to compute the value if key is not found
189+
default: Value to return if compute_fn raises an exception
190+
191+
Returns:
192+
The cached value or the computed value
193+
194+
This is an optional method with a default implementation. Cache
195+
implementations should override this for better thread-safety guarantees.
196+
"""
197+
# Default implementation - not thread-safe for computation
198+
value = self.get(key)
199+
if value is not None:
200+
return value
201+
202+
try:
203+
computed_value = await compute_fn()
204+
self.put(key, computed_value)
205+
return computed_value
206+
except Exception:
207+
return default

0 commit comments

Comments
 (0)