Skip to content
This repository was archived by the owner on Oct 25, 2024. It is now read-only.

Commit 6fe5a9f

Browse files
authored
[NeuralChat] fix wrong output of multi-round prediction (#971)
* fix template output of multi-round predictions
1 parent 47d8755 commit 6fe5a9f

File tree

2 files changed: +7 additions, -7 deletions

intel_extension_for_transformers/neural_chat/models/base_model.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -391,19 +391,19 @@ def get_conv_template(self, model_path: str, task: str = "") -> Conversation:
391391
if self.conv_template:
392392
return
393393
if not task:
394-
self.conv_template = PromptTemplate(self.get_default_conv_template(model_path).name)
394+
self.conv_template = PromptTemplate(self.get_default_conv_template(model_path).name, clear_history=True)
395395
else:
396-
clear_after_gen = True
396+
clear_history = True
397397
if task == "completion":
398398
name = "alpaca_without_input"
399399
elif task == "chat":
400400
name = "neural-chat-7b-v2"
401-
clear_after_gen = False
401+
clear_history = False
402402
elif task == "summarization":
403403
name = "summarization"
404404
else:
405405
raise NotImplementedError(f"Unsupported task {task}.")
406-
self.conv_template = PromptTemplate(name, clear_after_gen=clear_after_gen)
406+
self.conv_template = PromptTemplate(name, clear_history=clear_history)
407407

408408
def prepare_prompt(self, prompt: str, model_path: str, task: str = ""):
409409
self.get_conv_template(model_path, task)

intel_extension_for_transformers/neural_chat/prompts/prompt.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -202,9 +202,9 @@
202202
)
203203

204204
class PromptTemplate:
205-
def __init__(self, name="one_shot", clear_after_gen=False):
205+
def __init__(self, name="one_shot", clear_history=False):
206206
self.conv = get_conv_template(name)
207-
self.clear_after_gen = clear_after_gen
207+
self.clear_history = clear_history
208208

209209
@property
210210
def roles(self):
@@ -215,7 +215,7 @@ def append_message(self, role: str, message: str):
215215

216216
def get_prompt(self) -> str:
217217
res = self.conv.get_prompt()
218-
if self.clear_after_gen:
218+
if self.clear_history:
219219
self.clear_messages()
220220
return res
221221

0 commit comments

Comments (0)