Skip to content

Commit 8c8d672

Browse files
committed
feat: Optimize mixture search and enhance API client
1 parent 011d248 commit 8c8d672

File tree

5 files changed

+204
-119
lines changed

5 files changed

+204
-119
lines changed

src/memos/mem_scheduler/base_scheduler.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from collections.abc import Callable
77
from datetime import datetime
88
from pathlib import Path
9+
from typing import TYPE_CHECKING
910

1011
from sqlalchemy.engine import Engine
1112

@@ -50,6 +51,10 @@
5051
from memos.templates.mem_scheduler_prompts import MEMORY_ASSEMBLY_TEMPLATE
5152

5253

54+
if TYPE_CHECKING:
55+
from memos.mem_cube.base import BaseMemCube
56+
57+
5358
logger = get_logger(__name__)
5459

5560

@@ -124,7 +129,7 @@ def __init__(self, config: BaseSchedulerConfig):
124129
self._context_lock = threading.Lock()
125130
self.current_user_id: UserID | str | None = None
126131
self.current_mem_cube_id: MemCubeID | str | None = None
127-
self.current_mem_cube: GeneralMemCube | None = None
132+
self.current_mem_cube: BaseMemCube | None = None
128133
self.auth_config_path: str | Path | None = self.config.get("auth_config_path", None)
129134
self.auth_config = None
130135
self.rabbitmq_config = None

src/memos/mem_scheduler/general_modules/api_misc.py

Lines changed: 19 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,20 @@
1616

1717

1818
class SchedulerAPIModule(BaseSchedulerModule):
19-
def __init__(self, window_size=5):
19+
def __init__(self, window_size: int | None = None, history_memory_turns: int | None = None):
2020
super().__init__()
2121
self.window_size = window_size
22+
self.history_memory_turns = history_memory_turns
2223
self.search_history_managers: dict[str, APIRedisDBManager] = {}
23-
self.pre_memory_turns = 5
2424

