
Commit 97aaeae

Add different token sampling algorithms to decoder. (#123)
1 parent fe8dbde commit 97aaeae

File tree

5 files changed: +157 −47 lines

deps/JetStream (submodule update)

jetstream_pt/config.py

Lines changed: 27 additions & 0 deletions
@@ -83,6 +83,29 @@
     "for performance tuning and debugging only",
     required=False,
 )
+flags.DEFINE_float(
+    "temperature",
+    1.0,
+    "temperature parameter for scaling probability. "
+    "Only invoked when sampling algorithm is set to "
+    "weighted or topk",
+)
+flags.DEFINE_string(
+    "sampling_algorithm",
+    "greedy",
+    "sampling algorithm to use. Options: "
+    "('greedy', 'weighted', 'nucleus', 'topk')",
+)
+flags.DEFINE_float(
+    "nucleus_topp",
+    0.0,
+    "restricting to the top-p probability mass before sampling",
+)
+flags.DEFINE_integer(
+    "topk",
+    0,
+    "size of top k used when sampling the next token",
+)


 def create_quantization_config_from_flags():
@@ -148,6 +171,10 @@ def create_engine_from_config_flags():
       shard_on_batch=FLAGS.shard_on_batch,
       ragged_mha=FLAGS.ragged_mha,
       starting_position=FLAGS.starting_position,
+      temperature=FLAGS.temperature,
+      sampling_algorithm=FLAGS.sampling_algorithm,
+      nucleus_topp=FLAGS.nucleus_topp,
+      topk=FLAGS.topk,
   )

   print("Initialize engine", time.perf_counter() - start)

jetstream_pt/engine.py

Lines changed: 28 additions & 4 deletions
@@ -28,6 +28,7 @@
 import numpy as np

 from jetstream.engine import engine_api, tokenizer_api, tokenizer_pb2, token_utils
+from jetstream.engine import sampling_utils
 import torch_xla2
 from torch.utils import _pytree as pytree
@@ -85,6 +86,7 @@ def __init__(
     self.pt_model = pt_model
     self.env = env
     self.default_dtype = jnp.bfloat16 if env.bf16_enable else jnp.float32
+    self.rng = jax.random.PRNGKey(0)

     self.y_sharding = env.sharding_by_axis(1)
     self.x_sharding = env.sharding_by_axis(0)
@@ -220,7 +222,14 @@ def _sampling(self, logits: Any, batch_size: int) -> jnp.ndarray:
     if len(logits.shape) == 2:
       logits = jnp.expand_dims(logits, 0)
     return (
-        jnp.argmax(logits[:, -1], axis=-1)
+        sampling_utils.sampling(
+            logits[:, -1],
+            self.rng,
+            self.env.sampling_algorithm,
+            self.env.topk,
+            self.env.nucleus_topp,
+            self.env.temperature,
+        )
         .reshape(batch_size, -1)
         .astype(jnp.int32)
     )
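Instead of an unconditional argmax, _sampling now defers to JetStream's sampling_utils.sampling with the algorithm and its parameters taken from the environment. A minimal sketch of what that dispatch plausibly does, written from the argument order visible in this call (simplified; the authoritative implementation lives in JetStream's jetstream/engine/sampling_utils.py):

  import jax
  import jax.numpy as jnp

  # Sketch only: mirrors the call above, not JetStream's actual code.
  def sampling_sketch(logits, rng, algorithm, topk, nucleus_topp, temperature):
    if algorithm == "greedy":
      # Deterministic: take the highest-logit token.
      return jnp.argmax(logits, axis=-1)
    if algorithm == "topk":
      # Mask everything below the k-th largest logit before sampling.
      kth_largest = jnp.sort(logits, axis=-1)[..., -topk, None]
      logits = jnp.where(logits < kth_largest, -jnp.inf, logits)
    elif algorithm == "nucleus":
      # Keep the smallest set of highest-probability tokens whose
      # cumulative mass reaches nucleus_topp; mask the rest.
      sorted_logits = jnp.sort(logits, axis=-1)[..., ::-1]
      cumulative = jnp.cumsum(jax.nn.softmax(sorted_logits, axis=-1), axis=-1)
      cutoff_index = jnp.sum(cumulative < nucleus_topp, axis=-1, keepdims=True)
      cutoff_logit = jnp.take_along_axis(sorted_logits, cutoff_index, axis=-1)
      logits = jnp.where(logits < cutoff_logit, -jnp.inf, logits)
    # "weighted", "topk", and "nucleus" all end in a draw from the
    # temperature-scaled categorical distribution.
    return jax.random.categorical(rng, logits / temperature, axis=-1)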
@@ -248,9 +257,16 @@ def prefill(
         input_indexes,
     )
     if len(logits.shape) == 3:  # b, seqlen, num words
-      logits = logits[0]
-
-    token = jnp.argmax(logits[true_length - 1])
+      logits = logits[0]  # seqlen, num words
+
+    token = sampling_utils.sampling(
+        logits[true_length - 1],
+        self.rng,
+        self.env.sampling_algorithm,
+        self.env.topk,
+        self.env.nucleus_topp,
+        self.env.temperature,
+    )

     # truncating to true_length didn't work; it needs to happen outside of jit
     # caches = [
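Note that self.rng is created once as jax.random.PRNGKey(0) and is never split or advanced in this diff, so every stochastic draw sees the same key. For reference, a sketch of the usual JAX pattern for advancing a key per decode step (not part of this commit):

  import jax

  rng = jax.random.PRNGKey(0)
  for step in range(3):
    # split() derives a fresh subkey each step; reusing a single key
    # verbatim makes every categorical draw identical.
    rng, step_rng = jax.random.split(rng)
    # step_rng would then be passed where this diff passes self.rng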
@@ -762,6 +778,10 @@ def create_pytorch_engine(
     shard_on_batch=False,
     ragged_mha=False,
     starting_position=512,
+    temperature=None,
+    sampling_algorithm="greedy",
+    nucleus_topp=None,
+    topk=None,
 ) -> PyTorchEngine:
   """Returns: The pytorch engine."""
@@ -827,6 +847,10 @@ def create_pytorch_engine(
       shard_on_batch=shard_on_batch,
       ragged_mha=ragged_mha,
       starting_position=starting_position,
+      temperature=temperature,
+      sampling_algorithm=sampling_algorithm,
+      nucleus_topp=nucleus_topp,
+      topk=topk,
   )

   if shard_on_batch and sharding_config:

jetstream_pt/environment.py

Lines changed: 13 additions & 0 deletions
@@ -100,6 +100,19 @@ class JetEngineEnvironmentData:
   # Starting position
   starting_position: int = 512

+  # Variables used in token sampling
+  # sampling algorithm to use ("greedy", "weighted", "nucleus", "topk")
+  sampling_algorithm: str = "greedy"
+
+  # size of top k used when sampling the next token
+  topk: int = 0
+
+  # restricting to the top-p probability mass before sampling
+  nucleus_topp: float = 0.0
+
+  # temperature parameter for scaling probabilities
+  temperature: float = 1.0

 # pylint: disable-next=all
 class JetEngineEnvironment:
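These fields mirror the new flags and carry the sampling configuration into the engine. A minimal sketch of setting them directly; the previously commented-out test code in tests/test_engine.py constructs this dataclass with no arguments, so the remaining fields are assumed to have defaults:

  from jetstream_pt.environment import JetEngineEnvironmentData

  # Other fields (model config, cache sizes, ...) keep their defaults.
  env_data = JetEngineEnvironmentData(
      sampling_algorithm="topk",
      topk=40,
      temperature=0.7,
  )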

tests/test_engine.py

Lines changed: 88 additions & 42 deletions
@@ -14,46 +14,92 @@

 # pylint: disable=all

-
-# This model will output tokens with value of 2
-# and will update caches with value of 1.0
-# class Dummy(torch.nn.Module):
-
-#   def __init__(self):
-#     super().__init__()
-#     self.params = None
-
-#   def forward(
-#       self,
-#       tokens: torch.Tensor,
-#       input_pos: torch.Tensor,
-#       caches: List[Any],
-#       mask,
-#   ):
-#     batch_size, seqlen = tokens.shape
-#     for cache in caches:
-#       cache.update(torch.ones((batch_size, seqlen)))
-#     return torch.ones((batch_size, seqlen), dtype=torch.int32) * 2
-
-
-# class EngineTest(unittest.TestCase):
-
-#   def _make_small_engine(self, quantize=False):
-#     env_data = JetEngineEnvironmentData()
-#     env_data.max_input_sequence_length = 128
-#     env_data.max_input_sequence_length = 128
-#     env_data.cache_sequence_length = 128
-#     env_data.model_type = 'llama-2-tiny'
-#     if quantize:
-#       env_data.enable_kv_quantization = True
-#       env_data.enable_weight_quantization = True
-
-#     env = JetEngineEnvironment(env_data)
-#     model = Dummy()
-#     model.params = env._model_arg  # llama's model arg
-
-#     engine = PyTorchEngine(model, env)
-#     return engine
+import unittest
+import jax
+import jax.numpy as jnp
+
+from jetstream_pt.third_party.llama import model_exportable
+from jetstream_pt.engine import PyTorchEngine
+from tests import helpers
+
+
+class EngineTest(unittest.TestCase):
+
+  def setup(self):
+    env, model_arg = helpers.make_env_tiny(bf16_enable=True)
+    model_ours = model_exportable.Transformer(model_arg, env)
+    engine = PyTorchEngine(pt_model=model_ours, env=env)
+    engine.rng = jax.random.PRNGKey(0)
+    return engine
+
+  def test_sampling_2D(self):
+    # test greedy
+    engine = self.setup()
+    self.assertEqual(engine.env.sampling_algorithm, "greedy")
+    logits = jnp.array([[0.5, 0.6, 0.7, 0.8], [0.4, 0.3, 0.2, 0.1]])
+    token = engine._sampling(logits, batch_size=1)
+    self.assertEqual(token, jnp.array([[0]]))
+    self.assertTrue(jnp.isdtype(token, jnp.int32))
+
+    # test weighted
+    engine.env.sampling_algorithm = "weighted"
+    engine.env.temperature = 5.0
+    token = engine._sampling(logits, batch_size=1)
+    self.assertTrue(jnp.array_equal(token, jnp.array([[0]])))
+    self.assertTrue(jnp.isdtype(token, jnp.int32))
+
+    # test topk
+    engine.env.sampling_algorithm = "topk"
+    engine.env.temperature = 5.0
+    engine.env.topk = 4
+    token = engine._sampling(logits, batch_size=1)
+    self.assertTrue(jnp.array_equal(token, jnp.array([[0]])))
+    self.assertTrue(jnp.isdtype(token, jnp.int32))
+
+    # test nucleus
+    engine.env.sampling_algorithm = "nucleus"
+    engine.env.temperature = 0.0
+    engine.env.nucleus_topp = 0.8
+    token = engine._sampling(logits, batch_size=1)
+    self.assertTrue(jnp.array_equal(token, jnp.array([[0]])))
+    self.assertTrue(jnp.isdtype(token, jnp.int32))
+
+  def test_sampling_3D(self):
+    # test greedy
+    engine = self.setup()
+    self.assertEqual(engine.env.sampling_algorithm, "greedy")
+    logits = jnp.array(
+        [
+            [[0.4, 0.3, 0.2, 0.1], [0.5, 0.6, 0.7, 0.8]],
+            [[0.5, 0.6, 0.7, 0.8], [0.4, 0.3, 0.2, 0.1]],
+        ]
+    )
+    token = engine._sampling(logits, batch_size=2)
+    self.assertTrue(jnp.array_equal(token, jnp.array([[3], [0]])))
+    self.assertTrue(jnp.isdtype(token, jnp.int32))
+
+    # test weighted
+    engine.env.sampling_algorithm = "weighted"
+    engine.env.temperature = 10.0
+    token = engine._sampling(logits, batch_size=2)
+    self.assertTrue(jnp.array_equal(token, jnp.array([[3], [1]])))
+    self.assertTrue(jnp.isdtype(token, jnp.int32))
+
+    # test topk
+    engine.env.sampling_algorithm = "topk"
+    engine.env.temperature = 1.0
+    engine.env.topk = 3
+    token = engine._sampling(logits, batch_size=2)
+    self.assertTrue(jnp.array_equal(token, jnp.array([[1], [0]])))
+    self.assertTrue(jnp.isdtype(token, jnp.int32))
+
+    # test nucleus
+    engine.env.sampling_algorithm = "nucleus"
+    engine.env.temperature = 1.0
+    engine.env.nucleus_topp = 0.8
+    token = engine._sampling(logits, batch_size=2)
+    self.assertTrue(jnp.array_equal(token, jnp.array([[3], [1]])))
+    self.assertTrue(jnp.isdtype(token, jnp.int32))


 # def test_insert(self):

@@ -229,5 +275,5 @@
 # # prefill


-# if __name__ == '__main__':
-#   unittest.main()
+if __name__ == "__main__":
+  unittest.main()
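A quick way to see why the 2D greedy case expects token 0: _sampling promotes a 2D (seqlen, vocab) input to (1, seqlen, vocab) and samples only the final position, so the first row of the test logits never participates. Checking by hand in JAX:

  import jax.numpy as jnp

  logits = jnp.array([[0.5, 0.6, 0.7, 0.8], [0.4, 0.3, 0.2, 0.1]])
  batched = jnp.expand_dims(logits, 0)   # shape (1, 2, 4), as _sampling does
  last_step = batched[:, -1]             # [[0.4, 0.3, 0.2, 0.1]]
  print(jnp.argmax(last_step, axis=-1))  # [0] -> greedy picks token 0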
