Skip to content

Commit 1ecf4d7

Browse files
authored
Update run_interactive.py with finer control of profiler. (#103)
1 parent c360158 commit 1ecf4d7

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

run_interactive.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@ def main(argv):
4040
max_output_length = 1024
4141

4242
profiling_output = FLAGS.profiling_output
43-
if profiling_output:
43+
profiling_prefill = FLAGS.profiling_prefill
44+
if profiling_output and profiling_prefill:
4445
jax.profiler.start_trace(profiling_output)
4546

4647
decode_state = engine.init_decode_state()
@@ -68,7 +69,11 @@ def main(argv):
6869
print(f"---- Streaming decode started on #slot{slot}.")
6970
complete = np.zeros((1,), dtype=np.bool_)
7071
while True:
72+
if profiling_output and not profiling_prefill:
73+
jax.profiler.start_trace(profiling_output)
7174
decode_state, result_tokens = engine.generate(params, decode_state)
75+
if profiling_output and not profiling_prefill:
76+
jax.profiler.stop_trace()
7277
result_tokens = result_tokens.convert_to_numpy()
7378
output, complete = token_utils.process_result_tokens(
7479
tokenizer=tokenizer,
@@ -87,7 +92,7 @@ def main(argv):
8792
print("---- All output text.")
8893
print(tokenizer.decode(sampled_tokens_list))
8994

90-
if profiling_output:
95+
if profiling_output and profiling_prefill:
9196
jax.profiler.stop_trace()
9297

9398

0 commit comments

Comments
 (0)