Skip to content

Commit aabad8d

Browse files
committed
feat: Add conversation_turn tracking for session-based memory search
- Add conversation_turn field to APIMemoryHistoryEntryItem schema with default value 0 - Implement session counter in OptimizedScheduler to track turn count per session_id - Update sync_search_data method to accept and store conversation_turn parameter - Maintain session history with LRU eviction (max 5 sessions) - Rename conversation_id to session_id for consistency with request object - Enable direct access to session_id from search requests This feature allows tracking conversation turns within the same session, providing better context for memory retrieval and search history management.
1 parent 8c8d672 commit aabad8d

File tree

3 files changed

+43
-17
lines changed

3 files changed

+43
-17
lines changed

src/memos/mem_scheduler/general_modules/api_misc.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
APISearchHistoryManager,
99
TaskRunningStatus,
1010
)
11-
from memos.mem_scheduler.utils.db_utils import get_utc_now
1211
from memos.memories.textual.item import TextualMemoryItem
1312

1413

@@ -45,7 +44,8 @@ def sync_search_data(
4544
query: str,
4645
memories: list[TextualMemoryItem],
4746
formatted_memories: Any,
48-
conversation_id: str | None = None,
47+
session_id: str | None = None,
48+
conversation_turn: int = 0,
4949
) -> Any:
5050
logger.info(
5151
f"Syncing search data for item_id: {item_id}, user_id: {user_id}, mem_cube_id: {mem_cube_id}"
@@ -66,7 +66,7 @@ def sync_search_data(
6666
query=query,
6767
formatted_memories=formatted_memories,
6868
task_status=TaskRunningStatus.COMPLETED, # Use the provided running_status
69-
conversation_id=conversation_id,
69+
session_id=session_id,
7070
memories=memories,
7171
)
7272

@@ -76,18 +76,18 @@ def sync_search_data(
7676
logger.warning(f"Failed to update entry with item_id: {item_id}")
7777
else:
7878
# Add new entry based on running_status
79-
search_entry = APIMemoryHistoryEntryItem(
79+
entry_item = APIMemoryHistoryEntryItem(
8080
item_id=item_id,
8181
query=query,
8282
formatted_memories=formatted_memories,
8383
memories=memories,
8484
task_status=TaskRunningStatus.COMPLETED,
85-
conversation_id=conversation_id,
86-
created_time=get_utc_now(),
85+
session_id=session_id,
86+
conversation_turn=conversation_turn,
8787
)
8888

8989
# Add directly to completed list as APIMemoryHistoryEntryItem instance
90-
search_history.completed_entries.append(search_entry)
90+
search_history.completed_entries.append(entry_item)
9191

9292
# Maintain window size
9393
if len(search_history.completed_entries) > search_history.window_size:

src/memos/mem_scheduler/optimized_scheduler.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import json
22
import os
33

4+
from collections import OrderedDict
45
from typing import TYPE_CHECKING
56

67
from memos.api.product_models import APISearchRequest
@@ -39,6 +40,8 @@ def __init__(self, config: GeneralSchedulerConfig):
3940
super().__init__(config)
4041
self.window_size = int(os.getenv("API_SEARCH_WINDOW_SIZE", 5))
4142
self.history_memory_turns = int(os.getenv("API_SEARCH_HISTORY_TURNS", 5))
43+
self.session_counter = OrderedDict()
44+
self.max_session_history = 5
4245

4346
self.api_module = SchedulerAPIModule(
4447
window_size=self.window_size,
@@ -54,13 +57,14 @@ def submit_memory_history_async_task(
5457
self,
5558
search_req: APISearchRequest,
5659
user_context: UserContext,
60+
session_id: str | None = None,
5761
):
5862
# Create message for async fine search
5963
message_content = {
6064
"search_req": {
6165
"query": search_req.query,
6266
"user_id": search_req.user_id,
63-
"session_id": search_req.session_id,
67+
"session_id": session_id,
6468
"top_k": search_req.top_k,
6569
"internet_search": search_req.internet_search,
6670
"moscube": search_req.moscube,
@@ -163,6 +167,7 @@ def mix_search_memories(
163167
self.submit_memory_history_async_task(
164168
search_req=search_req,
165169
user_context=user_context,
170+
session_id=search_req.session_id,
166171
)
167172

168173
# Try to get pre-computed fine memories if available
@@ -171,6 +176,7 @@ def mix_search_memories(
171176
mem_cube_id=user_context.mem_cube_id,
172177
turns=self.history_memory_turns,
173178
)
179+
174180
if not history_memories:
175181
fast_memories = searcher.post_retrieve(
176182
retrieved_results=fast_retrieved_memories,
@@ -214,6 +220,23 @@ def update_search_memories_to_redis(
214220
search_req = content_dict["search_req"]
215221
user_context = content_dict["user_context"]
216222

223+
session_id = search_req.get("session_id")
224+
if session_id:
225+
if session_id not in self.session_counter:
226+
self.session_counter[session_id] = 0
227+
else:
228+
self.session_counter[session_id] += 1
229+
session_turn = self.session_counter[session_id]
230+
231+
# Move the current session to the end to mark it as recently used
232+
self.session_counter.move_to_end(session_id)
233+
234+
# If the counter exceeds the max size, remove the oldest item
235+
if len(self.session_counter) > self.max_session_history:
236+
self.session_counter.popitem(last=False)
237+
else:
238+
session_turn = 0
239+
217240
memories: list[TextualMemoryItem] = self.search_memories(
218241
search_req=APISearchRequest(**content_dict["search_req"]),
219242
user_context=UserContext(**content_dict["user_context"]),
@@ -230,6 +253,8 @@ def update_search_memories_to_redis(
230253
query=search_req["query"],
231254
memories=memories,
232255
formatted_memories=formatted_memories,
256+
session_id=session_id,
257+
conversation_turn=session_turn,
233258
)
234259

235260
def _api_mix_search_message_consumer(self, messages: list[ScheduleMessageItem]) -> None:

src/memos/mem_scheduler/schemas/api_schemas.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,10 @@ class APIMemoryHistoryEntryItem(BaseModel, DictConversionMixin):
3535
task_status: str = Field(
3636
default="running", description="Task status: running, completed, failed"
3737
)
38-
conversation_id: str | None = Field(
39-
default=None, description="Optional conversation identifier"
40-
)
38+
session_id: str | None = Field(default=None, description="Optional conversation identifier")
4139
created_time: datetime = Field(description="Entry creation time", default_factory=get_utc_now)
4240
timestamp: datetime | None = Field(default=None, description="Timestamp for the entry")
41+
conversation_turn: int = Field(default=0, description="Turn count for the same session_id")
4342

4443
model_config = ConfigDict(
4544
arbitrary_types_allowed=True,
@@ -107,11 +106,13 @@ def get_running_item_ids(self) -> list[str]:
107106
"""Get all running task IDs"""
108107
return self.running_item_ids.copy()
109108

110-
def get_completed_entries(self) -> list[dict[str, Any]]:
109+
def get_completed_entries(self) -> list[APIMemoryHistoryEntryItem]:
111110
"""Get all completed entries"""
112111
return self.completed_entries.copy()
113112

114-
def get_history_memory_entries(self, turns: int | None = None) -> list[dict[str, Any]]:
113+
def get_history_memory_entries(
114+
self, turns: int | None = None
115+
) -> list[APIMemoryHistoryEntryItem]:
115116
"""
116117
Get the most recent n completed search entries, sorted by created_time.
117118
@@ -179,7 +180,7 @@ def update_entry_by_item_id(
179180
query: str,
180181
formatted_memories: Any,
181182
task_status: TaskRunningStatus,
182-
conversation_id: str | None = None,
183+
session_id: str | None = None,
183184
memories: list[TextualMemoryItem] | None = None,
184185
) -> bool:
185186
"""
@@ -191,7 +192,7 @@ def update_entry_by_item_id(
191192
query: New query string
192193
formatted_memories: New formatted memories
193194
task_status: New task status
194-
conversation_id: New conversation ID
195+
session_id: New conversation ID
195196
memories: List of TextualMemoryItem objects
196197
197198
Returns:
@@ -204,8 +205,8 @@ def update_entry_by_item_id(
204205
entry.query = query
205206
entry.formatted_memories = formatted_memories
206207
entry.task_status = task_status
207-
if conversation_id is not None:
208-
entry.conversation_id = conversation_id
208+
if session_id is not None:
209+
entry.session_id = session_id
209210
if memories is not None:
210211
entry.memories = memories
211212

0 commit comments

Comments
 (0)