@@ -140,7 +140,7 @@ def main(argv):
   ]
   for prompt in prompts:
     slot = random.randint(0, _BATCH_SIZE.value - 1)
-    tokens, true_length = tokenizer.encode(prompt, is_bos=True)
+    tokens, true_length = tokenizer.encode(prompt)

     print(f"---- Input prompts are: {prompt}")
     print(f"---- Encoded tokens are: {tokens}")
@@ -157,12 +157,15 @@ def main(argv):
     while True:
       decode_state, result_tokens = engine.generate(params, decode_state)
       result_tokens = result_tokens.convert_to_numpy()
-      output, complete = tokenizer.decode(
-          slot, max_output_length, result_tokens, complete
-      )
-      if complete[0]:
+      res = result_tokens.get_result_at_slot(slot)
+      stop_tokens = set(tokenizer.tokenizer.stop_tokens)
+      stop_tokens.add(tokenizer.pad_id)
+      if (
+          res.tokens[0][0] in stop_tokens
+          or len(sampled_tokens_list) > max_output_length
+      ):
         break
-      token_id = output[0][0]
+      token_id = res.tokens[0][0]
       sampled_tokens_list.append(token_id)
       # output_str = tokenizer.decode_str([token_id])
       # print(Fore.GREEN + output_str, end="", flush=True)
@@ -173,7 +176,7 @@ def main(argv):
     print("---- All output tokens.")
     print(sampled_tokens_list)
     print("---- All output text.")
-    print(tokenizer.decode_str(sampled_tokens_list))
+    print(tokenizer.decode(sampled_tokens_list))

   if _PROFILING_OUTPUT.value:
     jax.profiler.stop_trace()
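
For context, a minimal standalone sketch of the new stopping logic introduced in the second hunk: the decode loop now ends when the latest sampled token is one of the tokenizer's stop tokens (or the pad token), or when the output length cap is exceeded, instead of relying on the tokenizer's `complete` flags. The concrete token ids and the fake token stream below are illustrative assumptions, not part of this commit.

```python
# Sketch of the stop-token-based early-stopping check (assumed values).
stop_tokens = {128001, 128009}  # stand-ins for tokenizer.tokenizer.stop_tokens
pad_id = 0                      # stand-in for tokenizer.pad_id
stop_tokens.add(pad_id)

max_output_length = 5
sampled_tokens_list = []
token_stream = iter([11, 22, 33, 128009, 44])  # fake per-step decode output

while True:
    token_id = next(token_stream)  # in the script: res.tokens[0][0]
    if token_id in stop_tokens or len(sampled_tokens_list) > max_output_length:
        break
    sampled_tokens_list.append(token_id)

print(sampled_tokens_list)  # -> [11, 22, 33]
```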