
Commit 91b3808

added vision generation (non-streaming version)
1 parent c4518ad commit 91b3808

File tree: 1 file changed, +146 -38 lines

src/embeddedllm/backend/openvino_engine.py

Lines changed: 146 additions & 38 deletions
@@ -1,13 +1,19 @@
 import contextlib
+from io import BytesIO
 import time
+import requests
+import os
+from PIL import Image
 from pathlib import Path
 from tempfile import TemporaryDirectory
 from typing import AsyncIterator, List, Optional
+from huggingface_hub import snapshot_download
 
 from loguru import logger
 from PIL import Image
 from transformers import (
     AutoConfig,
+    AutoProcessor,
     PreTrainedTokenizer,
     PreTrainedTokenizerFast,
     TextIteratorStreamer,
@@ -16,7 +22,7 @@
 from threading import Thread
 
 from optimum.intel import OVModelForCausalLM, OVWeightQuantizationConfig
-
+from embeddedllm.backend.ov_phi3_vision import OvPhi3Vision
 from embeddedllm.inputs import PromptInputs
 from embeddedllm.protocol import CompletionOutput, RequestOutput
 from embeddedllm.sampling_params import SamplingParams
@@ -27,11 +33,14 @@
 
 class OpenVinoEngine(BaseLLMEngine):
     def __init__(self, model_path: str, vision: bool, device: str = "gpu"):
+        self.vision = vision
         self.model_path = model_path
+        self.device = device
+
         self.model_config: AutoConfig = AutoConfig.from_pretrained(
-            self.model_path, trust_remote_code=True
+            self.model_path,
+            trust_remote_code=True
         )
-        self.device = device
 
         # model_config is to find out the max length of the model
         self.max_model_len = _get_and_verify_max_len(
@@ -40,51 +49,76 @@ def __init__(self, model_path: str, vision: bool, device: str = "gpu"):
             disable_sliding_window=False,
             sliding_window_len=self.get_hf_config_sliding_window(),
         )
-
         logger.info("Model Context Length: " + str(self.max_model_len))
-
+
         try:
             logger.info("Attempt to load fast tokenizer")
             self.tokenizer = PreTrainedTokenizerFast.from_pretrained(self.model_path)
         except Exception:
             logger.info("Attempt to load slower tokenizer")
             self.tokenizer = PreTrainedTokenizer.from_pretrained(self.model_path)
-
-        try:
-            self.model = OVModelForCausalLM.from_pretrained(
-                model_path, trust_remote_code=True, export=False, device=self.device
-            )
-        except Exception as e:
-            model = OVModelForCausalLM.from_pretrained(
-                model_path,
-                trust_remote_code=True,
-                export=True,
-                quantization_config=OVWeightQuantizationConfig(
-                    **{
-                        "bits": 4,
-                        "ratio": 1.0,
-                        "sym": True,
-                        "group_size": 128,
-                        "all_layers": None,
-                    }
-                ),
-            )
-            self.model = model.to(self.device)
-
-        logger.info("Model loaded")
         self.tokenizer_stream = TextIteratorStreamer(
-            self.tokenizer, skip_prompt=True, skip_special_tokens=True
+            self.tokenizer,
+            skip_prompt=True,
+            skip_special_tokens=True
         )
         logger.info("Tokenizer created")
+
+        # non vision
+        if not vision:
+            try:
+                self.model = OVModelForCausalLM.from_pretrained(
+                    self.model_path,
+                    trust_remote_code=True,
+                    export=False,
+                    device=self.device
+                )
+            except Exception as e:
+                model = OVModelForCausalLM.from_pretrained(
+                    self.model_path,
+                    trust_remote_code=True,
+                    export=True,
+                    quantization_config=OVWeightQuantizationConfig(
+                        **{
+                            "bits": 4,
+                            "ratio": 1.0,
+                            "sym": True,
+                            "group_size": 128,
+                            "all_layers": None,
+                        }
+                    ),
+                )
+                self.model = model.to(self.device)
+
+            logger.info("Model loaded")
+
+        # vision
+        elif self.vision:
+            logger.info("Your model is a vision model")
+
+            # snapshot_download vision model if model path provided
+            if not os.path.exists(model_path):
+                snapshot_path = snapshot_download(
+                    repo_id=model_path,
+                    allow_patterns=None,
+                    repo_type="model",
+                )
+                self.model_path = snapshot_path
+
+            # it is case sensitive, only receive all char captilized only
+            self.model = OvPhi3Vision(
+                self.model_path,
+                self.device.upper()
+            )
+            logger.info("Model loaded")
+
+            self.processor = AutoProcessor.from_pretrained(
+                self.model_path,
+                trust_remote_code=True
+            )
+            logger.info("Processor loaded")
+            print("processor directory: ",dir(self.processor))
 
-        self.vision = vision
-
-        # if self.vision:
-        #     self.onnx_processor = self.model.create_multimodal_processor()
-        #     self.processor = AutoImageProcessor.from_pretrained(
-        #         self.model_path, trust_remote_code=True
-        #     )
-        #     print(dir(self.processor))
 
     async def generate_vision(
         self,
@@ -93,7 +127,81 @@ async def generate_vision(
         request_id: str,
         stream: bool = True,
     ) -> AsyncIterator[RequestOutput]:
-        raise NotImplementedError(f"`generate_vision` yet to be implemented.")
+        # only work if vision is set to True
+        if not self.vision:
+            raise ValueError("Your model is not a vision model. Please set vision=True when initializing the engine.")
+
+        prompt_text = inputs['prompt']
+        input_tokens = self.tokenizer.encode(prompt_text)
+        file_data = inputs["multi_modal_data"][0]["image_pixel_data"]
+        mime_type = inputs["multi_modal_data"][0]["mime_type"]
+        print(f"Detected MIME type: {mime_type}")
+
+        assert "image" in mime_type
+
+        image = Image.open(BytesIO(file_data))
+
+        messages = [
+            {'role': 'user', 'content': f'<|image_1|>\n{prompt_text}'}
+        ]
+        prompt = self.processor.tokenizer.apply_chat_template(
+            messages,
+            tokenize=False,
+            add_generation_prompt=True
+        )
+        print("Prompt: ",prompt)
+
+        try:
+            inputs = self.processor(prompt, [image], return_tensors="pt")
+            print(f"Processed inputs")
+        except Exception as e:
+            print(f"Error processing inputs: {e}")
+
+
+        token_list: List[int] = []
+        output_text: str = ""
+
+        try:
+            generation_options = {
+                'max_new_tokens': sampling_params.max_new_tokens,
+                'do_sample': False,
+            }
+            token_list = self.model.generate(
+                **inputs,
+                eos_token_id=self.processor.tokenizer.eos_token_id,
+                **generation_options
+            )
+            print(f"Generated token list")
+        except Exception as e:
+            print(f"Error during token generation: {e}")
+
+        # Decode each element in the response
+        try:
+            decoded_text = [self.processor.tokenizer.decode(ids, skip_special_tokens=True) for ids in token_list]
+            print(f"Decoded text: {decoded_text}")
+        except Exception as e:
+            print(f"Error decoding text: {e}")
+
+        # Join the decoded text if needed
+        output_text = ' '.join(decoded_text).strip()
+        print(output_text)
+
+        yield RequestOutput(
+            request_id=request_id,
+            prompt=inputs,
+            prompt_token_ids=input_tokens,
+            finished=True,
+            outputs=[
+                CompletionOutput(
+                    index=0,
+                    text=output_text,
+                    token_ids=token_list,
+                    cumulative_logprob=-1.0,
+                    finish_reason="stop",
+                )
+            ],
+        )
+
 
     async def generate(
         self,
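The new non-streaming path can be exercised roughly as follows. This is a minimal sketch, not part of the commit: it assumes the first two parameters of generate_vision are the prompt inputs and the sampling params (consistent with how the method body reads them), that a plain dict with "prompt" and "multi_modal_data" keys is accepted in place of PromptInputs, that SamplingParams exposes max_new_tokens, and that the model path and image file are placeholders.

import asyncio

from embeddedllm.backend.openvino_engine import OpenVinoEngine
from embeddedllm.sampling_params import SamplingParams


async def main() -> None:
    # vision=True routes __init__ to the OvPhi3Vision loader; a Hugging Face
    # repo id is fetched via snapshot_download when the path does not exist locally.
    engine = OpenVinoEngine(
        model_path="path/to/phi-3-vision-openvino",  # placeholder path or repo id
        vision=True,
        device="gpu",
    )

    with open("example.jpg", "rb") as f:  # placeholder image file
        image_bytes = f.read()

    inputs = {
        "prompt": "Describe this image.",
        "multi_modal_data": [
            {"image_pixel_data": image_bytes, "mime_type": "image/jpeg"}
        ],
    }
    # Assumption: SamplingParams accepts max_new_tokens, the field generate_vision reads.
    sampling_params = SamplingParams(max_new_tokens=256)

    # generate_vision is an async generator; this version decodes everything first
    # and yields a single finished RequestOutput with the full text.
    async for output in engine.generate_vision(inputs, sampling_params, request_id="request-0"):
        print(output.outputs[0].text)


asyncio.run(main())

Note that stream still defaults to True in the signature, but this implementation decodes the whole token list before yielding, so the flag has no effect yet.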
