@@ -97,6 +97,164 @@ def __setitem__(self, key: Sequence[int], value: "llama_cpp.llama.LlamaState"):
             self.cache_state.popitem(last=False)
 
 
+class TrieNode:
+    """A node in the prefix tree (Trie)."""
+    def __init__(self):
+        # Child nodes: {token_id: TrieNode}
+        self.children: Dict[int, "TrieNode"] = {}
+        # Stores the LlamaState if this node marks the end of a cached sequence.
+        self.state: Optional["llama_cpp.llama.LlamaState"] = None
+
+
+class LlamaTrieCache(BaseLlamaCache):
+    """
+    A Llama cache implementation using a Trie for O(K) prefix lookup
+    and an OrderedDict for O(1) LRU eviction.
+
+    - K = length of the query key (number of tokens)
+    - N = total number of items in the cache
+
+    This solves the O(N*K) lookup bottleneck of the linear-scan cache.
+    """
+
+    def __init__(self, capacity_bytes: int = (2 << 30)):
+        super().__init__(capacity_bytes)
+        self.root = TrieNode()  # The root node of the Trie
+        self._current_size = 0  # O(1) tracking of cache size in bytes
+
+        # LRU tracker:
+        # Key: cached token sequence (Tuple[int, ...])
+        # Value: the *terminal* TrieNode for that key
+        self.lru_tracker: OrderedDict[
+            Tuple[int, ...], TrieNode
+        ] = OrderedDict()
+
+    @property
+    def cache_size(self) -> int:
+        """Returns the current total size of the cache in bytes (O(1))."""
+        return self._current_size
+
+    def _find_longest_prefix_node(
+        self, key: Tuple[int, ...]
+    ) -> Tuple[Optional[TrieNode], Optional[Tuple[int, ...]]]:
+        """
+        Finds the longest cached prefix for a given key in O(K) time.
+
+        Returns: (the matching TrieNode, the matching key)
+        """
+        node = self.root
+        longest_prefix_node: Optional[TrieNode] = None
+        longest_prefix_key: Optional[Tuple[int, ...]] = None
+        current_prefix: List[int] = []
+
+        # Check if the empty prefix (root) is cached
+        if node.state is not None:
+            longest_prefix_node = node
+            longest_prefix_key = tuple(current_prefix)
+
+        for token in key:
+            if token not in node.children:
+                # Path ends, no further prefix matches
+                break
+
+            node = node.children[token]
+            current_prefix.append(token)
+
+            if node.state is not None:
+                # Found a valid, longer prefix; update our best match
+                longest_prefix_node = node
+                longest_prefix_key = tuple(current_prefix)
+
+        return longest_prefix_node, longest_prefix_key
+
+    def __getitem__(self, key: Sequence[int]) -> "llama_cpp.llama.LlamaState":
+        """
+        Retrieves the state for the longest matching prefix in O(K) time.
+        Updates the LRU status.
+        """
+        key_tuple = tuple(key)
+        node, prefix_key = self._find_longest_prefix_node(key_tuple)
+
+        if node is None or node.state is None or prefix_key is None:
+            raise KeyError(f"Key prefix not found in cache for: {key_tuple}")
+
+        # Move the accessed key to the end (most recently used) in O(1)
+        self.lru_tracker.move_to_end(prefix_key)
+
+        return node.state
+
+    def __contains__(self, key: Sequence[int]) -> bool:
+        """Checks if any prefix of the key is cached in O(K) time."""
+        node, _ = self._find_longest_prefix_node(tuple(key))
+        return node is not None
+
+    def _prune(self, key: Tuple[int, ...]):
+        """
+        (Helper) Removes a key and its state from the Trie.
+        Also removes empty parent nodes (branch pruning).
+        """
+        path: List[Tuple[TrieNode, int]] = []  # Stores (parent_node, token)
+        node = self.root
+
+        # 1. Find the node and record the path
+        for token in key:
+            if token not in node.children:
+                return  # Key not found
+            path.append((node, token))
+            node = node.children[token]
+
+        # 2. Remove the state
+        if node.state is None:
+            return  # Node has no state
+
+        self._current_size -= node.state.llama_state_size
+        node.state = None
+
+        # 3. Prune empty parent nodes backward
+        for parent, token in reversed(path):
+            child = parent.children[token]
+
+            # If the child node is now empty (no children, no state), delete it
+            if not child.children and child.state is None:
+                del parent.children[token]
+            else:
+                # Node is still in use, stop pruning
+                break
+
+    def __setitem__(self, key: Sequence[int], value: "llama_cpp.llama.LlamaState"):
+        """
+        Adds a (key, state) pair to the cache in O(K) time.
+        Handles LRU updates and eviction.
+        """
+        key_tuple = tuple(key)
+
+        # 1. Find or create nodes for the key (O(K))
+        node = self.root
+        for token in key_tuple:
+            node = node.children.setdefault(token, TrieNode())
+
+        # 2. Check if updating an existing item
+        if node.state is not None:
+            self._current_size -= node.state.llama_state_size
+
+        # 3. Set the new state and update the O(1) size counter
+        node.state = value
+        self._current_size += value.llama_state_size
+
+        # 4. Update the LRU tracker (O(1))
+        if key_tuple in self.lru_tracker:
+            self.lru_tracker.move_to_end(key_tuple)
+        else:
+            self.lru_tracker[key_tuple] = node
+
+        # 5. Eviction logic
+        while self._current_size > self.capacity_bytes and self.lru_tracker:
+            # Get the least recently used item in O(1)
+            evicted_key, _ = self.lru_tracker.popitem(last=False)
+
+            # Remove the evicted item from the Trie
+            self._prune(evicted_key)
+
 # Alias for backwards compatibility
 LlamaCache = LlamaRAMCache
 
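For reference, a minimal usage sketch of the new cache's behavior. `_FakeState` here is a hypothetical stand-in for `llama_cpp.llama.LlamaState` (which requires a loaded model); the cache only reads the `llama_state_size` attribute for byte accounting, so any object exposing it works for illustration. The import path assumes `LlamaTrieCache` lands in the same module as the existing caches.

```python
from llama_cpp.llama_cache import LlamaTrieCache  # assumed module path for this PR


class _FakeState:
    """Hypothetical stand-in for LlamaState; only the size attribute is read."""
    def __init__(self, llama_state_size: int):
        self.llama_state_size = llama_state_size


cache = LlamaTrieCache(capacity_bytes=100)

# Cache states for two prompts that share a prefix.
cache[(1, 2)] = _FakeState(40)
cache[(1, 2, 3)] = _FakeState(40)

# Longest-prefix lookup: a query extending a cached key walks at most
# len(key) trie edges and returns the deepest cached state.
assert (1, 2, 3, 4) in cache
assert cache[(1, 2, 3, 4)] is cache[(1, 2, 3)]  # both resolve to (1, 2, 3)

# Exceeding capacity_bytes evicts the least recently used entry:
# (1, 2) is the oldest untouched key, so it is pruned from the trie.
cache[(9, 9)] = _FakeState(40)  # 40 + 40 + 40 > 100 -> evict (1, 2)
assert (1, 2, 3) in cache
assert (1, 2) not in cache  # no shorter cached prefix remains
```

Note that `lru_tracker` keys are full token tuples, so each `move_to_end` is a single OrderedDict operation (hashing the tuple is O(K)); eviction then re-walks the trie in `_prune` because cleanly removing a node requires its chain of parents.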