Skip to content

Commit 2ece451

Browse files
Fixing test cases for pooling
Signed-off-by: Mohan Lakshmaiah <mohalaks@in.ibm.com>
1 parent 338cf39 commit 2ece451

File tree

5 files changed

+53
-31
lines changed

5 files changed

+53
-31
lines changed

mcpgateway/cache/session_registry.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1229,7 +1229,12 @@ def _db_cleanup() -> int:
12291229
transport = session_data['transport']
12301230
try:
12311231
if not await transport.is_connected():
1232-
await self.remove_session(session_id)
1232+
if pooled:
1233+
# For pooled sessions, remove from registry but don't disconnect
1234+
await self.remove_session_from_registry_only(session_id)
1235+
else:
1236+
# For non-pooled sessions, full removal with disconnect
1237+
await self.remove_session(session_id)
12331238
continue
12341239

12351240
# Refresh session in database

mcpgateway/main.py

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3662,13 +3662,13 @@ async def websocket_endpoint(websocket: WebSocket):
36623662
Args:
36633663
websocket: The WebSocket connection instance.
36643664
"""
3665+
transport = None
3666+
proxy_user = None
3667+
token = None
36653668
try:
36663669
# Authenticate WebSocket connection
36673670
if settings.mcp_client_auth_enabled or settings.auth_required:
36683671
# Extract auth from query params or headers
3669-
token = None
3670-
proxy_user = None
3671-
36723672
# Try to get token from query parameter
36733673
if "token" in websocket.query_params:
36743674
token = websocket.query_params["token"]
@@ -3696,9 +3696,6 @@ async def websocket_endpoint(websocket: WebSocket):
36963696
await websocket.close(code=1008, reason="Invalid authentication")
36973697
return
36983698

3699-
await websocket.accept()
3700-
logger.info("WebSocket connection accepted")
3701-
37023699
# Identify user and server for pooling key
37033700
user_id = proxy_user or "anonymous"
37043701
server_id = websocket.query_params.get("server_id", "default-server")
@@ -3708,21 +3705,16 @@ async def websocket_endpoint(websocket: WebSocket):
37083705
transport = None
37093706
if await should_use_session_pooling(server_id):
37103707
# Use existing or create pooled session
3711-
# Note: WebSocket transport needs the actual WebSocket object, so pooling works differently
37123708
transport = WebSocketTransport(websocket, pooled=True, pool_key=f"{user_id}:{server_id}")
37133709
await transport.connect()
37143710
await session_registry.add_session(transport.session_id, transport, pooled=True)
3715-
logger.info(
3716-
f"Created pooled WebSocket session for user={user_id}, server={server_id}, session={transport.session_id}"
3717-
)
3711+
logger.info(f"Created pooled WebSocket session for user={user_id}, server={server_id}, session={transport.session_id}")
37183712
else:
37193713
# Fallback: create new transport
37203714
transport = WebSocketTransport(websocket)
37213715
await transport.connect()
37223716
await session_registry.add_session(transport.session_id, transport)
3723-
logger.info(
3724-
f"Created new WebSocket session for user={user_id}, server={server_id}, session={transport.session_id}"
3725-
)
3717+
logger.info(f"Created new WebSocket session for user={user_id}, server={server_id}, session={transport.session_id}")
37263718

37273719
while True:
37283720
try:
@@ -3753,10 +3745,8 @@ async def websocket_endpoint(websocket: WebSocket):
37533745
break
37543746
except WebSocketDisconnect:
37553747
logger.info("WebSocket disconnected")
3756-
# Cleanup pooled session if needed
37573748
if transport and hasattr(transport, '_pooled') and transport._pooled:
3758-
# For pooled sessions, we don't immediately remove from registry
3759-
# They get cleaned up by the pool's background task
3749+
# For pooled sessions, we don't immediately remove from registry. They get cleaned up by the pool's background task
37603750
pass
37613751
else:
37623752
# For non-pooled sessions, remove from registry

mcpgateway/services/server_service.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import asyncio
1616
from datetime import datetime, timezone
1717
from typing import Any, AsyncGenerator, Dict, List, Optional
18+
import builtins
1819

1920
# Third-Party
2021
import httpx
@@ -51,7 +52,7 @@ class ServerNotFoundError(ServerError):
5152
"""Raised when a requested server is not found."""
5253

5354

54-
class PermissionError(ServerError):
55+
class PermissionError(builtins.PermissionError, ServerError):
5556
"""Raised when a user does not have permission to perform an action on a server."""
5657

5758

@@ -647,7 +648,7 @@ async def get_server(self, db: Session, server_id: str) -> ServerRead:
647648
raise ServerNotFoundError(f"Server not found: {server_id}")
648649

649650
try:
650-
effective_strategy = await self.get_session_strategy(db, server_id)
651+
effective_strategy = await self.get_session_strategy(db, server_id, server=server)
651652
logger.debug(f"Server {server_id} effective session strategy: {effective_strategy}")
652653
except Exception as e:
653654
logger.warning(f"Could not determine session strategy for server {server_id}: {e}")
@@ -998,13 +999,14 @@ async def delete_server(self, db: Session, server_id: str, user_email: Optional[
998999
if not server:
9991000
raise ServerNotFoundError(f"Server not found: {server_id}")
10001001

1001-
# Check ownership if user_email provided
1002+
# Always perform ownership check if user_email is provided
10021003
if user_email:
10031004
# First-Party
10041005
from mcpgateway.services.permission_service import PermissionService # pylint: disable=import-outside-toplevel
10051006

10061007
permission_service = PermissionService(db)
1007-
if not await permission_service.check_resource_ownership(user_email, server):
1008+
can_delete = await permission_service.check_resource_ownership(user_email, server)
1009+
if not can_delete:
10081010
raise PermissionError("Only the owner can delete this server")
10091011

10101012
server_info = {"id": server.id, "name": server.name}
@@ -1147,7 +1149,7 @@ async def _notify_server_deleted(self, server_info: Dict[str, Any]) -> None:
11471149
}
11481150
await self._publish_event(event)
11491151

1150-
async def get_session_strategy(self, db: Session, server_id: str) -> str:
1152+
async def get_session_strategy(self, db: Session, server_id: str, server: Optional[DbServer] = None) -> str:
11511153
"""Determine effective session strategy for server.
11521154
11531155
This method resolves the session strategy for a specific server, taking into account:
@@ -1180,7 +1182,12 @@ async def get_session_strategy(self, db: Session, server_id: str) -> str:
11801182
>>> result == settings.session_pool_strategy
11811183
True
11821184
"""
1183-
server = db.get(DbServer, server_id)
1185+
# server = db.get(DbServer, server_id)
1186+
# if not server:
1187+
# raise ServerNotFoundError(f"Server not found: {server_id}")
1188+
# Allow callers to pass an already-loaded server object to avoid repeated DB lookups.
1189+
if server is None:
1190+
server = db.get(DbServer, server_id)
11841191
if not server:
11851192
raise ServerNotFoundError(f"Server not found: {server_id}")
11861193

mcpgateway/transports/websocket_transport.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,8 +128,9 @@ async def connect(self) -> None:
128128
>>> mock_ws.accept.called
129129
True
130130
"""
131-
await self._websocket.accept()
132-
self._connected = True
131+
if not self._connected:
132+
await self._websocket.accept()
133+
self._connected = True
133134

