2424from langgraph .checkpoint .redis .ashallow import AsyncShallowRedisSaver
2525from langgraph .checkpoint .redis .base import BaseRedisSaver
2626from langgraph .checkpoint .redis .shallow import ShallowRedisSaver
27+ from langgraph .checkpoint .redis .util import (
28+ EMPTY_ID_SENTINEL ,
29+ from_storage_safe_id ,
30+ from_storage_safe_str ,
31+ to_storage_safe_id ,
32+ to_storage_safe_str ,
33+ )
2734from langgraph .checkpoint .redis .version import __lib_name__ , __version__
2835
2936
@@ -79,12 +86,21 @@ def list(
7986 filter_expression = []
8087 if config :
8188 filter_expression .append (
82- Tag ("thread_id" ) == config ["configurable" ]["thread_id" ]
89+ Tag ("thread_id" )
90+ == to_storage_safe_id (config ["configurable" ]["thread_id" ])
8391 )
92+
93+ # Reproducing the logic from the Postgres implementation, we'll
94+ # search for checkpoints with any namespace, including an empty
95+ # string, while `checkpoint_id` has to have a value.
8496 if checkpoint_ns := config ["configurable" ].get ("checkpoint_ns" ):
85- filter_expression .append (Tag ("checkpoint_ns" ) == checkpoint_ns )
97+ filter_expression .append (
98+ Tag ("checkpoint_ns" ) == to_storage_safe_str (checkpoint_ns )
99+ )
86100 if checkpoint_id := get_checkpoint_id (config ):
87- filter_expression .append (Tag ("checkpoint_id" ) == checkpoint_id )
101+ filter_expression .append (
102+ Tag ("checkpoint_id" ) == to_storage_safe_id (checkpoint_id )
103+ )
88104
89105 if filter :
90106 for k , v in filter .items ():
@@ -122,9 +138,10 @@ def list(
122138
123139 # Process the results
124140 for doc in results .docs :
125- thread_id = str (getattr (doc , "thread_id" , "" ))
126- checkpoint_ns = str (getattr (doc , "checkpoint_ns" , "" ))
127- checkpoint_id = str (getattr (doc , "checkpoint_id" , "" ))
141+ thread_id = from_storage_safe_id (doc ["thread_id" ])
142+ checkpoint_ns = from_storage_safe_str (doc ["checkpoint_ns" ])
143+ checkpoint_id = from_storage_safe_id (doc ["checkpoint_id" ])
144+ parent_checkpoint_id = from_storage_safe_id (doc ["parent_checkpoint_id" ])
128145
129146 # Fetch channel_values
130147 channel_values = self .get_channel_values (
@@ -135,11 +152,11 @@ def list(
135152
136153 # Fetch pending_sends from parent checkpoint
137154 pending_sends = []
138- if doc [ " parent_checkpoint_id" ] :
155+ if parent_checkpoint_id :
139156 pending_sends = self ._load_pending_sends (
140157 thread_id = thread_id ,
141158 checkpoint_ns = checkpoint_ns ,
142- parent_checkpoint_id = doc [ " parent_checkpoint_id" ] ,
159+ parent_checkpoint_id = parent_checkpoint_id ,
143160 )
144161
145162 # Fetch and parse metadata
@@ -163,7 +180,7 @@ def list(
163180 "configurable" : {
164181 "thread_id" : thread_id ,
165182 "checkpoint_ns" : checkpoint_ns ,
166- "checkpoint_id" : doc [ " checkpoint_id" ] ,
183+ "checkpoint_id" : checkpoint_id ,
167184 }
168185 }
169186
@@ -194,49 +211,60 @@ def put(
194211 ) -> RunnableConfig :
195212 """Store a checkpoint to Redis."""
196213 configurable = config ["configurable" ].copy ()
214+
197215 thread_id = configurable .pop ("thread_id" )
198216 checkpoint_ns = configurable .pop ("checkpoint_ns" )
199- checkpoint_id = configurable .pop (
200- "checkpoint_id" , configurable .pop ("thread_ts" , None )
217+ checkpoint_id = checkpoint_id = configurable .pop (
218+ "checkpoint_id" , configurable .pop ("thread_ts" , "" )
201219 )
202220
221+ # For values we store in Redis, we need to convert empty strings to the
222+ # sentinel value.
223+ storage_safe_thread_id = to_storage_safe_id (thread_id )
224+ storage_safe_checkpoint_ns = to_storage_safe_str (checkpoint_ns )
225+ storage_safe_checkpoint_id = to_storage_safe_id (checkpoint_id )
226+
203227 copy = checkpoint .copy ()
228+ # When we return the config, we need to preserve empty strings that
229+ # were passed in, instead of the sentinel value.
204230 next_config = {
205231 "configurable" : {
206232 "thread_id" : thread_id ,
207233 "checkpoint_ns" : checkpoint_ns ,
208- "checkpoint_id" : checkpoint [ "id" ] ,
234+ "checkpoint_id" : checkpoint_id ,
209235 }
210236 }
211237
212- # Store checkpoint data
238+ # Store checkpoint data.
213239 checkpoint_data = {
214- "thread_id" : thread_id ,
215- "checkpoint_ns" : checkpoint_ns ,
216- "checkpoint_id" : checkpoint [ "id" ] ,
217- "parent_checkpoint_id" : checkpoint_id ,
240+ "thread_id" : storage_safe_thread_id ,
241+ "checkpoint_ns" : storage_safe_checkpoint_ns ,
242+ "checkpoint_id" : storage_safe_checkpoint_id ,
243+ "parent_checkpoint_id" : storage_safe_checkpoint_id ,
218244 "checkpoint" : self ._dump_checkpoint (copy ),
219245 "metadata" : self ._dump_metadata (metadata ),
220246 }
221247
222248 # store at top-level for filters in list()
223249 if all (key in metadata for key in ["source" , "step" ]):
224250 checkpoint_data ["source" ] = metadata ["source" ]
225- checkpoint_data ["step" ] = metadata ["step" ]
251+ checkpoint_data ["step" ] = metadata ["step" ] # type: ignore
226252
227253 self .checkpoints_index .load (
228254 [checkpoint_data ],
229255 keys = [
230256 BaseRedisSaver ._make_redis_checkpoint_key (
231- thread_id , checkpoint_ns , checkpoint ["id" ]
257+ storage_safe_thread_id ,
258+ storage_safe_checkpoint_ns ,
259+ storage_safe_checkpoint_id ,
232260 )
233261 ],
234262 )
235263
236- # Store blob values
264+ # Store blob values.
237265 blobs = self ._dump_blobs (
238- thread_id ,
239- checkpoint_ns ,
266+ storage_safe_thread_id ,
267+ storage_safe_checkpoint_ns ,
240268 copy .get ("channel_values" , {}),
241269 new_versions ,
242270 )
@@ -258,19 +286,22 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
258286 Optional[CheckpointTuple]: The retrieved checkpoint tuple, or None if no matching checkpoint was found.
259287 """
260288 thread_id = config ["configurable" ]["thread_id" ]
261- checkpoint_id = str ( get_checkpoint_id (config ) )
289+ checkpoint_id = get_checkpoint_id (config )
262290 checkpoint_ns = config ["configurable" ].get ("checkpoint_ns" , "" )
263291
264- if checkpoint_id :
292+ ascending = True
293+
294+ if checkpoint_id and checkpoint_id != EMPTY_ID_SENTINEL :
265295 checkpoint_filter_expression = (
266- (Tag ("thread_id" ) == thread_id )
267- & (Tag ("checkpoint_ns" ) == checkpoint_ns )
268- & (Tag ("checkpoint_id" ) == checkpoint_id )
296+ (Tag ("thread_id" ) == to_storage_safe_id ( thread_id ) )
297+ & (Tag ("checkpoint_ns" ) == to_storage_safe_str ( checkpoint_ns ) )
298+ & (Tag ("checkpoint_id" ) == to_storage_safe_id ( checkpoint_id ) )
269299 )
270300 else :
271- checkpoint_filter_expression = (Tag ("thread_id" ) == thread_id ) & (
272- Tag ("checkpoint_ns" ) == checkpoint_ns
273- )
301+ checkpoint_filter_expression = (
302+ Tag ("thread_id" ) == to_storage_safe_id (thread_id )
303+ ) & (Tag ("checkpoint_ns" ) == to_storage_safe_str (checkpoint_ns ))
304+ ascending = False
274305
275306 # Construct the query
276307 checkpoints_query = FilterQuery (
@@ -285,29 +316,33 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
285316 ],
286317 num_results = 1 ,
287318 )
288- checkpoints_query .sort_by ("checkpoint_id" , asc = False )
319+ checkpoints_query .sort_by ("checkpoint_id" , asc = ascending )
289320
290321 # Execute the query
291322 results = self .checkpoints_index .search (checkpoints_query )
292323 if not results .docs :
293324 return None
294325
295326 doc = results .docs [0 ]
327+ doc_thread_id = from_storage_safe_id (doc ["thread_id" ])
328+ doc_checkpoint_ns = from_storage_safe_str (doc ["checkpoint_ns" ])
329+ doc_checkpoint_id = from_storage_safe_id (doc ["checkpoint_id" ])
330+ doc_parent_checkpoint_id = from_storage_safe_id (doc ["parent_checkpoint_id" ])
296331
297332 # Fetch channel_values
298333 channel_values = self .get_channel_values (
299- thread_id = doc [ "thread_id" ] ,
300- checkpoint_ns = doc [ "checkpoint_ns" ] ,
301- checkpoint_id = doc [ "checkpoint_id" ] ,
334+ thread_id = doc_thread_id ,
335+ checkpoint_ns = doc_checkpoint_ns ,
336+ checkpoint_id = doc_checkpoint_id ,
302337 )
303338
304339 # Fetch pending_sends from parent checkpoint
305340 pending_sends = []
306- if doc [ "parent_checkpoint_id" ] :
341+ if doc_parent_checkpoint_id :
307342 pending_sends = self ._load_pending_sends (
308- thread_id = thread_id ,
309- checkpoint_ns = checkpoint_ns ,
310- parent_checkpoint_id = doc [ "parent_checkpoint_id" ] ,
343+ thread_id = doc_thread_id ,
344+ checkpoint_ns = doc_checkpoint_ns ,
345+ parent_checkpoint_id = doc_parent_checkpoint_id ,
311346 )
312347
313348 # Fetch and parse metadata
@@ -329,7 +364,7 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
329364 "configurable" : {
330365 "thread_id" : thread_id ,
331366 "checkpoint_ns" : checkpoint_ns ,
332- "checkpoint_id" : doc [ "checkpoint_id" ] ,
367+ "checkpoint_id" : doc_checkpoint_id ,
333368 }
334369 }
335370
@@ -340,7 +375,7 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
340375 )
341376
342377 pending_writes = self ._load_pending_writes (
343- thread_id , checkpoint_ns , checkpoint_id
378+ thread_id , checkpoint_ns , doc_checkpoint_id
344379 )
345380
346381 return CheckpointTuple (
@@ -379,10 +414,14 @@ def get_channel_values(
379414 self , thread_id : str , checkpoint_ns : str = "" , checkpoint_id : str = ""
380415 ) -> dict [str , Any ]:
381416 """Retrieve channel_values dictionary with properly constructed message objects."""
417+ storage_safe_thread_id = to_storage_safe_id (thread_id )
418+ storage_safe_checkpoint_ns = to_storage_safe_str (checkpoint_ns )
419+ storage_safe_checkpoint_id = to_storage_safe_id (checkpoint_id )
420+
382421 checkpoint_query = FilterQuery (
383- filter_expression = (Tag ("thread_id" ) == thread_id )
384- & (Tag ("checkpoint_ns" ) == checkpoint_ns )
385- & (Tag ("checkpoint_id" ) == checkpoint_id ),
422+ filter_expression = (Tag ("thread_id" ) == storage_safe_thread_id )
423+ & (Tag ("checkpoint_ns" ) == storage_safe_checkpoint_ns )
424+ & (Tag ("checkpoint_id" ) == storage_safe_checkpoint_id ),
386425 return_fields = ["$.checkpoint.channel_versions" ],
387426 num_results = 1 ,
388427 )
@@ -400,8 +439,8 @@ def get_channel_values(
400439 channel_values = {}
401440 for channel , version in channel_versions .items ():
402441 blob_query = FilterQuery (
403- filter_expression = (Tag ("thread_id" ) == thread_id )
404- & (Tag ("checkpoint_ns" ) == checkpoint_ns )
442+ filter_expression = (Tag ("thread_id" ) == storage_safe_thread_id )
443+ & (Tag ("checkpoint_ns" ) == storage_safe_checkpoint_ns )
405444 & (Tag ("channel" ) == channel )
406445 & (Tag ("version" ) == version ),
407446 return_fields = ["type" , "$.blob" ],
@@ -437,11 +476,15 @@ def _load_pending_sends(
437476 Returns:
438477 List of (type, blob) tuples representing pending sends
439478 """
479+ storage_safe_thread_id = to_storage_safe_str (thread_id )
480+ storage_safe_checkpoint_ns = to_storage_safe_str (checkpoint_ns )
481+ storage_safe_parent_checkpoint_id = to_storage_safe_str (parent_checkpoint_id )
482+
440483 # Query checkpoint_writes for parent checkpoint's TASKS channel
441484 parent_writes_query = FilterQuery (
442- filter_expression = (Tag ("thread_id" ) == thread_id )
443- & (Tag ("checkpoint_ns" ) == checkpoint_ns )
444- & (Tag ("checkpoint_id" ) == parent_checkpoint_id )
485+ filter_expression = (Tag ("thread_id" ) == storage_safe_thread_id )
486+ & (Tag ("checkpoint_ns" ) == storage_safe_checkpoint_ns )
487+ & (Tag ("checkpoint_id" ) == storage_safe_parent_checkpoint_id )
445488 & (Tag ("channel" ) == TASKS ),
446489 return_fields = ["type" , "blob" , "task_path" , "task_id" , "idx" ],
447490 num_results = 100 , # Adjust as needed
0 commit comments