Skip to content
This repository was archived by the owner on Oct 25, 2024. It is now read-only.

Commit 77fb812

Browse files
[NeuralChat] Support SOLAR-10.7B-Instruct-v1.0 model (#1069)
Support the SOLAR-10.7B-Instruct-v1.0 model.

Signed-off-by: lvliang-intel <liang1.lv@intel.com>
1 parent 37d4007 commit 77fb812

File tree

3 files changed

+73
-0
lines changed

3 files changed

+73
-0
lines changed

intel_extension_for_transformers/llm/quantization/optimization.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ def optimize(self, model, use_llm_runtime=False):
5656
or re.search("neural-chat-7b-v2", model_name, re.IGNORECASE)
5757
or re.search("neural-chat-7b-v3", model_name, re.IGNORECASE)
5858
or re.search("starcoder", model_name, re.IGNORECASE)
59+
or re.search("solar", model_name, re.IGNORECASE)
5960
):
6061
from intel_extension_for_transformers.transformers import AutoModelForCausalLM
6162
optimized_model = AutoModelForCausalLM.from_pretrained(

intel_extension_for_transformers/neural_chat/chatbot.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,9 @@ def build_chatbot(config: PipelineConfig=None):
8787
elif "mistral" in config.model_name_or_path.lower():
8888
from .models.mistral_model import MistralModel
8989
adapter = MistralModel()
90+
elif "solar" in config.model_name_or_path.lower():
91+
from .models.solar_model import SolarModel
92+
adapter = SolarModel()
9093
elif "opt" in config.model_name_or_path.lower() or \
9194
"gpt" in config.model_name_or_path.lower() or \
9295
"flan-t5" in config.model_name_or_path.lower() or \
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
#
4+
# Copyright (c) 2023 Intel Corporation
5+
#
6+
# Licensed under the Apache License, Version 2.0 (the "License");
7+
# you may not use this file except in compliance with the License.
8+
# You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
18+
from .base_model import BaseModel, register_model_adapter
19+
import logging
20+
from fastchat.conversation import get_conv_template, Conversation, register_conv_template, SeparatorStyle
21+
22+
# Module-wide logging setup: timestamped, INFO-level records tagged with
# the emitting module's name.
_LOG_FORMAT = "%(asctime)s - %(levelname)s - %(name)s - %(message)s"
_DATE_FORMAT = "%m/%d/%Y %H:%M:%S"
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT, datefmt=_DATE_FORMAT)
logger = logging.getLogger(__name__)
28+
29+
30+
# Conversation template for SOLAR-10.7B chat formatting.
# Prompt layout follows the tokenizer chat template published at:
# https://huggingface.co/upstage/SOLAR-10.7B-Instruct-v1.0/blob/main/tokenizer_config.json
_solar_conversation = Conversation(
    name="solar",
    system_message="",
    roles=("### User", "### Assistant"),
    sep_style=SeparatorStyle.ADD_COLON_SPACE_SINGLE,
    sep="\n\n",
    stop_str="</s>",
)
register_conv_template(_solar_conversation)
42+
43+
class SolarModel(BaseModel):
    """Chatbot adapter for upstage SOLAR-10.7B instruct-tuned models."""

    def match(self, model_path: str):
        """
        Check if the provided model_path matches the current model.

        Args:
            model_path (str): Path to a model.

        Returns:
            bool: True if the model_path matches, False otherwise.
        """
        # Lowercase once; a SOLAR instruct checkpoint must carry both markers.
        lowered = model_path.lower()
        return "solar-" in lowered and "instruct" in lowered

    def get_default_conv_template(self, model_path: str) -> Conversation:
        """
        Get the default conversation template for the given model path.

        Args:
            model_path (str): Path to the model. Unused here — every SOLAR
                instruct model shares the single "solar" template.

        Returns:
            Conversation: A default conversation template.
        """
        return get_conv_template("solar")
67+
68+
# Register SolarModel in the shared adapter registry (base_model) at import time.
register_model_adapter(SolarModel)
69+

0 commit comments

Comments
 (0)