Skip to content

Commit 30fb562

Browse files
authored
Session persistence (#302)
* feat: Session persistence * refactor: add pr feedback
1 parent 5dc3f59 commit 30fb562

16 files changed

+2033
-6
lines changed

src/strands/agent/agent.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
import random
1616
from concurrent.futures import ThreadPoolExecutor
1717
from typing import Any, AsyncGenerator, AsyncIterator, Callable, Mapping, Optional, Type, TypeVar, Union, cast
18-
from uuid import uuid4
1918

2019
from opentelemetry import trace
2120
from pydantic import BaseModel
@@ -32,6 +31,7 @@
3231
)
3332
from ..models.bedrock import BedrockModel
3433
from ..models.model import Model
34+
from ..session.session_manager import SessionManager
3535
from ..telemetry.metrics import EventLoopMetrics
3636
from ..telemetry.tracer import get_tracer
3737
from ..tools.registry import ToolRegistry
@@ -62,6 +62,7 @@ class _DefaultCallbackHandlerSentinel:
6262

6363
_DEFAULT_CALLBACK_HANDLER = _DefaultCallbackHandlerSentinel()
6464
_DEFAULT_AGENT_NAME = "Strands Agents"
65+
_DEFAULT_AGENT_ID = "default"
6566

6667

6768
class Agent:
@@ -207,6 +208,7 @@ def __init__(
207208
description: Optional[str] = None,
208209
state: Optional[Union[AgentState, dict]] = None,
209210
hooks: Optional[list[HookProvider]] = None,
211+
session_manager: Optional[SessionManager] = None,
210212
):
211213
"""Initialize the Agent with the specified configuration.
212214
@@ -237,22 +239,24 @@ def __init__(
237239
load_tools_from_directory: Whether to load and automatically reload tools in the `./tools/` directory.
238240
Defaults to False.
239241
trace_attributes: Custom trace attributes to apply to the agent's trace span.
240-
agent_id: Optional ID for the agent, useful for multi-agent scenarios.
241-
If None, a UUID is generated.
242+
agent_id: Optional ID for the agent, useful for session management and multi-agent scenarios.
243+
Defaults to "default".
242244
name: name of the Agent
243-
Defaults to None.
245+
Defaults to "Strands Agents".
244246
description: description of what the Agent does
245247
Defaults to None.
246248
state: stateful information for the agent. Can be either an AgentState object, or a json serializable dict.
247249
Defaults to an empty AgentState object.
248250
hooks: hooks to be added to the agent hook registry
249251
Defaults to None.
252+
session_manager: Manager for handling agent sessions including conversation history and state.
253+
If provided, enables session-based persistence and state management.
250254
"""
251255
self.model = BedrockModel() if not model else BedrockModel(model_id=model) if isinstance(model, str) else model
252256
self.messages = messages if messages is not None else []
253257

254258
self.system_prompt = system_prompt
255-
self.agent_id = agent_id or str(uuid4())
259+
self.agent_id = agent_id or _DEFAULT_AGENT_ID
256260
self.name = name or _DEFAULT_AGENT_NAME
257261
self.description = description
258262

