This repository was archived by the owner on Oct 25, 2024. It is now read-only.

Commit 9729b6a

[NeuralChat] Support Mixtral-8x7B-v0.1 model (#972)
* Support Mixtral-8x7b model
Signed-off-by: lvliang-intel <liang1.lv@intel.com>

Parent: 1a2afa9

File tree

3 files changed (+5, −2 lines)

intel_extension_for_transformers/neural_chat/README.md

Lines changed: 1 addition & 0 deletions
@@ -144,6 +144,7 @@ The table below displays the validated model list in NeuralChat for both inferen
 |LLaMA2 series|||||
 |MPT series|||||
 |Mistral|||||
+|Mixtral-8x7b-v0.1|||||
 |ChatGLM series|||||
 |Qwen series|||||
 |StarCoder series| | | ||

intel_extension_for_transformers/neural_chat/chatbot.py

Lines changed: 2 additions & 1 deletion
@@ -96,7 +96,8 @@ def build_chatbot(config: PipelineConfig=None):
         "bloom" in config.model_name_or_path.lower() or \
         "starcoder" in config.model_name_or_path.lower() or \
         "codegen" in config.model_name_or_path.lower() or \
-        "magicoder" in config.model_name_or_path.lower():
+        "magicoder" in config.model_name_or_path.lower() or \
+        "mixtral" in config.model_name_or_path.lower():
         from .models.base_model import BaseModel
         adapter = BaseModel()
     else:
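The hunk above extends build_chatbot's substring dispatch: any model path whose lowercased name contains "mixtral" is now routed to the generic BaseModel adapter, alongside bloom/starcoder/codegen/magicoder. A minimal standalone sketch of that dispatch (the function name and keyword tuple are illustrative, not the library's API):

```python
def routes_to_base_model(model_name_or_path: str) -> bool:
    """Hypothetical recreation of build_chatbot's substring check:
    model paths containing any of these keywords use BaseModel."""
    keywords = ("bloom", "starcoder", "codegen", "magicoder", "mixtral")
    name = model_name_or_path.lower()
    return any(keyword in name for keyword in keywords)

print(routes_to_base_model("mistralai/Mixtral-8x7B-v0.1"))  # True after this commit
print(routes_to_base_model("meta-llama/Llama-2-7b-hf"))     # False: handled by a different adapter branch
```

Matching on the lowercased path keeps the check case-insensitive, so Hub IDs like "Mixtral-8x7B-v0.1" match without any normalization by the caller.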

intel_extension_for_transformers/neural_chat/models/model_utils.py

Lines changed: 2 additions & 1 deletion
@@ -500,6 +500,7 @@ def load_model(
         or config.model_type == "mpt"
         or config.model_type == "llama"
         or config.model_type == "mistral"
+        or config.model_type == "mixtral"
     ) and not ipex_int8) or config.model_type == "opt":
         with smart_context_manager(use_deepspeed=use_deepspeed):
             model = AutoModelForCausalLM.from_pretrained(
@@ -554,7 +555,7 @@ def load_model(
             )
         else:
             raise ValueError(f"unsupported model name or path {model_name}, \
-                only supports FLAN-T5/LLAMA/MPT/GPT/BLOOM/OPT/QWEN/NEURAL-CHAT/MISTRAL/CODELLAMA/STARCODER/CODEGEN now.")
+                only supports t5/llama/mpt/gptj/bloom/opt/qwen/mistral/mixtral/gpt_bigcode model type now.")
     except EnvironmentError as e:
         logging.error(f"Exception: {e}")
         if "not a local folder and is not a valid model identifier" in str(e):
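Analogously, load_model gates the AutoModelForCausalLM loading path on config.model_type, and the commit adds "mixtral" to that allow-list while also rewriting the error message to list model types rather than marketing names. A hedged sketch of the allow-list check (set contents taken from the updated error message; the helper name is invented for illustration):

```python
# Model types accepted by load_model after this commit, per the updated error message.
SUPPORTED_MODEL_TYPES = {
    "t5", "llama", "mpt", "gptj", "bloom", "opt",
    "qwen", "mistral", "mixtral", "gpt_bigcode",
}

def validate_model_type(model_type: str) -> None:
    """Mirror the ValueError raised in load_model for unsupported types."""
    if model_type not in SUPPORTED_MODEL_TYPES:
        raise ValueError(
            f"unsupported model type {model_type}, "
            f"only supports {'/'.join(sorted(SUPPORTED_MODEL_TYPES))} now."
        )

validate_model_type("mixtral")  # accepted after this commit; raises for anything outside the set
```

Checking config.model_type (set by the Hugging Face config, e.g. "mixtral" for Mixtral-8x7B-v0.1) rather than the repo name makes the gate robust to fine-tuned forks with arbitrary names.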
