import contextlib
+from io import BytesIO
import time
+import os
+from PIL import Image
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import AsyncIterator, List, Optional
+from huggingface_hub import snapshot_download

from loguru import logger
from PIL import Image
from transformers import (
    AutoConfig,
+    AutoProcessor,
+    TextStreamer,
    PreTrainedTokenizer,
    PreTrainedTokenizerFast,
    TextIteratorStreamer,
)

from threading import Thread

from optimum.intel import OVModelForCausalLM, OVWeightQuantizationConfig
-
+from embeddedllm.backend.ov_phi3_vision import OvPhi3Vision
from embeddedllm.inputs import PromptInputs
from embeddedllm.protocol import CompletionOutput, RequestOutput
from embeddedllm.sampling_params import SamplingParams

class OpenVinoEngine(BaseLLMEngine):
    def __init__(self, model_path: str, vision: bool, device: str = "gpu"):
+        self.vision = vision
        self.model_path = model_path
+        self.device = device
+
        self.model_config: AutoConfig = AutoConfig.from_pretrained(
-            self.model_path, trust_remote_code=True
+            self.model_path,
+            trust_remote_code=True
        )
-        self.device = device

        # model_config is used to determine the model's maximum context length
        self.max_model_len = _get_and_verify_max_len(
@@ -40,51 +49,88 @@ def __init__(self, model_path: str, vision: bool, device: str = "gpu"):
            disable_sliding_window=False,
            sliding_window_len=self.get_hf_config_sliding_window(),
        )
-
        logger.info("Model Context Length: " + str(self.max_model_len))
-
+
        try:
            logger.info("Attempt to load fast tokenizer")
            self.tokenizer = PreTrainedTokenizerFast.from_pretrained(self.model_path)
        except Exception:
            logger.info("Attempt to load slower tokenizer")
            self.tokenizer = PreTrainedTokenizer.from_pretrained(self.model_path)
-
-        try:
-            self.model = OVModelForCausalLM.from_pretrained(
-                model_path, trust_remote_code=True, export=False, device=self.device
+        logger.info("Tokenizer created")
+
+        # text-only (non-vision) model
+        if not vision:
+            self.tokenizer_stream = TextIteratorStreamer(
+                self.tokenizer,
+                skip_prompt=True,
+                skip_special_tokens=True
            )
-        except Exception as e:
-            model = OVModelForCausalLM.from_pretrained(
-                model_path,
-                trust_remote_code=True,
-                export=True,
-                quantization_config=OVWeightQuantizationConfig(
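+            # Load strategy for the text-only path: first try an already-exported
+            # OpenVINO IR (export=False); if that fails, export from the original
+            # checkpoint with 4-bit symmetric weight quantization (group size 128).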
+            try:
+                self.model = OVModelForCausalLM.from_pretrained(
+                    self.model_path,
+                    trust_remote_code=True,
+                    export=False,
+                    device=self.device
+                )
+            except Exception as e:
+                model = OVModelForCausalLM.from_pretrained(
+                    self.model_path,
+                    trust_remote_code=True,
+                    export=True,
+                    quantization_config=OVWeightQuantizationConfig(
+                        **{
+                            "bits": 4,
+                            "ratio": 1.0,
+                            "sym": True,
+                            "group_size": 128,
+                            "all_layers": None,
+                        }
+                    ),
+                )
+                self.model = model.to(self.device)
+
+            logger.info("Model loaded")
+
+        # vision model
+        elif self.vision:
+            logger.info("Your model is a vision model")
+
+            # if model_path is not a local directory, download the model snapshot
+            # from the Hugging Face Hub and use the local snapshot path instead
+            if not os.path.exists(model_path):
+                snapshot_path = snapshot_download(
+                    repo_id=model_path,
+                    allow_patterns=None,
+                    repo_type="model",
+                )
+                self.model_path = snapshot_path
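+                # self.model_path now points to the locally cached snapshot directory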
+
+            try:
+                # the device name is case sensitive and must be fully
+                # capitalized, e.g. "GPU"
+                self.model = OvPhi3Vision(
+                    self.model_path,
+                    self.device.upper()
+                )
+                logger.info("Model loaded")
+
+                self.processor = AutoProcessor.from_pretrained(
+                    self.model_path,
+                    trust_remote_code=True
+                )
+                logger.info("Processor loaded")
+                print("processor directory: ", dir(self.processor))
+                self.tokenizer_stream = TextIteratorStreamer(
+                    self.processor,
                    **{
-                        "bits": 4,
-                        "ratio": 1.0,
-                        "sym": True,
-                        "group_size": 128,
-                        "all_layers": None,
-                    }
-                ),
-            )
-            self.model = model.to(self.device)
-
-        logger.info("Model loaded")
-        self.tokenizer_stream = TextIteratorStreamer(
-            self.tokenizer, skip_prompt=True, skip_special_tokens=True
-        )
-        logger.info("Tokenizer created")
-
-        self.vision = vision
-
-        # if self.vision:
-        #     self.onnx_processor = self.model.create_multimodal_processor()
-        #     self.processor = AutoImageProcessor.from_pretrained(
-        #         self.model_path, trust_remote_code=True
-        #     )
-        #     print(dir(self.processor))
+                        "skip_special_tokens": True,
+                        "skip_prompt": True,
+                        "clean_up_tokenization_spaces": False,
+                    },
+                )
+
+            except Exception as e:
+                logger.error(str(e))
+                logger.error("The EmbeddedLLM OpenVINO engine only supports the Phi-3 Vision model.")
+                exit()

    async def generate_vision(
        self,
@@ -93,7 +139,185 @@ async def generate_vision(
        request_id: str,
        stream: bool = True,
    ) -> AsyncIterator[RequestOutput]:
-        raise NotImplementedError(f"`generate_vision` yet to be implemented.")
+        # only works when the engine was initialized with vision=True
+        if not self.vision:
+            raise ValueError("Your model is not a vision model. Please set vision=True when initializing the engine.")
+
+        prompt_text = inputs['prompt']
+        input_tokens = self.tokenizer.encode(prompt_text)
+        file_data = inputs["multi_modal_data"][0]["image_pixel_data"]
+        mime_type = inputs["multi_modal_data"][0]["mime_type"]
+        print(f"Detected MIME type: {mime_type}")
+
+        assert "image" in mime_type
+
+        image = Image.open(BytesIO(file_data))
+        image_token_length = self.processor.calc_num_image_tokens(image)[0]
+        prompt_token_length = len(self.tokenizer.encode(prompt_text, return_tensors="pt")[0])
+
+        input_token_length = image_token_length + prompt_token_length
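+        # total prompt budget: image tokens plus text tokens, checked against the
+        # model's context length before generation starts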
+
+        # logger.debug(f"Prompt token length: {prompt_token_length}")
+        # logger.debug(f"Image token length: {image_token_length}")
+
+        max_tokens = sampling_params.max_tokens
+
+        assert input_token_length is not None
+
+        if input_token_length + max_tokens > self.max_model_len:
+            raise ValueError("Prompt and requested max_tokens exceed the model's context length")
+
+        messages = [
+            {'role': 'user', 'content': f'<|image_1|>\n{prompt_text}'}
+        ]
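+        # the <|image_1|> placeholder tells the Phi-3 vision processor where the
+        # image embeddings are inserted in the prompt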
+        prompt = self.processor.tokenizer.apply_chat_template(
+            messages,
+            tokenize=False,
+            add_generation_prompt=True
+        )
+        # print("Prompt: ", prompt)
+
+        model_inputs = self.processor(prompt, [image], return_tensors="pt")
+
+        generation_options = {
+            'max_new_tokens': max_tokens,
+            'do_sample': False,
+        }
+
+        token_list: List[int] = []
+        output_text: str = ""
+        if stream:
+            generation_options["streamer"] = self.tokenizer_stream
+            # include the processor outputs in the generation kwargs
+            generation_kwargs = {**model_inputs, **generation_options}
+
+            if RECORD_TIMING:
+                started_timestamp = time.time()
+                first_token_timestamp = 0
+                first = True
+                new_tokens = []
+
+            try:
+                thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
+                thread.start()
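+                # generation runs on a background thread; the TextIteratorStreamer
+                # yields decoded text pieces here as tokens are produced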
+                output_text = ""
+                first = True
+                for new_text in self.tokenizer_stream:
+                    if new_text == "":
+                        continue
+                    if RECORD_TIMING:
+                        if first:
+                            first_token_timestamp = time.time()
+                            first = False
+                    output_text += new_text
+                    token_list = self.processor.tokenizer.encode(output_text, return_tensors="pt")
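+                    # re-encode the accumulated text on every chunk so the partial
+                    # completion can be reported together with its token ids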
+
+                    yield RequestOutput(
+                        request_id=request_id,
+                        prompt=inputs,
+                        prompt_token_ids=input_tokens,
+                        finished=False,
+                        outputs=[
+                            CompletionOutput(
+                                index=0,
+                                text=output_text,
+                                token_ids=token_list[0],
+                                cumulative_logprob=-1.0,
+                            )
+                        ],
+                    )
+
+                if RECORD_TIMING:
+                    new_tokens = token_list[0]
+
+                yield RequestOutput(
+                    request_id=request_id,
+                    prompt=inputs,
+                    prompt_token_ids=input_tokens,
+                    finished=True,
+                    outputs=[
+                        CompletionOutput(
+                            index=0,
+                            text=output_text,
+                            token_ids=token_list[0],
+                            cumulative_logprob=-1.0,
+                            finish_reason="stop",
+                        )
+                    ],
+                )
+
+                if RECORD_TIMING:
+                    prompt_time = first_token_timestamp - started_timestamp
+                    run_time = time.time() - first_token_timestamp
+                    logger.info(
+                        f"Prompt length: {len(input_tokens)}, New tokens: {len(new_tokens)}, Time to first: {prompt_time:.2f}s, Prompt tokens per second: {len(input_tokens)/prompt_time:.2f} tps, New tokens per second: {len(new_tokens)/run_time:.2f} tps"
+                    )
+
+            except Exception as e:
+                logger.error(str(e))
+
+                error_output = RequestOutput(
+                    prompt=inputs,
+                    prompt_token_ids=input_tokens,
+                    finished=True,
+                    request_id=request_id,
+                    outputs=[
+                        CompletionOutput(
+                            index=0,
+                            text=output_text,
+                            token_ids=token_list,
+                            cumulative_logprob=-1.0,
+                            finish_reason="error",
+                            stop_reason=str(e),
+                        )
+                    ],
+                )
+                yield error_output
+
+        else:
+            try:
+                token_list = self.model.generate(**model_inputs, **generation_options)[0]
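+                # non-streaming path: run generation to completion and decode the
+                # whole sequence in a single call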
+                output_text = self.processor.tokenizer.decode(
+                    token_list, skip_special_tokens=True
+                )
+
+                yield RequestOutput(
+                    request_id=request_id,
+                    prompt=inputs,
+                    prompt_token_ids=input_tokens,
+                    finished=True,
+                    outputs=[
+                        CompletionOutput(
+                            index=0,
+                            text=output_text,
+                            token_ids=token_list,
+                            cumulative_logprob=-1.0,
+                            finish_reason="stop",
+                        )
+                    ],
+                )
+
+            except Exception as e:
+                logger.error(str(e))
+
+                error_output = RequestOutput(
+                    prompt=inputs,
+                    prompt_token_ids=input_tokens,
+                    finished=True,
+                    request_id=request_id,
+                    outputs=[
+                        CompletionOutput(
+                            index=0,
+                            text=output_text,
+                            token_ids=token_list,
+                            cumulative_logprob=-1.0,
+                            finish_reason="error",
+                            stop_reason=str(e),
+                        )
+                    ],
+                )
+                yield error_output
+

    async def generate(
        self,