     generate_stream_chatglm_v3,
     build_qwen_chat_input,
     check_is_qwen,
-    generate_stream,
+    generate_stream_v2,
     build_xverse_chat_input,
     check_is_xverse,
 )
@@ -65,7 +65,6 @@ def __init__(
         model_name: str,
         context_len: Optional[int] = None,
         prompt_name: Optional[str] = None,
-        use_streamer_v2: Optional[bool] = False,
     ) -> None:
         """
         Initialize the Default class.
@@ -76,7 +75,6 @@ def __init__(
             model_name (str): The name of the model.
             context_len (Optional[int], optional): The length of the context. Defaults to None.
             prompt_name (Optional[str], optional): The name of the prompt. Defaults to None.
-            use_streamer_v2 (Optional[bool], optional): Whether to use Streamer V2. Defaults to False.
         """
         self.model = model
         self.tokenizer = tokenizer
@@ -85,7 +83,6 @@ def __init__(
         self.model_name = model_name.lower()
         self.prompt_name = prompt_name.lower() if prompt_name is not None else None
         self.context_len = context_len
-        self.use_streamer_v2 = use_streamer_v2
 
         self.prompt_adapter = get_prompt_adapter(self.model_name, prompt_name=self.prompt_name)
 
@@ -101,10 +98,11 @@ def _prepare_for_generate(self) -> None:
         3. Checks and constructs the prompt.
         4. Sets the context length if it is not already set.
         """
-        self.generate_stream_func = generate_stream
+        self.generate_stream_func = generate_stream_v2
         if "chatglm3" in self.model_name:
             self.generate_stream_func = generate_stream_chatglm_v3
-            self.use_streamer_v2 = False
+        elif "chatglm4" in self.model_name:
+            self.generate_stream_func = generate_stream_v2
         elif check_is_chatglm(self.model):
             self.generate_stream_func = generate_stream_chatglm
         elif check_is_qwen(self.model):
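Read together with the import hunk above, this change makes `generate_stream_v2` the default streaming function and gives ChatGLM4 an explicit branch that resolves to that same default. Below is a free-standing sketch of the name-based part of the resulting dispatch; `pick_stream_func` is a hypothetical helper and the stubs merely stand in for the repo's real streamers:

```python
from typing import Callable

# Placeholder stubs; the real functions live in the repo and stream tokens.
def generate_stream_v2(*args, **kwargs): ...
def generate_stream_chatglm_v3(*args, **kwargs): ...

def pick_stream_func(model_name: str) -> Callable:
    """Mirror of the name-based dispatch as it reads after this patch."""
    func = generate_stream_v2  # new default (was generate_stream)
    if "chatglm3" in model_name:
        func = generate_stream_chatglm_v3
    elif "chatglm4" in model_name:
        func = generate_stream_v2  # explicit, same as the new default
    return func
```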
@@ -118,7 +116,10 @@ def _prepare_for_generate(self) -> None:
     def _check_construct_prompt(self) -> None:
         """ Check whether to need to construct prompts or inputs. """
         self.construct_prompt = self.prompt_name is not None
-        if "chatglm3" in self.model_name:
+        if "chatglm4" in self.model_name:
+            self.construct_prompt = False
+            logger.info("Using ChatGLM4 Model for Chat!")
+        elif "chatglm3" in self.model_name:
             logger.info("Using ChatGLM3 Model for Chat!")
         elif check_is_baichuan(self.model):
             logger.info("Using Baichuan Model for Chat!")
@@ -246,6 +247,12 @@ def build_chat_inputs(
         if "chatglm3" in self.model_name:
             query, role = messages[-1]["content"], messages[-1]["role"]
             inputs = self.tokenizer.build_chat_input(query, history=messages[:-1], role=role)
+        elif "chatglm4" in self.model_name:
+            inputs = self.tokenizer.apply_chat_template(
+                messages,
+                add_generation_prompt=True,
+                tokenize=True,
+            )[0]
         elif check_is_baichuan(self.model):
             inputs = build_baichuan_chat_input(
                 self.tokenizer, messages, self.context_len, max_new_tokens
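The new `chatglm4` branch builds inputs with the tokenizer's chat template instead of `build_chat_input`. A minimal sketch of the same call against a stock Hugging Face tokenizer follows; the model id is illustrative, and with the standard API a single conversation yields a flat list of token ids, so the `[0]` above presumably reflects batched output from GLM-4's remote-code tokenizer:

```python
from transformers import AutoTokenizer

# Illustrative model id; GLM-4 ships custom tokenizer code via trust_remote_code.
tokenizer = AutoTokenizer.from_pretrained(
    "THUDM/glm-4-9b-chat", trust_remote_code=True
)

messages = [{"role": "user", "content": "Hello!"}]

# add_generation_prompt=True appends the assistant-turn marker so the model
# answers as the assistant; tokenize=True returns token ids instead of a string.
input_ids = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
)
```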