1010import asyncio
1111import time
1212import logging
13- from typing import Dict , Optional , List , Any
13+ from typing import Dict , Optional , Any
1414from enum import Enum
1515from dataclasses import dataclass
1616from mcpgateway .cache .session_registry import SessionRegistry
2222
2323logger = logging .getLogger (__name__ )
2424
25+
2526class TransportType (Enum ):
2627 """Enumeration of supported transport types."""
2728 SSE = "sse"
2829 WEBSOCKET = "websocket"
2930
31+
3032@dataclass
3133class PoolKey :
3234 """Structured key for session pooling with proper hashing."""
@@ -45,6 +47,7 @@ def __eq__(self, other):
4547 self .server_id == other .server_id and
4648 self .transport_type == other .transport_type )
4749
50+
4851class PooledSession :
4952 """Wrapper around transport for pooling and metrics tracking."""
5053 def __init__ (self , transport : Transport , user_id : str , server_id : str , transport_type : TransportType ):
@@ -118,14 +121,14 @@ def __init__(self, session_registry: SessionRegistry):
118121 if settings .session_pooling_enabled :
119122 self ._start_cleanup_task ()
120123 logger .info ("Session pool initialized with cleanup interval=%s sec" ,
121- settings .session_pool_cleanup_interval )
124+ settings .session_pool_cleanup_interval )
122125
123126 # --------------------------------------------------------------------------
124127 # Core pooling logic with multi-transport support
125128 # --------------------------------------------------------------------------
126129
127130 async def get_or_create_session (self , user_id : str , server_id : str , base_url : str ,
128- transport_type : TransportType ) -> Transport :
131+ transport_type : TransportType ) -> Transport :
129132 """Get an existing session for (user, server, transport) or create a new one."""
130133 if not settings .session_pooling_enabled :
131134 logger .debug ("Session pooling disabled, creating fresh session." )
@@ -160,7 +163,7 @@ async def get_or_create_session(self, user_id: str, server_id: str, base_url: st
160163 return new_session .transport
161164
162165 async def _create_new_session (self , user_id : str , server_id : str , base_url : str ,
163- transport_type : TransportType ) -> PooledSession :
166+ transport_type : TransportType ) -> PooledSession :
164167 """Create and register a brand new transport session."""
165168 try :
166169 # Create transport instance based on type
@@ -185,7 +188,10 @@ async def _create_new_session(self, user_id: str, server_id: str, base_url: str,
185188 self ._metrics ["created" ] += 1
186189
187190 logger .info ("Created new %s session for user=%s server=%s (session_id=%s)" ,
188- transport_type .value , user_id , server_id , transport .session_id )
191+ transport_type .value ,
192+ user_id ,
193+ server_id ,
194+ transport .session_id )
189195 return session
190196
191197 except Exception as e :
@@ -210,14 +216,14 @@ async def _is_session_valid(self, session: PooledSession) -> bool:
210216
211217 if session .idle_time > settings .session_pool_max_idle_time :
212218 logger .debug ("Session %s idle too long (idle_time=%s)." ,
213- session .transport .session_id , session .idle_time )
219+ session .transport .session_id , session .idle_time )
214220 return False
215221
216222 # Additional transport-specific validation
217223 if hasattr (session .transport , 'validate_session' ):
218224 if not await session .transport .validate_session ():
219225 logger .debug ("Session %s failed transport-specific validation." ,
220- session .transport .session_id )
226+ session .transport .session_id )
221227 return False
222228
223229 return True
@@ -331,7 +337,7 @@ def _start_cleanup_task(self):
331337 if not self ._cleanup_task or self ._cleanup_task .done ():
332338 self ._cleanup_task = asyncio .create_task (self .cleanup_expired_sessions ())
333339 logger .info ("Session cleanup task started (interval=%s)." ,
334- settings .session_pool_cleanup_interval )
340+ settings .session_pool_cleanup_interval )
335341
336342 async def shutdown (self ):
337343 """Gracefully stop the session pool and cleanup task."""
@@ -349,4 +355,4 @@ async def shutdown(self):
349355 await self ._cleanup_session (pool_key , session )
350356 self ._pool .clear ()
351357
352- logger .info ("Session pool shut down. Final metrics: %s" , self ._metrics )
358+ logger .info ("Session pool shut down. Final metrics: %s" , self ._metrics )
0 commit comments