2323from mcp import ClientSession
2424from mcp .client .sse import sse_client
2525from sqlalchemy import select
26- from sqlalchemy .exc import IntegrityError
2726from sqlalchemy .orm import Session
2827
2928from mcpgateway .config import settings
@@ -122,7 +121,6 @@ async def register_gateway(self, db: Session, gateway: GatewayCreate) -> Gateway
122121
123122 Raises:
124123 GatewayNameConflictError: If gateway name already exists
125- GatewayError: If registration fails
126124 """
127125 try :
128126 # Check for name conflicts (both active and inactive)
@@ -138,13 +136,17 @@ async def register_gateway(self, db: Session, gateway: GatewayCreate) -> Gateway
138136 auth_type = getattr (gateway , "auth_type" , None )
139137 auth_value = getattr (gateway , "auth_value" , {})
140138
141- # Initialize connection and get capabilities
142139 capabilities , tools = await self ._initialize_gateway (str (gateway .url ), auth_value )
143-
140+
141+ all_names = [td .name for td in tools ]
142+
143+ existing_tools = db .execute (select (DbTool ).where (DbTool .name .in_ (all_names ))).scalars ().all ()
144+ existing_tool_names = [tool .name for tool in existing_tools ]
145+
144146 tools = [
145147 DbTool (
146148 name = tool .name ,
147- url = tool .url ,
149+ url = str ( gateway .url ) ,
148150 description = tool .description ,
149151 integration_type = tool .integration_type ,
150152 request_type = tool .request_type ,
@@ -157,6 +159,9 @@ async def register_gateway(self, db: Session, gateway: GatewayCreate) -> Gateway
157159 for tool in tools
158160 ]
159161
162+ existing_tools = [tool for tool in tools if tool .name in existing_tool_names ]
163+ new_tools = [tool for tool in tools if tool .name not in existing_tool_names ]
164+
160165 # Create DB model
161166 db_gateway = DbGateway (
162167 name = gateway .name ,
@@ -166,7 +171,8 @@ async def register_gateway(self, db: Session, gateway: GatewayCreate) -> Gateway
166171 last_seen = datetime .now (timezone .utc ),
167172 auth_type = auth_type ,
168173 auth_value = auth_value ,
169- tools = tools ,
174+ tools = new_tools ,
175+ # federated_tools=existing_tools + new_tools
170176 )
171177
172178 # Add to DB
@@ -181,12 +187,19 @@ async def register_gateway(self, db: Session, gateway: GatewayCreate) -> Gateway
181187 await self ._notify_gateway_added (db_gateway )
182188
183189 return GatewayRead .model_validate (gateway )
184- except IntegrityError :
185- db .rollback ()
186- raise GatewayError (f"Gateway already exists: { gateway .name } " )
187- except Exception as e :
188- db .rollback ()
189- raise GatewayError (f"Failed to register gateway: { str (e )} " )
190+ except* ValueError as ve :
191+ logger .error ("ValueErrors in group: %s" , ve .exceptions )
192+ except* RuntimeError as re :
193+ logger .error ("RuntimeErrors in group: %s" , re .exceptions )
194+ except* BaseException as other : # catches every other sub-exception
195+ logger .error ("Other grouped errors: %s" , other .exceptions )
196+ # except IntegrityError as ex:
197+ # logger.error(f"Error adding gateway: {ex}")
198+ # db.rollback()
199+ # raise GatewayError(f"Gateway already exists: {gateway.name}")
200+ # except Exception as e:
201+ # db.rollback()
202+ # raise GatewayError(f"Failed to register gateway: {str(e)}")
190203
191204 async def list_gateways (self , db : Session , include_inactive : bool = False ) -> List [GatewayRead ]:
192205 """List all registered gateways.
@@ -462,28 +475,42 @@ async def forward_request(self, gateway: DbGateway, method: str, params: Optiona
462475 raise GatewayConnectionError (f"Failed to forward request to { gateway .name } : { str (e )} " )
463476
464477 async def check_health_of_gateways (self , gateways : List [DbGateway ]) -> bool :
465- """Health check for gateways
478+ """Health check for a list of gateways.
479+
480+ Deactivates gateway if gateway is not healthy.
466481
467482 Args:
468- gateways: Gateways to check
483+ gateways (List[DbGateway]): List of gateways to check if healthy
469484
470485 Returns:
471- True if gateway is healthy
486+ bool: True if all active gateways are healthy
472487 """
473- for gateway in gateways :
474- if not gateway .is_active :
475- return False
488+ # Reuse a single HTTP client for all requests
489+ async with httpx .AsyncClient () as client :
490+ for gateway in gateways :
491+ # Inactive gateways are unhealthy
492+ if not gateway .is_active :
493+ continue
476494
477- try :
478- # Try to initialize connection
479- await self ._initialize_gateway (gateway .url , gateway .auth_value )
495+ try :
496+ # Ensure auth_value is a dict
497+ auth_data = gateway .auth_value or {}
498+ headers = decode_auth (auth_data )
499+
500+ # Perform the GET and raise on 4xx/5xx
501+ async with client .stream ("GET" , gateway .url , headers = headers ) as response :
502+ # This will raise immediately if status is 4xx/5xx
503+ response .raise_for_status ()
504+
505+ # Mark successful check
506+ gateway .last_seen = datetime .utcnow ()
480507
481- # Update last seen
482- gateway . last_seen = datetime . utcnow ()
483- return True
508+ except Exception :
509+ with SessionLocal () as db :
510+ await self . toggle_gateway_status ( db = db , gateway_id = gateway . id , activate = False )
484511
485- except Exception :
486- return False
512+ # All gateways passed
513+ return True
487514
488515 async def aggregate_capabilities (self , db : Session ) -> Dict [str , Any ]:
489516 """Aggregate capabilities from all gateways.
@@ -584,7 +611,11 @@ async def connect_to_sse_server(server_url: str, authentication: Optional[Dict[s
584611 raise GatewayConnectionError (f"Failed to initialize gateway at { url } : { str (e )} " )
585612
586613 def _get_active_gateways (self ) -> list [DbGateway ]:
587- """Sync function for database operations (runs in thread)."""
614+ """Sync function for database operations (runs in thread).
615+
616+ Returns:
617+ List[DbGateway]: List of active gateways
618+ """
588619 with SessionLocal () as db :
589620 return db .execute (select (DbGateway ).where (DbGateway .is_active )).scalars ().all ()
590621
@@ -598,7 +629,6 @@ async def _run_health_checks(self) -> None:
598629 if len (gateways ) > 0 :
599630 # Async health checks (non-blocking)
600631 await self .check_health_of_gateways (gateways )
601-
602632 except Exception as e :
603633 logger .error (f"Health check run failed: { str (e )} " )
604634
0 commit comments