
Commit aeca16a

Merge pull request #33 from EmbeddedLLM/szeyu-vision-1
[FEAT] Added OpenVINO vision model support #32
2 parents 4887326 + 86b59d6 commit aeca16a

3 files changed: +855, -41 lines


src/embeddedllm/backend/onnxruntime_engine.py

Lines changed: 1 addition & 1 deletion
@@ -48,7 +48,7 @@ def __init__(self, model_path: str, vision: bool, device: str = "cpu"):
                 allow_patterns=None,
                 repo_type="model",
             )
-            model_path = snapshot_path
+            self.model_path = snapshot_path

         self.model_config = AutoConfig.from_pretrained(self.model_path, trust_remote_code=True)
         self.device = device
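
The fix above matters because the statement that follows reads self.model_path: assigning the snapshot location to the local variable model_path left the attribute still pointing at the original repo id. A minimal sketch of the pattern, with an illustrative helper name (resolve_model_path is not part of the repository):

import os
from huggingface_hub import snapshot_download

def resolve_model_path(model_path: str) -> str:
    """Return a local directory for the model, downloading a snapshot when
    model_path is a Hugging Face repo id rather than an existing folder."""
    if os.path.exists(model_path):
        return model_path
    return snapshot_download(repo_id=model_path, repo_type="model")

# The engine must keep the resolved path on the instance, e.g.
#     self.model_path = resolve_model_path(model_path)
# so that the later AutoConfig.from_pretrained(self.model_path, ...) call
# uses the downloaded snapshot.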

src/embeddedllm/backend/openvino_engine.py

Lines changed: 264 additions & 40 deletions
@@ -1,13 +1,19 @@
 import contextlib
+from io import BytesIO
 import time
+import os
+from PIL import Image
 from pathlib import Path
 from tempfile import TemporaryDirectory
 from typing import AsyncIterator, List, Optional
+from huggingface_hub import snapshot_download

 from loguru import logger
 from PIL import Image
 from transformers import (
     AutoConfig,
+    AutoProcessor,
+    TextStreamer,
     PreTrainedTokenizer,
     PreTrainedTokenizerFast,
     TextIteratorStreamer,
@@ -16,7 +22,7 @@
 from threading import Thread

 from optimum.intel import OVModelForCausalLM, OVWeightQuantizationConfig
-
+from embeddedllm.backend.ov_phi3_vision import OvPhi3Vision
 from embeddedllm.inputs import PromptInputs
 from embeddedllm.protocol import CompletionOutput, RequestOutput
 from embeddedllm.sampling_params import SamplingParams
@@ -27,11 +33,14 @@

 class OpenVinoEngine(BaseLLMEngine):
     def __init__(self, model_path: str, vision: bool, device: str = "gpu"):
+        self.vision = vision
         self.model_path = model_path
+        self.device = device
+
         self.model_config: AutoConfig = AutoConfig.from_pretrained(
-            self.model_path, trust_remote_code=True
+            self.model_path,
+            trust_remote_code=True
         )
-        self.device = device

         # model_config is to find out the max length of the model
         self.max_model_len = _get_and_verify_max_len(
@@ -40,51 +49,88 @@ def __init__(self, model_path: str, vision: bool, device: str = "gpu"):
             disable_sliding_window=False,
             sliding_window_len=self.get_hf_config_sliding_window(),
         )
-
         logger.info("Model Context Length: " + str(self.max_model_len))
-
+
         try:
             logger.info("Attempt to load fast tokenizer")
             self.tokenizer = PreTrainedTokenizerFast.from_pretrained(self.model_path)
         except Exception:
             logger.info("Attempt to load slower tokenizer")
             self.tokenizer = PreTrainedTokenizer.from_pretrained(self.model_path)
-
-        try:
-            self.model = OVModelForCausalLM.from_pretrained(
-                model_path, trust_remote_code=True, export=False, device=self.device
+        logger.info("Tokenizer created")
+
+        # non vision
+        if not vision:
+            self.tokenizer_stream = TextIteratorStreamer(
+                self.tokenizer,
+                skip_prompt=True,
+                skip_special_tokens=True
             )
-        except Exception as e:
-            model = OVModelForCausalLM.from_pretrained(
-                model_path,
-                trust_remote_code=True,
-                export=True,
-                quantization_config=OVWeightQuantizationConfig(
+            try:
+                self.model = OVModelForCausalLM.from_pretrained(
+                    self.model_path,
+                    trust_remote_code=True,
+                    export=False,
+                    device=self.device
+                )
+            except Exception as e:
+                model = OVModelForCausalLM.from_pretrained(
+                    self.model_path,
+                    trust_remote_code=True,
+                    export=True,
+                    quantization_config=OVWeightQuantizationConfig(
+                        **{
+                            "bits": 4,
+                            "ratio": 1.0,
+                            "sym": True,
+                            "group_size": 128,
+                            "all_layers": None,
+                        }
+                    ),
+                )
+                self.model = model.to(self.device)
+
+            logger.info("Model loaded")
+
+        # vision
+        elif self.vision:
+            logger.info("Your model is a vision model")
+
+            # snapshot_download vision model if model path provided
+            if not os.path.exists(model_path):
+                snapshot_path = snapshot_download(
+                    repo_id=model_path,
+                    allow_patterns=None,
+                    repo_type="model",
+                )
+                self.model_path = snapshot_path
+
+            try:
+                # it is case sensitive, only receive all char captilized only
+                self.model = OvPhi3Vision(
+                    self.model_path,
+                    self.device.upper()
+                )
+                logger.info("Model loaded")
+
+                self.processor = AutoProcessor.from_pretrained(
+                    self.model_path,
+                    trust_remote_code=True
+                )
+                logger.info("Processor loaded")
+                print("processor directory: ", dir(self.processor))
+                self.tokenizer_stream = TextIteratorStreamer(
+                    self.processor,
                     **{
-                        "bits": 4,
-                        "ratio": 1.0,
-                        "sym": True,
-                        "group_size": 128,
-                        "all_layers": None,
-                    }
-                ),
-            )
-            self.model = model.to(self.device)
-
-        logger.info("Model loaded")
-        self.tokenizer_stream = TextIteratorStreamer(
-            self.tokenizer, skip_prompt=True, skip_special_tokens=True
-        )
-        logger.info("Tokenizer created")
-
-        self.vision = vision
-
-        # if self.vision:
-        #     self.onnx_processor = self.model.create_multimodal_processor()
-        #     self.processor = AutoImageProcessor.from_pretrained(
-        #         self.model_path, trust_remote_code=True
-        #     )
-        #     print(dir(self.processor))
+                        "skip_special_tokens": True,
+                        "skip_prompt": True,
+                        "clean_up_tokenization_spaces": False,
+                    },
+                )
+
+            except Exception as e:
+                logger.error("EmbeddedLLM Engine only support Phi 3 Vision Model.")
+                exit()

     async def generate_vision(
         self,
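
With this hunk the constructor branches on vision: the text-only path keeps the existing OVModelForCausalLM load, falling back to export with 4-bit weight quantization, while the vision path resolves the checkpoint (downloading a snapshot when the path is not a local directory), wraps it in OvPhi3Vision with an upper-cased OpenVINO device name, and streams through an AutoProcessor-backed TextIteratorStreamer. A rough instantiation sketch; the paths below are placeholders, not artifacts shipped with this PR:

from embeddedllm.backend.openvino_engine import OpenVinoEngine

# Text-only model: loads a pre-converted OpenVINO IR, or exports one with
# 4-bit weight quantization if loading with export=False fails.
text_engine = OpenVinoEngine("path/to/ov-llm", vision=False, device="gpu")

# Vision model: a local folder or a Hugging Face repo id (fetched via
# snapshot_download when the path does not exist). The device string is
# upper-cased before being handed to OvPhi3Vision, e.g. "gpu" -> "GPU".
vision_engine = OpenVinoEngine("path/to/phi-3-vision-ov", vision=True, device="gpu")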
@@ -93,7 +139,185 @@ async def generate_vision(
         request_id: str,
         stream: bool = True,
     ) -> AsyncIterator[RequestOutput]:
-        raise NotImplementedError(f"`generate_vision` yet to be implemented.")
+        # only work if vision is set to True
+        if not self.vision:
+            raise ValueError("Your model is not a vision model. Please set vision=True when initializing the engine.")
+
+        prompt_text = inputs['prompt']
+        input_tokens = self.tokenizer.encode(prompt_text)
+        file_data = inputs["multi_modal_data"][0]["image_pixel_data"]
+        mime_type = inputs["multi_modal_data"][0]["mime_type"]
+        print(f"Detected MIME type: {mime_type}")
+
+        assert "image" in mime_type
+
+        image = Image.open(BytesIO(file_data))
+        image_token_length = self.processor.calc_num_image_tokens(image)[0]
+        prompt_token_length = len(self.tokenizer.encode(prompt_text, return_tensors="pt")[0])
+
+        input_token_length = image_token_length + prompt_token_length
+
+        # logger.debug(f"Prompt token length: {prompt_token_length}")
+        # logger.debug(f"Image token length: {image_token_length}")
+
+        max_tokens = sampling_params.max_tokens
+
+        assert input_token_length is not None
+
+        if input_token_length + max_tokens > self.max_model_len:
+            raise ValueError("Exceed Context Length")
+
+        messages = [
+            {'role': 'user', 'content': f'<|image_1|>\n{prompt_text}'}
+        ]
+        prompt = self.processor.tokenizer.apply_chat_template(
+            messages,
+            tokenize=False,
+            add_generation_prompt=True
+        )
+        # print("Prompt: ", prompt)
+
+        inputs = self.processor(prompt, [image], return_tensors="pt")
+
+        generation_options = {
+            'max_new_tokens': max_tokens,
+            'do_sample': False,
+        }
+
+        token_list: List[int] = []
+        output_text: str = ""
+        if stream:
+            generation_options["streamer"] = self.tokenizer_stream
+            # Include the inputs in the generation_options
+            generation_kwargs = {**inputs, **generation_options}
+
+            if RECORD_TIMING:
+                started_timestamp = time.time()
+                first_token_timestamp = 0
+                first = True
+                new_tokens = []
+
+            try:
+                thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
+                thread.start()
+                output_text = ""
+                first = True
+                for new_text in self.tokenizer_stream:
+                    if new_text == "":
+                        continue
+                    if RECORD_TIMING:
+                        if first:
+                            first_token_timestamp = time.time()
+                            first = False
+                    output_text += new_text
+                    token_list = self.processor.tokenizer.encode(output_text, return_tensors="pt")
+
+                    yield RequestOutput(
+                        request_id=request_id,
+                        prompt=inputs,
+                        prompt_token_ids=input_tokens,
+                        finished=False,
+                        outputs=[
+                            CompletionOutput(
+                                index=0,
+                                text=output_text,
+                                token_ids=token_list[0],
+                                cumulative_logprob=-1.0,
+                            )
+                        ],
+                    )
+
+                    if RECORD_TIMING:
+                        new_tokens = token_list[0]
+
+                yield RequestOutput(
+                    request_id=request_id,
+                    prompt=inputs,
+                    prompt_token_ids=input_tokens,
+                    finished=True,
+                    outputs=[
+                        CompletionOutput(
+                            index=0,
+                            text=output_text,
+                            token_ids=token_list[0],
+                            cumulative_logprob=-1.0,
+                            finish_reason="stop",
+                        )
+                    ],
+                )
+
+                if RECORD_TIMING:
+                    prompt_time = first_token_timestamp - started_timestamp
+                    run_time = time.time() - first_token_timestamp
+                    logger.info(
+                        f"Prompt length: {len(input_tokens)}, New tokens: {len(new_tokens)}, Time to first: {(prompt_time):.2f}s, Prompt tokens per second: {len(input_tokens)/prompt_time:.2f} tps, New tokens per second: {len(new_tokens)/run_time:.2f} tps"
+                    )
+
+            except Exception as e:
+                logger.error(str(e))
+
+                error_output = RequestOutput(
+                    prompt=inputs,
+                    prompt_token_ids=input_tokens,
+                    finished=True,
+                    request_id=request_id,
+                    outputs=[
+                        CompletionOutput(
+                            index=0,
+                            text=output_text,
+                            token_ids=token_list,
+                            cumulative_logprob=-1.0,
+                            finish_reason="error",
+                            stop_reason=str(e),
+                        )
+                    ],
+                )
+                yield error_output
+
+        else:
+            try:
+                token_list = self.model.generate(**inputs, **generation_options)[0]
+                output_text = self.processor.tokenizer.decode(
+                    token_list, skip_special_tokens=True
+                )
+
+                yield RequestOutput(
+                    request_id=request_id,
+                    prompt=inputs,
+                    prompt_token_ids=input_tokens,
+                    finished=True,
+                    outputs=[
+                        CompletionOutput(
+                            index=0,
+                            text=output_text,
+                            token_ids=token_list,
+                            cumulative_logprob=-1.0,
+                            finish_reason="stop",
+                        )
+                    ],
+                )
+
+            except Exception as e:
+                logger.error(str(e))
+
+                error_output = RequestOutput(
+                    prompt=inputs,
+                    prompt_token_ids=input_tokens,
+                    finished=True,
+                    request_id=request_id,
+                    outputs=[
+                        CompletionOutput(
+                            index=0,
+                            text=output_text,
+                            token_ids=token_list,
+                            cumulative_logprob=-1.0,
+                            finish_reason="error",
+                            stop_reason=str(e),
+                        )
+                    ],
+                )
+                yield error_output

     async def generate(
         self,
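
generate_vision is now an async generator: it verifies the engine was built with vision=True, checks the image-plus-prompt token count against the model context length, builds a <|image_1|> chat prompt through the processor, and then either streams incremental RequestOutputs from a background generate thread or yields one final output when stream=False. A caller-side sketch under stated assumptions: SamplingParams accepting max_tokens, the multi_modal_data dict shape read above, and an illustrative image path and request id:

import asyncio

from embeddedllm.sampling_params import SamplingParams

async def describe_image(engine, image_path: str) -> str:
    # Shape expected by generate_vision: a prompt string plus raw image bytes
    # and their MIME type under "multi_modal_data".
    with open(image_path, "rb") as f:
        image_bytes = f.read()
    inputs = {
        "prompt": "What is shown in this image?",
        "multi_modal_data": [
            {"image_pixel_data": image_bytes, "mime_type": "image/png"},
        ],
    }
    sampling_params = SamplingParams(max_tokens=256)

    final_text = ""
    async for output in engine.generate_vision(
        inputs, sampling_params, request_id="vision-demo-1", stream=True
    ):
        # Each yielded RequestOutput carries the accumulated text so far.
        final_text = output.outputs[0].text
    return final_text

# asyncio.run(describe_image(vision_engine, "example.png"))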
