@@ -327,6 +327,8 @@ async def initialize(self) -> None:
327327 """
328328 logger .info (f"Initializing session registry with backend: { self ._backend } " )
329329
330+ self ._cleanup_task = None
331+
330332 if self ._backend == "database" :
331333 # Start database cleanup task
332334 self ._cleanup_task = asyncio .create_task (self ._db_cleanup_task ())
@@ -536,16 +538,28 @@ async def get_session(self, session_id: str) -> Any:
536538 if session_entry :
537539 logger .info (f"Session { session_id } exists in local cache" )
538540 # Return the transport object directly, not the dict
539- return session_entry ['transport' ]
541+ # DO NOT overwrite self._lock! That was the bug.
542+ if isinstance (session_entry , dict ):
543+ transport = session_entry .get ('transport' )
544+ if transport is not None : # Check if transport object actually exists in the dict
545+ return transport
546+ else :
547+ # Log if the structure is unexpected (missing 'transport' key)
548+ logger .warning (f"Session { session_id } found in local cache but missing 'transport' key: { session_entry } " )
549+ return None
550+ else :
551+ # For backward compatibility - if it's directly a transport object (shouldn't happen with new add_session)
552+ return session_entry
540553
541- # If not in local cache, check if it exists in shared backend
554+ # If not in local cache (or transport was missing from dict) , check if it exists in shared backend
542555 if self ._backend == "redis" :
543556 try :
557+ # Check if session marker exists in Redis (using EXISTS command might be better than GET if data is large)
544558 session_data = await self ._redis .get (f"mcp:session:{ session_id } " )
545559 if session_data :
546560 logger .info (f"Session { session_id } exists in Redis but not in local cache" )
547561 # Return None since we don't have the transport locally
548- return None
562+ return None
549563 except Exception as e :
550564 logger .error (f"Redis error checking session { session_id } : { e } " )
551565 return None
@@ -571,7 +585,6 @@ def _db_check() -> bool:
571585 """
572586 db_session = next (get_db ())
573587 try :
574- # Query with pooled flag if needed
575588 record = db_session .query (SessionRecord ).filter (SessionRecord .session_id == session_id ).first ()
576589 return record is not None
577590 finally :
@@ -847,8 +860,8 @@ def _db_add() -> None:
847860 except Exception as e :
848861 logger .error (f"Database error during broadcast: { e } " )
849862
850- def get_session_sync (self , session_id : str ) -> Any :
851- """Get session synchronously from local cache only.
863+ def get_session_sync (self , session_id : str ) -> Optional [ SSETransport ] :
864+ """Get session transport synchronously from local cache only.
852865
853866 This is a non-blocking method that only checks the local cache,
854867 not the distributed backend. Use this when you need quick access
@@ -863,19 +876,19 @@ def get_session_sync(self, session_id: str) -> Any:
863876 Examples:
864877 >>> from mcpgateway.cache.session_registry import SessionRegistry
865878 >>> import asyncio
866- >>>
879+
867880 >>> class MockTransport:
868881 ... pass
869- >>>
882+
870883 >>> reg = SessionRegistry()
871884 >>> transport = MockTransport()
872885 >>> asyncio.run(reg.add_session('sync-test', transport))
873- >>>
886+
874887 >>> # Synchronous lookup
875888 >>> found = reg.get_session_sync('sync-test')
876889 >>> found is transport
877890 True
878- >>>
891+
879892 >>> # Not found
880893 >>> reg.get_session_sync('nonexistent') is None
881894 True
@@ -884,8 +897,17 @@ def get_session_sync(self, session_id: str) -> Any:
884897 if self ._backend == "none" :
885898 return None
886899
887- return self ._sessions .get (session_id )
888-
900+ # For sync method, just access directly without lock to avoid async/sync mixing
901+ session_entry = self ._sessions .get (session_id )
902+ if session_entry :
903+ # Handle the new dict structure: {'transport': t, 'pooled': p, 'created_at': time}
904+ if isinstance (session_entry , dict ):
905+ return session_entry .get ('transport' ) # Return the transport object from the dict
906+ else :
907+ # For backward compatibility - if it's directly a transport object
908+ return session_entry
909+ return None
910+
889911 async def respond (
890912 self ,
891913 server_id : Optional [str ],
@@ -1108,11 +1130,24 @@ async def _refresh_redis_sessions(self) -> None:
11081130 """
11091131 try :
11101132 # Check all local sessions
1111- local_transports = {}
1133+ local_sessions_copy = {}
11121134 async with self ._lock :
1113- local_transports = self ._sessions .copy ()
1135+ # Create a copy of session data for checking
1136+ for sid , entry in self ._sessions .items ():
1137+ if isinstance (entry , dict ):
1138+ local_sessions_copy [sid ] = {
1139+ 'transport' : entry ['transport' ],
1140+ 'pooled' : entry .get ('pooled' , False )
1141+ }
1142+ else :
1143+ # For backward compatibility with direct transport storage
1144+ local_sessions_copy [sid ] = {
1145+ 'transport' : entry ,
1146+ 'pooled' : False
1147+ }
11141148
1115- for session_id , transport in local_transports .items ():
1149+ for session_id , session_data in local_sessions_copy .items ():
1150+ transport = session_data ['transport' ]
11161151 try :
11171152 if await transport .is_connected ():
11181153 # Refresh TTL in Redis
@@ -1180,11 +1215,24 @@ def _db_cleanup() -> int:
11801215 logger .info (f"Cleaned up { deleted } expired database sessions" )
11811216
11821217 # Check local sessions against database
1183- local_transports = {}
1218+ local_sessions_copy = {}
11841219 async with self ._lock :
1185- local_transports = self ._sessions .copy ()
1220+ # Create a copy of session data for checking
1221+ for sid , entry in self ._sessions .items ():
1222+ if isinstance (entry , dict ):
1223+ local_sessions_copy [sid ] = {
1224+ 'transport' : entry ['transport' ],
1225+ 'pooled' : entry .get ('pooled' , False )
1226+ }
1227+ else :
1228+ # For backward compatibility with direct transport storage
1229+ local_sessions_copy [sid ] = {
1230+ 'transport' : entry ,
1231+ 'pooled' : False
1232+ }
11861233
1187- for session_id , transport in local_transports .items ():
1234+ for session_id , session_data in local_sessions_copy .items ():
1235+ transport = session_data ['transport' ]
11881236 try :
11891237 if not await transport .is_connected ():
11901238 await self .remove_session (session_id )
@@ -1567,7 +1615,7 @@ async def generate_response(self, message: Dict[str, Any], transport: SSETranspo
15671615 # ------------------------------
15681616 # Observability
15691617 # ------------------------------
1570-
1618+
15711619 def get_metrics (self ) -> Dict [str , int ]:
15721620 """
15731621 Retrieve internal metrics counters for the session registry.
0 commit comments