Skip to content

Commit b94107d

Browse files
committed
feat: Support multiple large model APIs
1 parent 3022287 commit b94107d

File tree

6 files changed

+112
-17
lines changed

6 files changed

+112
-17
lines changed

config/config.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,18 @@
22
# model list
model_quester_anster = "text-davinci-003"
model_gpt_35_turbo = "gpt-3.5-turbo"
model_gpt_4o = "gpt-4o"
model_programming_translate = "code-davinci-002"


# Dotted path of the LLM API wrapper class loaded by llm_api.load_api.
llm_api_impl = "llm_api_default.LLMApiDefault"


# API configuration; see the LiteLLM docs: https://docs.litellm.ai/docs
# MODEL_NAME is consumed by the wrapper itself; every other key is
# exported as an environment variable for LiteLLM.
api_config = {
    "OPENAI_API_KEY": "Your OpenAI API Key",
    "OPENAI_API_BASE": "https://api.openai.com/v1",
    "MODEL_NAME": model_gpt_4o,
}
1517

1618
# 2. 提示词
1719
gpt_message = """

llm_api/llm_api_default.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import os
2+
from math import trunc
3+
4+
from litellm import completion
5+
from config.config import api_config as out_config
6+
from llm_api.llm_api_interface import LLMApiInterface
7+
from llm_api.load_api import create_llm_api_instance
8+
9+
10+
class LLMApiDefault(LLMApiInterface):
    """Default LLM backend implemented on top of LiteLLM's `completion`.

    Configuration comes in as a dict: the special key ``MODEL_NAME``
    selects the model, and every other entry is exported as an
    environment variable so LiteLLM can pick up provider credentials.
    """

    def __init__(self):
        # Model identifier handed to litellm.completion (set by set_config).
        self.model_name = None
        # Raw response object of the most recent generate_text call.
        self.response = None

    def set_config(self, api_config: dict) -> bool:
        """Apply configuration.

        Args:
            api_config: mapping of config keys to values. ``MODEL_NAME``
                is kept on the instance; all other entries are written to
                ``os.environ``.

        Raises:
            ValueError: if api_config is None.
        """
        if api_config is None:
            raise ValueError("api_config is None")
        for key, value in api_config.items():
            if key == "MODEL_NAME":
                self.model_name = value
            else:
                # os.environ accepts only strings; coerce defensively so a
                # non-str config value doesn't raise TypeError here.
                os.environ[key] = str(value)
        return True

    def generate_text(self, messages: list) -> bool:
        """Run a chat completion for `messages`; result is kept on self.response."""
        self.response = completion(
            model=self.model_name,
            messages=messages,
        )
        return True

    def get_respond_content(self) -> str:
        """Return the assistant message content of the last response."""
        return self.response['choices'][0]['message']['content']

    def get_respond_tokens(self) -> int:
        """Return the total token count reported by the last response."""
        # int() already truncates, so the original math.trunc wrapper was redundant.
        return int(self.response['usage']['total_tokens'])
39+
40+
41+
42+
43+
# Example usage: exercise the configured LLM API implementation end to end.
if __name__ == "__main__":
    demo_messages = [
        {"role": "system",
         "content": "你是一位作家"
         },
        {"role": "user",
         "content": "请写一首抒情的诗",
         }
    ]
    api = create_llm_api_instance()
    api.set_config(out_config)
    api.generate_text(demo_messages)
    print(api.get_respond_content())
    print(api.get_respond_tokens())
57+

llm_api/llm_api_interface.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
from abc import ABC, abstractmethod
2+
3+
class LLMApiInterface(ABC):
    """Abstract interface that every LLM API backend must implement."""

    @abstractmethod
    def set_config(self, api_config: dict) -> bool:
        """Apply model/provider configuration; return True on success."""
        pass

    @abstractmethod
    def generate_text(self, messages: list) -> bool:
        """Generate text from a list of chat message dicts; return True on success."""
        # Annotation fixed from `str` to `list`: both the default
        # implementation and its callers pass a list of message dicts.
        pass

    @abstractmethod
    def get_respond_content(self) -> str:
        """Return the content of the model's last response."""
        pass

    @abstractmethod
    def get_respond_tokens(self) -> int:
        """Return the token count of the model's last response."""
        pass

llm_api/load_api.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import importlib
2+
3+
from config.config import llm_api_impl
4+
5+
def get_llm_api_class(impl_path: str = None):
    """Resolve an LLM API implementation class from a dotted path.

    Args:
        impl_path: "module.ClassName" path to load. When omitted, falls
            back to the configured ``llm_api_impl`` (backward compatible
            with the original zero-argument call).

    Returns:
        The class object named by the path.
    """
    path = impl_path if impl_path is not None else llm_api_impl
    module_name, class_name = path.rsplit('.', 1)
    module = importlib.import_module(module_name)
    return getattr(module, class_name)


# Factory: instantiate the configured (or explicitly given) implementation.
def create_llm_api_instance(impl_path: str = None):
    """Return a new instance of the resolved LLM API class."""
    return get_llm_api_class(impl_path)()

requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,14 @@ importlib-metadata==6.8.0
2020
itsdangerous==2.1.2
2121
javalang==0.13.0
2222
Jinja2==3.1.2
23+
litellm==1.58.2
2324
markdown-it-py==3.0.0
2425
MarkupSafe==2.1.3
2526
mdurl==0.1.2
2627
multidict==6.0.4
2728
netifaces==0.10.6
2829
numpy==1.26.0
29-
openai==0.27.0
30+
openai==1.59.7
3031
openpyxl==3.1.2
3132
outcome==1.2.0
3233
pandas==2.1.1

service/chat_review.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
import openai
44
from openai import OpenAIError
55
from app.gitlab_utils import *
6-
from config.config import gitlab_server_url, gitlab_private_token, openai_api_key, openai_baseurl, openai_model_name
6+
from config.config import gitlab_server_url, gitlab_private_token, api_config
7+
from llm_api.load_api import create_llm_api_instance
78
from service.content_handle import filter_diff_content
89
from utils.logger import log
910
from utils.dingding import send_dingtalk_message_by_sign
@@ -45,8 +46,6 @@ def wait_and_retry(exception):
4546
def generate_review_note(change):
4647
try:
4748
content = filter_diff_content(change['diff'])
48-
openai.api_key = openai_api_key
49-
openai.api_base = openai_baseurl
5049
messages = [
5150
{"role": "system",
5251
"content": gpt_message
@@ -56,14 +55,13 @@ def generate_review_note(change):
5655
},
5756
]
5857
log.info(f"发送给gpt 内容如下:{messages}")
59-
response = openai.ChatCompletion.create(
60-
model=openai_model_name,
61-
messages=messages,
62-
)
58+
api = create_llm_api_instance()
59+
api.set_config(api_config)
60+
api.generate_text(messages)
6361
new_path = change['new_path']
6462
log.info(f'对 {new_path} review中...')
65-
response_content = response['choices'][0]['message']['content'].replace('\n\n', '\n')
66-
total_tokens = response['usage']['total_tokens']
63+
response_content = api.get_respond_content().replace('\n\n', '\n')
64+
total_tokens = api.get_respond_tokens()
6765
review_note = f'# 📚`{new_path}`' + '\n\n'
6866
review_note += f'({total_tokens} tokens) {"AI review 意见如下:"}' + '\n\n'
6967
review_note += response_content + """

0 commit comments

Comments
 (0)