
Commit f75a360

Author: xusenlin
Commit message: support glm4 for hf engine
Parent: 9e64c58

File tree

4 files changed: 15 additions & 15 deletions


api/config.py

Lines changed: 0 additions & 6 deletions
@@ -130,12 +130,6 @@ class LLMSettings(BaseModel):
         description="Use flash attention."
     )
 
-    # support for transformers.TextIteratorStreamer
-    use_streamer_v2: Optional[bool] = Field(
-        default=get_bool_env("USE_STREAMER_V2", "true"),
-        description="Support for transformers.TextIteratorStreamer."
-    )
-
     interrupt_requests: Optional[bool] = Field(
         default=get_bool_env("INTERRUPT_REQUESTS", "true"),
         description="Whether to interrupt requests when a new request is received.",

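Note: the removed use_streamer_v2 flag was read through get_bool_env, the same helper the surviving interrupt_requests field uses. Below is a minimal sketch of such a helper, assuming it parses common truthy strings; the real implementation lives elsewhere in the repo and may differ.

import os

def get_bool_env(name: str, default: str = "false") -> bool:
    # Treat "true"/"1"/"yes" (case-insensitive) as True; anything else is False.
    return os.getenv(name, default).strip().lower() in {"true", "1", "yes"}

# Under this sketch, interrupt_requests defaults to True unless the
# INTERRUPT_REQUESTS env var is set to a falsy string.
print(get_bool_env("INTERRUPT_REQUESTS", "true"))
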
api/core/default.py

Lines changed: 14 additions & 7 deletions
@@ -42,7 +42,7 @@
     generate_stream_chatglm_v3,
     build_qwen_chat_input,
     check_is_qwen,
-    generate_stream,
+    generate_stream_v2,
     build_xverse_chat_input,
     check_is_xverse,
 )

@@ -65,7 +65,6 @@ def __init__(
         model_name: str,
         context_len: Optional[int] = None,
         prompt_name: Optional[str] = None,
-        use_streamer_v2: Optional[bool] = False,
     ) -> None:
         """
         Initialize the Default class.

@@ -76,7 +75,6 @@ def __init__(
             model_name (str): The name of the model.
             context_len (Optional[int], optional): The length of the context. Defaults to None.
             prompt_name (Optional[str], optional): The name of the prompt. Defaults to None.
-            use_streamer_v2 (Optional[bool], optional): Whether to use Streamer V2. Defaults to False.
         """
         self.model = model
         self.tokenizer = tokenizer

@@ -85,7 +83,6 @@ def __init__(
         self.model_name = model_name.lower()
         self.prompt_name = prompt_name.lower() if prompt_name is not None else None
         self.context_len = context_len
-        self.use_streamer_v2 = use_streamer_v2
 
         self.prompt_adapter = get_prompt_adapter(self.model_name, prompt_name=self.prompt_name)
 
@@ -101,10 +98,11 @@ def _prepare_for_generate(self) -> None:
         3. Checks and constructs the prompt.
         4. Sets the context length if it is not already set.
         """
-        self.generate_stream_func = generate_stream
+        self.generate_stream_func = generate_stream_v2
         if "chatglm3" in self.model_name:
             self.generate_stream_func = generate_stream_chatglm_v3
-            self.use_streamer_v2 = False
+        elif "chatglm4" in self.model_name:
+            self.generate_stream_func = generate_stream_v2
         elif check_is_chatglm(self.model):
             self.generate_stream_func = generate_stream_chatglm
         elif check_is_qwen(self.model):

@@ -118,7 +116,10 @@ def _prepare_for_generate(self) -> None:
     def _check_construct_prompt(self) -> None:
         """ Check whether to need to construct prompts or inputs. """
         self.construct_prompt = self.prompt_name is not None
-        if "chatglm3" in self.model_name:
+        if "chatglm4" in self.model_name:
+            self.construct_prompt = False
+            logger.info("Using ChatGLM4 Model for Chat!")
+        elif "chatglm3" in self.model_name:
             logger.info("Using ChatGLM3 Model for Chat!")
         elif check_is_baichuan(self.model):
             logger.info("Using Baichuan Model for Chat!")

@@ -246,6 +247,12 @@ def build_chat_inputs(
         if "chatglm3" in self.model_name:
             query, role = messages[-1]["content"], messages[-1]["role"]
             inputs = self.tokenizer.build_chat_input(query, history=messages[:-1], role=role)
+        elif "chatglm4" in self.model_name:
+            inputs = self.tokenizer.apply_chat_template(
+                messages,
+                add_generation_prompt=True,
+                tokenize=True,
+            )[0]
         elif check_is_baichuan(self.model):
             inputs = build_baichuan_chat_input(
                 self.tokenizer, messages, self.context_len, max_new_tokens

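Note: the new "chatglm4" branch in build_chat_inputs relies on the tokenizer's own chat template instead of a hand-built prompt. Below is a standalone sketch of that input path, assuming the THUDM/glm-4-9b-chat checkpoint; the model id and messages are illustrative, and the trailing [0] presumably unwraps the batch returned by this tokenizer's template.

from transformers import AutoTokenizer

# trust_remote_code is required because GLM-4 ships a custom tokenizer.
tokenizer = AutoTokenizer.from_pretrained(
    "THUDM/glm-4-9b-chat", trust_remote_code=True
)

messages = [{"role": "user", "content": "What is the capital of France?"}]

# tokenize=True returns token ids ready to feed the model, so no string
# prompt needs to be constructed (matching construct_prompt = False above).
input_ids = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
)[0]
print(len(input_ids), "prompt tokens")
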
api/models.py

Lines changed: 0 additions & 1 deletion
@@ -86,7 +86,6 @@ def create_hf_llm():
         model_name=SETTINGS.model_name,
         context_len=SETTINGS.context_length if SETTINGS.context_length > 0 else None,
         prompt_name=SETTINGS.chat_template,
-        use_streamer_v2=SETTINGS.use_streamer_v2,
     )

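Note: the call above keeps the pattern of treating a non-positive context_length as "unset". A tiny illustration of that normalization, with names chosen for the example rather than taken from the repo:

import os

def normalize_context_len(raw: int) -> int | None:
    # Non-positive values mean "let the engine infer the window from the model config".
    return raw if raw > 0 else None

context_length = int(os.getenv("CONTEXT_LEN", "-1"))
print(normalize_context_len(context_length))  # None unless CONTEXT_LEN > 0
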
examples/chatglm3/tool_using.py

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@
 
 client = OpenAI(
     api_key="EMPTY",
-    base_url="http://192.168.20.59:7891/v1/",
+    base_url="http://192.168.0.59:7860/v1/",
 )
 
 functions = list(get_tools().values())

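Note: with the updated base_url, any OpenAI-compatible client call should work against the server. A minimal request, assuming the API server from this repo is running at 192.168.0.59:7860 and serves a model registered as "chatglm3" (the model name here is an assumption):

from openai import OpenAI

client = OpenAI(api_key="EMPTY", base_url="http://192.168.0.59:7860/v1/")

resp = client.chat.completions.create(
    model="chatglm3",
    messages=[{"role": "user", "content": "Hello!"}],
)
print(resp.choices[0].message.content)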