 from transformers import (
     AutoConfig,
     AutoProcessor,
+    TextStreamer,
     PreTrainedTokenizer,
     PreTrainedTokenizerFast,
     TextIteratorStreamer,
@@ -56,15 +57,15 @@ def __init__(self, model_path: str, vision: bool, device: str = "gpu"):
         except Exception:
             logger.info("Attempt to load slower tokenizer")
             self.tokenizer = PreTrainedTokenizer.from_pretrained(self.model_path)
-        self.tokenizer_stream = TextIteratorStreamer(
-            self.tokenizer,
-            skip_prompt=True,
-            skip_special_tokens=True
-        )
         logger.info("Tokenizer created")

         # non vision
         if not vision:
+            self.tokenizer_stream = TextIteratorStreamer(
+                self.tokenizer,
+                skip_prompt=True,
+                skip_special_tokens=True
+            )
             try:
                 self.model = OVModelForCausalLM.from_pretrained(
                     self.model_path,
@@ -117,6 +118,14 @@ def __init__(self, model_path: str, vision: bool, device: str = "gpu"):
             )
             logger.info("Processor loaded")
             print("processor directory: ", dir(self.processor))
+            self.tokenizer_stream = TextIteratorStreamer(
+                self.processor,
+                **{
+                    "skip_special_tokens": True,
+                    "skip_prompt": True,
+                    "clean_up_tokenization_spaces": False,
+                },
+            )


     async def generate_vision(
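Editor's note, not part of the commit: the two hunks above move the TextIteratorStreamer construction out of __init__'s common path and into the non-vision and vision branches, wrapping self.tokenizer in one case and self.processor (which is expected to forward decode calls to its underlying tokenizer) in the other. A minimal sketch of that construction against a plain Hugging Face tokenizer; the model id below is a placeholder, not taken from this repo:

from transformers import AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder model id for illustration
streamer = TextIteratorStreamer(
    tokenizer,
    skip_prompt=True,           # do not re-emit prompt tokens in the stream
    skip_special_tokens=True,   # forwarded to tokenizer.decode()
)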
@@ -139,15 +148,21 @@ async def generate_vision(
         assert "image" in mime_type

         image = Image.open(BytesIO(file_data))
-        input_token_length = self.processor.calc_num_image_tokens(image)[0]
+        image_token_length = self.processor.calc_num_image_tokens(image)[0]
+        prompt_token_length = len(self.tokenizer.encode(prompt_text, return_tensors="pt")[0])
+
+        input_token_length = image_token_length + prompt_token_length
+
+        # logger.debug(f"Prompt token length: {prompt_token_length}")
+        # logger.debug(f"Image token length: {image_token_length}")
+
         max_tokens = sampling_params.max_tokens

         assert input_token_length is not None

         if input_token_length + max_tokens > self.max_model_len:
             raise ValueError("Exceed Context Length")

-
         messages = [
             {'role': 'user', 'content': f'<|image_1|>\n{prompt_text}'}
         ]
@@ -156,58 +171,148 @@ async def generate_vision(
             tokenize=False,
             add_generation_prompt=True
         )
-        print("Prompt: ", prompt)
+        # print("Prompt: ", prompt)

-        try:
-            inputs = self.processor(prompt, [image], return_tensors="pt")
-            print(f"Processed inputs")
-        except Exception as e:
-            print(f"Error processing inputs: {e}")
+        inputs = self.processor(prompt, [image], return_tensors="pt")

+        generation_options = {
+            'max_new_tokens': max_tokens,
+            'do_sample': False,
+        }

         token_list: List[int] = []
         output_text: str = ""
-
-        try:
-            generation_options = {
-                'max_new_tokens': max_tokens,
-                'do_sample': False,
-            }
-            token_list = self.model.generate(
-                **inputs,
-                eos_token_id=self.processor.tokenizer.eos_token_id,
-                **generation_options
-            )
-            print(f"Generated token list")
-        except Exception as e:
-            print(f"Error during token generation: {e}")
-
-        # Decode each element in the response
-        try:
-            decoded_text = [self.processor.tokenizer.decode(ids, skip_special_tokens=True) for ids in token_list]
-            print(f"Decoded text: {decoded_text}")
-        except Exception as e:
-            print(f"Error decoding text: {e}")
-
-        # Join the decoded text if needed
-        output_text = ' '.join(decoded_text).strip()
-        print(output_text)
-
-        yield RequestOutput(
-            request_id=request_id,
-            prompt=inputs,
-            prompt_token_ids=input_tokens,
-            finished=True,
-            outputs=[
-                CompletionOutput(
-                    index=0,
-                    text=output_text,
-                    token_ids=token_list,
-                    cumulative_logprob=-1.0,
-                    finish_reason="stop",
+        if stream:
+            generation_options["streamer"] = self.tokenizer_stream
+            # Include the inputs in the generation_options
+            generation_kwargs = {**inputs, **generation_options}
+
+            if RECORD_TIMING:
+                started_timestamp = time.time()
+                first_token_timestamp = 0
+                first = True
+                new_tokens = []
+
+            try:
+                thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
+                thread.start()
+                output_text = ""
+                first = True
+                for new_text in self.tokenizer_stream:
+                    if new_text == "":
+                        continue
+                    if RECORD_TIMING:
+                        if first:
+                            first_token_timestamp = time.time()
+                            first = False
+                    output_text += new_text
+                    token_list = self.processor.tokenizer.encode(output_text, return_tensors="pt")
+
+                    yield RequestOutput(
+                        request_id=request_id,
+                        prompt=inputs,
+                        prompt_token_ids=input_tokens,
+                        finished=False,
+                        outputs=[
+                            CompletionOutput(
+                                index=0,
+                                text=output_text,
+                                token_ids=token_list[0],
+                                cumulative_logprob=-1.0,
+                            )
+                        ],
+                    )
+
+                if RECORD_TIMING:
+                    new_tokens = token_list[0]
+
+                yield RequestOutput(
+                    request_id=request_id,
+                    prompt=inputs,
+                    prompt_token_ids=input_tokens,
+                    finished=True,
+                    outputs=[
+                        CompletionOutput(
+                            index=0,
+                            text=output_text,
+                            token_ids=token_list[0],
+                            cumulative_logprob=-1.0,
+                            finish_reason="stop",
+                        )
+                    ],
                 )
-            ],
-        )
+
+                if RECORD_TIMING:
+                    prompt_time = first_token_timestamp - started_timestamp
+                    run_time = time.time() - first_token_timestamp
+                    logger.info(
+                        f"Prompt length: {len(input_tokens)}, New tokens: {len(new_tokens)}, Time to first: {(prompt_time):.2f}s, Prompt tokens per second: {len(input_tokens)/prompt_time:.2f} tps, New tokens per second: {len(new_tokens)/run_time:.2f} tps"
+                    )
+
+            except Exception as e:
+                logger.error(str(e))
+
+                error_output = RequestOutput(
+                    prompt=inputs,
+                    prompt_token_ids=input_tokens,
+                    finished=True,
+                    request_id=request_id,
+                    outputs=[
+                        CompletionOutput(
+                            index=0,
+                            text=output_text,
+                            token_ids=token_list,
+                            cumulative_logprob=-1.0,
+                            finish_reason="error",
+                            stop_reason=str(e),
+                        )
+                    ],
+                )
+                yield error_output
+
+        else:
+            try:
+                token_list = self.model.generate(**inputs, **generation_options)[0]
+                output_text = self.processor.tokenizer.decode(
+                    token_list, skip_special_tokens=True
+                )
+
+                yield RequestOutput(
+                    request_id=request_id,
+                    prompt=inputs,
+                    prompt_token_ids=input_tokens,
+                    finished=True,
+                    outputs=[
+                        CompletionOutput(
+                            index=0,
+                            text=output_text,
+                            token_ids=token_list,
+                            cumulative_logprob=-1.0,
+                            finish_reason="stop",
+                        )
+                    ],
+                )
+
+            except Exception as e:
+                logger.error(str(e))
+
+                error_output = RequestOutput(
+                    prompt=inputs,
+                    prompt_token_ids=input_tokens,
+                    finished=True,
+                    request_id=request_id,
+                    outputs=[
+                        CompletionOutput(
+                            index=0,
+                            text=output_text,
+                            token_ids=token_list,
+                            cumulative_logprob=-1.0,
+                            finish_reason="error",
+                            stop_reason=str(e),
+                        )
+                    ],
+                )
+                yield error_output


     async def generate(
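Editor's note, not part of the commit: the streaming branch added in the last hunk follows the usual transformers pattern of running generate() on a worker thread while the caller iterates the TextIteratorStreamer. A minimal self-contained sketch of that pattern, using a placeholder model id rather than the OpenVINO model classes this repo wires in (hunk 2 shows it using OVModelForCausalLM for the non-vision path):

from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_id = "gpt2"  # placeholder model id for illustration only

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
inputs = tokenizer("What is OpenVINO?", return_tensors="pt")
generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": 64, "do_sample": False}

# generate() blocks until completion, so it runs on a worker thread
# while this thread consumes decoded text chunks as they arrive.
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()

output_text = ""
for new_text in streamer:
    output_text += new_text
    print(new_text, end="", flush=True)
thread.join()

One design observation on the diff itself: it re-encodes the accumulated output_text on every streamed chunk to populate token_ids, which keeps each partial RequestOutput self-consistent at the cost of re-tokenizing a growing string per update.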