Skip to content

Commit d772287

Browse files
Add doctest to session registry
Signed-off-by: Mohan Lakshmaiah <mohalaks@in.ibm.com>
1 parent 2ece451 commit d772287

File tree

4 files changed

+37
-19
lines changed

4 files changed

+37
-19
lines changed

mcpgateway/cache/session_registry.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,7 @@ def __init__(
295295
super().__init__(backend=backend, redis_url=redis_url, database_url=database_url, session_ttl=session_ttl, message_ttl=message_ttl)
296296
self._sessions: Dict[str, Any] = {} # Local transport cache
297297
self._lock = asyncio.Lock()
298-
print(f"DEBUG SessionRegistry.__init__: self._lock type = {type(self._lock)}, hasattr(__enter__) = {hasattr(self._lock, '__enter__')}, hasattr(__exit__) = {hasattr(self._lock, '__exit__')}")
298+
# Removed debug print to prevent doctest output mismatches
299299
self._metrics = {
300300
"sessions_added": 0,
301301
"sessions_removed": 0,
@@ -1228,6 +1228,7 @@ def _db_cleanup() -> int:
12281228
for session_id, session_data in local_sessions_copy.items():
12291229
transport = session_data['transport']
12301230
try:
1231+
pooled = session_data.get('pooled', False)
12311232
if not await transport.is_connected():
12321233
if pooled:
12331234
# For pooled sessions, remove from registry but don't disconnect
@@ -1342,7 +1343,7 @@ async def _memory_cleanup_task(self) -> None:
13421343
else:
13431344
await self.remove_session(session_id)
13441345
self._metrics["sessions_expired"] += 1
1345-
await asyncio.sleep(60) # Run every minute
1346+
await asyncio.sleep(60) # Run every minute
13461347
except asyncio.CancelledError:
13471348
logger.info("Memory cleanup task cancelled")
13481349
break

mcpgateway/services/server_service.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1228,7 +1228,7 @@ async def should_use_pooling(self, db: Session, server_id: str) -> bool:
12281228
True
12291229
"""
12301230
strategy = await self.get_session_strategy(db, server_id)
1231-
return strategy in ["user-server", "global", "enabled"] # Handle potential naming variations like "user_server"
1231+
return settings.session_pooling_enabled and strategy in ["user-server", "global", "enabled"]
12321232

12331233
# --- Metrics ---
12341234
async def aggregate_metrics(self, db: Session) -> ServerMetrics:

mcpgateway/transports/sse_transport.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -484,12 +484,16 @@ def session_id(self) -> str:
484484
True
485485
"""
486486
return self._session_id
487-
487+
488488
@session_id.setter
489489
def session_id(self, value: str) -> None:
490490
"""
491491
Set the session ID for this transport.
492492
493493
Args:
494-
value (str): The session ID to set"""
494+
value (str): The session ID to set
495+
"""
495496
self._session_id = value
497+
498+
499+
# Fix Flake8 W293: remove trailing whitespace on blank line at end of file

mcpgateway/transports/stdio_transport.py

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -109,19 +109,25 @@ async def connect(self) -> None:
109109
True
110110
"""
111111
loop = asyncio.get_running_loop()
112-
113-
# Set up stdin reader
114-
reader = asyncio.StreamReader()
115-
protocol = asyncio.StreamReaderProtocol(reader)
116-
await loop.connect_read_pipe(lambda: protocol, sys.stdin)
117-
self._stdin_reader = reader
118-
119-
# Set up stdout writer
120-
transport, protocol = await loop.connect_write_pipe(asyncio.streams.FlowControlMixin, sys.stdout)
121-
self._stdout_writer = asyncio.StreamWriter(transport, protocol, reader, loop)
122-
123-
self._connected = True
124-
logger.info("stdio transport connected")
112+
try:
113+
reader = asyncio.StreamReader()
114+
protocol = asyncio.StreamReaderProtocol(reader)
115+
await loop.connect_read_pipe(lambda: protocol, sys.stdin)
116+
self._stdin_reader = reader
117+
118+
transport, protocol = await loop.connect_write_pipe(asyncio.streams.FlowControlMixin, sys.stdout)
119+
self._stdout_writer = asyncio.StreamWriter(transport, protocol, reader, loop)
120+
121+
self._connected = True
122+
logger.info("stdio transport connected")
123+
except (ConnectionResetError, BrokenPipeError) as e:
124+
logger.error(f"Connection lost during stdio setup: {e}")
125+
self._connected = False
126+
raise RuntimeError("Failed to establish stdio transport connection due to large environment or broken pipe.")
127+
except Exception as e:
128+
logger.error(f"Unexpected error during stdio connect: {e}")
129+
self._connected = False
130+
raise
125131

126132
async def disconnect(self) -> None:
127133
"""Clean up stdio streams.
@@ -174,8 +180,15 @@ async def send_message(self, message: Dict[str, Any]) -> None:
174180

175181
try:
176182
data = json.dumps(message)
177-
self._stdout_writer.write(f"{data}\n".encode())
183+
encoded = f"{data}\n".encode()
184+
if len(encoded) > 10_000_000: # 10MB safeguard
185+
logger.warning("Message size exceeds 10MB; may cause pipe reset.")
186+
self._stdout_writer.write(encoded)
178187
await self._stdout_writer.drain()
188+
except (ConnectionResetError, BrokenPipeError) as e:
189+
logger.error(f"Connection lost while sending message: {e}")
190+
self._connected = False
191+
raise RuntimeError("Connection lost while sending message; possible large environment variable overflow.")
179192
except Exception as e:
180193
logger.error(f"Failed to send message: {e}")
181194
raise

0 commit comments

Comments
 (0)