1- from typing import Annotated
1+ import threading
2+ from collections import defaultdict
3+ from typing import Annotated , Optional
24from fastapi import FastAPI , HTTPException , Request
35from fastapi .responses import Response , StreamingResponse
46from starlette .concurrency import run_in_threadpool
2830# 同步 Engine
2931engine : Engine = None
3032
33+ # 1. 存储每个 model_name 的下一个索引
34+ # defaultdict(int) 会在键不存在时自动创建并赋值为 0
35+ model_rr_counters = defaultdict (int )
36+ # 2. 为每个 model_name 提供一个独立的锁,以防止并发访问 rr_counters 时的竞争条件
37+ # defaultdict(threading.Lock) 会在键不存在时自动创建一个新的锁
38+ model_rr_locks = defaultdict (threading .Lock )
39+
3140
3241# --- SQLModel 数据模型 ---
3342class ModelRoute (SQLModel , table = True ):
3443 """
3544 一个 SQLModel 模型,用于存储模型名称到后端 URL 的路由。
36- 'model_name ' 是主键,确保了唯一性 。
45+ 'id ' 是主键,允许 'model_name' 重复以实现负载均衡 。
3746 """
3847
39- model_name : str = Field (primary_key = True , index = True )
48+ # <--- MODIFIED: 架构变更 --->
49+ # 1. 'model_name' 不再是主键。
50+ # 2. 我们添加了一个自动递增的 'id' 作为主键。
51+ id : Optional [int ] = Field (default = None , primary_key = True )
52+ model_name : str = Field (index = True ) # 仍然索引 'model_name' 以提高查询速度
4053 model_url : str
4154 api_key : str | None = Field (default = None )
4255 created : datetime = Field (
@@ -75,14 +88,51 @@ def get_all_routes_sync():
def get_routing_info_sync(model: str):
    """
    Synchronously resolve a model name to one backend route.

    All routes whose ``model_name`` equals ``model`` are loaded. When more
    than one exists, a per-model round-robin counter (guarded by a per-model
    lock) picks which one serves this request.

    Args:
        model: The model name requested by the client.

    Returns:
        A ``(server, available_models, backend_api_key)`` tuple. ``server``
        and ``backend_api_key`` are ``None`` when no route matches;
        ``available_models`` is whatever ``get_all_model_names_sync()``
        returns (used by callers for error messages).
    """
    with Session(engine) as session:
        # Fetch every route for this model. Ordering by primary key keeps the
        # list order stable between requests, so a round-robin index always
        # refers to the same row as long as the route set is unchanged.
        statement = (
            select(ModelRoute)
            .where(ModelRoute.model_name == model)
            .order_by(ModelRoute.id)
        )
        db_routes_list = session.exec(statement).all()

        # Names of all available models (used for error messages).
        available_models = get_all_model_names_sync()

        if not db_routes_list:
            # No route registered for this model.
            return None, available_models, None

        if len(db_routes_list) == 1:
            # Single backend: no locking or counting needed.
            selected_route = db_routes_list[0]
        else:
            # Round-robin across the matching routes.
            with model_rr_locks[model]:
                # The stored counter was computed against a previous route
                # list. If routes were deleted since then it may be out of
                # range, so bound it by the *current* list length before
                # using it as an index.
                current_index = model_rr_counters[model] % len(db_routes_list)

                # Advance the shared counter for the next request.
                model_rr_counters[model] = (current_index + 1) % len(db_routes_list)

            # Indexing happens outside the lock; the index is already ours.
            selected_route = db_routes_list[current_index]

            logger.debug(
                f"Round-Robin: model={ model } , selected_index={ current_index } , total={ len (db_routes_list )} "
            )

        server = selected_route.model_url
        backend_api_key = selected_route.api_key

    return server, available_models, backend_api_key
88138
@@ -145,6 +195,9 @@ async def lifespan(app: FastAPI):
145195 # 捕获任何其他意外的检查错误
146196 logger .error (f"数据库 schema 检查期间发生意外错误: { e } " )
147197 logger .warning ("将尝试继续,但可能会失败。" )
198+ import traceback
199+
200+ traceback .print_exc ()
148201 engine = temp_engine
149202 # ---------- 检查数据库 schema 是否与 SQLModel 定义匹配 ----------
150203
@@ -286,19 +339,33 @@ async def _non_stream_proxy(
286339async def list_models ():
287340 """
288341 OpenAI 兼容接口: 列出所有可用的模型。
342+ (此函数必须修改以处理重复的模型名称)
289343 """
290344 try :
345+
346+ # 1. 获取所有路由条目
291347 all_routes : list [ModelRoute ] = await run_in_threadpool (get_all_routes_sync )
292348
293- models_data = []
349+ # 2. 按 model_name 分组,并找到每组中 *最早* 的 created 时间戳
350+ models_by_name = {} # { model_name: earliest_timestamp }
294351 for route in all_routes :
295352 created_timestamp = int (route .created .timestamp ())
296353
354+ if route .model_name not in models_by_name :
355+ models_by_name [route .model_name ] = created_timestamp
356+ else :
357+ # 更新为更早(更小)的时间戳
358+ models_by_name [route .model_name ] = min (
359+ models_by_name [route .model_name ], created_timestamp
360+ )
361+
362+ # 3. 构建唯一的模型数据列表
363+ models_data = []
364+ for model_name , created_timestamp in models_by_name .items ():
297365 models_data .append (
298366 {
299- "id" : route . model_name ,
367+ "id" : model_name ,
300368 "object" : "model" ,
301- # 3. 使用数据库中的 created 时间戳
302369 "created" : created_timestamp ,
303370 "owned_by" : "openai_router" ,
304371 "permission" : [],
@@ -307,7 +374,9 @@ async def list_models():
307374
308375 response_data = {"object" : "list" , "data" : models_data }
309376
310- logger .info (f"Returning { len (all_routes )} available models for /v1/models." )
377+ logger .info (
378+ f"Returning { len (models_data )} unique available models for /v1/models."
379+ )
311380
312381 return response_data
313382
@@ -364,17 +433,23 @@ def add_or_update_route_sync(model_name: str, model_url: str, api_key: str | Non
364433 api_key = None
365434
366435 with Session (engine ) as session :
367- db_route = session .get (ModelRoute , model_name )
436+ statement = select (ModelRoute ).where (
437+ ModelRoute .model_name == model_name , ModelRoute .model_url == model_url
438+ )
439+ db_route = session .exec (statement ).first () # 使用 .first()
368440
369441 if db_route :
370- db_route . model_url = model_url
371- db_route .api_key = api_key # 更新 API 密钥
372- status_message = f"路由 '{ model_name } ' 已更新 。"
442+ # 2. 如果存在,我们只更新 API 密钥
443+ db_route .api_key = api_key
444+ status_message = f"路由 '{ model_name } -> { model_url } ' 的 API 密钥已更新 。"
373445 else :
446+ # 3. 如果不存在,我们创建一个全新的条目
374447 db_route = ModelRoute (
375448 model_name = model_name , model_url = model_url , api_key = api_key
376449 )
377- status_message = f"路由 '{ model_name } ' 已添加。"
450+ status_message = (
451+ f"新路由 '{ model_name } -> { model_url } ' 已添加 (用于负载均衡)。"
452+ )
378453
379454 session .add (db_route )
380455 session .commit ()
@@ -383,19 +458,27 @@ def add_or_update_route_sync(model_name: str, model_url: str, api_key: str | Non
383458 return status_message
384459
385460
def delete_route_sync(model_name: str, model_url: str):
    """
    Synchronously delete one specific (model_name, model_url) route.

    The lookup filters on both the model name and the backend URL and
    removes the first matching row.

    Args:
        model_name: Model name of the route to remove.
        model_url: Backend URL of the route to remove.

    Returns:
        A status message: a confirmation when a row was deleted, or an
        error message when no matching route was found.
    """
    lookup = select(ModelRoute).where(
        ModelRoute.model_name == model_name, ModelRoute.model_url == model_url
    )

    with Session(engine) as session:
        target = session.exec(lookup).first()

        # Guard clause: nothing matches this exact (name, url) pair.
        if not target:
            return f"错误: 未找到路由 '{ model_name } -> { model_url } '。"

        session.delete(target)
        session.commit()
        logger.info(f"[Admin] Route deleted: { model_name } -> { model_url } ")
        return f"路由 '{ model_name } -> { model_url } ' 已删除。"
401484
@@ -416,12 +499,14 @@ async def add_or_update_route(model_name: str, model_url: str, api_key: str | No
416499 return status_message , await get_current_routes ()
417500
418501
async def delete_route(model_name: str, model_url: str):
    """
    Delete a route by running the synchronous delete in a thread pool.

    Both arguments must be non-empty; otherwise nothing is deleted and a
    validation message is returned instead.

    Returns:
        A ``(status_message, routes)`` tuple, where ``routes`` is the result
        of ``get_current_routes()`` awaited after the delete attempt.
    """
    if model_name and model_url:
        outcome = await run_in_threadpool(delete_route_sync, model_name, model_url)
    else:
        # Missing name or URL: skip the database call entirely.
        outcome = "要删除的模型名称和 URL 均不能为空 "

    return outcome, await get_current_routes()
426511
427512
@@ -465,7 +550,7 @@ def create_admin_ui():
465550 "后端 URL (Backend URL)" ,
466551 "API 密钥 (API Key)" ,
467552 ],
468- label = "当前路由表" ,
553+ label = "当前路由表 (同一模型可有多个URL) " ,
469554 row_count = (1 , "fixed" ),
470555 col_count = (3 , "fixed" ),
471556 interactive = False ,
@@ -489,7 +574,7 @@ def create_admin_ui():
489574 with gr .Row ():
490575 add_update_button = gr .Button ("添加 / 更新" )
491576 delete_button = gr .Button (
492- "删除" ,
577+ "删除 (指定URL) " ,
493578 variant = "stop" ,
494579 )
495580
@@ -505,7 +590,7 @@ def create_admin_ui():
505590
506591 delete_button .click (
507592 delete_route ,
508- inputs = [model_name_input ],
593+ inputs = [model_name_input , model_url_input ],
509594 outputs = [status_output , routes_datagrid ],
510595 )
511596
0 commit comments