
Commit 2419dc2

Implement LlamaTrieCache in llama_cache.py: optimize LlamaCache lookup from O(N) to O(K) using a Trie
1 parent d42cd32 commit 2419dc2
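
For context on the complexity claim: the existing RAM cache finds the best entry by comparing the query against every cached key, so a lookup costs roughly O(N*K). The snippet below is a simplified stand-in for that baseline (not the library's exact code); the LlamaTrieCache added in this commit replaces it with a walk down a trie of token IDs that touches at most K nodes, no matter how many entries are cached.

```python
from typing import Iterable, Optional, Sequence, Tuple


def linear_longest_prefix(
    cached_keys: Iterable[Tuple[int, ...]], key: Sequence[int]
) -> Optional[Tuple[int, ...]]:
    """Shape of the O(N*K) baseline: compare the query against every cached key."""
    best_key, best_len = None, 0
    for k in cached_keys:              # N cached keys ...
        prefix_len = 0
        for a, b in zip(k, key):       # ... each compared over up to K tokens
            if a != b:
                break
            prefix_len += 1
        if prefix_len > best_len:
            best_key, best_len = k, prefix_len
    return best_key
```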

File tree

4 files changed, +166 -3 lines changed


llama_cpp/llama.py

Lines changed: 3 additions & 2 deletions
```diff
@@ -34,9 +34,10 @@
 from .llama_grammar import LlamaGrammar
 from .llama_cache import (
     BaseLlamaCache,
-    LlamaCache,  # type: ignore
+    LlamaCache,  # type: ignore
     LlamaDiskCache,  # type: ignore
-    LlamaRAMCache,  # type: ignore
+    LlamaRAMCache,  # type: ignore
+    LlamaTrieCache,  # type: ignore
 )
 from .llama_tokenizer import BaseLlamaTokenizer, LlamaTokenizer
 import llama_cpp.llama_cpp as llama_cpp
```
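
With the import in place, the new cache attaches to a Llama instance the same way the existing caches do. A minimal usage sketch, assuming the package-level re-export that the server code below relies on (llama_cpp.LlamaTrieCache), the existing Llama.set_cache API, and a placeholder model path:

```python
from llama_cpp import Llama, LlamaTrieCache  # re-export of the new class assumed

# Placeholder path; any local GGUF model works here.
llm = Llama(model_path="./models/model.gguf")

# Reuse evaluated prompt states across requests that share a token prefix,
# with a ~2 GiB budget enforced by LRU eviction.
llm.set_cache(LlamaTrieCache(capacity_bytes=2 << 30))
```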

llama_cpp/llama_cache.py

Lines changed: 158 additions & 0 deletions
```diff
@@ -97,6 +97,164 @@ def __setitem__(self, key: Sequence[int], value: "llama_cpp.llama.LlamaState"):
             self.cache_state.popitem(last=False)
 
 
+class TrieNode:
+    """A node in the prefix tree (Trie)."""
+    def __init__(self):
+        # Child nodes: {token_id: TrieNode}
+        self.children: Dict[int, "TrieNode"] = {}
+        # Stores the LlamaState if this node marks the end of a cached sequence.
+        self.state: Optional["llama_cpp.llama.LlamaState"] = None
+
+
+class LlamaTrieCache(BaseLlamaCache):
+    """
+    A Llama cache implementation using a Trie for O(K) prefix lookup
+    and an OrderedDict for O(1) LRU eviction.
+
+    - K = length of the query key (number of tokens)
+    - N = total number of items in the cache
+
+    This solves the O(N*K) lookup bottleneck of the linear scan cache.
+    """
+
+    def __init__(self, capacity_bytes: int = (2 << 30)):
+        super().__init__(capacity_bytes)
+        self.root = TrieNode()  # The root node of the Trie
+        self._current_size = 0  # O(1) tracking of cache size in bytes
+
+        # LRU Tracker:
+        #   Key: Cached token sequence (Tuple[int, ...])
+        #   Value: The *terminal* TrieNode for that key
+        self.lru_tracker: OrderedDict[
+            Tuple[int, ...], TrieNode
+        ] = OrderedDict()
+
+    @property
+    def cache_size(self) -> int:
+        """Returns the current total size of the cache in bytes (O(1))."""
+        return self._current_size
+
+    def _find_longest_prefix_node(
+        self, key: Tuple[int, ...]
+    ) -> Tuple[Optional[TrieNode], Optional[Tuple[int, ...]]]:
+        """
+        Finds the longest cached prefix for a given key in O(K) time.
+
+        Returns: (The matching TrieNode, The matching key)
+        """
+        node = self.root
+        longest_prefix_node: Optional[TrieNode] = None
+        longest_prefix_key: Optional[Tuple[int, ...]] = None
+        current_prefix: List[int] = []
+
+        # Check if the empty prefix (root) is cached
+        if node.state is not None:
+            longest_prefix_node = node
+            longest_prefix_key = tuple(current_prefix)
+
+        for token in key:
+            if token not in node.children:
+                # Path ends, no further prefix matches
+                break
+
+            node = node.children[token]
+            current_prefix.append(token)
+
+            if node.state is not None:
+                # Found a valid, longer prefix; update our best match
+                longest_prefix_node = node
+                longest_prefix_key = tuple(current_prefix)
+
+        return longest_prefix_node, longest_prefix_key
+
+    def __getitem__(self, key: Sequence[int]) -> "llama_cpp.llama.LlamaState":
+        """
+        Retrieves the state for the longest matching prefix in O(K) time.
+        Updates the LRU status.
+        """
+        key_tuple = tuple(key)
+        node, prefix_key = self._find_longest_prefix_node(key_tuple)
+
+        if node is None or node.state is None or prefix_key is None:
+            raise KeyError(f"Key prefix not found in cache for: {key_tuple}")
+
+        # Move the accessed key to the end (most recently used) in O(1)
+        self.lru_tracker.move_to_end(prefix_key)
+
+        return node.state
+
+    def __contains__(self, key: Sequence[int]) -> bool:
+        """Checks if any prefix of the key is cached in O(K) time."""
+        node, _ = self._find_longest_prefix_node(tuple(key))
+        return node is not None
+
+    def _prune(self, key: Tuple[int, ...]):
+        """
+        (Helper) Removes a key and its state from the Trie.
+        Also removes empty parent nodes (branch pruning).
+        """
+        path: List[Tuple[TrieNode, int]] = []  # Stores (parent_node, token)
+        node = self.root
+
+        # 1. Find the node and record the path
+        for token in key:
+            if token not in node.children:
+                return  # Key not found
+            path.append((node, token))
+            node = node.children[token]
+
+        # 2. Remove the state
+        if node.state is None:
+            return  # Node has no state
+
+        self._current_size -= node.state.llama_state_size
+        node.state = None
+
+        # 3. Prune empty parent nodes backward
+        for parent, token in reversed(path):
+            child = parent.children[token]
+
+            # If the child node is now empty (no children, no state), delete it
+            if not child.children and child.state is None:
+                del parent.children[token]
+            else:
+                # Node is still in use, stop pruning
+                break
+
+    def __setitem__(self, key: Sequence[int], value: "llama_cpp.llama.LlamaState"):
+        """
+        Adds a (key, state) pair to the cache in O(K) time.
+        Handles LRU updates and eviction.
+        """
+        key_tuple = tuple(key)
+
+        # 1. Find or create nodes for the key (O(K))
+        node = self.root
+        for token in key_tuple:
+            node = node.children.setdefault(token, TrieNode())
+
+        # 2. Check if updating an existing item
+        if node.state is not None:
+            self._current_size -= node.state.llama_state_size
+
+        # 3. Set new state and update O(1) size
+        node.state = value
+        self._current_size += value.llama_state_size
+
+        # 4. Update LRU tracker (O(1))
+        if key_tuple in self.lru_tracker:
+            self.lru_tracker.move_to_end(key_tuple)
+        else:
+            self.lru_tracker[key_tuple] = node
+
+        # 5. Eviction logic
+        while self._current_size > self.capacity_bytes and self.lru_tracker:
+            # Get the least recently used item in O(1)
+            evicted_key, _ = self.lru_tracker.popitem(last=False)
+
+            # Remove the evicted item from the Trie
+            self._prune(evicted_key)
+
+
 # Alias for backwards compatibility
 LlamaCache = LlamaRAMCache
```
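
A small behavioral sketch of the class above: a lookup returns the state stored for the longest cached prefix of the query, and an insert that pushes the byte total past capacity_bytes evicts the least recently used entry and prunes its now-empty trie branch. The stand-in state object is hypothetical; the cache only reads its llama_state_size attribute.

```python
from types import SimpleNamespace

from llama_cpp.llama_cache import LlamaTrieCache


def fake_state(size: int):
    # Hypothetical stand-in for llama_cpp.llama.LlamaState; only llama_state_size is read.
    return SimpleNamespace(llama_state_size=size)


cache = LlamaTrieCache(capacity_bytes=1000)

cache[(1, 2, 3)] = fake_state(600)            # cache a state for the prefix (1, 2, 3)
query = (1, 2, 3, 4, 5)                       # a longer prompt sharing that prefix
assert query in cache                         # O(K) longest-prefix check
assert cache[query].llama_state_size == 600   # state cached for (1, 2, 3) is returned

cache[(9, 9)] = fake_state(600)               # 1200 bytes now exceeds the 1000-byte capacity
assert (1, 2, 3) not in cache                 # LRU entry evicted, its trie branch pruned
assert (9, 9) in cache
```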

llama_cpp/server/model.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -323,6 +323,10 @@ def load_llama_from_model_settings(settings: ModelSettings) -> llama_cpp.Llama:
             if settings.verbose:
                 print(f"Using disk cache with size {settings.cache_size}")
             cache = llama_cpp.LlamaDiskCache(capacity_bytes=settings.cache_size)
+        elif settings.cache_type == "trie":
+            if settings.verbose:
+                print(f"Using trie cache with size {settings.cache_size}")
+            cache = llama_cpp.LlamaTrieCache(capacity_bytes=settings.cache_size)
         else:
             if settings.verbose:
                 print(f"Using ram cache with size {settings.cache_size}")
```

llama_cpp/server/settings.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -159,7 +159,7 @@ class ModelSettings(BaseSettings):
         default=False,
         description="Use a cache to reduce processing times for evaluated prompts.",
     )
-    cache_type: Literal["ram", "disk"] = Field(
+    cache_type: Literal["ram", "trie", "disk"] = Field(
         default="ram",
         description="The type of cache to use. Only used if cache is True.",
     )
```
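
With the new Literal value in place, the trie cache can be selected through the server's model settings. A minimal sketch, assuming the fields shown in this commit plus ModelSettings' existing model-path field:

```python
from llama_cpp.server.model import load_llama_from_model_settings
from llama_cpp.server.settings import ModelSettings

settings = ModelSettings(
    model="./models/model.gguf",  # placeholder path to a local GGUF model
    cache=True,                   # enable prompt-state caching
    cache_type="trie",            # selects the LlamaTrieCache branch added above
    cache_size=2 << 30,           # cache budget in bytes
)
llm = load_llama_from_model_settings(settings)
```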

0 commit comments
