Skip to content

Commit f7dd44c

Browse files
committed
实现基于 轮询 的负载均衡
1 parent d41d643 commit f7dd44c

File tree

1 file changed

+116
-31
lines changed

1 file changed

+116
-31
lines changed

src/openai_router/main.py

Lines changed: 116 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
from typing import Annotated
1+
import threading
2+
from collections import defaultdict
3+
from typing import Annotated, Optional
24
from fastapi import FastAPI, HTTPException, Request
35
from fastapi.responses import Response, StreamingResponse
46
from starlette.concurrency import run_in_threadpool
@@ -28,15 +30,26 @@
2830
# 同步 Engine
2931
engine: Engine = None
3032

33+
# 1. Next round-robin index for each model_name.
34+
# defaultdict(int) creates a missing key on first access with the value 0.
35+
model_rr_counters = defaultdict(int)
36+
# 2. One lock per model_name, guarding concurrent access to model_rr_counters.
37+
# defaultdict(threading.Lock) creates a new lock on first access to a missing key.
38+
model_rr_locks = defaultdict(threading.Lock)
39+
3140

3241
# --- SQLModel 数据模型 ---
3342
class ModelRoute(SQLModel, table=True):
3443
"""
3544
一个 SQLModel 模型,用于存储模型名称到后端 URL 的路由。
36-
'model_name' 是主键,确保了唯一性
45+
'id' 是主键,允许 'model_name' 重复以实现负载均衡
3746
"""
3847

