11import json
2+ import os
23
34from typing import TYPE_CHECKING
45
56from memos .api .product_models import APISearchRequest
67from memos .configs .mem_scheduler import GeneralSchedulerConfig
78from memos .log import get_logger
89from memos .mem_cube .general import GeneralMemCube
10+ from memos .mem_cube .navie import NaiveMemCube
911from memos .mem_scheduler .general_modules .api_misc import SchedulerAPIModule
1012from memos .mem_scheduler .general_scheduler import GeneralScheduler
1113from memos .mem_scheduler .schemas .general_schemas import (
2325
2426if TYPE_CHECKING :
2527 from memos .mem_scheduler .schemas .monitor_schemas import MemoryMonitorItem
28+ from memos .memories .textual .tree_text_memory .retrieve .searcher import Searcher
2629 from memos .reranker .http_bge import HTTPBGEReranker
2730
2831
@@ -34,43 +37,19 @@ class OptimizedScheduler(GeneralScheduler):
3437
3538 def __init__ (self , config : GeneralSchedulerConfig ):
3639 super ().__init__ (config )
37- self .api_module = SchedulerAPIModule ()
40+ self .window_size = int (os .getenv ("API_SEARCH_WINDOW_SIZE" , 5 ))
41+ self .history_memory_turns = int (os .getenv ("API_SEARCH_HISTORY_TURNS" , 5 ))
42+
43+ self .api_module = SchedulerAPIModule (
44+ window_size = self .window_size ,
45+ history_memory_turns = self .history_memory_turns ,
46+ )
3847 self .register_handlers (
3948 {
4049 API_MIX_SEARCH_LABEL : self ._api_mix_search_message_consumer ,
4150 }
4251 )
4352
44- def search_memories (
45- self ,
46- search_req : APISearchRequest ,
47- user_context : UserContext ,
48- mem_cube : GeneralMemCube ,
49- mode : SearchMode ,
50- ):
51- """Fine search memories function copied from server_router to avoid circular import"""
52- target_session_id = search_req .session_id
53- if not target_session_id :
54- target_session_id = "default_session"
55- search_filter = {"session_id" : search_req .session_id } if search_req .session_id else None
56-
57- # Create MemCube and perform search
58- search_results = mem_cube .text_mem .search (
59- query = search_req .query ,
60- user_name = user_context .mem_cube_id ,
61- top_k = search_req .top_k ,
62- mode = mode ,
63- manual_close_internet = not search_req .internet_search ,
64- moscube = search_req .moscube ,
65- search_filter = search_filter ,
66- info = {
67- "user_id" : search_req .user_id ,
68- "session_id" : target_session_id ,
69- "chat_history" : search_req .chat_history ,
70- },
71- )
72- return search_results
73-
7453 def submit_memory_history_async_task (
7554 self ,
7655 search_req : APISearchRequest ,
@@ -110,6 +89,36 @@ def submit_memory_history_async_task(
11089 logger .info (f"Submitted async fine search task for user { search_req .user_id } " )
11190 return async_task_id
11291
92+ def search_memories (
93+ self ,
94+ search_req : APISearchRequest ,
95+ user_context : UserContext ,
96+ mem_cube : NaiveMemCube ,
97+ mode : SearchMode ,
98+ ):
99+ """Fine search memories function copied from server_router to avoid circular import"""
100+ target_session_id = search_req .session_id
101+ if not target_session_id :
102+ target_session_id = "default_session"
103+ search_filter = {"session_id" : search_req .session_id } if search_req .session_id else None
104+
105+ # Create MemCube and perform search
106+ search_results = mem_cube .text_mem .search (
107+ query = search_req .query ,
108+ user_name = user_context .mem_cube_id ,
109+ top_k = search_req .top_k ,
110+ mode = mode ,
111+ manual_close_internet = not search_req .internet_search ,
112+ moscube = search_req .moscube ,
113+ search_filter = search_filter ,
114+ info = {
115+ "user_id" : search_req .user_id ,
116+ "session_id" : target_session_id ,
117+ "chat_history" : search_req .chat_history ,
118+ },
119+ )
120+ return search_results
121+
113122 def mix_search_memories (
114123 self ,
115124 search_req : APISearchRequest ,
@@ -122,12 +131,33 @@ def mix_search_memories(
122131 # Get mem_cube for fast search
123132 mem_cube = self .current_mem_cube
124133
125- # Perform fast search
126- fast_memories = self .search_memories (
127- search_req = search_req ,
128- user_context = user_context ,
129- mem_cube = mem_cube ,
134+ target_session_id = search_req .session_id
135+ if not target_session_id :
136+ target_session_id = "default_session"
137+ search_filter = {"session_id" : search_req .session_id } if search_req .session_id else None
138+
139+ text_mem : TreeTextMemory = mem_cube .text_mem
140+ searcher : Searcher = text_mem .get_searcher (
141+ manual_close_internet = not search_req .internet_search ,
142+ moscube = False ,
143+ )
144+ # Rerank Memories - reranker expects TextualMemoryItem objects
145+ reranker : HTTPBGEReranker = text_mem .reranker
146+ info = {
147+ "user_id" : search_req .user_id ,
148+ "session_id" : target_session_id ,
149+ "chat_history" : search_req .chat_history ,
150+ }
151+
152+ fast_retrieved_memories = searcher .retrieve (
153+ query = search_req .query ,
154+ user_name = user_context .mem_cube_id ,
155+ top_k = search_req .top_k ,
130156 mode = SearchMode .FAST ,
157+ manual_close_internet = not search_req .internet_search ,
158+ moscube = search_req .moscube ,
159+ search_filter = search_filter ,
160+ info = info ,
131161 )
132162
133163 self .submit_memory_history_async_task (
@@ -136,76 +166,69 @@ def mix_search_memories(
136166 )
137167
138168 # Try to get pre-computed fine memories if available
139- pre_fine_memories = self .api_module .get_pre_memories (
140- user_id = search_req .user_id , mem_cube_id = user_context .mem_cube_id
169+ history_memories = self .api_module .get_history_memories (
170+ user_id = search_req .user_id ,
171+ mem_cube_id = user_context .mem_cube_id ,
172+ turns = self .history_memory_turns ,
141173 )
142- if not pre_fine_memories :
174+ if not history_memories :
175+ fast_memories = searcher .post_retrieve (
176+ retrieved_results = fast_retrieved_memories ,
177+ top_k = search_req .top_k ,
178+ user_name = user_context .mem_cube_id ,
179+ info = info ,
180+ )
143181 # Format fast memories for return
144182 formatted_memories = [format_textual_memory_item (data ) for data in fast_memories ]
145183 return formatted_memories
146184
147- # Merge fast and pre-computed fine memories (both are TextualMemoryItem objects)
148- combined_memories = fast_memories + pre_fine_memories
149- # Remove duplicates based on memory content
150- seen_contents = set ()
151- unique_memories = []
152- for memory in combined_memories :
153- # Both fast_memories and pre_fine_memories are TextualMemoryItem objects
154- content_key = memory .memory # Use .memory attribute instead of .get("content", "")
155- if content_key not in seen_contents :
156- seen_contents .add (content_key )
157- unique_memories .append (memory )
158-
159- # Rerank Memories - reranker expects TextualMemoryItem objects
160- reranker : HTTPBGEReranker = mem_cube .text_mem .reranker
161-
162- # Use search_req parameters for reranking
163- target_session_id = search_req .session_id
164- if not target_session_id :
165- target_session_id = "default_session"
166- search_filter = {"session_id" : search_req .session_id } if search_req .session_id else None
167-
168- sorted_results = reranker .rerank (
185+ sorted_history_memories = reranker .rerank (
169186 query = search_req .query , # Use search_req.query instead of undefined query
170- graph_results = unique_memories , # Pass TextualMemoryItem objects directly
187+ graph_results = history_memories , # Pass TextualMemoryItem objects directly
171188 top_k = search_req .top_k , # Use search_req.top_k instead of undefined top_k
172189 search_filter = search_filter ,
173190 )
174191
192+ sorted_results = fast_retrieved_memories + sorted_history_memories
193+ final_results = searcher .post_retrieve (
194+ retrieved_results = sorted_results ,
195+ top_k = search_req .top_k ,
196+ user_name = user_context .mem_cube_id ,
197+ info = info ,
198+ )
199+
175200 formatted_memories = [
176- format_textual_memory_item (item ) for item , score in sorted_results [: search_req .top_k ]
201+ format_textual_memory_item (item ) for item in final_results [: search_req .top_k ]
177202 ]
178203
179204 return formatted_memories
180205
181206 def update_search_memories_to_redis (
182207 self ,
183- user_id : str ,
184- mem_cube_id : str ,
185208 messages : list [ScheduleMessageItem ],
186209 ):
187- mem_cube = messages [ 0 ]. mem_cube
210+ mem_cube : NaiveMemCube = self . current_mem_cube
188211
189212 for msg in messages :
190213 content_dict = json .loads (msg .content )
191214 search_req = content_dict ["search_req" ]
192215 user_context = content_dict ["user_context" ]
193216
194- fine_memories : list [TextualMemoryItem ] = self .search_memories (
217+ memories : list [TextualMemoryItem ] = self .search_memories (
195218 search_req = APISearchRequest (** content_dict ["search_req" ]),
196219 user_context = UserContext (** content_dict ["user_context" ]),
197220 mem_cube = mem_cube ,
198- mode = SearchMode .FINE ,
221+ mode = SearchMode .FAST ,
199222 )
200- formatted_memories = [format_textual_memory_item (data ) for data in fine_memories ]
223+ formatted_memories = [format_textual_memory_item (data ) for data in memories ]
201224
202225 # Sync search data to Redis
203226 self .api_module .sync_search_data (
204227 item_id = msg .item_id ,
205228 user_id = search_req ["user_id" ],
206229 mem_cube_id = user_context ["mem_cube_id" ],
207230 query = search_req ["query" ],
208- memories = fine_memories ,
231+ memories = memories ,
209232 formatted_memories = formatted_memories ,
210233 )
211234
@@ -228,9 +251,7 @@ def _api_mix_search_message_consumer(self, messages: list[ScheduleMessageItem])
228251 messages = grouped_messages [user_id ][mem_cube_id ]
229252 if len (messages ) == 0 :
230253 return
231- self .update_search_memories_to_redis (
232- user_id = user_id , mem_cube_id = mem_cube_id , messages = messages
233- )
254+ self .update_search_memories_to_redis (messages = messages )
234255
235256 def replace_working_memory (
236257 self ,
0 commit comments