 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
-from transformers import AutoConfig, AutoTokenizer
-from intel_extension_for_transformers.llm.runtime.graph.scripts.convert import convert_model
+
 import torch
+from intel_extension_for_transformers.llm.runtime.graph.scripts.convert import convert_model
+from transformers import AutoConfig, AutoTokenizer
+
 model_maps = {"gpt_neox": "gptneox", "gpt_bigcode": "starcoder"}

+
 class Model:
     def __init__(self):
         self.module = None
@@ -28,55 +31,68 @@ def __init__(self):
         self.bin_file = None
         self.generate_round = 0

-    def __import_package(self, model_name):
+    def __import_package(self, model_type):
         if self.module:
             return
-        if model_name == "gptj":
+        if model_type == "gptj":
             import intel_extension_for_transformers.llm.runtime.graph.gptj_cpp as cpp_model
-        elif model_name == "falcon":
+        elif model_type == "falcon":
             import intel_extension_for_transformers.llm.runtime.graph.falcon_cpp as cpp_model
-        elif model_name == "gptneox":
+        elif model_type == "gptneox":
             import intel_extension_for_transformers.llm.runtime.graph.gptneox_cpp as cpp_model
-        elif model_name == "dolly":
+        elif model_type == "dolly":
             import intel_extension_for_transformers.llm.runtime.graph.dolly_cpp as cpp_model
-        elif model_name == "llama" or model_name == "llama2":
+        elif model_type == "llama" or model_type == "llama2":
             import intel_extension_for_transformers.llm.runtime.graph.llama_cpp as cpp_model
-        elif model_name == "mpt":
+        elif model_type == "mpt":
             import intel_extension_for_transformers.llm.runtime.graph.mpt_cpp as cpp_model
-        elif model_name == "gpt_bigcode" or model_name == "starcoder":
+        elif model_type == "gpt_bigcode" or model_type == "starcoder":
             import intel_extension_for_transformers.llm.runtime.graph.starcoder_cpp as cpp_model
-        elif model_name == "opt":
+        elif model_type == "opt":
             import intel_extension_for_transformers.llm.runtime.graph.opt_cpp as cpp_model
-        elif model_name == "bloom":
+        elif model_type == "bloom":
             import intel_extension_for_transformers.llm.runtime.graph.bloom_cpp as cpp_model
-        elif model_name == "chatglm":
+        elif model_type == "chatglm":
             import intel_extension_for_transformers.llm.runtime.graph.chatglm_cpp as cpp_model
-        elif model_name == "chatglm2":
+        elif model_type == "chatglm2":
             import intel_extension_for_transformers.llm.runtime.graph.chatglm2_cpp as cpp_model
-        elif model_name == "baichuan":
+        elif model_type == "baichuan":
             import intel_extension_for_transformers.llm.runtime.graph.baichuan_cpp as cpp_model
-        elif model_name == "polyglot":
+        elif model_type == "polyglot":
             import intel_extension_for_transformers.llm.runtime.graph.polyglot_cpp as cpp_model
-        elif model_name == "mistral":
+        elif model_type == "mistral":
             import intel_extension_for_transformers.llm.runtime.graph.mistral_cpp as cpp_model
         else:
-            raise TypeError("Unspported model type {}!".format(model_name))
+            raise TypeError("Unspported model type {}!".format(model_type))
         self.module = cpp_model

+    @staticmethod
+    def get_model_type(model_config):
+        model_type = model_maps.get(model_config.model_type, model_config.model_type)
+        if model_type == "chatglm" and "chatglm2" in model_config._name_or_path:
+            model_type = "chatglm2"
+        return model_type
+
     def init(self, model_name, not_quant=False, use_cache=False, **quant_kwargs):
         self.config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
         self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
-        model_type = model_maps.get(self.config.model_type, self.config.model_type)
-        if model_type == "chatglm" and "chatglm2" in self.config._name_or_path:
-            model_type = "chatglm2"
+        model_type = Model.get_model_type(self.config)
         self.__import_package(model_type)

         # check cache and quantization
         output_path = "runtime_outs"
-        if not os.path.exists(output_path):
-            os.makedirs(output_path)
+        os.makedirs(output_path, exist_ok=True)
         fp32_bin = "{}/ne_{}_f32.bin".format(output_path, model_type)
-        quant_bin = "{}/ne_{}_q.bin".format(output_path, model_type)
+        quant_desc = quant_kwargs['weight_dtype']
+        if quant_kwargs['use_ggml']:
+            quant_desc += "_ggml"
+        else:
+            quant_desc += "_jblas_c" + quant_kwargs['compute_dtype']
+            if quant_kwargs['group_size'] == -1:
+                quant_desc += "_pc"
+            else:
+                quant_desc += "_g{}".format(quant_kwargs['group_size'])
+        quant_bin = "{}/ne_{}_q_{}.bin".format(output_path, model_type, quant_desc)

         if not_quant:
             self.bin_file = fp32_bin
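
With this change the quantized cache file encodes the quantization recipe in its name, so different weight_dtype / compute_dtype / group_size settings no longer overwrite each other in runtime_outs. A minimal sketch of the resulting filename, assuming illustrative values (weight_dtype="int4", compute_dtype="int8", group_size=32, use_ggml=False for a llama model) rather than anything mandated by this patch:

# Illustrative values only -- assumptions for the sketch, not defaults from this patch.
model_type = "llama"
weight_dtype, compute_dtype, group_size, use_ggml = "int4", "int8", 32, False

quant_desc = weight_dtype
if use_ggml:
    quant_desc += "_ggml"
else:
    quant_desc += "_jblas_c" + compute_dtype          # jblas path records the compute dtype
    quant_desc += "_pc" if group_size == -1 else "_g{}".format(group_size)

print("runtime_outs/ne_{}_q_{}.bin".format(model_type, quant_desc))
# runtime_outs/ne_llama_q_int4_jblas_cint8_g32.bin
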
@@ -85,20 +101,22 @@ def init(self, model_name, not_quant=False, use_cache=False, **quant_kwargs):
         if use_cache and os.path.exists(self.bin_file):
             return

-        convert_model(model_name, fp32_bin, "f32")
-        assert os.path.exists(fp32_bin), "Fail to convert pytorch model"
+        if not use_cache or not os.path.exists(fp32_bin):
+            convert_model(model_name, fp32_bin, "f32")
+            assert os.path.exists(fp32_bin), "Fail to convert pytorch model"

         if not_quant:
             print("FP32 model will be used.")
             return
-        self.module.Model.quant_model(model_path=fp32_bin, out_path=quant_bin, **quant_kwargs)
+        self.module.Model.quant_model(model_path=fp32_bin, out_path=quant_bin, **quant_kwargs)
         assert os.path.exists(quant_bin), "Fail to quantize model"
-
+
         # clean
-        os.remove(fp32_bin)
+        if not use_cache:
+            os.remove(fp32_bin)

-    def init_from_bin(self, model_name, model_path, **generate_kwargs):
-        self.__import_package(model_name)
+    def init_from_bin(self, model_type, model_path, **generate_kwargs):
+        self.__import_package(model_type)
         self.model = self.module.Model()
         if "threads" not in generate_kwargs:
             threads = os.getenv("OMP_NUM_THREADS")
@@ -108,11 +126,9 @@ def init_from_bin(self, model_name, model_path, **generate_kwargs):
                 generate_kwargs["threads"] = int(threads)
         self.model.init_model(model_path, **generate_kwargs)

-    def quant_model(self, model_name, model_path, out_path, **quant_kwargs):
-        self.__import_package(model_name)
-        self.module.Model.quant_model(model_path=model_path,
-                                      out_path=out_path, **quant_kwargs)
-
+    def quant_model(self, model_type, model_path, out_path, **quant_kwargs):
+        self.__import_package(model_type)
+        self.module.Model.quant_model(model_path=model_path, out_path=out_path, **quant_kwargs)

     def generate(self, input_ids, streamer=None, interactive=False, ignore_prompt=False, stopping_criteria=None, **generate_kwargs):
         max_new_tokens = generate_kwargs.get("max_new_tokens", -1)
@@ -129,8 +145,7 @@ def generate(self, input_ids, streamer=None, interactive=False, ignore_prompt=False, stopping_criteria=None, **generate_kwargs):
         ret = input_ids.tolist()

         beam_search = False
-        if ("num_beams" in generate_kwargs and generate_kwargs["num_beams"] > 1) and not \
-                generate_kwargs.get("do_sample", False):
+        if (generate_kwargs.get("num_beams", 1) > 1) and not generate_kwargs.get("do_sample", False):
             beam_search = True
         if not beam_search:
             # TODO support multi batch
@@ -142,30 +157,43 @@ def generate(self, input_ids, streamer=None, interactive=False, ignore_prompt=False, stopping_criteria=None, **generate_kwargs):
                 Make sure that `num_beams` is set to 1."
             if self.generate_round == 0 and not ignore_prompt:
                 streamer.put(input_ids)
-
+
         if interactive:
             self.model.reset_token_end()
         out_count = 0
+        input_list = input_ids.tolist()
         while True:
-            response = self.model.generate(input_ids=input_ids.tolist())
+            response = self.model.generate(input_ids=input_list)
+            input_list = []  # next-token stage will use previous output
             if len(response) == 0:
                 break
             if streamer:
                 streamer.put(torch.tensor([response[0]]))
             for i in range(len(response)):
                 ret[i].extend(response[i])
+            if beam_search:
+                break
             if stopping_criteria is not None:
                 if stopping_criteria(torch.tensor(ret), None):
                     break
             elif ret[0][-1] == self.tokenizer.eos_token_id or \
-                (max_new_tokens != -1 and out_count > max_new_tokens):
+                    (max_new_tokens != -1 and out_count > max_new_tokens):
                 break
             out_count += 1
         if streamer:
             streamer.end()
-
+
         self.generate_round += 1
         return ret

     def is_token_end(self):
         return self.model.is_token_end()
+
+    def __call__(self, input_ids, reinit=False, **kwargs):
+        if self.model is None:
+            self.init_from_bin(self.model_type, self.bin_file, **kwargs)
+            self.generate_round = 0
+        elif reinit:
+            self.model.reinit()
+            self.generate_round = 0
+        return self.model.evaluate(input_ids.tolist())
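
Taken together, the refactored flow is: init() converts and quantizes (or reuses the cached binaries), init_from_bin() loads a quantized binary into the C++ runtime, and generate() / __call__ drive decoding. A rough usage sketch; the model id, quantization arguments, and the import path of Model are illustrative assumptions rather than values taken from this patch, and runtime parameters not passed to init_from_bin are assumed to fall back to their defaults:

from transformers import AutoTokenizer
# Assumed import path; Model is the class defined in this file.
from intel_extension_for_transformers.llm.runtime.graph import Model

model_name = "meta-llama/Llama-2-7b-hf"  # hypothetical model id
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

model = Model()
# Convert to ne_llama_f32.bin, quantize to ne_llama_q_<desc>.bin, and reuse both on later runs.
model.init(model_name, use_cache=True,
           weight_dtype="int4", compute_dtype="int8", group_size=32, use_ggml=False)
# Load the quantized binary; the model type is derived from the HF config.
model.init_from_bin(Model.get_model_type(model.config), model.bin_file)

input_ids = tokenizer("Once upon a time", return_tensors="pt").input_ids
output_ids = model.generate(input_ids, max_new_tokens=32)
print(tokenizer.decode(output_ids[0]))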