@@ -40,7 +40,8 @@ def main(argv):
4040 max_output_length = 1024
4141
4242 profiling_output = FLAGS .profiling_output
43- if profiling_output :
43+ profiling_prefill = FLAGS .profiling_prefill
44+ if profiling_output and profiling_prefill :
4445 jax .profiler .start_trace (profiling_output )
4546
4647 decode_state = engine .init_decode_state ()
@@ -68,7 +69,11 @@ def main(argv):
6869 print (f"---- Streaming decode started on #slot{ slot } ." )
6970 complete = np .zeros ((1 ,), dtype = np .bool_ )
7071 while True :
72+ if profiling_output and not profiling_prefill :
73+ jax .profiler .start_trace (profiling_output )
7174 decode_state , result_tokens = engine .generate (params , decode_state )
75+ if profiling_output and not profiling_prefill :
76+ jax .profiler .stop_trace ()
7277 result_tokens = result_tokens .convert_to_numpy ()
7378 output , complete = token_utils .process_result_tokens (
7479 tokenizer = tokenizer ,
@@ -87,7 +92,7 @@ def main(argv):
8792 print ("---- All output text." )
8893 print (tokenizer .decode (sampled_tokens_list ))
8994
90- if profiling_output :
95+ if profiling_output and profiling_prefill :
9196 jax .profiler .stop_trace ()
9297
9398
0 commit comments