Skip to content

Commit 0c4e3c0

Browse files
committed
feat: add dynamodb session manager
1 parent 7226025 commit 0c4e3c0

File tree

2 files changed

+730
-0
lines changed

2 files changed

+730
-0
lines changed
Lines changed: 330 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,330 @@
1+
"""DynamoDB-based session manager for cloud storage."""
2+
3+
import logging
4+
from typing import Any, List, Optional
5+
6+
import boto3
7+
from boto3.dynamodb.types import TypeDeserializer, TypeSerializer
8+
from botocore.config import Config as BotocoreConfig
9+
from botocore.exceptions import ClientError
10+
11+
from .. import _identifier
12+
from ..types.exceptions import SessionException
13+
from ..types.session import Session, SessionAgent, SessionMessage
14+
from .repository_session_manager import RepositorySessionManager
15+
from .session_repository import SessionRepository
16+
17+
logger = logging.getLogger(__name__)
18+
19+
20+
class DynamoDBSessionManager(RepositorySessionManager, SessionRepository):
21+
"""DynamoDB-based session manager for cloud storage.
22+
23+
Uses a single table design with the following structure:
24+
- PK (HASH): session_<session_id>
25+
- SK (RANGE): session | agent_<agent_id> | agent_<agent_id>#message_<message_id>
26+
27+
Example:
28+
```
29+
┌─────────────────┬──────────────────────────┬─────────────────┬──────────────────┐
30+
│ PK │ SK │ entity_type │ data │
31+
├─────────────────┼──────────────────────────┼─────────────────┼──────────────────┤
32+
│ session_abc123 │ session │ SESSION │ {session_json} │
33+
│ session_abc123 │ agent_agent1 │ AGENT │ {agent_json} │
34+
│ session_abc123 │ agent_agent1#message_0 │ MESSAGE │ {message_json} │
35+
│ session_abc123 │ agent_agent1#message_1 │ MESSAGE │ {message_json} │
36+
└─────────────────┴──────────────────────────┴─────────────────┴──────────────────┘
37+
```
38+
"""
39+
40+
def __init__(
41+
self,
42+
session_id: str,
43+
table_name: str,
44+
boto_session: Optional[boto3.Session] = None,
45+
boto_client_config: Optional[BotocoreConfig] = None,
46+
region_name: Optional[str] = None,
47+
**kwargs: Any,
48+
):
49+
"""Initialize DynamoDBSessionManager.
50+
51+
Args:
52+
session_id: ID for the session
53+
table_name: DynamoDB table name
54+
boto_session: Optional boto3 session
55+
boto_client_config: Optional boto3 client configuration
56+
region_name: AWS region for DynamoDB
57+
**kwargs: Additional keyword arguments for future extensibility.
58+
"""
59+
self.table_name = table_name
60+
61+
session = boto_session or boto3.Session(region_name=region_name)
62+
63+
# Add strands-agents to the request user agent
64+
if boto_client_config:
65+
existing_user_agent = getattr(boto_client_config, "user_agent_extra", None)
66+
# Append 'strands-agents' to existing user_agent_extra or set it if not present
67+
if existing_user_agent:
68+
new_user_agent = f"{existing_user_agent} strands-agents"
69+
else:
70+
new_user_agent = "strands-agents"
71+
client_config = boto_client_config.merge(BotocoreConfig(user_agent_extra=new_user_agent))
72+
else:
73+
client_config = BotocoreConfig(user_agent_extra="strands-agents")
74+
75+
self.client = session.client(service_name="dynamodb", config=client_config)
76+
self.serializer = TypeSerializer()
77+
self.deserializer = TypeDeserializer()
78+
super().__init__(session_id=session_id, session_repository=self)
79+
80+
def _validate_dynamodb_id(self, id_: str, id_type: _identifier.Identifier) -> str:
81+
"""Validate ID for DynamoDB key structure.
82+
83+
Args:
84+
id_: ID to validate
85+
id_type: Type of ID for error messages
86+
87+
Returns:
88+
Validated ID
89+
90+
Raises:
91+
ValueError: If ID contains characters that would break DynamoDB key structure
92+
"""
93+
if "_" in id_ or "#" in id_:
94+
raise ValueError(f"{id_type.value}_id={id_} | id cannot contain underscore (_) or hash (#) characters")
95+
return id_
96+
97+
def _get_session_pk(self, session_id: str) -> str:
98+
"""Get session partition key."""
99+
session_id = self._validate_dynamodb_id(session_id, _identifier.Identifier.SESSION)
100+
return f"session_{session_id}"
101+
102+
def _get_session_sk(self) -> str:
103+
"""Get session sort key."""
104+
return "session"
105+
106+
def _get_agent_sk(self, agent_id: str) -> str:
107+
"""Get agent sort key."""
108+
agent_id = self._validate_dynamodb_id(agent_id, _identifier.Identifier.AGENT)
109+
return f"agent_{agent_id}"
110+
111+
def _get_message_sk(self, agent_id: str, message_id: int) -> str:
112+
"""Get message sort key."""
113+
if not isinstance(message_id, int):
114+
raise ValueError(f"message_id=<{message_id}> | message id must be an integer")
115+
agent_id = self._validate_dynamodb_id(agent_id, _identifier.Identifier.AGENT)
116+
return f"agent_{agent_id}#message_{message_id}"
117+
118+
def create_session(self, session: Session, **kwargs: Any) -> Session:
119+
"""Create a new session in DynamoDB."""
120+
pk = self._get_session_pk(session.session_id)
121+
sk = self._get_session_sk()
122+
123+
try:
124+
self.client.put_item(
125+
TableName=self.table_name,
126+
Item={
127+
"PK": {"S": pk},
128+
"SK": {"S": sk},
129+
"entity_type": {"S": "SESSION"},
130+
"data": self.serializer.serialize(session.to_dict()),
131+
},
132+
ConditionExpression="attribute_not_exists(PK)",
133+
)
134+
return session
135+
except ClientError as e:
136+
if e.response["Error"]["Code"] == "ConditionalCheckFailedException":
137+
raise SessionException(f"Session {session.session_id} already exists") from e
138+
raise SessionException(f"DynamoDB error creating session: {e}") from e
139+
140+
def read_session(self, session_id: str, **kwargs: Any) -> Optional[Session]:
141+
"""Read session data from DynamoDB."""
142+
pk = self._get_session_pk(session_id)
143+
sk = self._get_session_sk()
144+
145+
try:
146+
response = self.client.get_item(TableName=self.table_name, Key={"PK": {"S": pk}, "SK": {"S": sk}})
147+
if "Item" not in response:
148+
return None
149+
150+
data = self.deserializer.deserialize(response["Item"]["data"])
151+
return Session.from_dict(data)
152+
except ClientError as e:
153+
raise SessionException(f"DynamoDB error reading session: {e}") from e
154+
155+
def delete_session(self, session_id: str, **kwargs: Any) -> None:
156+
"""Delete session and all associated data from DynamoDB."""
157+
pk = self._get_session_pk(session_id)
158+
159+
try:
160+
# Query all items for this session
161+
response = self.client.query(
162+
TableName=self.table_name,
163+
KeyConditionExpression="PK = :pk",
164+
ExpressionAttributeValues={":pk": {"S": pk}},
165+
)
166+
167+
if not response["Items"]:
168+
raise SessionException(f"Session {session_id} does not exist")
169+
170+
# Delete all items in batches
171+
for i in range(0, len(response["Items"]), 25):
172+
batch = response["Items"][i : i + 25]
173+
delete_requests = [{"DeleteRequest": {"Key": {"PK": item["PK"], "SK": item["SK"]}}} for item in batch]
174+
self.client.batch_write_item(RequestItems={self.table_name: delete_requests})
175+
176+
except ClientError as e:
177+
raise SessionException(f"DynamoDB error deleting session: {e}") from e
178+
179+
def create_agent(self, session_id: str, session_agent: SessionAgent, **kwargs: Any) -> None:
180+
"""Create a new agent in DynamoDB."""
181+
pk = self._get_session_pk(session_id)
182+
sk = self._get_agent_sk(session_agent.agent_id)
183+
184+
try:
185+
self.client.put_item(
186+
TableName=self.table_name,
187+
Item={
188+
"PK": {"S": pk},
189+
"SK": {"S": sk},
190+
"entity_type": {"S": "AGENT"},
191+
"data": self.serializer.serialize(session_agent.to_dict()),
192+
},
193+
)
194+
except ClientError as e:
195+
raise SessionException(f"DynamoDB error creating agent: {e}") from e
196+
197+
def read_agent(self, session_id: str, agent_id: str, **kwargs: Any) -> Optional[SessionAgent]:
198+
"""Read agent data from DynamoDB."""
199+
pk = self._get_session_pk(session_id)
200+
sk = self._get_agent_sk(agent_id)
201+
202+
try:
203+
response = self.client.get_item(TableName=self.table_name, Key={"PK": {"S": pk}, "SK": {"S": sk}})
204+
if "Item" not in response:
205+
return None
206+
207+
data = self.deserializer.deserialize(response["Item"]["data"])
208+
return SessionAgent.from_dict(data)
209+
except ClientError as e:
210+
raise SessionException(f"DynamoDB error reading agent: {e}") from e
211+
212+
def update_agent(self, session_id: str, session_agent: SessionAgent, **kwargs: Any) -> None:
213+
"""Update agent data in DynamoDB."""
214+
previous_agent = self.read_agent(session_id=session_id, agent_id=session_agent.agent_id)
215+
if previous_agent is None:
216+
raise SessionException(f"Agent {session_agent.agent_id} in session {session_id} does not exist")
217+
218+
# Preserve creation timestamp
219+
session_agent.created_at = previous_agent.created_at
220+
221+
pk = self._get_session_pk(session_id)
222+
sk = self._get_agent_sk(session_agent.agent_id)
223+
224+
try:
225+
self.client.put_item(
226+
TableName=self.table_name,
227+
Item={
228+
"PK": {"S": pk},
229+
"SK": {"S": sk},
230+
"entity_type": {"S": "AGENT"},
231+
"data": self.serializer.serialize(session_agent.to_dict()),
232+
},
233+
)
234+
except ClientError as e:
235+
raise SessionException(f"DynamoDB error updating agent: {e}") from e
236+
237+
def create_message(self, session_id: str, agent_id: str, session_message: SessionMessage, **kwargs: Any) -> None:
238+
"""Create a new message in DynamoDB."""
239+
pk = self._get_session_pk(session_id)
240+
sk = self._get_message_sk(agent_id, session_message.message_id)
241+
242+
try:
243+
self.client.put_item(
244+
TableName=self.table_name,
245+
Item={
246+
"PK": {"S": pk},
247+
"SK": {"S": sk},
248+
"entity_type": {"S": "MESSAGE"},
249+
"data": self.serializer.serialize(session_message.to_dict()),
250+
},
251+
)
252+
except ClientError as e:
253+
raise SessionException(f"DynamoDB error creating message: {e}") from e
254+
255+
def read_message(self, session_id: str, agent_id: str, message_id: int, **kwargs: Any) -> Optional[SessionMessage]:
256+
"""Read message data from DynamoDB."""
257+
pk = self._get_session_pk(session_id)
258+
sk = self._get_message_sk(agent_id, message_id)
259+
260+
try:
261+
response = self.client.get_item(TableName=self.table_name, Key={"PK": {"S": pk}, "SK": {"S": sk}})
262+
if "Item" not in response:
263+
return None
264+
265+
data = self.deserializer.deserialize(response["Item"]["data"])
266+
return SessionMessage.from_dict(data)
267+
except ClientError as e:
268+
raise SessionException(f"DynamoDB error reading message: {e}") from e
269+
270+
def update_message(self, session_id: str, agent_id: str, session_message: SessionMessage, **kwargs: Any) -> None:
271+
"""Update message data in DynamoDB."""
272+
previous_message = self.read_message(
273+
session_id=session_id, agent_id=agent_id, message_id=session_message.message_id
274+
)
275+
if previous_message is None:
276+
raise SessionException(f"Message {session_message.message_id} does not exist")
277+
278+
# Preserve creation timestamp
279+
session_message.created_at = previous_message.created_at
280+
281+
pk = self._get_session_pk(session_id)
282+
sk = self._get_message_sk(agent_id, session_message.message_id)
283+
284+
try:
285+
self.client.put_item(
286+
TableName=self.table_name,
287+
Item={
288+
"PK": {"S": pk},
289+
"SK": {"S": sk},
290+
"entity_type": {"S": "MESSAGE"},
291+
"data": self.serializer.serialize(session_message.to_dict()),
292+
},
293+
)
294+
except ClientError as e:
295+
raise SessionException(f"DynamoDB error updating message: {e}") from e
296+
297+
def list_messages(
298+
self, session_id: str, agent_id: str, limit: Optional[int] = None, offset: int = 0, **kwargs: Any
299+
) -> List[SessionMessage]:
300+
"""List messages for an agent with pagination from DynamoDB."""
301+
pk = self._get_session_pk(session_id)
302+
agent_prefix = f"agent_{self._validate_dynamodb_id(agent_id, _identifier.Identifier.AGENT)}#message_"
303+
304+
try:
305+
# Query messages for this agent
306+
response = self.client.query(
307+
TableName=self.table_name,
308+
KeyConditionExpression="PK = :pk AND begins_with(SK, :sk_prefix)",
309+
ExpressionAttributeValues={":pk": {"S": pk}, ":sk_prefix": {"S": agent_prefix}},
310+
)
311+
312+
# Sort by message ID (extracted from SK)
313+
items = sorted(response["Items"], key=lambda x: int(x["SK"]["S"].split("_")[-1]))
314+
315+
# Apply pagination
316+
if limit is not None:
317+
items = items[offset : offset + limit]
318+
else:
319+
items = items[offset:]
320+
321+
# Convert to SessionMessage objects
322+
messages = []
323+
for item in items:
324+
data = self.deserializer.deserialize(item["data"])
325+
messages.append(SessionMessage.from_dict(data))
326+
327+
return messages
328+
329+
except ClientError as e:
330+
raise SessionException(f"DynamoDB error listing messages: {e}") from e

0 commit comments

Comments
 (0)