1616
1717import asyncio
1818import logging
19- from datetime import datetime
19+ from datetime import datetime , timezone
2020from typing import Any , AsyncGenerator , Dict , List , Optional , Set
2121
2222import httpx
2323from mcp import ClientSession
2424from mcp .client .sse import sse_client
2525from sqlalchemy import select
26- from sqlalchemy .exc import IntegrityError
2726from sqlalchemy .orm import Session
2827
2928from mcpgateway .config import settings
@@ -122,7 +121,6 @@ async def register_gateway(self, db: Session, gateway: GatewayCreate) -> Gateway
122121
123122 Raises:
124123 GatewayNameConflictError: If gateway name already exists
125- GatewayError: If registration fails
126124 """
127125 try :
128126 # Check for name conflicts (both active and inactive)
@@ -138,12 +136,43 @@ async def register_gateway(self, db: Session, gateway: GatewayCreate) -> Gateway
138136 auth_type = getattr (gateway , "auth_type" , None )
139137 auth_value = getattr (gateway , "auth_value" , {})
140138
141- # Initialize connection and get capabilities
142139 capabilities , tools = await self ._initialize_gateway (str (gateway .url ), auth_value )
143140
141+ all_names = [td .name for td in tools ]
142+
143+ existing_tools = db .execute (select (DbTool ).where (DbTool .name .in_ (all_names ))).scalars ().all ()
144+ existing_tool_names = [tool .name for tool in existing_tools ]
145+
146+ tools = [
147+ DbTool (
148+ name = tool .name ,
149+ url = str (gateway .url ),
150+ description = tool .description ,
151+ integration_type = tool .integration_type ,
152+ request_type = tool .request_type ,
153+ headers = tool .headers ,
154+ input_schema = tool .input_schema ,
155+ jsonpath_filter = tool .jsonpath_filter ,
156+ auth_type = auth_type ,
157+ auth_value = auth_value ,
158+ )
159+ for tool in tools
160+ ]
161+
162+ existing_tools = [tool for tool in tools if tool .name in existing_tool_names ]
163+ new_tools = [tool for tool in tools if tool .name not in existing_tool_names ]
164+
144165 # Create DB model
145166 db_gateway = DbGateway (
146- name = gateway .name , url = str (gateway .url ), description = gateway .description , capabilities = capabilities , last_seen = datetime .utcnow (), auth_type = auth_type , auth_value = auth_value
167+ name = gateway .name ,
168+ url = str (gateway .url ),
169+ description = gateway .description ,
170+ capabilities = capabilities ,
171+ last_seen = datetime .now (timezone .utc ),
172+ auth_type = auth_type ,
173+ auth_value = auth_value ,
174+ tools = new_tools ,
175+ # federated_tools=existing_tools + new_tools
147176 )
148177
149178 # Add to DB
@@ -157,23 +186,20 @@ async def register_gateway(self, db: Session, gateway: GatewayCreate) -> Gateway
157186 # Notify subscribers
158187 await self ._notify_gateway_added (db_gateway )
159188
160- inserted_gateway = db .execute (select (DbGateway ).where (DbGateway .name == gateway .name )).scalar_one_or_none ()
161- inserted_gateway_id = inserted_gateway .id
162-
163- logger .info (f"Registered gateway: { gateway .name } " )
164-
165- for tool in tools :
166- tool .gateway_id = inserted_gateway_id
167- await self .tool_service .register_tool (db = db , tool = tool )
168-
169189 return GatewayRead .model_validate (gateway )
170-
171- except IntegrityError :
172- db .rollback ()
173- raise GatewayError (f"Gateway already exists: { gateway .name } " )
174- except Exception as e :
175- db .rollback ()
176- raise GatewayError (f"Failed to register gateway: { str (e )} " )
190+ except* ValueError as ve :
191+ logger .error ("ValueErrors in group: %s" , ve .exceptions )
192+ except* RuntimeError as re :
193+ logger .error ("RuntimeErrors in group: %s" , re .exceptions )
194+ except* BaseException as other : # catches every other sub-exception
195+ logger .error ("Other grouped errors: %s" , other .exceptions )
196+ # except IntegrityError as ex:
197+ # logger.error(f"Error adding gateway: {ex}")
198+ # db.rollback()
199+ # raise GatewayError(f"Gateway already exists: {gateway.name}")
200+ # except Exception as e:
201+ # db.rollback()
202+ # raise GatewayError(f"Failed to register gateway: {str(e)}")
177203
178204 async def list_gateways (self , db : Session , include_inactive : bool = False ) -> List [GatewayRead ]:
179205 """List all registered gateways.
@@ -393,14 +419,6 @@ async def delete_gateway(self, db: Session, gateway_id: int) -> None:
393419 # Store gateway info for notification before deletion
394420 gateway_info = {"id" : gateway .id , "name" : gateway .name , "url" : gateway .url }
395421
396- # Remove associated tools
397- try :
398- db .query (DbTool ).filter (DbTool .gateway_id == gateway_id ).delete ()
399- db .commit ()
400- logger .info (f"Deleted tools associated with gateway: { gateway .name } " )
401- except Exception as ex :
402- logger .warning (f"No tools found: { ex } " )
403-
404422 # Hard delete gateway
405423 db .delete (gateway )
406424 db .commit ()
@@ -457,28 +475,42 @@ async def forward_request(self, gateway: DbGateway, method: str, params: Optiona
457475 raise GatewayConnectionError (f"Failed to forward request to { gateway .name } : { str (e )} " )
458476
459477 async def check_health_of_gateways (self , gateways : List [DbGateway ]) -> bool :
460- """Health check for gateways
478+ """Health check for a list of gateways.
479+
480+ Deactivates gateway if gateway is not healthy.
461481
462482 Args:
463- gateways: Gateways to check
483+ gateways (List[DbGateway]): List of gateways to check if healthy
464484
465485 Returns:
466- True if gateway is healthy
486+ bool: True if all active gateways are healthy
467487 """
468- for gateway in gateways :
469- if not gateway .is_active :
470- return False
488+ # Reuse a single HTTP client for all requests
489+ async with httpx .AsyncClient () as client :
490+ for gateway in gateways :
491+ # Inactive gateways are unhealthy
492+ if not gateway .is_active :
493+ continue
471494
472- try :
473- # Try to initialize connection
474- await self ._initialize_gateway (gateway .url , gateway .auth_value )
495+ try :
496+ # Ensure auth_value is a dict
497+ auth_data = gateway .auth_value or {}
498+ headers = decode_auth (auth_data )
499+
500+ # Perform the GET and raise on 4xx/5xx
501+ async with client .stream ("GET" , gateway .url , headers = headers ) as response :
502+ # This will raise immediately if status is 4xx/5xx
503+ response .raise_for_status ()
475504
476- # Update last seen
477- gateway .last_seen = datetime .utcnow ()
478- return True
505+ # Mark successful check
506+ gateway .last_seen = datetime .utcnow ()
507+
508+ except Exception :
509+ with SessionLocal () as db :
510+ await self .toggle_gateway_status (db = db , gateway_id = gateway .id , activate = False )
479511
480- except Exception :
481- return False
512+ # All gateways passed
513+ return True
482514
483515 async def aggregate_capabilities (self , db : Session ) -> Dict [str , Any ]:
484516 """Aggregate capabilities from all gateways.
@@ -579,7 +611,11 @@ async def connect_to_sse_server(server_url: str, authentication: Optional[Dict[s
579611 raise GatewayConnectionError (f"Failed to initialize gateway at { url } : { str (e )} " )
580612
581613 def _get_active_gateways (self ) -> list [DbGateway ]:
582- """Sync function for database operations (runs in thread)."""
614+ """Sync function for database operations (runs in thread).
615+
616+ Returns:
617+ List[DbGateway]: List of active gateways
618+ """
583619 with SessionLocal () as db :
584620 return db .execute (select (DbGateway ).where (DbGateway .is_active )).scalars ().all ()
585621
@@ -590,9 +626,9 @@ async def _run_health_checks(self) -> None:
590626 # Run sync database code in a thread
591627 gateways = await asyncio .to_thread (self ._get_active_gateways )
592628
593- # Async health checks (non-blocking)
594- await self . check_health_of_gateways ( gateways )
595-
629+ if len ( gateways ) > 0 :
630+ # Async health checks (non-blocking )
631+ await self . check_health_of_gateways ( gateways )
596632 except Exception as e :
597633 logger .error (f"Health check run failed: { str (e )} " )
598634
0 commit comments