Skip to content

Commit 8b3ed05

Browse files
DeanChensjcopybara-github
authored andcommitted
chore: Refactor and fix state management in the session service
Also refactoring the test cases to focus on the expected behaviors PiperOrigin-RevId: 820734484
1 parent cf34032 commit 8b3ed05

File tree

4 files changed

+223
-250
lines changed

4 files changed

+223
-250
lines changed

src/google/adk/sessions/_session_util.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
from typing import Type
2020
from typing import TypeVar
2121

22+
from .state import State
23+
2224
M = TypeVar("M")
2325

2426

@@ -29,3 +31,19 @@ def decode_model(
2931
if data is None:
3032
return None
3133
return model_cls.model_validate(data)
34+
35+
36+
def extract_state_delta(
37+
state: dict[str, Any],
38+
) -> dict[str, dict[str, Any]]:
39+
"""Extracts app, user, and session state deltas from a state dictionary."""
40+
deltas = {"app": {}, "user": {}, "session": {}}
41+
if state:
42+
for key in state.keys():
43+
if key.startswith(State.APP_PREFIX):
44+
deltas["app"][key.removeprefix(State.APP_PREFIX)] = state[key]
45+
elif key.startswith(State.USER_PREFIX):
46+
deltas["user"][key.removeprefix(State.USER_PREFIX)] = state[key]
47+
elif not key.startswith(State.TEMP_PREFIX):
48+
deltas["session"][key] = state[key]
49+
return deltas

src/google/adk/sessions/database_session_service.py

Lines changed: 34 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -465,40 +465,31 @@ async def create_session(
465465
# 5. Return the session
466466

467467
with self.database_session_factory() as sql_session:
468-
469468
# Fetch app and user states from storage
470469
storage_app_state = sql_session.get(StorageAppState, (app_name))
471-
storage_user_state = sql_session.get(
472-
StorageUserState, (app_name, user_id)
473-
)
474-
475-
app_state = storage_app_state.state if storage_app_state else {}
476-
user_state = storage_user_state.state if storage_user_state else {}
477-
478-
# Create state tables if not exist
479470
if not storage_app_state:
480471
storage_app_state = StorageAppState(app_name=app_name, state={})
481472
sql_session.add(storage_app_state)
473+
storage_user_state = sql_session.get(
474+
StorageUserState, (app_name, user_id)
475+
)
482476
if not storage_user_state:
483477
storage_user_state = StorageUserState(
484478
app_name=app_name, user_id=user_id, state={}
485479
)
486480
sql_session.add(storage_user_state)
487481

488482
# Extract state deltas
489-
app_state_delta, user_state_delta, session_state = _extract_state_delta(
490-
state
491-
)
483+
state_deltas = _session_util.extract_state_delta(state)
484+
app_state_delta = state_deltas["app"]
485+
user_state_delta = state_deltas["user"]
486+
session_state = state_deltas["session"]
492487

493488
# Apply state delta
494-
app_state.update(app_state_delta)
495-
user_state.update(user_state_delta)
496-
497-
# Store app and user state
498489
if app_state_delta:
499-
storage_app_state.state = app_state
490+
storage_app_state.state = storage_app_state.state | app_state_delta
500491
if user_state_delta:
501-
storage_user_state.state = user_state
492+
storage_user_state.state = storage_user_state.state | user_state_delta
502493

503494
# Store the session
504495
storage_session = StorageSession(
@@ -513,7 +504,9 @@ async def create_session(
513504
sql_session.refresh(storage_session)
514505

515506
# Merge states for response
516-
merged_state = _merge_state(app_state, user_state, session_state)
507+
merged_state = _merge_state(
508+
storage_app_state.state, storage_user_state.state, session_state
509+
)
517510
session = storage_session.to_session(state=merged_state)
518511
return session
519512

@@ -536,19 +529,18 @@ async def get_session(
536529
if storage_session is None:
537530
return None
538531

532+
query = sql_session.query(StorageEvent).filter(
533+
StorageEvent.app_name == app_name,
534+
StorageEvent.user_id == user_id,
535+
StorageEvent.session_id == storage_session.id,
536+
)
537+
539538
if config and config.after_timestamp:
540539
after_dt = datetime.fromtimestamp(config.after_timestamp)
541-
timestamp_filter = StorageEvent.timestamp >= after_dt
542-
else:
543-
timestamp_filter = True
540+
query = query.filter(StorageEvent.timestamp >= after_dt)
544541

545542
storage_events = (
546-
sql_session.query(StorageEvent)
547-
.filter(StorageEvent.app_name == app_name)
548-
.filter(StorageEvent.session_id == storage_session.id)
549-
.filter(StorageEvent.user_id == user_id)
550-
.filter(timestamp_filter)
551-
.order_by(StorageEvent.timestamp.desc())
543+
query.order_by(StorageEvent.timestamp.desc())
552544
.limit(
553545
config.num_recent_events
554546
if config and config.num_recent_events
@@ -660,30 +652,21 @@ async def append_event(self, session: Session, event: Event) -> Event:
660652
StorageUserState, (session.app_name, session.user_id)
661653
)
662654

663-
app_state = storage_app_state.state if storage_app_state else {}
664-
user_state = storage_user_state.state if storage_user_state else {}
665-
session_state = storage_session.state
666-
667655
# Extract state delta
668-
app_state_delta = {}
669-
user_state_delta = {}
670-
session_state_delta = {}
671-
if event.actions:
672-
if event.actions.state_delta:
673-
app_state_delta, user_state_delta, session_state_delta = (
674-
_extract_state_delta(event.actions.state_delta)
675-
)
676-
677-
# Merge state and update storage
678-
if app_state_delta:
679-
app_state.update(app_state_delta)
680-
storage_app_state.state = app_state
681-
if user_state_delta:
682-
user_state.update(user_state_delta)
683-
storage_user_state.state = user_state
684-
if session_state_delta:
685-
session_state.update(session_state_delta)
686-
storage_session.state = session_state
656+
if event.actions and event.actions.state_delta:
657+
state_deltas = _session_util.extract_state_delta(
658+
event.actions.state_delta
659+
)
660+
app_state_delta = state_deltas["app"]
661+
user_state_delta = state_deltas["user"]
662+
session_state_delta = state_deltas["session"]
663+
# Merge state and update storage
664+
if app_state_delta:
665+
storage_app_state.state = storage_app_state.state | app_state_delta
666+
if user_state_delta:
667+
storage_user_state.state = storage_user_state.state | user_state_delta
668+
if session_state_delta:
669+
storage_session.state = storage_session.state | session_state_delta
687670

688671
sql_session.add(StorageEvent.from_event(session, event))
689672

@@ -698,21 +681,6 @@ async def append_event(self, session: Session, event: Event) -> Event:
698681
return event
699682

700683

701-
def _extract_state_delta(state: dict[str, Any]):
702-
app_state_delta = {}
703-
user_state_delta = {}
704-
session_state_delta = {}
705-
if state:
706-
for key in state.keys():
707-
if key.startswith(State.APP_PREFIX):
708-
app_state_delta[key.removeprefix(State.APP_PREFIX)] = state[key]
709-
elif key.startswith(State.USER_PREFIX):
710-
user_state_delta[key.removeprefix(State.USER_PREFIX)] = state[key]
711-
elif not key.startswith(State.TEMP_PREFIX):
712-
session_state_delta[key] = state[key]
713-
return app_state_delta, user_state_delta, session_state_delta
714-
715-
716684
def _merge_state(app_state, user_state, session_state):
717685
# Merge states for response
718686
merged_state = copy.deepcopy(session_state)

src/google/adk/sessions/in_memory_session_service.py

Lines changed: 38 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
from typing_extensions import override
2424

25+
from . import _session_util
2526
from ..events.event import Event
2627
from .base_session_service import BaseSessionService
2728
from .base_session_service import GetSessionConfig
@@ -88,6 +89,17 @@ def _create_session_impl(
8889
state: Optional[dict[str, Any]] = None,
8990
session_id: Optional[str] = None,
9091
) -> Session:
92+
state_deltas = _session_util.extract_state_delta(state)
93+
app_state_delta = state_deltas['app']
94+
user_state_delta = state_deltas['user']
95+
session_state = state_deltas['session']
96+
if app_state_delta:
97+
self.app_state.setdefault(app_name, {}).update(app_state_delta)
98+
if user_state_delta:
99+
self.user_state.setdefault(app_name, {}).setdefault(user_id, {}).update(
100+
user_state_delta
101+
)
102+
91103
session_id = (
92104
session_id.strip()
93105
if session_id and session_id.strip()
@@ -97,7 +109,7 @@ def _create_session_impl(
97109
app_name=app_name,
98110
user_id=user_id,
99111
id=session_id,
100-
state=state or {},
112+
state=session_state or {},
101113
last_update_time=time.time(),
102114
)
103115

@@ -174,11 +186,13 @@ def _get_session_impl(
174186
if i >= 0:
175187
copied_session.events = copied_session.events[i + 1 :]
176188

189+
# Return a copy of the session object with merged state.
177190
return self._merge_state(app_name, user_id, copied_session)
178191

179192
def _merge_state(
180193
self, app_name: str, user_id: str, copied_session: Session
181194
) -> Session:
195+
"""Merges app and user state into session state."""
182196
# Merge app state
183197
if app_name in self.app_state:
184198
for key in self.app_state[app_name].keys():
@@ -269,11 +283,9 @@ def _delete_session_impl(
269283

270284
@override
271285
async def append_event(self, session: Session, event: Event) -> Event:
272-
# Update the in-memory session.
273-
await super().append_event(session=session, event=event)
274-
session.last_update_time = event.timestamp
286+
if event.partial:
287+
return event
275288

276-
# Update the storage session
277289
app_name = session.app_name
278290
user_id = session.user_id
279291
session_id = session.id
@@ -293,21 +305,29 @@ def _warning(message: str) -> None:
293305
_warning(f'session_id {session_id} not in sessions[app_name][user_id]')
294306
return event
295307

296-
if event.actions and event.actions.state_delta:
297-
for key in event.actions.state_delta:
298-
if key.startswith(State.APP_PREFIX):
299-
self.app_state.setdefault(app_name, {})[
300-
key.removeprefix(State.APP_PREFIX)
301-
] = event.actions.state_delta[key]
302-
303-
if key.startswith(State.USER_PREFIX):
304-
self.user_state.setdefault(app_name, {}).setdefault(user_id, {})[
305-
key.removeprefix(State.USER_PREFIX)
306-
] = event.actions.state_delta[key]
308+
# Update the in-memory session.
309+
await super().append_event(session=session, event=event)
310+
session.last_update_time = event.timestamp
307311

312+
# Update the storage session
308313
storage_session = self.sessions[app_name][user_id].get(session_id)
309-
await super().append_event(session=storage_session, event=event)
310-
314+
storage_session.events.append(event)
311315
storage_session.last_update_time = event.timestamp
312316

317+
if event.actions and event.actions.state_delta:
318+
state_deltas = _session_util.extract_state_delta(
319+
event.actions.state_delta
320+
)
321+
app_state_delta = state_deltas['app']
322+
user_state_delta = state_deltas['user']
323+
session_state_delta = state_deltas['session']
324+
if app_state_delta:
325+
self.app_state.setdefault(app_name, {}).update(app_state_delta)
326+
if user_state_delta:
327+
self.user_state.setdefault(app_name, {}).setdefault(user_id, {}).update(
328+
user_state_delta
329+
)
330+
if session_state_delta:
331+
storage_session.state.update(session_state_delta)
332+
313333
return event

0 commit comments

Comments
 (0)