Skip to content

Commit 1ec793d

Browse files
authored
fix: session manager tracks all agent last message (#455)
1 parent 730f01e commit 1ec793d

File tree

9 files changed

+124
-133
lines changed

9 files changed

+124
-133
lines changed

src/strands/session/file_session_manager.py

Lines changed: 21 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import os
66
import shutil
77
import tempfile
8-
from dataclasses import asdict
98
from typing import Any, Optional, cast
109

1110
from ..types.exceptions import SessionException
@@ -57,22 +56,18 @@ def _get_agent_path(self, session_id: str, agent_id: str) -> str:
5756
session_path = self._get_session_path(session_id)
5857
return os.path.join(session_path, "agents", f"{AGENT_PREFIX}{agent_id}")
5958

60-
def _get_message_path(self, session_id: str, agent_id: str, message_id: str, timestamp: str) -> str:
59+
def _get_message_path(self, session_id: str, agent_id: str, message_id: int) -> str:
6160
"""Get message file path.
6261
6362
Args:
6463
session_id: ID of the session
6564
agent_id: ID of the agent
66-
message_id: ID of the message
67-
timestamp: ISO format timestamp to include in filename for sorting
65+
message_id: Index of the message
6866
Returns:
6967
The filename for the message
7068
"""
7169
agent_path = self._get_agent_path(session_id, agent_id)
72-
# Use timestamp for sortable filenames
73-
# Replace colons and periods in ISO format with underscores for filesystem compatibility
74-
filename_timestamp = timestamp.replace(":", "_").replace(".", "_")
75-
return os.path.join(agent_path, "messages", f"{MESSAGE_PREFIX}{filename_timestamp}_{message_id}.json")
70+
return os.path.join(agent_path, "messages", f"{MESSAGE_PREFIX}{message_id}.json")
7671

7772
def _read_file(self, path: str) -> dict[str, Any]:
7873
"""Read JSON file."""
@@ -100,7 +95,7 @@ def create_session(self, session: Session) -> Session:
10095

10196
# Write session file
10297
session_file = os.path.join(session_dir, "session.json")
103-
session_dict = asdict(session)
98+
session_dict = session.to_dict()
10499
self._write_file(session_file, session_dict)
105100

106101
return session
@@ -123,7 +118,7 @@ def create_agent(self, session_id: str, session_agent: SessionAgent) -> None:
123118
os.makedirs(os.path.join(agent_dir, "messages"), exist_ok=True)
124119

125120
agent_file = os.path.join(agent_dir, "agent.json")
126-
session_data = asdict(session_agent)
121+
session_data = session_agent.to_dict()
127122
self._write_file(agent_file, session_data)
128123

129124
def delete_session(self, session_id: str) -> None:
@@ -152,34 +147,25 @@ def update_agent(self, session_id: str, session_agent: SessionAgent) -> None:
152147

153148
session_agent.created_at = previous_agent.created_at
154149
agent_file = os.path.join(self._get_agent_path(session_id, agent_id), "agent.json")
155-
self._write_file(agent_file, asdict(session_agent))
150+
self._write_file(agent_file, session_agent.to_dict())
156151

157152
def create_message(self, session_id: str, agent_id: str, session_message: SessionMessage) -> None:
158153
"""Create a new message for the agent."""
159154
message_file = self._get_message_path(
160155
session_id,
161156
agent_id,
162157
session_message.message_id,
163-
session_message.created_at,
164158
)
165-
session_dict = asdict(session_message)
159+
session_dict = session_message.to_dict()
166160
self._write_file(message_file, session_dict)
167161

168-
def read_message(self, session_id: str, agent_id: str, message_id: str) -> Optional[SessionMessage]:
162+
def read_message(self, session_id: str, agent_id: str, message_id: int) -> Optional[SessionMessage]:
169163
"""Read message data."""
170-
# Get the messages directory
171-
messages_dir = os.path.join(self._get_agent_path(session_id, agent_id), "messages")
172-
if not os.path.exists(messages_dir):
164+
message_path = self._get_message_path(session_id, agent_id, message_id)
165+
if not os.path.exists(message_path):
173166
return None
174-
175-
# List files in messages directory, and check if the filename ends with the message id
176-
for filename in os.listdir(messages_dir):
177-
if filename.endswith(f"{message_id}.json"):
178-
file_path = os.path.join(messages_dir, filename)
179-
message_data = self._read_file(file_path)
180-
return SessionMessage.from_dict(message_data)
181-
182-
return None
167+
message_data = self._read_file(message_path)
168+
return SessionMessage.from_dict(message_data)
183169

184170
def update_message(self, session_id: str, agent_id: str, session_message: SessionMessage) -> None:
185171
"""Update message data."""
@@ -190,8 +176,8 @@ def update_message(self, session_id: str, agent_id: str, session_message: Sessio
190176

191177
# Preserve the original created_at timestamp
192178
session_message.created_at = previous_message.created_at
193-
message_file = self._get_message_path(session_id, agent_id, message_id, session_message.created_at)
194-
self._write_file(message_file, asdict(session_message))
179+
message_file = self._get_message_path(session_id, agent_id, message_id)
180+
self._write_file(message_file, session_message.to_dict())
195181

196182
def list_messages(
197183
self, session_id: str, agent_id: str, limit: Optional[int] = None, offset: int = 0
@@ -201,14 +187,16 @@ def list_messages(
201187
if not os.path.exists(messages_dir):
202188
raise SessionException(f"Messages directory missing from agent: {agent_id} in session {session_id}")
203189

204-
# Read all message files
205-
message_files: list[str] = []
190+
# Read all message files, and record the index
191+
message_index_files: list[tuple[int, str]] = []
206192
for filename in os.listdir(messages_dir):
207193
if filename.startswith(MESSAGE_PREFIX) and filename.endswith(".json"):
208-
message_files.append(filename)
194+
# Extract index from message_<index>.json format
195+
index = int(filename[len(MESSAGE_PREFIX) : -5]) # Remove prefix and .json suffix
196+
message_index_files.append((index, filename))
209197

210-
# Sort filenames - the timestamp in the file's name will sort chronologically
211-
message_files.sort()
198+
# Sort by index and extract just the filenames
199+
message_files = [f for _, f in sorted(message_index_files)]
212200

213201
# Apply pagination to filenames
214202
if limit is not None:

src/strands/session/repository_session_manager.py

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -48,11 +48,8 @@ def __init__(
4848

4949
self.session = session
5050

51-
# Keep track of the initialized agent id's so that two agents in a session cannot share an id
52-
self._initialized_agent_ids: set[str] = set()
53-
54-
# Keep track of the latest message stored in the session in case we need to redact its content.
55-
self._latest_message: Optional[SessionMessage] = None
51+
# Keep track of the latest message of each agent in case we need to redact it.
52+
self._latest_agent_message: dict[str, Optional[SessionMessage]] = {}
5653

5754
def append_message(self, message: Message, agent: Agent) -> None:
5855
"""Append a message to the agent's session.
@@ -61,8 +58,16 @@ def append_message(self, message: Message, agent: Agent) -> None:
6158
message: Message to add to the agent in the session
6259
agent: Agent to append the message to
6360
"""
64-
self._latest_message = SessionMessage.from_message(message)
65-
self.session_repository.create_message(self.session_id, agent.agent_id, self._latest_message)
61+
# Calculate the next index (0 if this is the first message, otherwise increment the previous index)
62+
latest_agent_message = self._latest_agent_message[agent.agent_id]
63+
if latest_agent_message:
64+
next_index = latest_agent_message.message_id + 1
65+
else:
66+
next_index = 0
67+
68+
session_message = SessionMessage.from_message(message, next_index)
69+
self._latest_agent_message[agent.agent_id] = session_message
70+
self.session_repository.create_message(self.session_id, agent.agent_id, session_message)
6671

6772
def redact_latest_message(self, redact_message: Message, agent: Agent) -> None:
6873
"""Redact the latest message appended to the session.
@@ -71,10 +76,11 @@ def redact_latest_message(self, redact_message: Message, agent: Agent) -> None:
7176
redact_message: New message to use that contains the redact content
7277
agent: Agent to apply the message redaction to
7378
"""
74-
if self._latest_message is None:
79+
latest_agent_message = self._latest_agent_message[agent.agent_id]
80+
if latest_agent_message is None:
7581
raise SessionException("No message to redact.")
76-
self._latest_message.redact_message = redact_message
77-
return self.session_repository.update_message(self.session_id, agent.agent_id, self._latest_message)
82+
latest_agent_message.redact_message = redact_message
83+
return self.session_repository.update_message(self.session_id, agent.agent_id, latest_agent_message)
7884

7985
def sync_agent(self, agent: Agent) -> None:
8086
"""Serialize and update the agent into the session repository.
@@ -93,9 +99,9 @@ def initialize(self, agent: Agent) -> None:
9399
Args:
94100
agent: Agent to initialize from the session
95101
"""
96-
if agent.agent_id in self._initialized_agent_ids:
102+
if agent.agent_id in self._latest_agent_message:
97103
raise SessionException("The `agent_id` of an agent must be unique in a session.")
98-
self._initialized_agent_ids.add(agent.agent_id)
104+
self._latest_agent_message[agent.agent_id] = None
99105

100106
session_agent = self.session_repository.read_agent(self.session_id, agent.agent_id)
101107

@@ -108,8 +114,9 @@ def initialize(self, agent: Agent) -> None:
108114

109115
session_agent = SessionAgent.from_agent(agent)
110116
self.session_repository.create_agent(self.session_id, session_agent)
111-
for message in agent.messages:
112-
session_message = SessionMessage.from_message(message)
117+
# Initialize messages with sequential indices
118+
for i, message in enumerate(agent.messages):
119+
session_message = SessionMessage.from_message(message, i)
113120
self.session_repository.create_message(self.session_id, agent.agent_id, session_message)
114121
else:
115122
logger.debug(

src/strands/session/s3_session_manager.py

Lines changed: 27 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import json
44
import logging
5-
from dataclasses import asdict
65
from typing import Any, Dict, List, Optional, cast
76

87
import boto3
@@ -85,22 +84,18 @@ def _get_agent_path(self, session_id: str, agent_id: str) -> str:
8584
session_path = self._get_session_path(session_id)
8685
return f"{session_path}agents/{AGENT_PREFIX}{agent_id}/"
8786

88-
def _get_message_path(self, session_id: str, agent_id: str, message_id: str, timestamp: str) -> str:
87+
def _get_message_path(self, session_id: str, agent_id: str, message_id: int) -> str:
8988
"""Get message S3 key.
9089
9190
Args:
9291
session_id: ID of the session
9392
agent_id: ID of the agent
94-
message_id: ID of the message
95-
timestamp: ISO format timestamp to include in key for sorting
93+
message_id: Index of the message
9694
Returns:
9795
The key for the message
9896
"""
9997
agent_path = self._get_agent_path(session_id, agent_id)
100-
# Use timestamp for sortable keys
101-
# Replace colons and periods in ISO format with underscores for filesystem compatibility
102-
filename_timestamp = timestamp.replace(":", "_").replace(".", "_")
103-
return f"{agent_path}messages/{MESSAGE_PREFIX}{filename_timestamp}_{message_id}.json"
98+
return f"{agent_path}messages/{MESSAGE_PREFIX}{message_id}.json"
10499

105100
def _read_s3_object(self, key: str) -> Optional[Dict[str, Any]]:
106101
"""Read JSON object from S3."""
@@ -139,7 +134,7 @@ def create_session(self, session: Session) -> Session:
139134
raise SessionException(f"S3 error checking session existence: {e}") from e
140135

141136
# Write session object
142-
session_dict = asdict(session)
137+
session_dict = session.to_dict()
143138
self._write_s3_object(session_key, session_dict)
144139
return session
145140

@@ -177,7 +172,7 @@ def delete_session(self, session_id: str) -> None:
177172
def create_agent(self, session_id: str, session_agent: SessionAgent) -> None:
178173
"""Create a new agent in S3."""
179174
agent_id = session_agent.agent_id
180-
agent_dict = asdict(session_agent)
175+
agent_dict = session_agent.to_dict()
181176
agent_key = f"{self._get_agent_path(session_id, agent_id)}agent.json"
182177
self._write_s3_object(agent_key, agent_dict)
183178

@@ -199,35 +194,22 @@ def update_agent(self, session_id: str, session_agent: SessionAgent) -> None:
199194
# Preserve creation timestamp
200195
session_agent.created_at = previous_agent.created_at
201196
agent_key = f"{self._get_agent_path(session_id, agent_id)}agent.json"
202-
self._write_s3_object(agent_key, asdict(session_agent))
197+
self._write_s3_object(agent_key, session_agent.to_dict())
203198

204199
def create_message(self, session_id: str, agent_id: str, session_message: SessionMessage) -> None:
205200
"""Create a new message in S3."""
206201
message_id = session_message.message_id
207-
message_dict = asdict(session_message)
208-
message_key = self._get_message_path(session_id, agent_id, message_id, session_message.created_at)
202+
message_dict = session_message.to_dict()
203+
message_key = self._get_message_path(session_id, agent_id, message_id)
209204
self._write_s3_object(message_key, message_dict)
210205

211-
def read_message(self, session_id: str, agent_id: str, message_id: str) -> Optional[SessionMessage]:
206+
def read_message(self, session_id: str, agent_id: str, message_id: int) -> Optional[SessionMessage]:
212207
"""Read message data from S3."""
213-
# Get the messages prefix
214-
messages_prefix = f"{self._get_agent_path(session_id, agent_id)}messages/"
215-
try:
216-
paginator = self.client.get_paginator("list_objects_v2")
217-
pages = paginator.paginate(Bucket=self.bucket, Prefix=messages_prefix)
218-
219-
for page in pages:
220-
if "Contents" in page:
221-
for obj in page["Contents"]:
222-
if obj["Key"].endswith(f"{message_id}.json"):
223-
message_data = self._read_s3_object(obj["Key"])
224-
if message_data:
225-
return SessionMessage.from_dict(message_data)
226-
208+
message_key = self._get_message_path(session_id, agent_id, message_id)
209+
message_data = self._read_s3_object(message_key)
210+
if message_data is None:
227211
return None
228-
229-
except ClientError as e:
230-
raise SessionException(f"S3 error reading message: {e}") from e
212+
return SessionMessage.from_dict(message_data)
231213

232214
def update_message(self, session_id: str, agent_id: str, session_message: SessionMessage) -> None:
233215
"""Update message data in S3."""
@@ -238,8 +220,8 @@ def update_message(self, session_id: str, agent_id: str, session_message: Sessio
238220

239221
# Preserve creation timestamp
240222
session_message.created_at = previous_message.created_at
241-
message_key = self._get_message_path(session_id, agent_id, message_id, session_message.created_at)
242-
self._write_s3_object(message_key, asdict(session_message))
223+
message_key = self._get_message_path(session_id, agent_id, message_id)
224+
self._write_s3_object(message_key, session_message.to_dict())
243225

244226
def list_messages(
245227
self, session_id: str, agent_id: str, limit: Optional[int] = None, offset: int = 0
@@ -250,16 +232,21 @@ def list_messages(
250232
paginator = self.client.get_paginator("list_objects_v2")
251233
pages = paginator.paginate(Bucket=self.bucket, Prefix=messages_prefix)
252234

253-
# Collect all message keys first
254-
message_keys = []
235+
# Collect all message keys and extract their indices
236+
message_index_keys: list[tuple[int, str]] = []
255237
for page in pages:
256238
if "Contents" in page:
257239
for obj in page["Contents"]:
258-
if obj["Key"].endswith(".json") and MESSAGE_PREFIX in obj["Key"]:
259-
message_keys.append(obj["Key"])
260-
261-
# Sort keys - timestamp prefixed keys will sort chronologically
262-
message_keys.sort()
240+
key = obj["Key"]
241+
if key.endswith(".json") and MESSAGE_PREFIX in key:
242+
# Extract the filename part from the full S3 key
243+
filename = key.split("/")[-1]
244+
# Extract index from message_<index>.json format
245+
index = int(filename[len(MESSAGE_PREFIX) : -5]) # Remove prefix and .json suffix
246+
message_index_keys.append((index, key))
247+
248+
# Sort by index and extract just the keys
249+
message_keys = [k for _, k in sorted(message_index_keys)]
263250

264251
# Apply pagination to keys before loading content
265252
if limit is not None:

src/strands/session/session_repository.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def create_message(self, session_id: str, agent_id: str, session_message: Sessio
3434
"""Create a new Message for the Agent."""
3535

3636
@abstractmethod
37-
def read_message(self, session_id: str, agent_id: str, message_id: str) -> Optional[SessionMessage]:
37+
def read_message(self, session_id: str, agent_id: str, message_id: int) -> Optional[SessionMessage]:
3838
"""Read a Message."""
3939

4040
@abstractmethod

0 commit comments

Comments
 (0)