55from redis import Redis
66
77from redisvl .extensions .session_manager import BaseSessionManager
8- from redisvl .redis .connection import RedisConnectionFactory
8+ from redisvl .index import SearchIndex
9+ from redisvl .query import FilterQuery
10+ from redisvl .query .filter import Tag
11+ from redisvl .schema .schema import IndexSchema
12+
13+
14+ class StandardSessionIndexSchema (IndexSchema ):
15+
16+ @classmethod
17+ def from_params (cls , name : str , prefix : str ):
18+
19+ return cls (
20+ index = {"name" : name , "prefix" : prefix }, # type: ignore
21+ fields = [ # type: ignore
22+ {"name" : "role" , "type" : "text" },
23+ {"name" : "content" , "type" : "text" },
24+ {"name" : "tool_call_id" , "type" : "text" },
25+ {"name" : "timestamp" , "type" : "numeric" },
26+ {"name" : "session_tag" , "type" : "tag" },
27+ {"name" : "user_tag" , "type" : "tag" },
28+ ],
29+ )
930
1031
1132class StandardSessionManager (BaseSessionManager ):
33+ session_field_name : str = "session_tag"
34+ user_field_name : str = "user_tag"
1235
1336 def __init__ (
1437 self ,
1538 name : str ,
1639 session_tag : str ,
1740 user_tag : str ,
41+ prefix : Optional [str ] = None ,
1842 redis_client : Optional [Redis ] = None ,
1943 redis_url : str = "redis://localhost:6379" ,
2044 connection_kwargs : Dict [str , Any ] = {},
@@ -29,9 +53,11 @@ def __init__(
2953
3054 Args:
3155 name (str): The name of the session manager index.
32- session_tag (str): Tag to be added to entries to link to a specific
56+ session_tag (Optional[ str] ): Tag to be added to entries to link to a specific
3357 session.
34- user_tag (str): Tag to be added to entries to link to a specific user.
58+ user_tag (Optional[str]): Tag to be added to entries to link to a specific user.
59+ prefix (Optional[str]): Prefix for the keys for this session data.
60+ Defaults to None and will be replaced with the index name.
3561 redis_client (Optional[Redis]): A Redis client instance. Defaults to
3662 None.
3763 redis_url (str, optional): The redis url. Defaults to redis://localhost:6379.
@@ -44,14 +70,18 @@ def __init__(
4470 """
4571 super ().__init__ (name , session_tag , user_tag )
4672
73+ prefix = prefix or name
74+
75+ schema = StandardSessionIndexSchema .from_params (name , prefix )
76+ self ._index = SearchIndex (schema = schema )
77+
4778 # handle redis connection
4879 if redis_client :
49- self ._client = redis_client
80+ self ._index . set_client ( redis_client )
5081 elif redis_url :
51- self ._client = RedisConnectionFactory .get_redis_connection (
52- redis_url , ** connection_kwargs
53- )
54- RedisConnectionFactory .validate_sync_redis (self ._client )
82+ self ._index .connect (redis_url = redis_url , ** connection_kwargs )
83+
84+ self ._index .create (overwrite = False )
5585
5686 self .set_scope (session_tag , user_tag )
5787
@@ -63,27 +93,35 @@ def set_scope(
6393 """Set the filter to apply to queries based on the desired scope.
6494
6595 This new scope persists until another call to set_scope is made, or if
66- scope is specified in calls to get_recent.
96+ scope specified in calls to get_recent or get_relevant .
6797
6898 Args:
6999 session_tag (str): Id of the specific session to filter to. Default is
70- None, which means session_tag will be unchanged .
100+ None, which means all sessions will be in scope .
71101 user_tag (str): Id of the specific user to filter to. Default is None,
72- which means user_tag will be unchanged .
102+ which means all users will be in scope .
73103 """
74104 if not (session_tag or user_tag ):
75105 return
76-
77106 self ._session_tag = session_tag or self ._session_tag
78107 self ._user_tag = user_tag or self ._user_tag
108+ tag_filter = Tag (self .user_field_name ) == []
109+ if user_tag :
110+ tag_filter = tag_filter & (Tag (self .user_field_name ) == self ._user_tag )
111+ if session_tag :
112+ tag_filter = tag_filter & (
113+ Tag (self .session_field_name ) == self ._session_tag
114+ )
115+
116+ self ._tag_filter = tag_filter
79117
80118 def clear (self ) -> None :
81119 """Clears the chat session history."""
82- self ._client . delete ( self . key )
120+ self ._index . clear ( )
83121
84122 def delete (self ) -> None :
85- """Clears the chat session history ."""
86- self ._client .delete (self . key )
123+ """Clear all conversation keys and remove the search index ."""
124+ self ._index .delete (drop = True )
87125
88126 def drop (self , id_field : Optional [str ] = None ) -> None :
89127 """Remove a specific exchange from the conversation history.
@@ -93,19 +131,36 @@ def drop(self, id_field: Optional[str] = None) -> None:
93131 If None then the last entry is deleted.
94132 """
95133 if id_field :
96- messages = self ._client .lrange (self .key , 0 , - 1 )
97- messages = [json .loads (msg ) for msg in messages ]
98- messages = [msg for msg in messages if msg ["id_field" ] != id_field ]
99- messages = [json .dumps (msg ) for msg in messages ]
100- self .clear ()
101- self ._client .rpush (self .key , * messages )
134+ sep = self ._index .key_separator
135+ key = sep .join ([self ._index .schema .index .name , id_field ])
102136 else :
103- self ._client .rpop (self .key )
137+ key = self .get_recent (top_k = 1 , raw = True )[0 ]["id" ] # type: ignore
138+ self ._index .client .delete (key ) # type: ignore
104139
105140 @property
106141 def messages (self ) -> Union [List [str ], List [Dict [str , str ]]]:
107142 """Returns the full chat history."""
108- return self .get_recent (top_k = - 1 )
143+ # TODO raw or as_text?
144+ return_fields = [
145+ self .id_field_name ,
146+ self .session_field_name ,
147+ self .user_field_name ,
148+ self .role_field_name ,
149+ self .content_field_name ,
150+ self .tool_field_name ,
151+ self .timestamp_field_name ,
152+ ]
153+
154+ query = FilterQuery (
155+ filter_expression = self ._tag_filter ,
156+ return_fields = return_fields ,
157+ )
158+
159+ sorted_query = query .query
160+ sorted_query .sort_by (self .timestamp_field_name , asc = True )
161+ hits = self ._index .search (sorted_query , query .params ).docs
162+
163+ return self ._format_context (hits , as_text = False )
109164
110165 def get_recent (
111166 self ,
@@ -119,7 +174,6 @@ def get_recent(
119174
120175 Args:
121176 top_k (int): The number of previous messages to return. Default is 5.
122- To get all messages set top_k = -1.
123177 session_tag (str): Tag to be added to entries to link to a specific
124178 session.
125179 user_tag (str): Tag to be added to entries to link to a specific user.
@@ -133,24 +187,35 @@ def get_recent(
133187 or list of strings if as_text is false.
134188
135189 Raises:
136- ValueError: if top_k is not an integer greater than or equal to -1 .
190+ ValueError: if top_k is not an integer greater than or equal to 0 .
137191 """
138- if type (top_k ) != int or top_k < - 1 :
139- raise ValueError ("top_k must be an integer greater than or equal to -1" )
140- if top_k == 0 :
141- return []
142- elif top_k == - 1 :
143- top_k = 0
192+ if type (top_k ) != int or top_k < 0 :
193+ raise ValueError ("top_k must be an integer greater than or equal to 0" )
194+
144195 self .set_scope (session_tag , user_tag )
145- messages = self ._client .lrange (self .key , - top_k , - 1 )
146- messages = [json .loads (msg ) for msg in messages ]
147- if raw :
148- return messages
149- return self ._format_context (messages , as_text )
196+ return_fields = [
197+ self .id_field_name ,
198+ self .session_field_name ,
199+ self .user_field_name ,
200+ self .role_field_name ,
201+ self .content_field_name ,
202+ self .tool_field_name ,
203+ self .timestamp_field_name ,
204+ ]
205+
206+ query = FilterQuery (
207+ filter_expression = self ._tag_filter ,
208+ return_fields = return_fields ,
209+ num_results = top_k ,
210+ )
150211
151- @property
152- def key (self ):
153- return ":" .join ([self ._name , self ._user_tag , self ._session_tag ])
212+ sorted_query = query .query
213+ sorted_query .sort_by (self .timestamp_field_name , asc = False )
214+ hits = self ._index .search (sorted_query , query .params ).docs
215+
216+ if raw :
217+ return hits [::- 1 ]
218+ return self ._format_context (hits [::- 1 ], as_text )
154219
155220 def store (self , prompt : str , response : str ) -> None :
156221 """Insert a prompt:response pair into the session memory. A timestamp
@@ -162,7 +227,10 @@ def store(self, prompt: str, response: str) -> None:
162227 response (str): The corresponding LLM response.
163228 """
164229 self .add_messages (
165- [{"role" : "user" , "content" : prompt }, {"role" : "llm" , "content" : response }]
230+ [
231+ {self .role_field_name : "user" , self .content_field_name : prompt },
232+ {self .role_field_name : "llm" , self .content_field_name : response },
233+ ]
166234 )
167235
168236 def add_messages (self , messages : List [Dict [str , str ]]) -> None :
@@ -173,23 +241,23 @@ def add_messages(self, messages: List[Dict[str, str]]) -> None:
173241 Args:
174242 messages (List[Dict[str, str]]): The list of user prompts and LLM responses.
175243 """
244+ sep = self ._index .key_separator
176245 payloads = []
177246 for message in messages :
178247 timestamp = time ()
248+ id_field = sep .join ([self ._user_tag , self ._session_tag , str (timestamp )])
179249 payload = {
180- self .id_field_name : ":" .join (
181- [self ._user_tag , self ._session_tag , str (timestamp )]
182- ),
250+ self .id_field_name : id_field ,
183251 self .role_field_name : message [self .role_field_name ],
184252 self .content_field_name : message [self .content_field_name ],
185253 self .timestamp_field_name : timestamp ,
254+ self .session_field_name : self ._session_tag ,
255+ self .user_field_name : self ._user_tag ,
186256 }
187257 if self .tool_field_name in message :
188258 payload .update ({self .tool_field_name : message [self .tool_field_name ]})
189-
190- payloads .append (json .dumps (payload ))
191-
192- self ._client .rpush (self .key , * payloads )
259+ payloads .append (payload )
260+ self ._index .load (data = payloads , id_field = self .id_field_name )
193261
194262 def add_message (self , message : Dict [str , str ]) -> None :
195263 """Insert a single prompt or response into the session memory.
0 commit comments