2525
def get_search_history_manager(self, user_id: str, mem_cube_id: str) -> APIRedisDBManager:
2626
"""Get or create a Redis manager for search history."""
27+
logger.info(
28+
f"Getting search history manager for user_id: {user_id}, mem_cube_id: {mem_cube_id}"
29+
)
2730
key = f"search_history:{user_id}:{mem_cube_id}"
2831
if key not in self.search_history_managers:
32+
logger.info(f"Creating new search history manager for key: {key}")
2933
self.search_history_managers[key] = APIRedisDBManager(
3034
user_id=user_id,
3135
mem_cube_id=mem_cube_id,
@@ -43,6 +47,9 @@ def sync_search_data(
4347
formatted_memories: Any,
4448
conversation_id: str | None = None,
4549
) -> Any:
50+
logger.info(
51+
f"Syncing search data for item_id: {item_id}, user_id: {user_id}, mem_cube_id: {mem_cube_id}"
52+
)
4653
# Get the search history manager
4754
manager = self.get_search_history_manager(user_id, mem_cube_id)
4855
manager.sync_with_redis(size_limit=self.window_size)
@@ -101,37 +108,22 @@ def sync_search_data(
101108
manager.sync_with_redis(size_limit=self.window_size)
102109
return manager
103110

104-
def get_pre_memories(self, user_id: str, mem_cube_id: str) -> list:
105-
"""
106-
Get pre-computed memories from the most recent completed search entry.
107-
108-
Args:
109-
user_id: User identifier
110-
mem_cube_id: Memory cube identifier
111-
112-
Returns:
113-
List of TextualMemoryItem objects from the most recent completed search
114-
"""
115-
manager = self.get_search_history_manager(user_id, mem_cube_id)
116-
117-
existing_data = manager.load_from_db()
118-
if existing_data is None:
119-
return []
120-
121-
search_history: APISearchHistoryManager = existing_data
122-
123-
# Get memories from the most recent completed entry
124-
history_memories = search_history.get_history_memories(turns=self.pre_memory_turns)
125-
return history_memories
126-
127-
def get_history_memories(self, user_id: str, mem_cube_id: str, n: int) -> list:
111+
def get_history_memories(
112+
self, user_id: str, mem_cube_id: str, turns: int | None = None
113+
) -> list:
128114
"""Get memories from the stored search history; `turns` falls back to self.history_memory_turns when None."""
115+
logger.info(
116+
f"Getting history memories for user_id: {user_id}, mem_cube_id: {mem_cube_id}, turns: {turns}"
117+
)
129118
manager = self.get_search_history_manager(user_id, mem_cube_id)
130119
existing_data = manager.load_from_db()
131120

132121
if existing_data is None:
133122
return []
134123

124+
if turns is None:
125+
turns = self.history_memory_turns
126+
135127
# Handle different data formats
136128
if isinstance(existing_data, APISearchHistoryManager):
137129
search_history = existing_data
@@ -142,4 +134,4 @@ def get_history_memories(self, user_id: str, mem_cube_id: str, n: int) -> list:
142134
except Exception:
143135
return []
144136

145-
return search_history.get_history_memories(turns=n)
137+
return search_history.get_history_memories(turns=turns)

src/memos/mem_scheduler/optimized_scheduler.py

Lines changed: 94 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
import json
2+
import os
23

34
from typing import TYPE_CHECKING
45

56
from memos.api.product_models import APISearchRequest
67
from memos.configs.mem_scheduler import GeneralSchedulerConfig
78
from memos.log import get_logger
89
from memos.mem_cube.general import GeneralMemCube
10+
from memos.mem_cube.navie import NaiveMemCube
911
from memos.mem_scheduler.general_modules.api_misc import SchedulerAPIModule
1012
from memos.mem_scheduler.general_scheduler import GeneralScheduler
1113
from memos.mem_scheduler.schemas.general_schemas import (
@@ -23,6 +25,7 @@
2325

2426
if TYPE_CHECKING:
2527
from memos.mem_scheduler.schemas.monitor_schemas import MemoryMonitorItem
28+
from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher
2629
from memos.reranker.http_bge import HTTPBGEReranker
2730

2831

@@ -34,43 +37,19 @@ class OptimizedScheduler(GeneralScheduler):
3437

3538
def __init__(self, config: GeneralSchedulerConfig):
3639
super().__init__(config)
37-
self.api_module = SchedulerAPIModule()
40+
self.window_size = int(os.getenv("API_SEARCH_WINDOW_SIZE", 5))
41+
self.history_memory_turns = int(os.getenv("API_SEARCH_HISTORY_TURNS", 5))
42+
43+
self.api_module = SchedulerAPIModule(
44+
window_size=self.window_size,
45+
history_memory_turns=self.history_memory_turns,
46+
)
3847
self.register_handlers(
3948
{
4049
API_MIX_SEARCH_LABEL: self._api_mix_search_message_consumer,
4150
}
4251
)
4352

44-
def search_memories(
45-
self,
46-
search_req: APISearchRequest,
47-
user_context: UserContext,
48-
mem_cube: GeneralMemCube,
49-
mode: SearchMode,
50-
):
51-
"""Fine search memories function copied from server_router to avoid circular import"""
52-
target_session_id = search_req.session_id
53-
if not target_session_id:
54-
target_session_id = "default_session"
55-
search_filter = {"session_id": search_req.session_id} if search_req.session_id else None
56-
57-
# Create MemCube and perform search
58-
search_results = mem_cube.text_mem.search(
59-
query=search_req.query,
60-
user_name=user_context.mem_cube_id,
61-
top_k=search_req.top_k,
62-
mode=mode,
63-
manual_close_internet=not search_req.internet_search,
64-
moscube=search_req.moscube,
65-
search_filter=search_filter,
66-
info={
67-
"user_id": search_req.user_id,
68-
"session_id": target_session_id,
69-
"chat_history": search_req.chat_history,
70-
},
71-
)
72-
return search_results
73-
7453
def submit_memory_history_async_task(
7554
self,
7655
search_req: APISearchRequest,
@@ -110,6 +89,36 @@ def submit_memory_history_async_task(
11089
logger.info(f"Submitted async fine search task for user {search_req.user_id}")
11190
return async_task_id
11291

92+
def search_memories(
93+
self,
94+
search_req: APISearchRequest,
95+
user_context: UserContext,
96+
mem_cube: NaiveMemCube,
97+
mode: SearchMode,
98+
):
99+
"""Search memories in the requested mode; copied from server_router to avoid circular import"""
100+
target_session_id = search_req.session_id
101+
if not target_session_id:
102+
target_session_id = "default_session"
103+
search_filter = {"session_id": search_req.session_id} if search_req.session_id else None
104+
105+
# Run the search against the provided MemCube's text memory
106+
search_results = mem_cube.text_mem.search(
107+
query=search_req.query,
108+
user_name=user_context.mem_cube_id,
109+
top_k=search_req.top_k,
110+
mode=mode,
111+
manual_close_internet=not search_req.internet_search,
112+
moscube=search_req.moscube,
113+
search_filter=search_filter,
114+
info={
115+
"user_id": search_req.user_id,
116+
"session_id": target_session_id,
117+
"chat_history": search_req.chat_history,
118+
},
119+
)
120+
return search_results
121+
113122
def mix_search_memories(
114123
self,
115124
search_req: APISearchRequest,
@@ -122,12 +131,33 @@ def mix_search_memories(
122131
# Get mem_cube for fast search
123132
mem_cube = self.current_mem_cube
124133

125-
# Perform fast search
126-
fast_memories = self.search_memories(
127-
search_req=search_req,
128-
user_context=user_context,
129-
mem_cube=mem_cube,
134+
target_session_id = search_req.session_id
135+
if not target_session_id:
136+
target_session_id = "default_session"
137+
search_filter = {"session_id": search_req.session_id} if search_req.session_id else None
138+
139+
text_mem: TreeTextMemory = mem_cube.text_mem
140+
searcher: Searcher = text_mem.get_searcher(
141+
manual_close_internet=not search_req.internet_search,
142+
moscube=False,
143+
)
144+
# Rerank Memories - reranker expects TextualMemoryItem objects
145+
reranker: HTTPBGEReranker = text_mem.reranker
146+
info = {
147+
"user_id": search_req.user_id,
148+
"session_id": target_session_id,
149+
"chat_history": search_req.chat_history,
150+
}
151+
152+
fast_retrieved_memories = searcher.retrieve(
153+
query=search_req.query,
154+
user_name=user_context.mem_cube_id,
155+
top_k=search_req.top_k,
130156
mode=SearchMode.FAST,
157+
manual_close_internet=not search_req.internet_search,
158+
moscube=search_req.moscube,
159+
search_filter=search_filter,
160+
info=info,
131161
)
132162

133163
self.submit_memory_history_async_task(
@@ -136,76 +166,69 @@ def mix_search_memories(
136166
)
137167

138168
# Try to get memories from previous search turns if available
139-
pre_fine_memories = self.api_module.get_pre_memories(
140-
user_id=search_req.user_id, mem_cube_id=user_context.mem_cube_id
169+
history_memories = self.api_module.get_history_memories(
170+
user_id=search_req.user_id,
171+
mem_cube_id=user_context.mem_cube_id,
172+
turns=self.history_memory_turns,
141173
)
142-
if not pre_fine_memories:
174+
if not history_memories:
175+
fast_memories = searcher.post_retrieve(
176+
retrieved_results=fast_retrieved_memories,
177+
top_k=search_req.top_k,
178+
user_name=user_context.mem_cube_id,
179+
info=info,
180+
)
143181
# Format fast memories for return
144182
formatted_memories = [format_textual_memory_item(data) for data in fast_memories]
145183
return formatted_memories
146184

147-
# Merge fast and pre-computed fine memories (both are TextualMemoryItem objects)
148-
combined_memories = fast_memories + pre_fine_memories
149-
# Remove duplicates based on memory content
150-
seen_contents = set()
151-
unique_memories = []
152-
for memory in combined_memories:
153-
# Both fast_memories and pre_fine_memories are TextualMemoryItem objects
154-
content_key = memory.memory # Use .memory attribute instead of .get("content", "")
155-
if content_key not in seen_contents:
156-
seen_contents.add(content_key)
157-
unique_memories.append(memory)
158-
159-
# Rerank Memories - reranker expects TextualMemoryItem objects
160-
reranker: HTTPBGEReranker = mem_cube.text_mem.reranker
161-
162-
# Use search_req parameters for reranking
163-
target_session_id = search_req.session_id
164-
if not target_session_id:
165-
target_session_id = "default_session"
166-
search_filter = {"session_id": search_req.session_id} if search_req.session_id else None
167-
168-
sorted_results = reranker.rerank(
185+
sorted_history_memories = reranker.rerank(
169186
query=search_req.query, # Use search_req.query instead of undefined query
170-
graph_results=unique_memories, # Pass TextualMemoryItem objects directly
187+
graph_results=history_memories, # Pass TextualMemoryItem objects directly
171188
top_k=search_req.top_k, # Use search_req.top_k instead of undefined top_k
172189
search_filter=search_filter,
173190
)
174191

192+
sorted_results = fast_retrieved_memories + sorted_history_memories
193+
final_results = searcher.post_retrieve(
194+
retrieved_results=sorted_results,
195+
top_k=search_req.top_k,
196+
user_name=user_context.mem_cube_id,
197+
info=info,
198+
)
199+
175200
formatted_memories = [
176-
format_textual_memory_item(item) for item, score in sorted_results[: search_req.top_k]
201+
format_textual_memory_item(item) for item in final_results[: search_req.top_k]
177202
]
178203

179204
return formatted_memories
180205

181206
def update_search_memories_to_redis(
182207
self,
183-
user_id: str,
184-
mem_cube_id: str,
185208
messages: list[ScheduleMessageItem],
186209
):
187-
mem_cube = messages[0].mem_cube
210+
mem_cube: NaiveMemCube = self.current_mem_cube
188211

189212
for msg in messages:
190213
content_dict = json.loads(msg.content)
191214
search_req = content_dict["search_req"]
192215
user_context = content_dict["user_context"]
193216

194-
fine_memories: list[TextualMemoryItem] = self.search_memories(
217+
memories: list[TextualMemoryItem] = self.search_memories(
195218
search_req=APISearchRequest(**content_dict["search_req"]),
196219
user_context=UserContext(**content_dict["user_context"]),
197220
mem_cube=mem_cube,
198-
mode=SearchMode.FINE,
221+
mode=SearchMode.FAST,
199222
)
200-
formatted_memories = [format_textual_memory_item(data) for data in fine_memories]
223+
formatted_memories = [format_textual_memory_item(data) for data in memories]
201224

202225
# Sync search data to Redis
203226
self.api_module.sync_search_data(
204227
item_id=msg.item_id,
205228
user_id=search_req["user_id"],
206229
mem_cube_id=user_context["mem_cube_id"],
207230
query=search_req["query"],
208-
memories=fine_memories,
231+
memories=memories,
209232
formatted_memories=formatted_memories,
210233
)
211234

@@ -228,9 +251,7 @@ def _api_mix_search_message_consumer(self, messages: list[ScheduleMessageItem])
228251
messages = grouped_messages[user_id][mem_cube_id]
229252
if len(messages) == 0:
230253
return
231-
self.update_search_memories_to_redis(
232-
user_id=user_id, mem_cube_id=mem_cube_id, messages=messages
233-
)
254+
self.update_search_memories_to_redis(messages=messages)
234255

235256
def replace_working_memory(
236257
self,

0 commit comments

Comments
 (0)