Skip to content

Commit b675aca

Browse files
flake8 fixes
Signed-off-by: Mohan Lakshmaiah <mohalaks@in.ibm.com>
1 parent f4d0bd8 commit b675aca

File tree

2 files changed

+81
-17
lines changed

2 files changed

+81
-17
lines changed

mcpgateway/cache/session_pool.py

Lines changed: 80 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,17 @@ class PoolKey:
3737
transport_type: TransportType
3838

3939
def __hash__(self):
40-
"""Compute hash based on user_id, server_id, and transport_type."""
40+
"""Compute hash based on user_id, server_id, and transport_type.
41+
Returns:
42+
int: Hash value"""
4143
return hash((self.user_id, self.server_id, self.transport_type))
4244

4345
def __eq__(self, other):
44-
"""Equality check based on user_id, server_id, and transport_type."""
46+
"""Equality check based on user_id, server_id, and transport_type.
47+
Args:
48+
other (PoolKey): Another PoolKey instance to compare with.
49+
Returns:
50+
bool: True if equal, False otherwise."""
4551
return (isinstance(other, PoolKey) and
4652
self.user_id == other.user_id and
4753
self.server_id == other.server_id and
@@ -51,7 +57,12 @@ def __eq__(self, other):
5157
class PooledSession:
5258
"""Wrapper around transport for pooling and metrics tracking."""
5359
def __init__(self, transport: Transport, user_id: str, server_id: str, transport_type: TransportType):
54-
"""Initialize pooled session wrapper."""
60+
"""Initialize pooled session wrapper.
61+
Args:
62+
transport (Transport): The transport instance
63+
user_id (str): Identifier for the user
64+
server_id (str): Identifier for the server
65+
transport_type (TransportType): Type of transport"""
5566
self.transport = transport
5667
self.user_id = user_id
5768
self.server_id = server_id
@@ -65,12 +76,16 @@ def __init__(self, transport: Transport, user_id: str, server_id: str, transport
6576

6677
@property
6778
def age(self) -> float:
68-
"""Get the age of the session in seconds."""
79+
"""Get the age of the session in seconds.
80+
Returns:
81+
float: Age of the session in seconds."""
6982
return time.time() - self.created_at
7083

7184
@property
7285
def idle_time(self) -> float:
73-
"""Get the idle time of the session in seconds."""
86+
"""Get the idle time of the session in seconds.
87+
Returns:
88+
float: Idle time of the session in seconds."""
7489
return time.time() - self.last_used
7590

7691
def capture_state(self) -> None:
@@ -103,7 +118,9 @@ class SessionPool:
103118
}
104119

105120
def __init__(self, session_registry: SessionRegistry):
106-
"""Initialize the session pool."""
121+
"""Initialize the session pool.
122+
Args:
123+
session_registry (SessionRegistry): Registry to track active sessions"""
107124
self._registry = session_registry
108125
self._pool: Dict[PoolKey, PooledSession] = {}
109126
self._lock = asyncio.Lock()
@@ -129,7 +146,21 @@ def __init__(self, session_registry: SessionRegistry):
129146

130147
async def get_or_create_session(self, user_id: str, server_id: str, base_url: str,
131148
transport_type: TransportType) -> Transport:
132-
"""Get an existing session for (user, server, transport) or create a new one."""
149+
"""
150+
Get an existing session for (user, server, transport) or create a new one.
151+
152+
Args:
153+
user_id: Identifier for the user
154+
server_id: Identifier for the server
155+
base_url: Base URL for transport connection
156+
transport_type: Type of transport to use
157+
158+
Returns:
159+
Transport: An active transport session
160+
161+
Raises:
162+
Exception: If session creation fails
163+
"""
133164
if not settings.session_pooling_enabled:
134165
logger.debug("Session pooling disabled, creating fresh session.")
135166
return await self._create_new_session(user_id, server_id, base_url, transport_type)
@@ -164,7 +195,22 @@ async def get_or_create_session(self, user_id: str, server_id: str, base_url: st
164195

165196
async def _create_new_session(self, user_id: str, server_id: str, base_url: str,
166197
transport_type: TransportType) -> PooledSession:
167-
"""Create and register a brand new transport session."""
198+
"""
199+
Create and register a brand new transport session.
200+
201+
Args:
202+
user_id: Identifier for the user
203+
server_id: Identifier for the server
204+
base_url: Base URL for transport connection
205+
transport_type: Type of transport to create
206+
207+
Raises:
208+
ValueError: If the transport type is unsupported
209+
210+
Returns:
211+
PooledSession: The newly created pooled session
212+
213+
"""
168214
try:
169215
# Create transport instance based on type
170216
transport_class = self.TRANSPORT_CLASSES.get(transport_type)
@@ -188,9 +234,9 @@ async def _create_new_session(self, user_id: str, server_id: str, base_url: str,
188234
self._metrics["created"] += 1
189235

190236
logger.info("Created new %s session for user=%s server=%s (session_id=%s)",
191-
transport_type.value,
192-
user_id,
193-
server_id,
237+
transport_type.value,
238+
user_id,
239+
server_id,
194240
transport.session_id)
195241
return session
196242

@@ -204,7 +250,11 @@ async def _create_new_session(self, user_id: str, server_id: str, base_url: str,
204250
# --------------------------------------------------------------------------
205251

206252
async def _is_session_valid(self, session: PooledSession) -> bool:
207-
"""Check whether a pooled session is still alive and eligible for reuse."""
253+
"""Check whether a pooled session is still alive and eligible for reuse.
254+
Args:
255+
session (PooledSession): The session to validate
256+
Returns:
257+
bool: True if valid, False otherwise"""
208258
try:
209259
if not await session.transport.is_connected():
210260
logger.debug("Session %s disconnected.", session.transport.session_id)
@@ -233,7 +283,10 @@ async def _is_session_valid(self, session: PooledSession) -> bool:
233283
return False
234284

235285
async def _cleanup_session(self, pool_key: PoolKey, session: PooledSession) -> None:
236-
"""Safely close and remove a single session."""
286+
"""Safely close and remove a single session.
287+
Args:
288+
pool_key (PoolKey): The key identifying the session in the pool
289+
session (PooledSession): The session to clean up"""
237290
try:
238291
# Capture final state before cleanup
239292
session.capture_state()
@@ -286,7 +339,10 @@ async def cleanup_expired_sessions(self):
286339
# --------------------------------------------------------------------------
287340

288341
async def capture_all_states(self) -> Dict[str, Any]:
289-
"""Capture states from all active sessions for persistence."""
342+
"""Capture states from all active sessions for persistence.
343+
Returns:
344+
Dict[str, Any]: Mapping of session IDs to their captured states
345+
"""
290346
states = {}
291347
async with self._lock:
292348
for pool_key, session in self._pool.items():
@@ -303,7 +359,12 @@ async def capture_all_states(self) -> Dict[str, Any]:
303359
return states
304360

305361
async def restore_session_state(self, session_id: str, state: Dict[str, Any]) -> bool:
306-
"""Restore state to a specific session."""
362+
"""Restore state to a specific session.
363+
Args:
364+
session_id (str): The session ID to restore state to
365+
state (Dict[str, Any]): The state data to restore
366+
Returns:
367+
bool: True if restoration was successful, False otherwise"""
307368
async with self._lock:
308369
for pool_key, session in self._pool.items():
309370
if session.transport.session_id == session_id:
@@ -319,7 +380,10 @@ async def restore_session_state(self, session_id: str, state: Dict[str, Any]) ->
319380
# --------------------------------------------------------------------------
320381

321382
def get_pool_stats(self) -> Dict[str, Any]:
322-
"""Get comprehensive pool statistics."""
383+
"""Get comprehensive pool statistics.
384+
Returns:
385+
Dict[str, Any]: Current pool statistics
386+
"""
323387
stats = {
324388
"metrics": self._metrics.copy(),
325389
"active_sessions": len(self._pool),

mcpgateway/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4007,7 +4007,7 @@ async def websocket_endpoint(websocket: WebSocket):
40074007
# Identify user and server for pooling key
40084008
user_id = proxy_user or "anonymous"
40094009
server_id = websocket.query_params.get("server_id", "default-server")
4010-
base_url = f"ws://localhost:{settings.port}{settings.app_root_path}/ws"
4010+
# base_url = f"ws://localhost:{settings.port}{settings.app_root_path}/ws"
40114011

40124012
# Session Pooling logic
40134013
transport = None

0 commit comments

Comments
 (0)