39-
model_name: str = Field(primary_key=True, index=True)
48+
# <--- MODIFIED: 架构变更 --->
49+
# 1. 'model_name' 不再是主键。
50+
# 2. 我们添加了一个自动递增的 'id' 作为主键。
51+
id: Optional[int] = Field(default=None, primary_key=True)
52+
model_name: str = Field(index=True) # 仍然索引 'model_name' 以提高查询速度
4053
model_url: str
4154
api_key: str | None = Field(default=None)
4255
created: datetime = Field(
@@ -75,14 +88,51 @@ def get_all_routes_sync():
7588
def get_routing_info_sync(model: str):
    """Resolve ``model`` to a backend, rotating round-robin across its routes.

    All routes whose ``model_name`` equals ``model`` are loaded. With several
    matches, a per-model counter (guarded by a per-model lock) picks the next
    one in turn.

    Args:
        model: The model name to resolve.

    Returns:
        A ``(server, available_models, backend_api_key)`` tuple. ``server`` and
        ``backend_api_key`` are ``None`` when no route matches ``model``.
        ``available_models`` is the result of ``get_all_model_names_sync()``.
    """
    with Session(engine) as session:
        # Order by primary key so every request sees the routes in the same
        # order; without an ORDER BY the database may return rows in any
        # order, and an index would not reliably map to the same backend.
        statement = (
            select(ModelRoute)
            .where(ModelRoute.model_name == model)
            .order_by(ModelRoute.id)
        )
        db_routes_list = session.exec(statement).all()

        # Names of every available model, collected for the caller.
        available_models = get_all_model_names_sync()

        if not db_routes_list:
            # No route registered for this model.
            server = None
            backend_api_key = None
        elif len(db_routes_list) == 1:
            # Single backend: no lock or counter needed.
            selected_route = db_routes_list[0]
            server = selected_route.model_url
            backend_api_key = selected_route.api_key
        else:
            # Round-robin across several backends.
            lock = model_rr_locks[model]

            with lock:
                # The stored counter may be stale: if routes were removed
                # since the last request it can point past the end of the
                # current list. Reduce it modulo the current size before
                # using it as an index, otherwise the lookup below raises
                # IndexError.
                current_index = model_rr_counters[model] % len(db_routes_list)

                # Advance the shared counter for the next request.
                next_index = (current_index + 1) % len(db_routes_list)
                model_rr_counters[model] = next_index

            # The index was captured under the lock; the lookup itself does
            # not need it.
            selected_route = db_routes_list[current_index]

            server = selected_route.model_url
            backend_api_key = selected_route.api_key

            logger.debug(
                f"Round-Robin: model={model}, selected_index={current_index}, total={len(db_routes_list)}"
            )

    return server, available_models, backend_api_key
88138

@@ -145,6 +195,9 @@ async def lifespan(app: FastAPI):
145195
# 捕获任何其他意外的检查错误
146196
logger.error(f"数据库 schema 检查期间发生意外错误: {e}")
147197
logger.warning("将尝试继续,但可能会失败。")
198+
import traceback
199+
200+
traceback.print_exc()
148201
engine = temp_engine
149202
# ---------- 检查数据库 schema 是否与 SQLModel 定义匹配 ----------
150203

@@ -286,19 +339,33 @@ async def _non_stream_proxy(
286339
async def list_models():
287340
"""
288341
OpenAI 兼容接口: 列出所有可用的模型。
342+
(此函数必须修改以处理重复的模型名称)
289343
"""
290344
try:
345+
346+
# 1. 获取所有路由条目
291347
all_routes: list[ModelRoute] = await run_in_threadpool(get_all_routes_sync)
292348

293-
models_data = []
349+
# 2. 按 model_name 分组,并找到每组中 *最早* 的 created 时间戳
350+
models_by_name = {} # { model_name: earliest_timestamp }
294351
for route in all_routes:
295352
created_timestamp = int(route.created.timestamp())
296353

354+
if route.model_name not in models_by_name:
355+
models_by_name[route.model_name] = created_timestamp
356+
else:
357+
# 更新为更早(更小)的时间戳
358+
models_by_name[route.model_name] = min(
359+
models_by_name[route.model_name], created_timestamp
360+
)
361+
362+
# 3. 构建唯一的模型数据列表
363+
models_data = []
364+
for model_name, created_timestamp in models_by_name.items():
297365
models_data.append(
298366
{
299-
"id": route.model_name,
367+
"id": model_name,
300368
"object": "model",
301-
# 3. 使用数据库中的 created 时间戳
302369
"created": created_timestamp,
303370
"owned_by": "openai_router",
304371
"permission": [],
@@ -307,7 +374,9 @@ async def list_models():
307374

308375
response_data = {"object": "list", "data": models_data}
309376

310-
logger.info(f"Returning {len(all_routes)} available models for /v1/models.")
377+
logger.info(
378+
f"Returning {len(models_data)} unique available models for /v1/models."
379+
)
311380

312381
return response_data
313382

@@ -364,17 +433,23 @@ def add_or_update_route_sync(model_name: str, model_url: str, api_key: str | Non
364433
api_key = None
365434

366435
with Session(engine) as session:
367-
db_route = session.get(ModelRoute, model_name)
436+
statement = select(ModelRoute).where(
437+
ModelRoute.model_name == model_name, ModelRoute.model_url == model_url
438+
)
439+
db_route = session.exec(statement).first() # 使用 .first()
368440

369441
if db_route:
370-
db_route.model_url = model_url
371-
db_route.api_key = api_key # 更新 API 密钥
372-
status_message = f"路由 '{model_name}' 已更新。"
442+
# 2. 如果存在,我们只更新 API 密钥
443+
db_route.api_key = api_key
444+
status_message = f"路由 '{model_name} -> {model_url}' 的 API 密钥已更新。"
373445
else:
446+
# 3. 如果不存在,我们创建一个全新的条目
374447
db_route = ModelRoute(
375448
model_name=model_name, model_url=model_url, api_key=api_key
376449
)
377-
status_message = f"路由 '{model_name}' 已添加。"
450+
status_message = (
451+
f"新路由 '{model_name} -> {model_url}' 已添加 (用于负载均衡)。"
452+
)
378453

379454
session.add(db_route)
380455
session.commit()
@@ -383,19 +458,27 @@ def add_or_update_route_sync(model_name: str, model_url: str, api_key: str | Non
383458
return status_message
384459

385460

386-
def delete_route_sync(model_name: str):
387-
"""同步从数据库删除一个路由"""
461+
def delete_route_sync(model_name: str, model_url: str):
    """Delete one route identified by the (model_name, model_url) pair.

    Args:
        model_name: Model name of the route to remove.
        model_url: Backend URL of the route to remove.

    Returns:
        A status message saying whether the route was deleted or not found.
    """
    with Session(engine) as session:
        # Only the first row matching both columns is considered.
        target = session.exec(
            select(ModelRoute).where(
                ModelRoute.model_name == model_name,
                ModelRoute.model_url == model_url,
            )
        ).first()

        if not target:
            return f"错误: 未找到路由 '{model_name} -> {model_url}'。"

        session.delete(target)
        session.commit()
        logger.info(f"[Admin] Route deleted: {model_name} -> {model_url}")
        return f"路由 '{model_name} -> {model_url}' 已删除。"
401484

@@ -416,12 +499,14 @@ async def add_or_update_route(model_name: str, model_url: str, api_key: str | No
416499
return status_message, await get_current_routes()
417500

418501

419-
async def delete_route(model_name: str):
502+
async def delete_route(model_name: str, model_url: str):
    """Delete a route via the synchronous helper, run in the thread pool.

    Args:
        model_name: Model name of the route to remove.
        model_url: Backend URL of the route to remove.

    Returns:
        A ``(status_message, routes)`` tuple, where ``routes`` is the result
        of ``get_current_routes()`` fetched after the attempt.
    """
    if model_name and model_url:
        outcome = await run_in_threadpool(delete_route_sync, model_name, model_url)
    else:
        # Both fields are required to identify a single route.
        outcome = "要删除的模型名称和 URL 均不能为空"

    refreshed_routes = await get_current_routes()
    return outcome, refreshed_routes
426511

427512

@@ -465,7 +550,7 @@ def create_admin_ui():
465550
"后端 URL (Backend URL)",
466551
"API 密钥 (API Key)",
467552
],
468-
label="当前路由表",
553+
label="当前路由表 (同一模型可有多个URL)",
469554
row_count=(1, "fixed"),
470555
col_count=(3, "fixed"),
471556
interactive=False,
@@ -489,7 +574,7 @@ def create_admin_ui():
489574
with gr.Row():
490575
add_update_button = gr.Button("添加 / 更新")
491576
delete_button = gr.Button(
492-
"删除",
577+
"删除 (指定URL)",
493578
variant="stop",
494579
)
495580

@@ -505,7 +590,7 @@ def create_admin_ui():
505590

506591
delete_button.click(
507592
delete_route,
508-
inputs=[model_name_input],
593+
inputs=[model_name_input, model_url_input],
509594
outputs=[status_output, routes_datagrid],
510595
)
511596

0 commit comments

Comments
 (0)