import contextlib
+from io import BytesIO
import time
+import requests
+import os
+from PIL import Image
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import AsyncIterator, List, Optional
+from huggingface_hub import snapshot_download

from loguru import logger
from PIL import Image
from transformers import (
    AutoConfig,
+    AutoProcessor,
    PreTrainedTokenizer,
    PreTrainedTokenizerFast,
    TextIteratorStreamer,
from threading import Thread

from optimum.intel import OVModelForCausalLM, OVWeightQuantizationConfig
-
+from embeddedllm.backend.ov_phi3_vision import OvPhi3Vision
from embeddedllm.inputs import PromptInputs
from embeddedllm.protocol import CompletionOutput, RequestOutput
from embeddedllm.sampling_params import SamplingParams

class OpenVinoEngine(BaseLLMEngine):
    def __init__(self, model_path: str, vision: bool, device: str = "gpu"):
+        self.vision = vision
        self.model_path = model_path
+        self.device = device
+
        self.model_config: AutoConfig = AutoConfig.from_pretrained(
-            self.model_path, trust_remote_code=True
+            self.model_path,
+            trust_remote_code=True
        )
-        self.device = device

        # model_config is to find out the max length of the model
        self.max_model_len = _get_and_verify_max_len(
@@ -40,51 +49,76 @@ def __init__(self, model_path: str, vision: bool, device: str = "gpu"):
            disable_sliding_window=False,
            sliding_window_len=self.get_hf_config_sliding_window(),
        )
-
        logger.info("Model Context Length: " + str(self.max_model_len))
-
+
        try:
            logger.info("Attempt to load fast tokenizer")
            self.tokenizer = PreTrainedTokenizerFast.from_pretrained(self.model_path)
        except Exception:
            logger.info("Attempt to load slower tokenizer")
            self.tokenizer = PreTrainedTokenizer.from_pretrained(self.model_path)
-
-        try:
-            self.model = OVModelForCausalLM.from_pretrained(
-                model_path, trust_remote_code=True, export=False, device=self.device
-            )
-        except Exception as e:
-            model = OVModelForCausalLM.from_pretrained(
-                model_path,
-                trust_remote_code=True,
-                export=True,
-                quantization_config=OVWeightQuantizationConfig(
-                    **{
-                        "bits": 4,
-                        "ratio": 1.0,
-                        "sym": True,
-                        "group_size": 128,
-                        "all_layers": None,
-                    }
-                ),
-            )
-            self.model = model.to(self.device)
-
-        logger.info("Model loaded")
        self.tokenizer_stream = TextIteratorStreamer(
-            self.tokenizer, skip_prompt=True, skip_special_tokens=True
+            self.tokenizer,
+            skip_prompt=True,
+            skip_special_tokens=True
        )
        logger.info("Tokenizer created")
+
+        # Text-only (non-vision) models: load with OVModelForCausalLM
+        if not vision:
+            try:
+                self.model = OVModelForCausalLM.from_pretrained(
+                    self.model_path,
+                    trust_remote_code=True,
+                    export=False,
+                    device=self.device
+                )
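+            # Fall back: export the checkpoint to OpenVINO on the fly with 4-bit weight quantization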
+            except Exception as e:
+                model = OVModelForCausalLM.from_pretrained(
+                    self.model_path,
+                    trust_remote_code=True,
+                    export=True,
+                    quantization_config=OVWeightQuantizationConfig(
+                        **{
+                            "bits": 4,
+                            "ratio": 1.0,
+                            "sym": True,
+                            "group_size": 128,
+                            "all_layers": None,
+                        }
+                    ),
+                )
+                self.model = model.to(self.device)
+
+            logger.info("Model loaded")
+
+        # Vision models: load with the dedicated Phi-3-vision OpenVINO backend
+        elif self.vision:
+            logger.info("Your model is a vision model")
+
+            # Download the model via snapshot_download when model_path is a
+            # Hugging Face repo id rather than an existing local path
+            if not os.path.exists(model_path):
+                snapshot_path = snapshot_download(
+                    repo_id=model_path,
+                    allow_patterns=None,
+                    repo_type="model",
+                )
+                self.model_path = snapshot_path
+
+            # OpenVINO device names are case sensitive and must be fully capitalized (e.g. "GPU")
+            self.model = OvPhi3Vision(
+                self.model_path,
+                self.device.upper()
+            )
+            logger.info("Model loaded")
+
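+            # The processor provides both the chat-template tokenizer and the image
+            # preprocessor used by generate_vision below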
+            self.processor = AutoProcessor.from_pretrained(
+                self.model_path,
+                trust_remote_code=True
+            )
+            logger.info("Processor loaded")
+            print("processor directory: ", dir(self.processor))

-        self.vision = vision
-
-        # if self.vision:
-        #     self.onnx_processor = self.model.create_multimodal_processor()
-        #     self.processor = AutoImageProcessor.from_pretrained(
-        #         self.model_path, trust_remote_code=True
-        #     )
-        #     print(dir(self.processor))

    async def generate_vision(
        self,
@@ -93,7 +127,81 @@ async def generate_vision(
        request_id: str,
        stream: bool = True,
    ) -> AsyncIterator[RequestOutput]:
-        raise NotImplementedError(f"`generate_vision` yet to be implemented.")
+        # Only valid when the engine was initialized with vision=True
+        if not self.vision:
+            raise ValueError("Your model is not a vision model. Please set vision=True when initializing the engine.")
+
+        prompt_text = inputs['prompt']
+        input_tokens = self.tokenizer.encode(prompt_text)
+        file_data = inputs["multi_modal_data"][0]["image_pixel_data"]
+        mime_type = inputs["multi_modal_data"][0]["mime_type"]
+        print(f"Detected MIME type: {mime_type}")
+
+        assert "image" in mime_type
+
+        image = Image.open(BytesIO(file_data))
+
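+        # Build a Phi-3-vision style chat message; "<|image_1|>" is the placeholder
+        # that the processor expands into image tokens for the first attached image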
+        messages = [
+            {'role': 'user', 'content': f'<|image_1|>\n{prompt_text}'}
+        ]
+        prompt = self.processor.tokenizer.apply_chat_template(
+            messages,
+            tokenize=False,
+            add_generation_prompt=True
+        )
+        print("Prompt: ", prompt)
+
+        try:
+            inputs = self.processor(prompt, [image], return_tensors="pt")
+            print(f"Processed inputs")
+        except Exception as e:
+            print(f"Error processing inputs: {e}")
+
+
+        token_list: List[int] = []
+        output_text: str = ""
+
+        try:
+            generation_options = {
+                'max_new_tokens': sampling_params.max_new_tokens,
+                'do_sample': False,
+            }
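+            # Run generation on the OpenVINO vision model; do_sample=False means greedy decoding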
+            token_list = self.model.generate(
+                **inputs,
+                eos_token_id=self.processor.tokenizer.eos_token_id,
+                **generation_options
+            )
+            print(f"Generated token list")
+        except Exception as e:
+            print(f"Error during token generation: {e}")
+
+        # Decode each element in the response
+        try:
+            decoded_text = [self.processor.tokenizer.decode(ids, skip_special_tokens=True) for ids in token_list]
+            print(f"Decoded text: {decoded_text}")
+        except Exception as e:
+            print(f"Error decoding text: {e}")
+
+        # Join the decoded text if needed
+        output_text = ' '.join(decoded_text).strip()
+        print(output_text)
+
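+        # Emit a single RequestOutput with the full completion (no token-by-token streaming)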
+        yield RequestOutput(
+            request_id=request_id,
+            prompt=prompt_text,
+            prompt_token_ids=input_tokens,
+            finished=True,
+            outputs=[
+                CompletionOutput(
+                    index=0,
+                    text=output_text,
+                    token_ids=token_list,
+                    cumulative_logprob=-1.0,
+                    finish_reason="stop",
+                )
+            ],
+        )
+

    async def generate(
        self,