Skip to content

Commit 21d1290

Browse files
authored
Refactor interactive (#6)
1 parent 2ba06a2 commit 21d1290

File tree

1 file changed

+43
-20
lines changed

1 file changed

+43
-20
lines changed

run_interactive.py

Lines changed: 43 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -1,13 +1,17 @@
11
from absl import app
22
from absl import flags
33
from absl import logging
4+
import random
5+
from typing import List
46
import sys
57
import jax
68
import jax.numpy as jnp
79
import numpy as np
810

911
from jetstream.engine import token_utils
1012
from absl.testing import absltest
13+
from colorama import Fore, Back, Style
14+
1115

1216
import os
1317
import sys
@@ -16,6 +20,7 @@
1620
import time
1721
import logging
1822

23+
1924
logging.getLogger().setLevel(logging.ERROR)
2025

2126

@@ -86,21 +91,28 @@ def main(argv):
8691
params = engine.load_params()
8792
print('Load params ', time.perf_counter() - start)
8893

89-
prefill_times = {}
90-
slot = jnp.int32(0)
9194
metadata = engine.get_tokenizer()
9295
vocab = token_utils.load_vocab(
9396
metadata.path, metadata.extra_ids)
94-
tokenizer = vocab.tokenizer
97+
stop_tokens = [vocab.eos_id, vocab.pad_id]
98+
max_output_length = 1024
9599

96-
while True:
97-
# text = input('Text >>>> ')
98-
text = 'I believe the meaning of life is'
99-
decode_state = engine.init_decode_state()
100-
tokens, true_length = token_utils.tokenize_and_pad(text, vocab, is_bos=True)
101-
# tokens = tokenizer.encode(text)
102-
# tokens = [tokenizer.bos_id()] + tokens
103-
print('Encoded tokens are: ', tokens)
100+
if _PROFILING_OUTPUT.value:
101+
jax.profiler.start_trace(_PROFILING_OUTPUT.value)
102+
103+
decode_state = engine.init_decode_state()
104+
prompts: List[str] = [
105+
"I believe the meaning of life is",
106+
"To add an element to an ArrayList of a specific class type in Java, you can follow the following steps:\n\n1. Create an instance of the class to be added.\n2. Get a reference to the ArrayList.\n3. Call the `add()` method on the ArrayList, passing the instance of the class as the argument.\n\nHere's an example of how to add an object of type `Person` to an ArrayList of type `ArrayList<Person>`:\n```csharp\n// Create a new instance of the Person class\nPerson person = new Person(\"John\", 25);\n\n// Get a reference to the ArrayList\nArrayList<Person> peopleList = new ArrayList<>();\n\n// Add the person object to the ArrayList\npeopleList.add(person);\n```\nIn this example, the `Person` class is assumed to have a constructor that takes two arguments: a String for the person's name, and an int for their age. You can substitute your own class and constructor as necessary.",
107+
"<s>[INST] <<SYS>>\nYou are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.\n<</SYS>>\n\nQuestion 1: What is commercial real estate finance?\nQuestion 2: What are Commercial Real Estate services?\nOptions are:\n[a]. no.\n[b]. yes.\nWould the answer to these two questions be the same? [/INST]",
108+
"<s>[INST] <<SYS>>\nYou are an AI assistant that helps people find information. Provide a detailed answer so user don\u2019t need to search outside to understand the answer.\n<</SYS>>\n\nUse reasoning to lead to the answer of the following question:\nWhere are you likely to find water underneath?\nOptions:\n- toilet\n- sink\n- jar\n- bridge\n- house\n Reasoning process: [/INST",
109+
"<s>[INST] <<SYS>>\nYou are an AI assistant. You will be given a task. You must generate a detailed and long answer.\n<</SYS>>\n\nContinue the following story.\n\nKay didn't have shoes that fit her feet properly. She only wore sneakers, because the \nChoose from: [I] shoes fitted badly. [II] sneakers fitted badly. [/INST]",
110+
]
111+
for prompt in prompts:
112+
slot = random.randint(0, _BATCH_SIZE.value)
113+
tokens, true_length = token_utils.tokenize_and_pad(prompt, vocab, is_bos=True)
114+
print(f"---- Input prompts are: {prompt}")
115+
print(f"---- Encoded tokens are: {tokens}")
104116

105117
prefill_result = engine.prefill(
106118
params=params, padded_tokens=tokens, true_length=true_length
@@ -109,25 +121,36 @@ def main(argv):
109121
prefill_result, decode_state, slot=slot
110122
)
111123
sampled_tokens_list = []
112-
for i in range(100):
113-
decode_state, sampled_tokens = engine.generate(
124+
print(f"---- Streaming decode started on #slot{slot}.")
125+
while True:
126+
decode_state, result_tokens = engine.generate(
114127
params, decode_state
115128
)
116-
tstart, end = sampled_tokens.tokens_idx
117-
sampled_tokens_list.append(sampled_tokens.data[0, 0].item())
118129

119-
print('---- ans ----')
120-
print(sampled_tokens_list)
121-
print(tokenizer.decode(sampled_tokens_list))
122-
break
130+
slot_data = result_tokens.get_result_at_slot(slot)
131+
slot_tokens = slot_data.tokens
132+
slot_lengths = slot_data.lengths
133+
134+
token_id = slot_tokens[slot, 0].item()
135+
if slot_lengths > max_output_length or token_id in stop_tokens:
136+
break
137+
138+
sampled_tokens_list.append(token_id)
139+
output = token_utils.mix_decode(vocab, token_id)
140+
print(Fore.GREEN + output, end="", flush=True)
123141

142+
print(Style.RESET_ALL + "\n")
143+
print("---- Streaming decode finished.")
144+
145+
146+
print("---- All output tokens.")
147+
print(sampled_tokens_list)
124148

125149

126150
if _PROFILING_OUTPUT.value:
127151
jax.profiler.stop_trace()
128152

129153

130-
131154
if __name__ == "__main__":
132155
import os
133156
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"

0 commit comments

Comments
 (0)