@@ -465,40 +465,31 @@ async def create_session(
465465 # 5. Return the session
466466
467467 with self .database_session_factory () as sql_session :
468-
469468 # Fetch app and user states from storage
470469 storage_app_state = sql_session .get (StorageAppState , (app_name ))
471- storage_user_state = sql_session .get (
472- StorageUserState , (app_name , user_id )
473- )
474-
475- app_state = storage_app_state .state if storage_app_state else {}
476- user_state = storage_user_state .state if storage_user_state else {}
477-
478- # Create state tables if not exist
479470 if not storage_app_state :
480471 storage_app_state = StorageAppState (app_name = app_name , state = {})
481472 sql_session .add (storage_app_state )
473+ storage_user_state = sql_session .get (
474+ StorageUserState , (app_name , user_id )
475+ )
482476 if not storage_user_state :
483477 storage_user_state = StorageUserState (
484478 app_name = app_name , user_id = user_id , state = {}
485479 )
486480 sql_session .add (storage_user_state )
487481
488482 # Extract state deltas
489- app_state_delta , user_state_delta , session_state = _extract_state_delta (
490- state
491- )
483+ state_deltas = _session_util .extract_state_delta (state )
484+ app_state_delta = state_deltas ["app" ]
485+ user_state_delta = state_deltas ["user" ]
486+ session_state = state_deltas ["session" ]
492487
493488 # Apply state delta
494- app_state .update (app_state_delta )
495- user_state .update (user_state_delta )
496-
497- # Store app and user state
498489 if app_state_delta :
499- storage_app_state .state = app_state
490+ storage_app_state .state = storage_app_state . state | app_state_delta
500491 if user_state_delta :
501- storage_user_state .state = user_state
492+ storage_user_state .state = storage_user_state . state | user_state_delta
502493
503494 # Store the session
504495 storage_session = StorageSession (
@@ -513,7 +504,9 @@ async def create_session(
513504 sql_session .refresh (storage_session )
514505
515506 # Merge states for response
516- merged_state = _merge_state (app_state , user_state , session_state )
507+ merged_state = _merge_state (
508+ storage_app_state .state , storage_user_state .state , session_state
509+ )
517510 session = storage_session .to_session (state = merged_state )
518511 return session
519512
@@ -536,19 +529,18 @@ async def get_session(
536529 if storage_session is None :
537530 return None
538531
532+ query = sql_session .query (StorageEvent ).filter (
533+ StorageEvent .app_name == app_name ,
534+ StorageEvent .user_id == user_id ,
535+ StorageEvent .session_id == storage_session .id ,
536+ )
537+
539538 if config and config .after_timestamp :
540539 after_dt = datetime .fromtimestamp (config .after_timestamp )
541- timestamp_filter = StorageEvent .timestamp >= after_dt
542- else :
543- timestamp_filter = True
540+ query = query .filter (StorageEvent .timestamp >= after_dt )
544541
545542 storage_events = (
546- sql_session .query (StorageEvent )
547- .filter (StorageEvent .app_name == app_name )
548- .filter (StorageEvent .session_id == storage_session .id )
549- .filter (StorageEvent .user_id == user_id )
550- .filter (timestamp_filter )
551- .order_by (StorageEvent .timestamp .desc ())
543+ query .order_by (StorageEvent .timestamp .desc ())
552544 .limit (
553545 config .num_recent_events
554546 if config and config .num_recent_events
@@ -660,30 +652,21 @@ async def append_event(self, session: Session, event: Event) -> Event:
660652 StorageUserState , (session .app_name , session .user_id )
661653 )
662654
663- app_state = storage_app_state .state if storage_app_state else {}
664- user_state = storage_user_state .state if storage_user_state else {}
665- session_state = storage_session .state
666-
667655 # Extract state delta
668- app_state_delta = {}
669- user_state_delta = {}
670- session_state_delta = {}
671- if event .actions :
672- if event .actions .state_delta :
673- app_state_delta , user_state_delta , session_state_delta = (
674- _extract_state_delta (event .actions .state_delta )
675- )
676-
677- # Merge state and update storage
678- if app_state_delta :
679- app_state .update (app_state_delta )
680- storage_app_state .state = app_state
681- if user_state_delta :
682- user_state .update (user_state_delta )
683- storage_user_state .state = user_state
684- if session_state_delta :
685- session_state .update (session_state_delta )
686- storage_session .state = session_state
656+ if event .actions and event .actions .state_delta :
657+ state_deltas = _session_util .extract_state_delta (
658+ event .actions .state_delta
659+ )
660+ app_state_delta = state_deltas ["app" ]
661+ user_state_delta = state_deltas ["user" ]
662+ session_state_delta = state_deltas ["session" ]
663+ # Merge state and update storage
664+ if app_state_delta :
665+ storage_app_state .state = storage_app_state .state | app_state_delta
666+ if user_state_delta :
667+ storage_user_state .state = storage_user_state .state | user_state_delta
668+ if session_state_delta :
669+ storage_session .state = storage_session .state | session_state_delta
687670
688671 sql_session .add (StorageEvent .from_event (session , event ))
689672
@@ -698,21 +681,6 @@ async def append_event(self, session: Session, event: Event) -> Event:
698681 return event
699682
700683
701- def _extract_state_delta (state : dict [str , Any ]):
702- app_state_delta = {}
703- user_state_delta = {}
704- session_state_delta = {}
705- if state :
706- for key in state .keys ():
707- if key .startswith (State .APP_PREFIX ):
708- app_state_delta [key .removeprefix (State .APP_PREFIX )] = state [key ]
709- elif key .startswith (State .USER_PREFIX ):
710- user_state_delta [key .removeprefix (State .USER_PREFIX )] = state [key ]
711- elif not key .startswith (State .TEMP_PREFIX ):
712- session_state_delta [key ] = state [key ]
713- return app_state_delta , user_state_delta , session_state_delta
714-
715-
716684def _merge_state (app_state , user_state , session_state ):
717685 # Merge states for response
718686 merged_state = copy .deepcopy (session_state )
0 commit comments