
Commit aff0ae2

Update openvino_engine.py
* added record timing function
* added streaming feature
1 parent 7fe7f74 commit aff0ae2

1 file changed: +158 -53 lines changed

src/embeddedllm/backend/openvino_engine.py

Lines changed: 158 additions & 53 deletions
@@ -13,6 +13,7 @@
 from transformers import (
     AutoConfig,
     AutoProcessor,
+    TextStreamer,
     PreTrainedTokenizer,
     PreTrainedTokenizerFast,
     TextIteratorStreamer,
@@ -56,15 +57,15 @@ def __init__(self, model_path: str, vision: bool, device: str = "gpu"):
         except Exception:
             logger.info("Attempt to load slower tokenizer")
             self.tokenizer = PreTrainedTokenizer.from_pretrained(self.model_path)
-        self.tokenizer_stream = TextIteratorStreamer(
-            self.tokenizer,
-            skip_prompt=True,
-            skip_special_tokens=True
-        )
         logger.info("Tokenizer created")

         # non vision
         if not vision:
+            self.tokenizer_stream = TextIteratorStreamer(
+                self.tokenizer,
+                skip_prompt=True,
+                skip_special_tokens=True
+            )
             try:
                 self.model = OVModelForCausalLM.from_pretrained(
                     self.model_path,
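For context, the streamer created in this hunk is meant to be consumed with the standard transformers producer/consumer pattern: generate() runs on a background thread while the caller iterates over the TextIteratorStreamer. A minimal, self-contained sketch of that pattern (the model id and prompt below are placeholders, not from this repo):

from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Placeholder model id, for illustration only.
tok = AutoTokenizer.from_pretrained("some-org/some-causal-lm")
model = AutoModelForCausalLM.from_pretrained("some-org/some-causal-lm")

streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)
inputs = tok("Hello, OpenVINO!", return_tensors="pt")

# generate() blocks until completion, so it runs in a background thread;
# the caller receives decoded text pieces as they are produced.
thread = Thread(target=model.generate, kwargs={**inputs, "streamer": streamer, "max_new_tokens": 32})
thread.start()
for piece in streamer:
    print(piece, end="", flush=True)
thread.join()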
@@ -117,6 +118,14 @@ def __init__(self, model_path: str, vision: bool, device: str = "gpu"):
             )
             logger.info("Processor loaded")
             print("processor directory: ",dir(self.processor))
+            self.tokenizer_stream = TextIteratorStreamer(
+                self.processor,
+                **{
+                    "skip_special_tokens": True,
+                    "skip_prompt": True,
+                    "clean_up_tokenization_spaces": False,
+                },
+            )


     async def generate_vision(
@@ -139,15 +148,21 @@ async def generate_vision(
         assert "image" in mime_type

         image = Image.open(BytesIO(file_data))
-        input_token_length = self.processor.calc_num_image_tokens(image)[0]
+        image_token_length = self.processor.calc_num_image_tokens(image)[0]
+        prompt_token_length = len(self.tokenizer.encode(prompt_text, return_tensors="pt")[0])
+
+        input_token_length = image_token_length + prompt_token_length
+
+        # logger.debug(f"Prompt token length: {prompt_token_length}")
+        # logger.debug(f"Image token length: {image_token_length}")
+
         max_tokens = sampling_params.max_tokens

         assert input_token_length is not None

         if input_token_length + max_tokens > self.max_model_len:
             raise ValueError("Exceed Context Length")

-
         messages = [
             {'role': 'user', 'content': f'<|image_1|>\n{prompt_text}'}
         ]
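The hunk above changes the context-length check to budget for both image tokens and prompt tokens. A hypothetical helper (not in the repo) showing the arithmetic the check performs:

def check_context_fit(image_tokens: int, prompt_tokens: int,
                      max_new_tokens: int, max_model_len: int) -> int:
    # Total prompt cost is image tokens plus text tokens; generation must
    # still leave room for max_new_tokens inside the model's context window.
    input_token_length = image_tokens + prompt_tokens
    if input_token_length + max_new_tokens > max_model_len:
        raise ValueError("Exceed Context Length")
    return input_token_length

# Example: 1921 image tokens + 15 prompt tokens + 1024 new tokens fits a 4096 window.
check_context_fit(1921, 15, 1024, 4096)  # 1936 + 1024 = 2960 <= 4096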
@@ -156,58 +171,148 @@ async def generate_vision(
             tokenize=False,
             add_generation_prompt=True
         )
-        print("Prompt: ",prompt)
+        # print("Prompt: ", prompt)

-        try:
-            inputs = self.processor(prompt, [image], return_tensors="pt")
-            print(f"Processed inputs")
-        except Exception as e:
-            print(f"Error processing inputs: {e}")
+        inputs = self.processor(prompt, [image], return_tensors="pt")

+        generation_options = {
+            'max_new_tokens': max_tokens,
+            'do_sample': False,
+        }

         token_list: List[int] = []
         output_text: str = ""
-
-        try:
-            generation_options = {
-                'max_new_tokens': max_tokens,
-                'do_sample': False,
-            }
-            token_list = self.model.generate(
-                **inputs,
-                eos_token_id=self.processor.tokenizer.eos_token_id,
-                **generation_options
-            )
-            print(f"Generated token list")
-        except Exception as e:
-            print(f"Error during token generation: {e}")
-
-        # Decode each element in the response
-        try:
-            decoded_text = [self.processor.tokenizer.decode(ids, skip_special_tokens=True) for ids in token_list]
-            print(f"Decoded text: {decoded_text}")
-        except Exception as e:
-            print(f"Error decoding text: {e}")
-
-        # Join the decoded text if needed
-        output_text = ' '.join(decoded_text).strip()
-        print(output_text)
-
-        yield RequestOutput(
-            request_id=request_id,
-            prompt=inputs,
-            prompt_token_ids=input_tokens,
-            finished=True,
-            outputs=[
-                CompletionOutput(
-                    index=0,
-                    text=output_text,
-                    token_ids=token_list,
-                    cumulative_logprob=-1.0,
-                    finish_reason="stop",
+        if stream:
+            generation_options["streamer"] = self.tokenizer_stream
+            # Include the inputs in the generation_options
+            generation_kwargs = {**inputs, **generation_options}
+
+            if RECORD_TIMING:
+                started_timestamp = time.time()
+                first_token_timestamp = 0
+                first = True
+                new_tokens = []
+
+            try:
+                thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
+                thread.start()
+                output_text = ""
+                first = True
+                for new_text in self.tokenizer_stream:
+                    if new_text == "":
+                        continue
+                    if RECORD_TIMING:
+                        if first:
+                            first_token_timestamp = time.time()
+                            first = False
+                    output_text += new_text
+                    token_list = self.processor.tokenizer.encode(output_text, return_tensors="pt")
+
+                    yield RequestOutput(
+                        request_id=request_id,
+                        prompt=inputs,
+                        prompt_token_ids=input_tokens,
+                        finished=False,
+                        outputs=[
+                            CompletionOutput(
+                                index=0,
+                                text=output_text,
+                                token_ids=token_list[0],
+                                cumulative_logprob=-1.0,
+                            )
+                        ],
+                    )
+
+                if RECORD_TIMING:
+                    new_tokens = token_list[0]
+
+                yield RequestOutput(
+                    request_id=request_id,
+                    prompt=inputs,
+                    prompt_token_ids=input_tokens,
+                    finished=True,
+                    outputs=[
+                        CompletionOutput(
+                            index=0,
+                            text=output_text,
+                            token_ids=token_list[0],
+                            cumulative_logprob=-1.0,
+                            finish_reason="stop",
+                        )
+                    ],
                 )
-            ],
-        )
+
+                if RECORD_TIMING:
+                    prompt_time = first_token_timestamp - started_timestamp
+                    run_time = time.time() - first_token_timestamp
+                    logger.info(
+                        f"Prompt length: {len(input_tokens)}, New tokens: {len(new_tokens)}, Time to first: {(prompt_time):.2f}s, Prompt tokens per second: {len(input_tokens)/prompt_time:.2f} tps, New tokens per second: {len(new_tokens)/run_time:.2f} tps"
+                    )
+
+            except Exception as e:
+                logger.error(str(e))
+
+                error_output = RequestOutput(
+                    prompt=inputs,
+                    prompt_token_ids=input_tokens,
+                    finished=True,
+                    request_id=request_id,
+                    outputs=[
+                        CompletionOutput(
+                            index=0,
+                            text=output_text,
+                            token_ids=token_list,
+                            cumulative_logprob=-1.0,
+                            finish_reason="error",
+                            stop_reason=str(e),
+                        )
+                    ],
+                )
+                yield error_output
+
+        else:
+            try:
+                token_list = self.model.generate(**inputs, **generation_options)[0]
+                output_text = self.processor.tokenizer.decode(
+                    token_list, skip_special_tokens=True
+                )
+
+                yield RequestOutput(
+                    request_id=request_id,
+                    prompt=inputs,
+                    prompt_token_ids=input_tokens,
+                    finished=True,
+                    outputs=[
+                        CompletionOutput(
+                            index=0,
+                            text=output_text,
+                            token_ids=token_list,
+                            cumulative_logprob=-1.0,
+                            finish_reason="stop",
+                        )
+                    ],
+                )
+
+            except Exception as e:
+                logger.error(str(e))
+
+                error_output = RequestOutput(
+                    prompt=inputs,
+                    prompt_token_ids=input_tokens,
+                    finished=True,
+                    request_id=request_id,
+                    outputs=[
+                        CompletionOutput(
+                            index=0,
+                            text=output_text,
+                            token_ids=token_list,
+                            cumulative_logprob=-1.0,
+                            finish_reason="error",
+                            stop_reason=str(e),
+                        )
+                    ],
+                )
+                yield error_output


     async def generate(
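As a usage note, the timing block guarded by RECORD_TIMING (assumed to be a module-level flag) reports time-to-first-token and throughput. A small sketch of the same arithmetic with made-up numbers:

import time

def timing_summary(num_prompt_tokens: int, num_new_tokens: int,
                   started: float, first_token: float, finished: float) -> str:
    # Mirrors the RECORD_TIMING log line: time to first token, then tokens/s
    # for the prompt phase and the decode phase.
    prompt_time = first_token - started
    run_time = finished - first_token
    return (
        f"Prompt length: {num_prompt_tokens}, New tokens: {num_new_tokens}, "
        f"Time to first: {prompt_time:.2f}s, "
        f"Prompt tokens per second: {num_prompt_tokens / prompt_time:.2f} tps, "
        f"New tokens per second: {num_new_tokens / run_time:.2f} tps"
    )

t0 = time.time()
print(timing_summary(1936, 256, t0, t0 + 0.8, t0 + 5.8))
# 1936 prompt tokens / 0.8 s = 2420.00 tps; 256 new tokens / 5.0 s = 51.20 tps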