134135
# Start ping task
135136
if settings.websocket_ping_interval > 0:

tests/unit/mcpgateway/services/test_resource_ownership.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -202,18 +202,37 @@ async def test_delete_server_non_owner_denied(self, server_service, mock_db_sess
202202
mock_server = MagicMock(spec=Server)
203203
mock_server.id = "server-1"
204204
mock_server.owner_email = "owner@example.com"
205-
205+
mock_server.team_id = None
206+
mock_server.visibility = "private"
207+
mock_server.name = "Test Server"
208+
mock_server.session_pooling_strategy = "inherit"
209+
206210
mock_db_session.get.return_value = mock_server
211+
mock_db_session.rollback = MagicMock()
212+
mock_db_session.commit = MagicMock()
213+
mock_db_session.delete = MagicMock()
207214

208215
with patch('mcpgateway.services.permission_service.PermissionService') as mock_perm_service_class:
209216
mock_perm_service = mock_perm_service_class.return_value
210217
mock_perm_service.check_resource_ownership = AsyncMock(return_value=False)
211-
212-
with pytest.raises(PermissionError, match="Only the owner can delete this server"):
213-
await server_service.delete_server(mock_db_session, "server-1", user_email="other@example.com")
214-
218+
server_service._notify_server_deleted = AsyncMock()
219+
220+
try:
221+
with pytest.raises(PermissionError, match="Only the owner can delete this server"):
222+
await server_service.delete_server(mock_db_session, "server-1", user_email="other@example.com")
223+
except AssertionError as e:
224+
# This will help us understand if we're getting a different error message
225+
print(f"Test failed because: {e}")
226+
raise
227+
except Exception as e:
228+
print(f"Unexpected error: {e}")
229+
raise
230+
231+
# Verify the expectations
215232
mock_db_session.delete.assert_not_called()
216-
233+
mock_db_session.rollback.assert_called_once()
234+
mock_perm_service.check_resource_ownership.assert_called_once_with("other@example.com", mock_server)
235+
mock_db_session.commit.assert_not_called()
217236

218237
class TestToolServiceOwnership:
219238
"""Test ownership checks in ToolService delete/update methods."""

0 commit comments

Comments
 (0)