@@ -312,6 +316,12 @@ def __init__(
312316
self.tool_caller = Agent.ToolCaller(self)
313317

314318
self.hooks = HookRegistry()
319+
320+
# Initialize session management functionality
321+
self._session_manager = session_manager
322+
if self._session_manager:
323+
self.hooks.add_hook(self._session_manager)
324+
315325
if hooks:
316326
for hook in hooks:
317327
self.hooks.add_hook(hook)
Lines changed: 226 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,226 @@
1+
"""File-based session manager for local filesystem storage."""
2+
3+
import json
4+
import logging
5+
import os
6+
import shutil
7+
import tempfile
8+
from dataclasses import asdict
9+
from typing import Any, Optional, cast
10+
11+
from ..types.exceptions import SessionException
12+
from ..types.session import Session, SessionAgent, SessionMessage
13+
from .repository_session_manager import RepositorySessionManager
14+
from .session_repository import SessionRepository
15+
16+
logger = logging.getLogger(__name__)
17+
18+
SESSION_PREFIX = "session_"
19+
AGENT_PREFIX = "agent_"
20+
MESSAGE_PREFIX = "message_"
21+
22+
23+
class FileSessionManager(RepositorySessionManager, SessionRepository):
24+
"""File-based session manager for local filesystem storage.
25+
26+
Creates the following filesystem structure for the session storage:
27+
/<sessions_dir>/
28+
└── session_<session_id>/
29+
├── session.json # Session metadata
30+
└── agents/
31+
└── agent_<agent_id>/
32+
├── agent.json # Agent metadata
33+
└── messages/
34+
├── message_<created_timestamp>_<id1>.json
35+
└── message_<created_timestamp>_<id2>.json
36+
37+
"""
38+
39+
def __init__(self, session_id: str, storage_dir: Optional[str] = None):
40+
"""Initialize FileSession with filesystem storage.
41+
42+
Args:
43+
session_id: ID for the session
44+
storage_dir: Directory for local filesystem storage (defaults to temp dir)
45+
"""
46+
self.storage_dir = storage_dir or os.path.join(tempfile.gettempdir(), "strands/sessions")
47+
os.makedirs(self.storage_dir, exist_ok=True)
48+
49+
super().__init__(session_id=session_id, session_repository=self)
50+
51+
def _get_session_path(self, session_id: str) -> str:
52+
"""Get session directory path."""
53+
return os.path.join(self.storage_dir, f"{SESSION_PREFIX}{session_id}")
54+
55+
def _get_agent_path(self, session_id: str, agent_id: str) -> str:
56+
"""Get agent directory path."""
57+
session_path = self._get_session_path(session_id)
58+
return os.path.join(session_path, "agents", f"{AGENT_PREFIX}{agent_id}")
59+
60+
def _get_message_path(self, session_id: str, agent_id: str, message_id: str, timestamp: str) -> str:
61+
"""Get message file path.
62+
63+
Args:
64+
session_id: ID of the session
65+
agent_id: ID of the agent
66+
message_id: ID of the message
67+
timestamp: ISO format timestamp to include in filename for sorting
68+
Returns:
69+
The filename for the message
70+
"""
71+
agent_path = self._get_agent_path(session_id, agent_id)
72+
# Use timestamp for sortable filenames
73+
# Replace colons and periods in ISO format with underscores for filesystem compatibility
74+
filename_timestamp = timestamp.replace(":", "_").replace(".", "_")
75+
return os.path.join(agent_path, "messages", f"{MESSAGE_PREFIX}{filename_timestamp}_{message_id}.json")
76+
77+
def _read_file(self, path: str) -> dict[str, Any]:
78+
"""Read JSON file."""
79+
try:
80+
with open(path, "r", encoding="utf-8") as f:
81+
return cast(dict[str, Any], json.load(f))
82+
except json.JSONDecodeError as e:
83+
raise SessionException(f"Invalid JSON in file {path}: {str(e)}") from e
84+
85+
def _write_file(self, path: str, data: dict[str, Any]) -> None:
86+
"""Write JSON file."""
87+
os.makedirs(os.path.dirname(path), exist_ok=True)
88+
with open(path, "w", encoding="utf-8") as f:
89+
json.dump(data, f, indent=2, ensure_ascii=False)
90+
91+
def create_session(self, session: Session) -> Session:
92+
"""Create a new session."""
93+
session_dir = self._get_session_path(session.session_id)
94+
if os.path.exists(session_dir):
95+
raise SessionException(f"Session {session.session_id} already exists")
96+
97+
# Create directory structure
98+
os.makedirs(session_dir, exist_ok=True)
99+
os.makedirs(os.path.join(session_dir, "agents"), exist_ok=True)
100+
101+
# Write session file
102+
session_file = os.path.join(session_dir, "session.json")
103+
session_dict = asdict(session)
104+
self._write_file(session_file, session_dict)
105+
106+
return session
107+
108+
def read_session(self, session_id: str) -> Optional[Session]:
109+
"""Read session data."""
110+
session_file = os.path.join(self._get_session_path(session_id), "session.json")
111+
if not os.path.exists(session_file):
112+
return None
113+
114+
session_data = self._read_file(session_file)
115+
return Session.from_dict(session_data)
116+
117+
def create_agent(self, session_id: str, session_agent: SessionAgent) -> None:
118+
"""Create a new agent in the session."""
119+
agent_id = session_agent.agent_id
120+
121+
agent_dir = self._get_agent_path(session_id, agent_id)
122+
os.makedirs(agent_dir, exist_ok=True)
123+
os.makedirs(os.path.join(agent_dir, "messages"), exist_ok=True)
124+
125+
agent_file = os.path.join(agent_dir, "agent.json")
126+
session_data = asdict(session_agent)
127+
self._write_file(agent_file, session_data)
128+
129+
def delete_session(self, session_id: str) -> None:
130+
"""Delete session and all associated data."""
131+
session_dir = self._get_session_path(session_id)
132+
if not os.path.exists(session_dir):
133+
raise SessionException(f"Session {session_id} does not exist")
134+
135+
shutil.rmtree(session_dir)
136+
137+
def read_agent(self, session_id: str, agent_id: str) -> Optional[SessionAgent]:
138+
"""Read agent data."""
139+
agent_file = os.path.join(self._get_agent_path(session_id, agent_id), "agent.json")
140+
if not os.path.exists(agent_file):
141+
return None
142+
143+
agent_data = self._read_file(agent_file)
144+
return SessionAgent.from_dict(agent_data)
145+
146+
def update_agent(self, session_id: str, session_agent: SessionAgent) -> None:
147+
"""Update agent data."""
148+
agent_id = session_agent.agent_id
149+
previous_agent = self.read_agent(session_id=session_id, agent_id=agent_id)
150+
if previous_agent is None:
151+
raise SessionException(f"Agent {agent_id} in session {session_id} does not exist")
152+
153+
session_agent.created_at = previous_agent.created_at
154+
agent_file = os.path.join(self._get_agent_path(session_id, agent_id), "agent.json")
155+
self._write_file(agent_file, asdict(session_agent))
156+
157+
def create_message(self, session_id: str, agent_id: str, session_message: SessionMessage) -> None:
158+
"""Create a new message for the agent."""
159+
message_file = self._get_message_path(
160+
session_id,
161+
agent_id,
162+
session_message.message_id,
163+
session_message.created_at,
164+
)
165+
session_dict = asdict(session_message)
166+
self._write_file(message_file, session_dict)
167+
168+
def read_message(self, session_id: str, agent_id: str, message_id: str) -> Optional[SessionMessage]:
169+
"""Read message data."""
170+
# Get the messages directory
171+
messages_dir = os.path.join(self._get_agent_path(session_id, agent_id), "messages")
172+
if not os.path.exists(messages_dir):
173+
return None
174+
175+
# List files in messages directory, and check if the filename ends with the message id
176+
for filename in os.listdir(messages_dir):
177+
if filename.endswith(f"{message_id}.json"):
178+
file_path = os.path.join(messages_dir, filename)
179+
message_data = self._read_file(file_path)
180+
return SessionMessage.from_dict(message_data)
181+
182+
return None
183+
184+
def update_message(self, session_id: str, agent_id: str, session_message: SessionMessage) -> None:
185+
"""Update message data."""
186+
message_id = session_message.message_id
187+
previous_message = self.read_message(session_id=session_id, agent_id=agent_id, message_id=message_id)
188+
if previous_message is None:
189+
raise SessionException(f"Message {message_id} does not exist")
190+
191+
# Preserve the original created_at timestamp
192+
session_message.created_at = previous_message.created_at
193+
message_file = self._get_message_path(session_id, agent_id, message_id, session_message.created_at)
194+
self._write_file(message_file, asdict(session_message))
195+
196+
def list_messages(
197+
self, session_id: str, agent_id: str, limit: Optional[int] = None, offset: int = 0
198+
) -> list[SessionMessage]:
199+
"""List messages for an agent with pagination."""
200+
messages_dir = os.path.join(self._get_agent_path(session_id, agent_id), "messages")
201+
if not os.path.exists(messages_dir):
202+
raise SessionException(f"Messages directory missing from agent: {agent_id} in session {session_id}")
203+
204+
# Read all message files
205+
message_files: list[str] = []
206+
for filename in os.listdir(messages_dir):
207+
if filename.startswith(MESSAGE_PREFIX) and filename.endswith(".json"):
208+
message_files.append(filename)
209+
210+
# Sort filenames - the timestamp in the file's name will sort chronologically
211+
message_files.sort()
212+
213+
# Apply pagination to filenames
214+
if limit is not None:
215+
message_files = message_files[offset : offset + limit]
216+
else:
217+
message_files = message_files[offset:]
218+
219+
# Load only the message files
220+
messages: list[SessionMessage] = []
221+
for filename in message_files:
222+
file_path = os.path.join(messages_dir, filename)
223+
message_data = self._read_file(file_path)
224+
messages.append(SessionMessage.from_dict(message_data))
225+
226+
return messages

0 commit comments

Comments
 (0)